parity_scale_codec_derive/
lib.rs

1// Copyright 2017-2021 Parity Technologies
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Derives serialization and deserialization codec for complex structs for simple marshalling.
16
17#![recursion_limit = "128"]
18extern crate proc_macro;
19
20#[macro_use]
21extern crate syn;
22
23#[macro_use]
24extern crate quote;
25
26use crate::utils::{codec_crate_path, is_lint_attribute};
27use syn::{spanned::Spanned, Data, DeriveInput, Error, Field, Fields};
28
29mod decode;
30mod encode;
31mod max_encoded_len;
32mod trait_bounds;
33mod utils;
34
35/// Wraps the impl block in a "dummy const"
36fn wrap_with_dummy_const(
37	input: DeriveInput,
38	impl_block: proc_macro2::TokenStream,
39) -> proc_macro::TokenStream {
40	let attrs = input.attrs.into_iter().filter(is_lint_attribute);
41	let generated = quote! {
42		#[allow(deprecated)]
43		const _: () = {
44			#(#attrs)*
45			#impl_block
46		};
47	};
48
49	generated.into()
50}
51
52/// Derive `parity_scale_codec::Encode` and `parity_scale_codec::EncodeLike` for struct and enum.
53///
54/// # Top level attributes
55///
56/// By default the macro will add [`Encode`] and [`Decode`] bounds to all types, but the bounds can
57/// be specified manually with the top level attributes:
58/// * `#[codec(encode_bound(T: Encode))]`: a custom bound added to the `where`-clause when deriving
59///   the `Encode` trait, overriding the default.
60/// * `#[codec(decode_bound(T: Decode))]`: a custom bound added to the `where`-clause when deriving
61///   the `Decode` trait, overriding the default.
62///
63/// # Struct
64///
65/// A struct is encoded by encoding each of its fields successively.
66///
67/// Fields can have some attributes:
68/// * `#[codec(skip)]`: the field is not encoded. It must derive `Default` if Decode is derived.
69/// * `#[codec(compact)]`: the field is encoded in its compact representation i.e. the field must
70///   implement `parity_scale_codec::HasCompact` and will be encoded as `HasCompact::Type`.
71/// * `#[codec(encoded_as = "$EncodeAs")]`: the field is encoded as an alternative type. $EncodedAs
72///   type must implement `parity_scale_codec::EncodeAsRef<'_, $FieldType>` with $FieldType the type
73///   of the field with the attribute. This is intended to be used for types implementing
74///   `HasCompact` as shown in the example.
75///
76/// ```
77/// # use parity_scale_codec_derive::Encode;
78/// # use parity_scale_codec::{Encode as _, HasCompact};
79/// #[derive(Encode)]
80/// struct StructType {
81///     #[codec(skip)]
82///     a: u32,
83///     #[codec(compact)]
84///     b: u32,
85///     #[codec(encoded_as = "<u32 as HasCompact>::Type")]
86///     c: u32,
87/// }
88/// ```
89///
90/// # Enum
91///
92/// The variable is encoded with one byte for the variant and then the variant struct encoding.
93/// The variant number is:
94/// * if variant has attribute: `#[codec(index = "$n")]` then n
95/// * else if variant has discriminant (like 3 in `enum T { A = 3 }`) then the discriminant.
96/// * else its position in the variant set, excluding skipped variants, but including variant with
97/// discriminant or attribute. Warning this position does collision with discriminant or attribute
98/// index.
99///
100/// variant attributes:
101/// * `#[codec(skip)]`: the variant is not encoded.
102/// * `#[codec(index = "$n")]`: override variant index.
103///
104/// field attributes: same as struct fields attributes.
105///
106/// ```
107/// # use parity_scale_codec_derive::Encode;
108/// # use parity_scale_codec::Encode as _;
109/// #[derive(Encode)]
110/// enum EnumType {
111///     #[codec(index = 15)]
112///     A,
113///     #[codec(skip)]
114///     B,
115///     C = 3,
116///     D,
117/// }
118///
119/// assert_eq!(EnumType::A.encode(), vec![15]);
120/// assert_eq!(EnumType::B.encode(), vec![]);
121/// assert_eq!(EnumType::C.encode(), vec![3]);
122/// assert_eq!(EnumType::D.encode(), vec![2]);
123/// ```
124#[proc_macro_derive(Encode, attributes(codec))]
125pub fn encode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
126	let mut input: DeriveInput = match syn::parse(input) {
127		Ok(input) => input,
128		Err(e) => return e.to_compile_error().into(),
129	};
130
131	if let Err(e) = utils::check_attributes(&input) {
132		return e.to_compile_error().into()
133	}
134
135	let crate_path = match codec_crate_path(&input.attrs) {
136		Ok(crate_path) => crate_path,
137		Err(error) => return error.into_compile_error().into(),
138	};
139
140	if let Err(e) = trait_bounds::add(
141		&input.ident,
142		&mut input.generics,
143		&input.data,
144		utils::custom_encode_trait_bound(&input.attrs),
145		parse_quote!(#crate_path::Encode),
146		None,
147		utils::has_dumb_trait_bound(&input.attrs),
148		&crate_path,
149	) {
150		return e.to_compile_error().into()
151	}
152
153	let name = &input.ident;
154	let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
155
156	let encode_impl = encode::quote(&input.data, name, &crate_path);
157
158	let impl_block = quote! {
159		#[automatically_derived]
160		impl #impl_generics #crate_path::Encode for #name #ty_generics #where_clause {
161			#encode_impl
162		}
163
164		#[automatically_derived]
165		impl #impl_generics #crate_path::EncodeLike for #name #ty_generics #where_clause {}
166	};
167
168	wrap_with_dummy_const(input, impl_block)
169}
170
171/// Derive `parity_scale_codec::Decode` and for struct and enum.
172///
173/// see derive `Encode` documentation.
174#[proc_macro_derive(Decode, attributes(codec))]
175pub fn decode_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
176	let mut input: DeriveInput = match syn::parse(input) {
177		Ok(input) => input,
178		Err(e) => return e.to_compile_error().into(),
179	};
180
181	if let Err(e) = utils::check_attributes(&input) {
182		return e.to_compile_error().into()
183	}
184
185	let crate_path = match codec_crate_path(&input.attrs) {
186		Ok(crate_path) => crate_path,
187		Err(error) => return error.into_compile_error().into(),
188	};
189
190	if let Err(e) = trait_bounds::add(
191		&input.ident,
192		&mut input.generics,
193		&input.data,
194		utils::custom_decode_trait_bound(&input.attrs),
195		parse_quote!(#crate_path::Decode),
196		Some(parse_quote!(Default)),
197		utils::has_dumb_trait_bound(&input.attrs),
198		&crate_path,
199	) {
200		return e.to_compile_error().into()
201	}
202
203	let name = &input.ident;
204	let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
205	let ty_gen_turbofish = ty_generics.as_turbofish();
206
207	let input_ = quote!(__codec_input_edqy);
208	let decoding =
209		decode::quote(&input.data, name, &quote!(#ty_gen_turbofish), &input_, &crate_path);
210
211	let decode_into_body = decode::quote_decode_into(
212		&input.data,
213		&crate_path,
214		&input_,
215		&input.attrs
216	);
217
218	let impl_decode_into = if let Some(body) = decode_into_body {
219		quote! {
220			fn decode_into<__CodecInputEdqy: #crate_path::Input>(
221				#input_: &mut __CodecInputEdqy,
222				dst_: &mut ::core::mem::MaybeUninit<Self>,
223			) -> ::core::result::Result<#crate_path::DecodeFinished, #crate_path::Error> {
224				#body
225			}
226		}
227	} else {
228		quote! {}
229	};
230
231	let impl_block = quote! {
232		#[automatically_derived]
233		impl #impl_generics #crate_path::Decode for #name #ty_generics #where_clause {
234			fn decode<__CodecInputEdqy: #crate_path::Input>(
235				#input_: &mut __CodecInputEdqy
236			) -> ::core::result::Result<Self, #crate_path::Error> {
237				#decoding
238			}
239
240			#impl_decode_into
241		}
242	};
243
244	wrap_with_dummy_const(input, impl_block)
245}
246
247/// Derive `parity_scale_codec::Compact` and `parity_scale_codec::CompactAs` for struct with single
248/// field.
249///
250/// Attribute skip can be used to skip other fields.
251///
252/// # Example
253///
254/// ```
255/// # use parity_scale_codec_derive::CompactAs;
256/// # use parity_scale_codec::{Encode, HasCompact};
257/// # use std::marker::PhantomData;
258/// #[derive(CompactAs)]
259/// struct MyWrapper<T>(u32, #[codec(skip)] PhantomData<T>);
260/// ```
261#[proc_macro_derive(CompactAs, attributes(codec))]
262pub fn compact_as_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
263	let mut input: DeriveInput = match syn::parse(input) {
264		Ok(input) => input,
265		Err(e) => return e.to_compile_error().into(),
266	};
267
268	if let Err(e) = utils::check_attributes(&input) {
269		return e.to_compile_error().into()
270	}
271
272	let crate_path = match codec_crate_path(&input.attrs) {
273		Ok(crate_path) => crate_path,
274		Err(error) => return error.into_compile_error().into(),
275	};
276
277	if let Err(e) = trait_bounds::add::<()>(
278		&input.ident,
279		&mut input.generics,
280		&input.data,
281		None,
282		parse_quote!(#crate_path::CompactAs),
283		None,
284		utils::has_dumb_trait_bound(&input.attrs),
285		&crate_path,
286	) {
287		return e.to_compile_error().into()
288	}
289
290	let name = &input.ident;
291	let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
292
293	fn val_or_default(field: &Field) -> proc_macro2::TokenStream {
294		let skip = utils::should_skip(&field.attrs);
295		if skip {
296			quote_spanned!(field.span()=> Default::default())
297		} else {
298			quote_spanned!(field.span()=> x)
299		}
300	}
301
302	let (inner_ty, inner_field, constructor) = match input.data {
303		Data::Struct(ref data) => match data.fields {
304			Fields::Named(ref fields) if utils::filter_skip_named(fields).count() == 1 => {
305				let recurse = fields.named.iter().map(|f| {
306					let name_ident = &f.ident;
307					let val_or_default = val_or_default(f);
308					quote_spanned!(f.span()=> #name_ident: #val_or_default)
309				});
310				let field = utils::filter_skip_named(fields).next().expect("Exactly one field");
311				let field_name = &field.ident;
312				let constructor = quote!( #name { #( #recurse, )* });
313				(&field.ty, quote!(&self.#field_name), constructor)
314			},
315			Fields::Unnamed(ref fields) if utils::filter_skip_unnamed(fields).count() == 1 => {
316				let recurse = fields.unnamed.iter().enumerate().map(|(_, f)| {
317					let val_or_default = val_or_default(f);
318					quote_spanned!(f.span()=> #val_or_default)
319				});
320				let (id, field) =
321					utils::filter_skip_unnamed(fields).next().expect("Exactly one field");
322				let id = syn::Index::from(id);
323				let constructor = quote!( #name(#( #recurse, )*));
324				(&field.ty, quote!(&self.#id), constructor)
325			},
326			_ =>
327				return Error::new(
328					data.fields.span(),
329					"Only structs with a single non-skipped field can derive CompactAs",
330				)
331				.to_compile_error()
332				.into(),
333		},
334		Data::Enum(syn::DataEnum { enum_token: syn::token::Enum { span }, .. }) |
335		Data::Union(syn::DataUnion { union_token: syn::token::Union { span }, .. }) =>
336			return Error::new(span, "Only structs can derive CompactAs").to_compile_error().into(),
337	};
338
339	let impl_block = quote! {
340		#[automatically_derived]
341		impl #impl_generics #crate_path::CompactAs for #name #ty_generics #where_clause {
342			type As = #inner_ty;
343			fn encode_as(&self) -> &#inner_ty {
344				#inner_field
345			}
346			fn decode_from(x: #inner_ty)
347				-> ::core::result::Result<#name #ty_generics, #crate_path::Error>
348			{
349				::core::result::Result::Ok(#constructor)
350			}
351		}
352
353		#[automatically_derived]
354		impl #impl_generics From<#crate_path::Compact<#name #ty_generics>>
355			for #name #ty_generics #where_clause
356		{
357			fn from(x: #crate_path::Compact<#name #ty_generics>) -> #name #ty_generics {
358				x.0
359			}
360		}
361	};
362
363	wrap_with_dummy_const(input, impl_block)
364}
365
/// Derive `parity_scale_codec::MaxEncodedLen` for struct and enum.
///
/// The expansion itself is produced by `max_encoded_len::derive_max_encoded_len`.
///
/// # Top level attribute
///
/// By default the macro will try to bound the types needed to implement `MaxEncodedLen`, but the
/// bounds can be specified manually with the top level attribute:
/// ```
/// # use parity_scale_codec_derive::Encode;
/// # use parity_scale_codec::MaxEncodedLen;
/// # #[derive(Encode, MaxEncodedLen)]
/// #[codec(mel_bound(T: MaxEncodedLen))]
/// # struct MyWrapper<T>(T);
/// ```
///
/// `codec` is registered as a helper attribute of this derive, so `#[codec(...)]` attributes are
/// accepted even when no other codec derive is present on the type.
#[cfg(feature = "max-encoded-len")]
#[proc_macro_derive(MaxEncodedLen, attributes(codec, max_encoded_len_mod))]
pub fn derive_max_encoded_len(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
	max_encoded_len::derive_max_encoded_len(input)
}