Skip to content

Commit

Permalink
Use associated constants instead of functions where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
meithecatte committed Sep 27, 2020
1 parent 5445541 commit 0b33b75
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 51 deletions.
57 changes: 24 additions & 33 deletions enumflags_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub fn bitflags_internal(
) -> proc_macro::TokenStream {
let ast: DeriveInput = syn::parse(input).unwrap();

let impls = match ast.data {
let output = match ast.data {
Data::Enum(ref data) => {
gen_enumflags(&ast.ident, &ast, data)
}
Expand All @@ -38,12 +38,13 @@ pub fn bitflags_internal(
}
};

let impls = impls.unwrap_or_else(|err| err.to_compile_error());
let combined = quote! {
#ast
#impls
};
combined.into()
output.unwrap_or_else(|err| {
let error = err.to_compile_error();
quote! {
#ast
#error
}
}).into()
}

/// Try to evaluate the expression given.
Expand Down Expand Up @@ -178,10 +179,11 @@ fn gen_enumflags(ident: &Ident, item: &DeriveInput, data: &DataEnum)
{
let span = Span::call_site();
// for quote! interpolation
let variant_names = data.variants.iter().map(|v| &v.ident);
let variant_count = data.variants.len();

let repeated_name = std::iter::repeat(&ident);
let variant_names =
data.variants.iter()
.map(|v| &v.ident)
.collect::<Vec<_>>();
let repeated_name = vec![&ident; data.variants.len()];

let variants = collect_flags(data.variants.iter())?;
let deferred = variants.iter()
Expand All @@ -192,16 +194,11 @@ fn gen_enumflags(ident: &Ident, item: &DeriveInput, data: &DataEnum)
.ok_or_else(|| syn::Error::new_spanned(&ident,
"repr attribute missing. Add #[repr(u64)] or a similar attribute to specify the size of the bitfield."))?;
let std_path = quote_spanned!(span => ::enumflags2::_internal::core);
let all = if variant_count == 0 {
quote!(0)
} else {
let repeated_name = repeated_name.clone();
let variant_names = variant_names.clone();
quote!(#(#repeated_name::#variant_names as #ty)|*)
};

Ok(quote_spanned! {
span => #(#deferred)*
span =>
#item
#(#deferred)*
impl #std_path::ops::Not for #ident {
type Output = ::enumflags2::BitFlags<#ident>;
fn not(self) -> Self::Output {
Expand Down Expand Up @@ -236,23 +233,17 @@ fn gen_enumflags(ident: &Ident, item: &DeriveInput, data: &DataEnum)
impl ::enumflags2::_internal::RawBitFlags for #ident {
type Type = #ty;

fn all_bits() -> Self::Type {
// make sure it's evaluated at compile time
const VALUE: #ty = #all;
VALUE
}
const ALL_BITS: Self::Type =
0 #(| (#repeated_name::#variant_names as #ty))*;

fn bits(self) -> Self::Type {
self as #ty
}
const FLAG_LIST: &'static [Self] =
&[#(#repeated_name::#variant_names),*];

fn flag_list() -> &'static [Self] {
const VARIANTS: [#ident; #variant_count] = [#(#repeated_name :: #variant_names),*];
&VARIANTS
}
const BITFLAGS_TYPE_NAME : &'static str =
concat!("BitFlags<", stringify!(#ident), ">");

fn bitflags_type_name() -> &'static str {
concat!("BitFlags<", stringify!(#ident), ">")
fn bits(self) -> Self::Type {
self as #ty
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/formatting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ where
T: BitFlag + fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
let name = T::bitflags_type_name();
let name = T::BITFLAGS_TYPE_NAME;
let bits = DebugBinaryFormatter(&self.val);
let iter = if !self.is_empty() {
let iter = T::flag_list().iter().filter(|&&flag| self.contains(flag));
let iter = T::FLAG_LIST.iter().filter(|&&flag| self.contains(flag));
Some(FlagFormatter(iter))
} else {
None
Expand Down
30 changes: 14 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,21 +167,19 @@ pub mod _internal {
/// The underlying integer type.
type Type: BitFlagNum;

/// Return a value with all flag bits set.
fn all_bits() -> Self::Type;
/// A value with all flag bits set.
const ALL_BITS: Self::Type;

/// Return the bits as a number type.
fn bits(self) -> Self::Type;
/// A slice that contains each variant exactly one.
const FLAG_LIST: &'static [Self];

/// Return a slice that contains each variant exactly one.
fn flag_list() -> &'static [Self];

/// Return the name of the type for debug formatting purposes.
/// The name of the type for debug formatting purposes.
///
/// This is typically `BitFlags<EnumName>`
fn bitflags_type_name() -> &'static str {
"BitFlags"
}
const BITFLAGS_TYPE_NAME: &'static str;

/// Return the bits as a number type.
fn bits(self) -> Self::Type;
}

use ::core::ops::{BitAnd, BitOr, BitXor, Not};
Expand Down Expand Up @@ -333,12 +331,12 @@ where
/// assert_eq!(empty.contains(MyFlag::Three), true);
/// ```
pub fn all() -> Self {
unsafe { BitFlags::new(T::all_bits()) }
unsafe { BitFlags::new(T::ALL_BITS) }
}

/// Returns true if all flags are set
pub fn is_all(self) -> bool {
self.val == T::all_bits()
self.val == T::ALL_BITS
}

/// Returns true if no flag is set
Expand Down Expand Up @@ -383,7 +381,7 @@ where

/// Truncates flags that are illegal
pub fn from_bits_truncate(bits: T::Type) -> Self {
unsafe { BitFlags::new(bits & T::all_bits()) }
unsafe { BitFlags::new(bits & T::ALL_BITS) }
}

/// Toggles the matching bits
Expand All @@ -403,7 +401,7 @@ where

/// Returns an iterator that yields each set flag
pub fn iter(self) -> impl Iterator<Item = T> {
T::flag_list().iter().cloned().filter(move |&flag| self.contains(flag))
T::FLAG_LIST.iter().cloned().filter(move |&flag| self.contains(flag))
}
}

Expand Down Expand Up @@ -485,7 +483,7 @@ where
{
type Output = BitFlags<T>;
fn not(self) -> BitFlags<T> {
unsafe { BitFlags::new(!self.bits() & T::all_bits()) }
unsafe { BitFlags::new(!self.bits() & T::ALL_BITS) }
}
}

Expand Down

0 comments on commit 0b33b75

Please sign in to comment.