parity_scale_codec_derive/
decode.rs

1// Copyright 2017, 2018 Parity Technologies
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use proc_macro2::{Span, TokenStream, Ident};
16use syn::{
17	spanned::Spanned,
18	Data, Fields, Field, Error,
19};
20
21use crate::utils;
22
23/// Generate function block for function `Decode::decode`.
24///
25/// * data: data info of the type,
26/// * type_name: name of the type,
27/// * type_generics: the generics of the type in turbofish format, without bounds, e.g. `::<T, I>`
28/// * input: the variable name for the argument of function `decode`.
29pub fn quote(
30	data: &Data,
31	type_name: &Ident,
32	type_generics: &TokenStream,
33	input: &TokenStream,
34	crate_path: &syn::Path,
35) -> TokenStream {
36	match *data {
37		Data::Struct(ref data) => match data.fields {
38			Fields::Named(_) | Fields::Unnamed(_) => create_instance(
39				quote! { #type_name #type_generics },
40				&type_name.to_string(),
41				input,
42				&data.fields,
43				crate_path,
44			),
45			Fields::Unit => {
46				quote_spanned! { data.fields.span() =>
47					::core::result::Result::Ok(#type_name)
48				}
49			},
50		},
51		Data::Enum(ref data) => {
52			let data_variants = || data.variants.iter().filter(|variant| !utils::should_skip(&variant.attrs));
53
54			if data_variants().count() > 256 {
55				return Error::new(
56					data.variants.span(),
57					"Currently only enums with at most 256 variants are encodable."
58				).to_compile_error();
59			}
60
61			let recurse = data_variants().enumerate().map(|(i, v)| {
62				let name = &v.ident;
63				let index = utils::variant_index(v, i);
64
65				let create = create_instance(
66					quote! { #type_name #type_generics :: #name },
67					&format!("{}::{}", type_name, name),
68					input,
69					&v.fields,
70					crate_path,
71				);
72
73				quote_spanned! { v.span() =>
74					#[allow(clippy::unnecessary_cast)]
75					__codec_x_edqy if __codec_x_edqy == #index as ::core::primitive::u8 => {
76						// NOTE: This lambda is necessary to work around an upstream bug
77						// where each extra branch results in excessive stack usage:
78						//   https://github.com/rust-lang/rust/issues/34283
79						#[allow(clippy::redundant_closure_call)]
80						return (move || {
81							#create
82						})();
83					},
84				}
85			});
86
87			let read_byte_err_msg = format!(
88				"Could not decode `{}`, failed to read variant byte",
89				type_name,
90			);
91			let invalid_variant_err_msg = format!(
92				"Could not decode `{}`, variant doesn't exist",
93				type_name,
94			);
95			quote! {
96				match #input.read_byte()
97					.map_err(|e| e.chain(#read_byte_err_msg))?
98				{
99					#( #recurse )*
100					_ => {
101						#[allow(clippy::redundant_closure_call)]
102						return (move || {
103							::core::result::Result::Err(
104								<_ as ::core::convert::Into<_>>::into(#invalid_variant_err_msg)
105							)
106						})();
107					},
108				}
109			}
110
111		},
112		Data::Union(_) => Error::new(Span::call_site(), "Union types are not supported.").to_compile_error(),
113	}
114}
115
116pub fn quote_decode_into(
117	data: &Data,
118	crate_path: &syn::Path,
119	input: &TokenStream,
120	attrs: &[syn::Attribute]
121) -> Option<TokenStream> {
122	// Make sure the type is `#[repr(transparent)]`, as this guarantees that
123	// there can be only one field that is not zero-sized.
124	if !crate::utils::is_transparent(attrs) {
125		return None;
126	}
127
128	let fields = match data {
129		Data::Struct(
130			syn::DataStruct {
131				fields: Fields::Named(syn::FieldsNamed { named: fields, .. }) |
132				        Fields::Unnamed(syn::FieldsUnnamed { unnamed: fields, .. }),
133				..
134			}
135		) => {
136			fields
137		},
138		_ => return None
139	};
140
141	if fields.is_empty() {
142		return None;
143	}
144
145	// Bail if there are any extra attributes which could influence how the type is decoded.
146	if fields.iter().any(|field|
147		utils::get_encoded_as_type(field).is_some() ||
148		utils::is_compact(field) ||
149		utils::should_skip(&field.attrs)
150	) {
151		return None;
152	}
153
154	// Go through each field and call `decode_into` on it.
155	//
156	// Normally if there's more than one field in the struct this would be incorrect,
157	// however since the struct's marked as `#[repr(transparent)]` we're guaranteed that
158	// there's at most one non zero-sized field, so only one of these `decode_into` calls
159	// should actually do something, and the rest should just be dummy calls that do nothing.
160	let mut decode_fields = Vec::new();
161	let mut sizes = Vec::new();
162	let mut non_zst_field_count = Vec::new();
163	for field in fields {
164		let field_type = &field.ty;
165		decode_fields.push(quote! {{
166			let dst_: &mut ::core::mem::MaybeUninit<Self> = dst_; // To make sure the type is what we expect.
167
168			// Here we cast `&mut MaybeUninit<Self>` into a `&mut MaybeUninit<#field_type>`.
169			//
170			// SAFETY: The struct is marked as `#[repr(transparent)]` so the address of every field will
171			//         be the same as the address of the struct itself.
172			let dst_: &mut ::core::mem::MaybeUninit<#field_type> = unsafe {
173				&mut *dst_.as_mut_ptr().cast::<::core::mem::MaybeUninit<#field_type>>()
174			};
175			<#field_type as #crate_path::Decode>::decode_into(#input, dst_)?;
176		}});
177
178		if !sizes.is_empty() {
179			sizes.push(quote! { + });
180		}
181		sizes.push(quote! { ::core::mem::size_of::<#field_type>() });
182
183		if !non_zst_field_count.is_empty() {
184			non_zst_field_count.push(quote! { + });
185		}
186		non_zst_field_count.push(quote! { if ::core::mem::size_of::<#field_type>() > 0 { 1 } else { 0 } });
187	}
188
189	Some(quote!{
190		// Just a sanity check. These should always be true and will be optimized-out.
191		::core::assert_eq!(#(#sizes)*, ::core::mem::size_of::<Self>());
192		::core::assert!(#(#non_zst_field_count)* <= 1);
193
194		#(#decode_fields)*
195
196		// SAFETY: We've successfully called `decode_into` for all of the fields.
197		unsafe { ::core::result::Result::Ok(#crate_path::DecodeFinished::assert_decoding_finished()) }
198	})
199}
200
201fn create_decode_expr(field: &Field, name: &str, input: &TokenStream, crate_path: &syn::Path) -> TokenStream {
202	let encoded_as = utils::get_encoded_as_type(field);
203	let compact = utils::is_compact(field);
204	let skip = utils::should_skip(&field.attrs);
205
206	let res = quote!(__codec_res_edqy);
207
208	if encoded_as.is_some() as u8 + compact as u8 + skip as u8 > 1 {
209		return Error::new(
210			field.span(),
211			"`encoded_as`, `compact` and `skip` can only be used one at a time!"
212		).to_compile_error();
213	}
214
215	let err_msg = format!("Could not decode `{}`", name);
216
217	if compact {
218		let field_type = &field.ty;
219		quote_spanned! { field.span() =>
220			{
221				let #res = <
222					<#field_type as #crate_path::HasCompact>::Type as #crate_path::Decode
223				>::decode(#input);
224				match #res {
225					::core::result::Result::Err(e) => return ::core::result::Result::Err(e.chain(#err_msg)),
226					::core::result::Result::Ok(#res) => #res.into(),
227				}
228			}
229		}
230	} else if let Some(encoded_as) = encoded_as {
231		quote_spanned! { field.span() =>
232			{
233				let #res = <#encoded_as as #crate_path::Decode>::decode(#input);
234				match #res {
235					::core::result::Result::Err(e) => return ::core::result::Result::Err(e.chain(#err_msg)),
236					::core::result::Result::Ok(#res) => #res.into(),
237				}
238			}
239		}
240	} else if skip {
241		quote_spanned! { field.span() => ::core::default::Default::default() }
242	} else {
243		let field_type = &field.ty;
244		quote_spanned! { field.span() =>
245			{
246				let #res = <#field_type as #crate_path::Decode>::decode(#input);
247				match #res {
248					::core::result::Result::Err(e) => return ::core::result::Result::Err(e.chain(#err_msg)),
249					::core::result::Result::Ok(#res) => #res,
250				}
251			}
252		}
253	}
254}
255
256fn create_instance(
257	name: TokenStream,
258	name_str: &str,
259	input: &TokenStream,
260	fields: &Fields,
261	crate_path: &syn::Path,
262) -> TokenStream {
263	match *fields {
264		Fields::Named(ref fields) => {
265			let recurse = fields.named.iter().map(|f| {
266				let name_ident = &f.ident;
267				let field_name = match name_ident {
268					Some(a) => format!("{}::{}", name_str, a),
269					None => name_str.to_string(), // Should never happen, fields are named.
270				};
271				let decode = create_decode_expr(f, &field_name, input, crate_path);
272
273				quote_spanned! { f.span() =>
274					#name_ident: #decode
275				}
276			});
277
278			quote_spanned! { fields.span() =>
279				::core::result::Result::Ok(#name {
280					#( #recurse, )*
281				})
282			}
283		},
284		Fields::Unnamed(ref fields) => {
285			let recurse = fields.unnamed.iter().enumerate().map(|(i, f) | {
286				let field_name = format!("{}.{}", name_str, i);
287
288				create_decode_expr(f, &field_name, input, crate_path)
289			});
290
291			quote_spanned! { fields.span() =>
292				::core::result::Result::Ok(#name (
293					#( #recurse, )*
294				))
295			}
296		},
297		Fields::Unit => {
298			quote_spanned! { fields.span() =>
299				::core::result::Result::Ok(#name)
300			}
301		},
302	}
303}