Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Derive on structs with generics #79

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 5 additions & 31 deletions bilge-impl/src/bitsize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,40 +120,14 @@ fn analyze_enum(bitsize: BitSize, variants: Iter<Variant>) {
}
}

fn generate_struct(item: &ItemStruct, declared_bitsize: u8) -> TokenStream {
let ItemStruct { vis, ident, fields, .. } = item;
let declared_bitsize = declared_bitsize as usize;

let computed_bitsize = fields.iter().fold(quote!(0), |acc, next| {
let field_size = shared::generate_type_bitsize(&next.ty);
quote!(#acc + #field_size)
});

// we could remove this if the whole struct gets passed
let is_tuple_struct = fields.iter().any(|field| field.ident.is_none());
let fields_def = if is_tuple_struct {
let fields = fields.iter();
quote! {
( #(#fields,)* );
}
} else {
let fields = fields.iter();
quote! {
{ #(#fields,)* }
}
fn generate_struct(item: &ItemStruct, _declared_bitsize: u8) -> TokenStream {
let item = ItemStruct {
attrs: Vec::new(),
..item.clone()
};

quote! {
#vis struct #ident #fields_def

// constness: when we get const blocks evaluated at compile time, add a const computed_bitsize
const _: () = assert!(
(#computed_bitsize) == (#declared_bitsize),
concat!("struct size and declared bit size differ: ",
// stringify!(#computed_bitsize),
" != ",
stringify!(#declared_bitsize))
);
#item
}
}

Expand Down
137 changes: 105 additions & 32 deletions bilge-impl/src/bitsize_internal.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{Attribute, Field, Item, ItemEnum, ItemStruct, Type};
use syn::{Attribute, Field, Generics, Item, ItemEnum, ItemStruct, Type};

use crate::shared::{self, unreachable};

Expand All @@ -12,36 +12,36 @@ struct ItemIr<'a> {
name: &'a Ident,
/// generated item (and setters, getters, constructor, impl Bitsized)
expanded: TokenStream,
generics: &'a Generics,
check_expr: TokenStream,
}

pub(super) fn bitsize_internal(args: TokenStream, item: TokenStream) -> TokenStream {
let (item, arb_int) = parse(item, args);
let (item, arb_int, declared_bitsize) = parse(item, args);
let ir = match item {
Item::Struct(ref item) => {
let expanded = generate_struct(item, &arb_int);
let attrs = &item.attrs;
let name = &item.ident;
ItemIr { attrs, name, expanded }
}
Item::Enum(ref item) => {
let expanded = generate_enum(item);
let attrs = &item.attrs;
let name = &item.ident;
ItemIr { attrs, name, expanded }
}
Item::Struct(ref item) => generate_struct(item, &arb_int, declared_bitsize),
Item::Enum(ref item) => generate_enum(item),
_ => unreachable(()),
};
generate_common(ir, &arb_int)
}

fn parse(item: TokenStream, args: TokenStream) -> (Item, TokenStream) {
fn parse(item: TokenStream, args: TokenStream) -> (Item, TokenStream, u8) {
let item = syn::parse2(item).unwrap_or_else(unreachable);
let (_declared_bitsize, arb_int) = shared::bitsize_and_arbitrary_int_from(args);
(item, arb_int)
let (declared_bitsize, arb_int) = shared::bitsize_and_arbitrary_int_from(args);
(item, arb_int, declared_bitsize)
}

fn generate_struct(struct_data: &ItemStruct, arb_int: &TokenStream) -> TokenStream {
let ItemStruct { vis, ident, fields, .. } = struct_data;
fn generate_struct<'a>(struct_data: &'a ItemStruct, arb_int: &TokenStream, declared_bitsize: u8) -> ItemIr<'a> {
let ItemStruct {
vis,
ident,
fields,
generics,
attrs,
..
} = struct_data;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let mut fieldless_next_int = 0;
let mut previous_field_sizes = vec![];
Expand All @@ -66,12 +66,15 @@ fn generate_struct(struct_data: &ItemStruct, arb_int: &TokenStream) -> TokenStre

let const_ = if cfg!(feature = "nightly") { quote!(const) } else { quote!() };

quote! {
#vis struct #ident {
let phantom_type = generate_phantom_type(generics);

let expanded = quote! {
#vis struct #ident #generics #where_clause {
/// WARNING: modifying this value directly can break invariants
value: #arb_int,
_phantom: ::core::marker::PhantomData<#phantom_type>
}
impl #ident {
impl #impl_generics #ident #ty_generics #where_clause {
// #[inline]
#[allow(clippy::too_many_arguments, clippy::type_complexity, unused_parens)]
pub #const_ fn new(#( #constructor_args )*) -> Self {
Expand All @@ -81,10 +84,48 @@ fn generate_struct(struct_data: &ItemStruct, arb_int: &TokenStream) -> TokenStre
let mut offset = 0;
let raw_value = #( #constructor_parts )|*;
let value = #arb_int::new(raw_value);
Self { value }
Self { value, _phantom: ::core::marker::PhantomData }
}
#( #accessors )*
}
};

let computed_bitsize = fields.iter().fold(quote!(0), |acc, next| {
let field_size = shared::generate_type_bitsize(&next.ty);
quote!(#acc + #field_size)
});

let declared_bitsize = declared_bitsize as usize;

let check_expr = quote!(assert!(
(#computed_bitsize) == (#declared_bitsize),
concat!("struct size and declared bit size differ: ",
// stringify!(#computed_bitsize),
" != ",
stringify!(#declared_bitsize))
));

ItemIr {
attrs,
name: ident,
expanded,
generics,
check_expr,
}
}

/// Returns a tuple with the following types, in order:
/// - The original struct's type parameters
/// - References to `()` bound by the original struct's lifetime parameters
/// If there are 0 generics, the type will simply be `()`.
/// If there is a single generic type or lifetime, the type will not wrapped in a tuple.
fn generate_phantom_type(generics: &Generics) -> TokenStream {
let phantom_ty = generics.type_params().map(|e| &e.ident).map(|ident| quote!(#ident));
let phantom_lt = generics.lifetimes().map(|l| &l.lifetime).map(|lifetime| quote!(& #lifetime ()));
// TODO: integrate user-provided PhantomData somehow? (so that the user can set the variance)
let phantom = phantom_ty.chain(phantom_lt);
quote! {
(#(#phantom),*)
}
}

Expand Down Expand Up @@ -209,27 +250,59 @@ fn generate_constructor_stuff(ty: &Type, name: &Ident) -> (TokenStream, TokenStr
(constructor_arg, constructor_part)
}

fn generate_enum(enum_data: &ItemEnum) -> TokenStream {
let ItemEnum { vis, ident, variants, .. } = enum_data;
quote! {
fn generate_enum(enum_data: &ItemEnum) -> ItemIr {
let ItemEnum {
vis,
ident,
variants,
generics,
attrs,
..
} = enum_data;
let expanded = quote! {
#vis enum #ident {
#variants
}
};
ItemIr {
attrs,
name: ident,
expanded,
generics,
check_expr: quote! { () },
}
}

/// We have _one_ `generate_common` function, which holds everything struct and enum have _in common_.
/// Everything else has its own `generate_` functions.
fn generate_common(ir: ItemIr, arb_int: &TokenStream) -> TokenStream {
let ItemIr { attrs, name, expanded } = ir;
let ItemIr {
attrs,
name,
expanded,
generics,
check_expr,
} = ir;

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

quote! {
#(#attrs)*
#expanded
impl ::bilge::Bitsized for #name {
type ArbitraryInt = #arb_int;
const BITS: usize = <Self::ArbitraryInt as Bitsized>::BITS;
const MAX: Self::ArbitraryInt = <Self::ArbitraryInt as Bitsized>::MAX;
}
const _: () = {
trait Assertion {
const SIZE_CHECK: ();
}

impl #impl_generics Assertion for #name #ty_generics #where_clause {
const SIZE_CHECK: () = #check_expr;
}

impl #impl_generics ::bilge::Bitsized for #name #ty_generics #where_clause {
type ArbitraryInt = #arb_int;
const BITS: usize = (<Self::ArbitraryInt as Bitsized>::BITS, <Self as Assertion>::SIZE_CHECK).0;
const MAX: Self::ArbitraryInt = (<Self::ArbitraryInt as Bitsized>::MAX, <Self as Assertion>::SIZE_CHECK).0;
}
};
Comment on lines +301 to +306
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's clever! congratulations on finding this workaround!

}
}
10 changes: 7 additions & 3 deletions bilge-impl/src/debug_bits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub(super) fn debug_bits(item: TokenStream) -> TokenStream {
};

let fmt_impl = match struct_data.fields {
Fields::Named(fields) => {
Fields::Named(ref fields) => {
let calls = fields.named.iter().map(|f| {
// We can unwrap since this is a named field
let call = f.ident.as_ref().unwrap();
Expand All @@ -30,7 +30,7 @@ pub(super) fn debug_bits(item: TokenStream) -> TokenStream {
#(#calls)*.finish()
}
}
Fields::Unnamed(fields) => {
Fields::Unnamed(ref fields) => {
let calls = fields.unnamed.iter().map(|_| {
let call: Ident = syn::parse_str(&format!("val_{}", fieldless_next_int)).unwrap_or_else(unreachable);
fieldless_next_int += 1;
Expand All @@ -45,8 +45,12 @@ pub(super) fn debug_bits(item: TokenStream) -> TokenStream {
Fields::Unit => todo!("this is a unit struct, which is not supported right now"),
};

let (impl_generics, ty_generics, where_clause) = derive_input.generics.split_for_impl();

let where_clause = shared::generate_trait_where_clause(&derive_input.generics, where_clause, quote!(::core::fmt::Debug));

quote! {
impl ::core::fmt::Debug for #name {
impl #impl_generics ::core::fmt::Debug for #name #ty_generics #where_clause {
fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
#fmt_impl
}
Expand Down
20 changes: 12 additions & 8 deletions bilge-impl/src/default_bits.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,39 @@
use proc_macro2::{Ident, TokenStream};
use proc_macro_error::abort_call_site;
use quote::quote;
use syn::{Data, DeriveInput, Fields, Type};
use syn::{Data, DeriveInput, Fields, Generics, Type};

use crate::shared::{self, fallback::Fallback, unreachable, BitSize};

pub(crate) fn default_bits(item: TokenStream) -> TokenStream {
let derive_input = parse(item);
//TODO: does fallback need handling?
let (derive_data, _, name, ..) = analyze(&derive_input);
let (derive_data, _, name, generics, ..) = analyze(&derive_input);

match derive_data {
Data::Struct(data) => generate_struct_default_impl(name, &data.fields),
Data::Struct(data) => generate_struct_default_impl(name, &data.fields, generics),
Data::Enum(_) => abort_call_site!("use derive(Default) for enums"),
_ => unreachable(()),
}
}

fn generate_struct_default_impl(struct_name: &Ident, fields: &Fields) -> TokenStream {
fn generate_struct_default_impl(struct_name: &Ident, fields: &Fields, generics: &Generics) -> TokenStream {
let default_value = fields
.iter()
.map(|field| generate_default_inner(&field.ty))
.reduce(|acc, next| quote!(#acc | #next));

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let where_clause = shared::generate_trait_where_clause(generics, where_clause, quote!(::core::default::Default));

quote! {
impl ::core::default::Default for #struct_name {
impl #impl_generics ::core::default::Default for #struct_name #ty_generics #where_clause {
fn default() -> Self {
let mut offset = 0;
let value = #default_value;
let value = <#struct_name as Bitsized>::ArbitraryInt::new(value);
Self { value }
let value = <#struct_name #ty_generics as Bitsized>::ArbitraryInt::new(value);
Self { value, _phantom: ::core::marker::PhantomData }
}
}
}
Expand Down Expand Up @@ -87,6 +91,6 @@ fn parse(item: TokenStream) -> DeriveInput {
shared::parse_derive(item)
}

fn analyze(derive_input: &DeriveInput) -> (&Data, TokenStream, &Ident, BitSize, Option<Fallback>) {
fn analyze(derive_input: &DeriveInput) -> (&Data, TokenStream, &Ident, &Generics, BitSize, Option<Fallback>) {
shared::analyze_derive(derive_input, false)
}
14 changes: 8 additions & 6 deletions bilge-impl/src/fmt_bits.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
use proc_macro2::{Ident, TokenStream};
use quote::quote;
use syn::{punctuated::Iter, Data, DeriveInput, Fields, Variant};
use syn::{punctuated::Iter, Data, DeriveInput, Fields, Generics, Variant};

use crate::shared::{self, discriminant_assigner::DiscriminantAssigner, fallback::Fallback, unreachable, BitSize};

pub(crate) fn binary(item: TokenStream) -> TokenStream {
let derive_input = parse(item);
let (derive_data, arb_int, name, bitsize, fallback) = analyze(&derive_input);
let (derive_data, arb_int, name, generics, bitsize, fallback) = analyze(&derive_input);

match derive_data {
Data::Struct(data) => generate_struct_binary_impl(name, &data.fields),
Data::Struct(data) => generate_struct_binary_impl(name, &data.fields, generics),
Data::Enum(data) => generate_enum_binary_impl(name, data.variants.iter(), arb_int, bitsize, fallback),
_ => unreachable(()),
}
}

fn generate_struct_binary_impl(struct_name: &Ident, fields: &Fields) -> TokenStream {
fn generate_struct_binary_impl(struct_name: &Ident, fields: &Fields, generics: &Generics) -> TokenStream {
let write_underscore = quote! { write!(f, "_")?; };

// fields are printed from most significant to least significant, separated by an underscore
Expand All @@ -37,8 +37,10 @@ fn generate_struct_binary_impl(struct_name: &Ident, fields: &Fields) -> TokenStr
})
.reduce(|acc, next| quote!(#acc #write_underscore #next));

let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

quote! {
impl ::core::fmt::Binary for #struct_name {
impl #impl_generics ::core::fmt::Binary for #struct_name #ty_generics #where_clause {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let struct_size = <#struct_name as Bitsized>::BITS;
let mut last_bit_pos = struct_size;
Expand Down Expand Up @@ -107,6 +109,6 @@ fn parse(item: TokenStream) -> DeriveInput {
shared::parse_derive(item)
}

fn analyze(derive_input: &DeriveInput) -> (&Data, TokenStream, &Ident, BitSize, Option<Fallback>) {
fn analyze(derive_input: &DeriveInput) -> (&Data, TokenStream, &Ident, &Generics, BitSize, Option<Fallback>) {
shared::analyze_derive(derive_input, false)
}
Loading