referrerpolicy=no-referrer-when-downgrade

pallet_revive_proc_macro/
lib.rs

1// This file is part of Substrate.
2
3// Copyright (C) Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: Apache-2.0
5
6// Licensed under the Apache License, Version 2.0 (the "License");
7// you may not use this file except in compliance with the License.
8// You may obtain a copy of the License at
9//
10// 	http://www.apache.org/licenses/LICENSE-2.0
11//
12// Unless required by applicable law or agreed to in writing, software
13// distributed under the License is distributed on an "AS IS" BASIS,
14// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15// See the License for the specific language governing permissions and
16// limitations under the License.
17
18//! Procedural macros used in `pallet-revive`.
19//!
20//! Most likely you should use the [`#[define_env]`][`macro@define_env`] attribute macro which hides
21//! boilerplate of defining external environment for a polkavm module.
22
23use proc_macro::TokenStream;
24use proc_macro2::{Literal, Span, TokenStream as TokenStream2};
25use quote::{quote, ToTokens};
26use syn::{parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, FnArg, Ident};
27
28/// Defines a host functions set that can be imported by contract polkavm code.
29///
30/// **CAUTION**: Be advised that all functions defined by this macro
31/// cause undefined behaviour inside the contract if the signature does not match.
32///
33/// WARNING: It is CRITICAL for contracts to make sure that the signatures match exactly.
34/// Failure to do so may result in undefined behavior, traps or security vulnerabilities inside the
35/// contract. The runtime itself is unharmed due to sandboxing.
36/// For example, if a function is called with an incorrect signature, it could lead to memory
37/// corruption or unexpected results within the contract.
38#[proc_macro_attribute]
39pub fn define_env(attr: TokenStream, item: TokenStream) -> TokenStream {
40	if !attr.is_empty() {
41		let msg = r#"Invalid `define_env` attribute macro: expected no attributes:
42					 - `#[define_env]`"#;
43		let span = TokenStream2::from(attr).span();
44		return syn::Error::new(span, msg).to_compile_error().into();
45	}
46
47	let item = syn::parse_macro_input!(item as syn::ItemMod);
48
49	match EnvDef::try_from(item) {
50		Ok(mut def) => expand_env(&mut def).into(),
51		Err(e) => e.to_compile_error().into(),
52	}
53}
54
/// Parsed environment definition.
///
/// Built from the module annotated with `#[define_env]`; only the `fn` items of that module
/// are retained.
struct EnvDef {
	/// The host functions found in the module, in declaration order.
	host_funcs: Vec<HostFn>,
}
59
/// Parsed host function definition.
struct HostFn {
	/// The function item. For `#[mutating]` functions the body already starts with the
	/// read-only guard inserted during parsing.
	item: syn::ItemFn,
	/// The function identifier as a string. It doubles as the syscall symbol.
	name: String,
	/// The success type of the function's `Result` return type.
	returns: HostFnReturn,
	/// The `#[cfg]` attribute of the function, if any. It is re-emitted on the generated code.
	cfg: Option<syn::Attribute>,
}
67
/// The success types a host function is allowed to return inside its `Result`.
enum HostFnReturn {
	/// `Result<(), TrapReason>`
	Unit,
	/// `Result<u32, TrapReason>`
	U32,
	/// `Result<u64, TrapReason>`
	U64,
	/// `Result<ReturnErrorCode, TrapReason>`
	ReturnCode,
}
74
75impl HostFnReturn {
76	fn map_output(&self) -> TokenStream2 {
77		match self {
78			Self::Unit => quote! { |_| None },
79			_ => quote! { |ret_val| Some(ret_val.into()) },
80		}
81	}
82
83	fn success_type(&self) -> syn::ReturnType {
84		match self {
85			Self::Unit => syn::ReturnType::Default,
86			Self::U32 => parse_quote! { -> u32 },
87			Self::U64 => parse_quote! { -> u64 },
88			Self::ReturnCode => parse_quote! { -> ReturnErrorCode },
89		}
90	}
91
92	fn trace_return_value(&self) -> TokenStream2 {
93		match self {
94			Self::Unit => quote! { None },
95			Self::U32 => quote! { result.as_ref().ok().map(|r| *r as u64) },
96			Self::ReturnCode => quote! { result.as_ref().ok().copied().map(u64::from) },
97			Self::U64 => quote! { result.as_ref().ok().copied() },
98		}
99	}
100}
101
102impl EnvDef {
103	pub fn try_from(item: syn::ItemMod) -> syn::Result<Self> {
104		let span = item.span();
105		let err = |msg| syn::Error::new(span, msg);
106		let items = &item
107			.content
108			.as_ref()
109			.ok_or(err("Invalid environment definition, expected `mod` to be inlined."))?
110			.1;
111
112		let extract_fn = |i: &syn::Item| match i {
113			syn::Item::Fn(i_fn) => Some(i_fn.clone()),
114			_ => None,
115		};
116
117		let host_funcs = items
118			.iter()
119			.filter_map(extract_fn)
120			.map(HostFn::try_from)
121			.collect::<Result<Vec<_>, _>>()?;
122
123		Ok(Self { host_funcs })
124	}
125}
126
127impl HostFn {
128	pub fn try_from(mut item: syn::ItemFn) -> syn::Result<Self> {
129		let err = |span, msg| {
130			let msg = format!("Invalid host function definition.\n{}", msg);
131			syn::Error::new(span, msg)
132		};
133
134		// process attributes
135		let msg = "Only #[cfg] and #[mutating] attributes are allowed.";
136		let span = item.span();
137		let mut attrs = item.attrs.clone();
138		attrs.retain(|a| !a.path().is_ident("doc"));
139		let mut mutating = false;
140		let mut cfg = None;
141		while let Some(attr) = attrs.pop() {
142			let ident = attr.path().get_ident().ok_or(err(span, msg))?.to_string();
143			match ident.as_str() {
144				"mutating" => {
145					if mutating {
146						return Err(err(span, "#[mutating] can only be specified once"));
147					}
148					mutating = true;
149				},
150				"cfg" => {
151					if cfg.is_some() {
152						return Err(err(span, "#[cfg] can only be specified once"));
153					}
154					cfg = Some(attr);
155				},
156				id => return Err(err(span, &format!("Unsupported attribute \"{id}\". {msg}"))),
157			}
158		}
159
160		if mutating {
161			let stmt = syn::parse_quote! {
162				if self.ext().is_read_only() {
163					return Err(Error::<E::T>::StateChangeDenied.into());
164				}
165			};
166			item.block.stmts.insert(0, stmt);
167		}
168
169		let name = item.sig.ident.to_string();
170
171		let msg = "Every function must start with these two parameters: &mut self, memory: &mut M";
172		let special_args = item
173			.sig
174			.inputs
175			.iter()
176			.take(2)
177			.enumerate()
178			.map(|(i, arg)| is_valid_special_arg(i, arg))
179			.fold(0u32, |acc, valid| if valid { acc + 1 } else { acc });
180
181		if special_args != 2 {
182			return Err(err(span, msg));
183		}
184
185		// process return type
186		let msg = r#"Should return one of the following:
187				- Result<(), TrapReason>,
188				- Result<ReturnErrorCode, TrapReason>,
189				- Result<u32, TrapReason>,
190				- Result<u64, TrapReason>"#;
191		let ret_ty = match item.clone().sig.output {
192			syn::ReturnType::Type(_, ty) => Ok(ty.clone()),
193			_ => Err(err(span, &msg)),
194		}?;
195		match *ret_ty {
196			syn::Type::Path(tp) => {
197				let result = &tp.path.segments.last().ok_or(err(span, &msg))?;
198				let (id, span) = (result.ident.to_string(), result.ident.span());
199				id.eq(&"Result".to_string()).then_some(()).ok_or(err(span, &msg))?;
200
201				match &result.arguments {
202					syn::PathArguments::AngleBracketed(group) => {
203						if group.args.len() != 2 {
204							return Err(err(span, &msg));
205						};
206
207						let arg2 = group.args.last().ok_or(err(span, &msg))?;
208
209						let err_ty = match arg2 {
210							syn::GenericArgument::Type(ty) => Ok(ty.clone()),
211							_ => Err(err(arg2.span(), &msg)),
212						}?;
213
214						match err_ty {
215							syn::Type::Path(tp) => Ok(tp
216								.path
217								.segments
218								.first()
219								.ok_or(err(arg2.span(), &msg))?
220								.ident
221								.to_string()),
222							_ => Err(err(tp.span(), &msg)),
223						}?
224						.eq("TrapReason")
225						.then_some(())
226						.ok_or(err(span, &msg))?;
227
228						let arg1 = group.args.first().ok_or(err(span, &msg))?;
229						let ok_ty = match arg1 {
230							syn::GenericArgument::Type(ty) => Ok(ty.clone()),
231							_ => Err(err(arg1.span(), &msg)),
232						}?;
233						let ok_ty_str = match ok_ty {
234							syn::Type::Path(tp) => Ok(tp
235								.path
236								.segments
237								.first()
238								.ok_or(err(arg1.span(), &msg))?
239								.ident
240								.to_string()),
241							syn::Type::Tuple(tt) => {
242								if !tt.elems.is_empty() {
243									return Err(err(arg1.span(), &msg));
244								};
245								Ok("()".to_string())
246							},
247							_ => Err(err(ok_ty.span(), &msg)),
248						}?;
249						let returns = match ok_ty_str.as_str() {
250							"()" => Ok(HostFnReturn::Unit),
251							"u32" => Ok(HostFnReturn::U32),
252							"u64" => Ok(HostFnReturn::U64),
253							"ReturnErrorCode" => Ok(HostFnReturn::ReturnCode),
254							_ => Err(err(arg1.span(), &msg)),
255						}?;
256
257						Ok(Self { item, name, returns, cfg })
258					},
259					_ => Err(err(span, &msg)),
260				}
261			},
262			_ => Err(err(span, &msg)),
263		}
264	}
265}
266
267fn is_valid_special_arg(idx: usize, arg: &FnArg) -> bool {
268	match (idx, arg) {
269		(0, FnArg::Receiver(rec)) => rec.reference.is_some() && rec.mutability.is_some(),
270		(1, FnArg::Typed(pat)) => {
271			let ident = if let syn::Pat::Ident(ref ident) = *pat.pat {
272				&ident.ident
273			} else {
274				return false;
275			};
276			if !(ident == "memory" || ident == "_memory") {
277				return false;
278			}
279			matches!(*pat.ty, syn::Type::Reference(_))
280		},
281		_ => false,
282	}
283}
284
285fn arg_decoder<'a, P, I>(param_names: P, param_types: I) -> TokenStream2
286where
287	P: Iterator<Item = &'a std::boxed::Box<syn::Pat>> + Clone,
288	I: Iterator<Item = &'a std::boxed::Box<syn::Type>> + Clone,
289{
290	const ALLOWED_REGISTERS: usize = 6;
291
292	// too many arguments
293	if param_names.clone().count() > ALLOWED_REGISTERS {
294		panic!("Syscalls take a maximum of {ALLOWED_REGISTERS} arguments");
295	}
296
297	// all of them take one register but we truncate them before passing into the function
298	// it is important to not allow any type which has illegal bit patterns like 'bool'
299	if !param_types.clone().all(|ty| {
300		let syn::Type::Path(path) = &**ty else {
301			panic!("Type needs to be path");
302		};
303		let Some(ident) = path.path.get_ident() else {
304			panic!("Type needs to be ident");
305		};
306		matches!(ident.to_string().as_ref(), "u8" | "u16" | "u32" | "u64")
307	}) {
308		panic!("Only primitive unsigned integers are allowed as arguments to syscalls");
309	}
310
311	// one argument per register
312	let bindings = param_names.zip(param_types).enumerate().map(|(idx, (name, ty))| {
313		let reg = quote::format_ident!("__a{}__", idx);
314		quote! {
315			let #name = #reg as #ty;
316		}
317	});
318	quote! {
319		#( #bindings )*
320	}
321}
322
/// Expands environment definition.
///
/// Generates source code for:
///  - `list_syscalls()` and `lookup_syscall_index()` (see `expand_func_list()` and
///    `expand_func_lookup()`),
///  - `list_trace_ops()` and `lookup_trace_op_index()` (see `expand_trace_op_list()` and
///    `expand_trace_op_lookup()`),
///  - `Runtime::handle_ecall`, the dispatcher containing the implementations of the host
///    functions (see `expand_functions()`),
///  - one `bench_*` method per host function, only compiled with the `runtime-benchmarks`
///    feature (see `expand_bench_functions()`),
///  - the `SyscallDoc` trait, only compiled under `cfg(doc)` (see `expand_func_doc()`).
fn expand_env(def: &EnvDef) -> TokenStream2 {
	let impls = expand_functions(def);
	let bench_impls = expand_bench_functions(def);
	let docs = expand_func_doc(def);
	let all_syscalls = expand_func_list(def);
	let lookup_syscall = expand_func_lookup(def);
	let all_trace_ops = expand_trace_op_list(def);
	let lookup_trace_op = expand_trace_op_lookup(def);

	quote! {
		/// Returns the list of all syscalls that contracts can import.
		pub fn list_syscalls() -> &'static [&'static [u8]] {
			#all_syscalls
		}

		/// Return the index of a syscall in the `list_syscalls()` list.
		pub fn lookup_syscall_index(name: &'static str) -> Option<u8> {
			#lookup_syscall
		}

		/// Returns the list of all trace operations (real syscalls + synthetic trace steps).
		pub fn list_trace_ops() -> &'static [&'static [u8]] {
			#all_trace_ops
		}

		/// Return the index of a trace operation in the `list_trace_ops()` list.
		pub fn lookup_trace_op_index(name: &'static str) -> Option<u8> {
			#lookup_trace_op
		}

		impl<'a, E: Ext, M: PolkaVmInstance<E::T>> Runtime<'a, E, M> {
			fn handle_ecall(
				&mut self,
				memory: &mut M,
				__syscall_symbol__: &[u8],
			) -> Result<Option<u64>, TrapReason>
			{
				#impls
			}
		}

		#[cfg(feature = "runtime-benchmarks")]
		impl<'a, E: Ext, M: ?Sized + Memory<E::T>> Runtime<'a, E, M> {
			#bench_impls
		}

		/// Documentation of the syscalls (host functions) available to contracts.
		///
		/// Each of the functions in this trait represent a function that is callable
		/// by the contract. Guests use the function name as the import symbol.
		///
		/// # Note
		///
		/// This module is not meant to be used by any code. Rather, it is meant to be
		/// consumed by humans through rustdoc.
		#[cfg(doc)]
		pub trait SyscallDoc {
			#docs
		}
	}
}
388
389fn expand_functions(def: &EnvDef) -> TokenStream2 {
390	let impls = def.host_funcs.iter().map(|f| {
391		// skip the self and memory argument
392		let params = f.item.sig.inputs.iter().skip(2);
393		let param_names = params.clone().filter_map(|arg| {
394			let FnArg::Typed(arg) = arg else {
395				return None;
396			};
397			Some(&arg.pat)
398		});
399		let param_types = params.clone().filter_map(|arg| {
400			let FnArg::Typed(arg) = arg else {
401				return None;
402			};
403			Some(&arg.ty)
404		});
405		let arg_decoder = arg_decoder(param_names, param_types);
406		let cfg = &f.cfg;
407		let name = &f.name;
408		let syscall_symbol = Literal::byte_string(name.as_bytes());
409		let body = &f.item.block;
410		let map_output = f.returns.map_output();
411		let trace_return = f.returns.trace_return_value();
412		let output = &f.item.sig.output;
413
414		// wrapped host function body call with host function traces
415		let wrapped_body_with_trace = {
416			let trace_fmt_args = params.clone().filter_map(|arg| match arg {
417				syn::FnArg::Receiver(_) => None,
418				syn::FnArg::Typed(p) => match *p.pat.clone() {
419					syn::Pat::Ident(ref pat_ident) => Some(pat_ident.ident.clone()),
420					_ => None,
421				},
422			});
423
424			let params_fmt_str = trace_fmt_args
425				.clone()
426				.map(|s| format!("{s}: {{:?}}"))
427				.collect::<Vec<_>>()
428				.join(", ");
429			let trace_fmt_str = format!("{}({}) = {{:?}} weight_consumed: {{:?}}", name, params_fmt_str);
430			let trace_args_for_tracer: Vec<_> = trace_fmt_args.clone().collect();
431
432			quote! {
433				crate::tracing::if_tracing(|tracer| {
434					tracer.enter_ecall(#name, &[#( #trace_args_for_tracer as u64 ),*], self)
435				});
436
437				// wrap body in closure to make sure the tracing is always executed
438				let result = (|| #body)();
439				::log::trace!(target: "runtime::revive::strace", #trace_fmt_str, #( #trace_fmt_args, )* result, self.ext.frame_meter().weight_consumed());
440
441				crate::tracing::if_tracing(|tracer| tracer.exit_step(self, #trace_return));
442				result
443			}
444		};
445
446		quote! {
447			#cfg
448			#syscall_symbol => {
449				// closure is needed so that "?" can infere the correct type
450				(|| #output {
451					#arg_decoder
452					#wrapped_body_with_trace
453				})().map(#map_output)
454			},
455		}
456	});
457
458	quote! {
459		crate::tracing::if_tracing(|tracer| {
460			tracer.enter_ecall(crate::tracing::PVM_FUEL_NAME, &[], self)
461		});
462
463		let __sync_result__ = self.ext
464			.frame_meter_mut()
465			.sync_from_executor(memory.gas())
466			.map_err(TrapReason::from);
467
468		crate::tracing::if_tracing(|tracer| tracer.exit_step(self, None));
469
470		__sync_result__?;
471
472		// This is the overhead to call an empty syscall that always needs to be charged.
473		self.charge_gas(crate::vm::RuntimeCosts::HostFn).map_err(TrapReason::from)?;
474
475		// They will be mapped to variable names by the syscall specific code.
476		let (__a0__, __a1__, __a2__, __a3__, __a4__, __a5__) = memory.read_input_regs();
477
478		// Execute the syscall specific logic in a closure so that the gas metering code is always executed.
479		let result = (|| match __syscall_symbol__ {
480			#( #impls )*
481			_ => Err(TrapReason::SupervisorError(Error::<E::T>::InvalidSyscall.into()))
482		})();
483
484		// Write gas from pallet-revive into polkavm after leaving the host function.
485		let gas = self.ext.frame_meter_mut().sync_to_executor();
486		memory.set_gas(gas.into());
487		result
488	}
489}
490
491fn expand_bench_functions(def: &EnvDef) -> TokenStream2 {
492	let impls = def.host_funcs.iter().map(|f| {
493		// skip the context and memory argument
494		let params = f.item.sig.inputs.iter().skip(2);
495		let cfg = &f.cfg;
496		let name = &f.name;
497		let body = &f.item.block;
498		let output = &f.item.sig.output;
499
500		let name = Ident::new(&format!("bench_{name}"), Span::call_site());
501		quote! {
502			#cfg
503			pub fn #name(&mut self, memory: &mut M, #(#params),*) #output {
504				#body
505			}
506		}
507	});
508
509	quote! {
510		#( #impls )*
511	}
512}
513
514fn expand_func_doc(def: &EnvDef) -> TokenStream2 {
515	let docs = def.host_funcs.iter().map(|func| {
516		// Remove auxiliary args: `ctx: _` and `memory: _`
517		let func_decl = {
518			let mut sig = func.item.sig.clone();
519			sig.inputs = sig
520				.inputs
521				.iter()
522				.skip(2)
523				.map(|p| p.clone())
524				.collect::<Punctuated<FnArg, Comma>>();
525			sig.output = func.returns.success_type();
526			sig.to_token_stream()
527		};
528		let func_doc = {
529			let func_docs = {
530				let docs = func.item.attrs.iter().filter(|a| a.path().is_ident("doc")).map(|d| {
531					let docs = d.to_token_stream();
532					quote! { #docs }
533				});
534				quote! { #( #docs )* }
535			};
536			quote! {
537				#func_docs
538			}
539		};
540		quote! {
541			#func_doc
542			#func_decl;
543		}
544	});
545
546	quote! {
547		#( #docs )*
548	}
549}
550
551fn expand_func_list(def: &EnvDef) -> TokenStream2 {
552	let docs = def.host_funcs.iter().map(|f| {
553		let name = Literal::byte_string(f.name.as_bytes());
554		quote! {
555			#name.as_slice()
556		}
557	});
558	let len = docs.clone().count();
559
560	quote! {
561		{
562			static FUNCS: [&[u8]; #len] = [#(#docs),*];
563			FUNCS.as_slice()
564		}
565	}
566}
567
568fn expand_func_lookup(def: &EnvDef) -> TokenStream2 {
569	let arms = def.host_funcs.iter().enumerate().map(|(idx, f)| {
570		let name_str = &f.name;
571		quote! {
572			#name_str => Some(#idx as u8)
573		}
574	});
575	quote! {
576		match name {
577			#( #arms, )*
578			_ => None,
579		}
580	}
581}
582
583fn expand_trace_op_list(def: &EnvDef) -> TokenStream2 {
584	let syscalls = def.host_funcs.iter().map(|f| {
585		let name = Literal::byte_string(f.name.as_bytes());
586		quote! {
587			#name.as_slice()
588		}
589	});
590	let len = syscalls.clone().count() + 1;
591
592	quote! {
593		{
594			static OPS: [&[u8]; #len] = [
595				#(#syscalls,)*
596				crate::tracing::PVM_FUEL_NAME.as_bytes(),
597			];
598			OPS.as_slice()
599		}
600	}
601}
602
603fn expand_trace_op_lookup(def: &EnvDef) -> TokenStream2 {
604	let arms = def.host_funcs.iter().enumerate().map(|(idx, f)| {
605		let name_str = &f.name;
606		quote! {
607			#name_str => Some(#idx as u8)
608		}
609	});
610	let pvm_fuel_idx = def.host_funcs.len();
611
612	quote! {
613		match name {
614			#( #arms, )*
615			crate::tracing::PVM_FUEL_NAME => Some(#pvm_fuel_idx as u8),
616			_ => None,
617		}
618	}
619}