Re: [PATCH v4 03/13] rust: add derive macro for `Zeroable`

From: Gary Guo
Date: Wed Aug 16 2023 - 13:42:04 EST


On Mon, 14 Aug 2023 08:46:41 +0000
Benno Lossin <benno.lossin@xxxxxxxxx> wrote:

> Add a derive proc-macro for the `Zeroable` trait. The macro supports
> structs where every field implements the `Zeroable` trait. This way,
> hand-written `unsafe` implementations can be avoided.
>
> The macro is split into two parts:
> - a proc-macro to parse generics into impl and ty generics,
> - a declarative macro that expands to the impl block.
>
> Suggested-by: Asahi Lina <lina@xxxxxxxxxxxxx>
> Signed-off-by: Benno Lossin <benno.lossin@xxxxxxxxx>

Reviewed-by: Gary Guo <gary@xxxxxxxxxxx>

> ---
> v3 -> v4:
> - add support for `+` in `quote!`.
>
> v2 -> v3:
> - change derive behavior: instead of adding `Zeroable` bounds for every
> field, add them only for generic type parameters,
> - still check that every field implements `Zeroable`,
> - removed Reviewed-by tags due to changes.
>
> v1 -> v2:
> - fix Zeroable path,
> - add Reviewed-by from Gary and Björn.
>
> rust/kernel/init/macros.rs | 35 ++++++++++++++++++
> rust/kernel/prelude.rs | 2 +-
> rust/macros/lib.rs | 20 +++++++++++
> rust/macros/quote.rs | 12 +++++++
> rust/macros/zeroable.rs | 72 ++++++++++++++++++++++++++++++++++++++
> 5 files changed, 140 insertions(+), 1 deletion(-)
> create mode 100644 rust/macros/zeroable.rs
>
> diff --git a/rust/kernel/init/macros.rs b/rust/kernel/init/macros.rs
> index 9182fdf99e7e..78091756dec0 100644
> --- a/rust/kernel/init/macros.rs
> +++ b/rust/kernel/init/macros.rs
> @@ -1215,3 +1215,38 @@ macro_rules! __init_internal {
> );
> };
> }
> +
> +#[doc(hidden)]
> +#[macro_export]
> +macro_rules! __derive_zeroable {
> + (parse_input:
> + @sig(
> + $(#[$($struct_attr:tt)*])*
> + $vis:vis struct $name:ident
> + $(where $($whr:tt)*)?
> + ),
> + @impl_generics($($impl_generics:tt)*),
> + @ty_generics($($ty_generics:tt)*),
> + @body({
> + $(
> + $(#[$($field_attr:tt)*])*
> + $field:ident : $field_ty:ty
> + ),* $(,)?
> + }),
> + ) => {
> + // SAFETY: every field type implements `Zeroable` and padding bytes may be zero.
> + #[automatically_derived]
> + unsafe impl<$($impl_generics)*> $crate::init::Zeroable for $name<$($ty_generics)*>
> + where
> + $($($whr)*)?
> + {}
> + const _: () = {
> + fn assert_zeroable<T: ?::core::marker::Sized + $crate::init::Zeroable>() {}
> + fn ensure_zeroable<$($impl_generics)*>()
> + where $($($whr)*)?
> + {
> + $(assert_zeroable::<$field_ty>();)*
> + }
> + };
> + };
> +}
> diff --git a/rust/kernel/prelude.rs b/rust/kernel/prelude.rs
> index c28587d68ebc..ae21600970b3 100644
> --- a/rust/kernel/prelude.rs
> +++ b/rust/kernel/prelude.rs
> @@ -18,7 +18,7 @@
> pub use alloc::{boxed::Box, vec::Vec};
>
> #[doc(no_inline)]
> -pub use macros::{module, pin_data, pinned_drop, vtable};
> +pub use macros::{module, pin_data, pinned_drop, vtable, Zeroable};
>
> pub use super::build_assert;
>
> diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs
> index b4bc44c27bd4..fd7a815e68a8 100644
> --- a/rust/macros/lib.rs
> +++ b/rust/macros/lib.rs
> @@ -11,6 +11,7 @@
> mod pin_data;
> mod pinned_drop;
> mod vtable;
> +mod zeroable;
>
> use proc_macro::TokenStream;
>
> @@ -343,3 +344,22 @@ pub fn paste(input: TokenStream) -> TokenStream {
> paste::expand(&mut tokens);
> tokens.into_iter().collect()
> }
> +
> +/// Derives the [`Zeroable`] trait for the given struct.
> +///
> +/// This can only be used for structs where every field implements the [`Zeroable`] trait.
> +///
> +/// # Examples
> +///
> +/// ```rust
> +/// #[derive(Zeroable)]
> +/// pub struct DriverData {
> +/// id: i64,
> +/// buf_ptr: *mut u8,
> +/// len: usize,
> +/// }
> +/// ```
> +#[proc_macro_derive(Zeroable)]
> +pub fn derive_zeroable(input: TokenStream) -> TokenStream {
> + zeroable::derive(input)
> +}
> diff --git a/rust/macros/quote.rs b/rust/macros/quote.rs
> index dddbb4e6f4cb..33a199e4f176 100644
> --- a/rust/macros/quote.rs
> +++ b/rust/macros/quote.rs
> @@ -124,6 +124,18 @@ macro_rules! quote_spanned {
> ));
> quote_spanned!(@proc $v $span $($tt)*);
> };
> + (@proc $v:ident $span:ident ; $($tt:tt)*) => {
> + $v.push(::proc_macro::TokenTree::Punct(
> + ::proc_macro::Punct::new(';', ::proc_macro::Spacing::Alone)
> + ));
> + quote_spanned!(@proc $v $span $($tt)*);
> + };
> + (@proc $v:ident $span:ident + $($tt:tt)*) => {
> + $v.push(::proc_macro::TokenTree::Punct(
> + ::proc_macro::Punct::new('+', ::proc_macro::Spacing::Alone)
> + ));
> + quote_spanned!(@proc $v $span $($tt)*);
> + };
> (@proc $v:ident $span:ident $id:ident $($tt:tt)*) => {
> $v.push(::proc_macro::TokenTree::Ident(::proc_macro::Ident::new(stringify!($id), $span)));
> quote_spanned!(@proc $v $span $($tt)*);
> diff --git a/rust/macros/zeroable.rs b/rust/macros/zeroable.rs
> new file mode 100644
> index 000000000000..0d605c46ab3b
> --- /dev/null
> +++ b/rust/macros/zeroable.rs
> @@ -0,0 +1,72 @@
> +// SPDX-License-Identifier: GPL-2.0
> +
> +use crate::helpers::{parse_generics, Generics};
> +use proc_macro::{TokenStream, TokenTree};
> +
> +pub(crate) fn derive(input: TokenStream) -> TokenStream {
> + let (
> + Generics {
> + impl_generics,
> + ty_generics,
> + },
> + mut rest,
> + ) = parse_generics(input);
> + // This should be the body of the struct `{...}`.
> + let last = rest.pop();
> + // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
> + let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
> + // Are we inside of a generic where we want to add `Zeroable`?
> + let mut in_generic = !impl_generics.is_empty();
> + // Have we already inserted `Zeroable`?
> + let mut inserted = false;
> + // Level of `<>` nestings.
> + let mut nested = 0;
> + for tt in impl_generics {
> + match &tt {
> + // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
> + TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
> + if in_generic && !inserted {
> + new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
> + }
> + in_generic = true;
> + inserted = false;
> + new_impl_generics.push(tt);
> + }
> + // If we find `'`, then we are entering a lifetime.
> + TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
> + in_generic = false;
> + new_impl_generics.push(tt);
> + }
> + TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
> + new_impl_generics.push(tt);
> + if in_generic {
> + new_impl_generics.extend(quote! { ::kernel::init::Zeroable + });
> + inserted = true;
> + }
> + }
> + TokenTree::Punct(p) if p.as_char() == '<' => {
> + nested += 1;
> + new_impl_generics.push(tt);
> + }
> + TokenTree::Punct(p) if p.as_char() == '>' => {
> + assert!(nested > 0);
> + nested -= 1;
> + new_impl_generics.push(tt);
> + }
> + _ => new_impl_generics.push(tt),
> + }
> + }
> + assert_eq!(nested, 0);
> + if in_generic && !inserted {
> + new_impl_generics.extend(quote! { : ::kernel::init::Zeroable });
> + }
> + quote! {
> + ::kernel::__derive_zeroable!(
> + parse_input:
> + @sig(#(#rest)*),
> + @impl_generics(#(#new_impl_generics)*),
> + @ty_generics(#(#ty_generics)*),
> + @body(#last),
> + );
> + }
> +}