pallet_revive_proc_macro/
lib.rs1use proc_macro::TokenStream;
24use proc_macro2::{Literal, Span, TokenStream as TokenStream2};
25use quote::{quote, ToTokens};
26use syn::{parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, FnArg, Ident};
27
28#[proc_macro_attribute]
29pub fn unstable_hostfn(_attr: TokenStream, item: TokenStream) -> TokenStream {
30 let input = syn::parse_macro_input!(item as syn::Item);
31 let expanded = quote! {
32 #[cfg(feature = "unstable-hostfn")]
33 #[cfg_attr(docsrs, doc(cfg(feature = "unstable-hostfn")))]
34 #input
35 };
36 expanded.into()
37}
38
39#[proc_macro_attribute]
50pub fn define_env(attr: TokenStream, item: TokenStream) -> TokenStream {
51 if !attr.is_empty() {
52 let msg = r#"Invalid `define_env` attribute macro: expected no attributes:
53 - `#[define_env]`"#;
54 let span = TokenStream2::from(attr).span();
55 return syn::Error::new(span, msg).to_compile_error().into()
56 }
57
58 let item = syn::parse_macro_input!(item as syn::ItemMod);
59
60 match EnvDef::try_from(item) {
61 Ok(mut def) => expand_env(&mut def).into(),
62 Err(e) => e.to_compile_error().into(),
63 }
64}
65
66struct EnvDef {
68 host_funcs: Vec<HostFn>,
69}
70
71struct HostFn {
73 item: syn::ItemFn,
74 is_stable: bool,
75 name: String,
76 returns: HostFnReturn,
77 cfg: Option<syn::Attribute>,
78}
79
/// The success type `T` of a host function returning `Result<T, TrapReason>`.
enum HostFnReturn {
    /// `()`: the syscall hands no value back to the caller.
    Unit,
    /// `u32`.
    U32,
    /// `u64`.
    U64,
    /// `ReturnErrorCode`.
    ReturnCode,
}
86
87impl HostFnReturn {
88 fn map_output(&self) -> TokenStream2 {
89 match self {
90 Self::Unit => quote! { |_| None },
91 _ => quote! { |ret_val| Some(ret_val.into()) },
92 }
93 }
94
95 fn success_type(&self) -> syn::ReturnType {
96 match self {
97 Self::Unit => syn::ReturnType::Default,
98 Self::U32 => parse_quote! { -> u32 },
99 Self::U64 => parse_quote! { -> u64 },
100 Self::ReturnCode => parse_quote! { -> ReturnErrorCode },
101 }
102 }
103}
104
105impl EnvDef {
106 pub fn try_from(item: syn::ItemMod) -> syn::Result<Self> {
107 let span = item.span();
108 let err = |msg| syn::Error::new(span, msg);
109 let items = &item
110 .content
111 .as_ref()
112 .ok_or(err("Invalid environment definition, expected `mod` to be inlined."))?
113 .1;
114
115 let extract_fn = |i: &syn::Item| match i {
116 syn::Item::Fn(i_fn) => Some(i_fn.clone()),
117 _ => None,
118 };
119
120 let host_funcs = items
121 .iter()
122 .filter_map(extract_fn)
123 .map(HostFn::try_from)
124 .collect::<Result<Vec<_>, _>>()?;
125
126 Ok(Self { host_funcs })
127 }
128}
129
130impl HostFn {
131 pub fn try_from(mut item: syn::ItemFn) -> syn::Result<Self> {
132 let err = |span, msg| {
133 let msg = format!("Invalid host function definition.\n{}", msg);
134 syn::Error::new(span, msg)
135 };
136
137 let msg = "Only #[stable], #[cfg] and #[mutating] attributes are allowed.";
139 let span = item.span();
140 let mut attrs = item.attrs.clone();
141 attrs.retain(|a| !a.path().is_ident("doc"));
142 let mut is_stable = false;
143 let mut mutating = false;
144 let mut cfg = None;
145 while let Some(attr) = attrs.pop() {
146 let ident = attr.path().get_ident().ok_or(err(span, msg))?.to_string();
147 match ident.as_str() {
148 "stable" => {
149 if is_stable {
150 return Err(err(span, "#[stable] can only be specified once"))
151 }
152 is_stable = true;
153 },
154 "mutating" => {
155 if mutating {
156 return Err(err(span, "#[mutating] can only be specified once"))
157 }
158 mutating = true;
159 },
160 "cfg" => {
161 if cfg.is_some() {
162 return Err(err(span, "#[cfg] can only be specified once"))
163 }
164 cfg = Some(attr);
165 },
166 id => return Err(err(span, &format!("Unsupported attribute \"{id}\". {msg}"))),
167 }
168 }
169
170 if mutating {
171 let stmt = syn::parse_quote! {
172 if self.ext().is_read_only() {
173 return Err(Error::<E::T>::StateChangeDenied.into());
174 }
175 };
176 item.block.stmts.insert(0, stmt);
177 }
178
179 let name = item.sig.ident.to_string();
180
181 let msg = "Every function must start with these two parameters: &mut self, memory: &mut M";
182 let special_args = item
183 .sig
184 .inputs
185 .iter()
186 .take(2)
187 .enumerate()
188 .map(|(i, arg)| is_valid_special_arg(i, arg))
189 .fold(0u32, |acc, valid| if valid { acc + 1 } else { acc });
190
191 if special_args != 2 {
192 return Err(err(span, msg))
193 }
194
195 let msg = r#"Should return one of the following:
197 - Result<(), TrapReason>,
198 - Result<ReturnErrorCode, TrapReason>,
199 - Result<u32, TrapReason>,
200 - Result<u64, TrapReason>"#;
201 let ret_ty = match item.clone().sig.output {
202 syn::ReturnType::Type(_, ty) => Ok(ty.clone()),
203 _ => Err(err(span, &msg)),
204 }?;
205 match *ret_ty {
206 syn::Type::Path(tp) => {
207 let result = &tp.path.segments.last().ok_or(err(span, &msg))?;
208 let (id, span) = (result.ident.to_string(), result.ident.span());
209 id.eq(&"Result".to_string()).then_some(()).ok_or(err(span, &msg))?;
210
211 match &result.arguments {
212 syn::PathArguments::AngleBracketed(group) => {
213 if group.args.len() != 2 {
214 return Err(err(span, &msg))
215 };
216
217 let arg2 = group.args.last().ok_or(err(span, &msg))?;
218
219 let err_ty = match arg2 {
220 syn::GenericArgument::Type(ty) => Ok(ty.clone()),
221 _ => Err(err(arg2.span(), &msg)),
222 }?;
223
224 match err_ty {
225 syn::Type::Path(tp) => Ok(tp
226 .path
227 .segments
228 .first()
229 .ok_or(err(arg2.span(), &msg))?
230 .ident
231 .to_string()),
232 _ => Err(err(tp.span(), &msg)),
233 }?
234 .eq("TrapReason")
235 .then_some(())
236 .ok_or(err(span, &msg))?;
237
238 let arg1 = group.args.first().ok_or(err(span, &msg))?;
239 let ok_ty = match arg1 {
240 syn::GenericArgument::Type(ty) => Ok(ty.clone()),
241 _ => Err(err(arg1.span(), &msg)),
242 }?;
243 let ok_ty_str = match ok_ty {
244 syn::Type::Path(tp) => Ok(tp
245 .path
246 .segments
247 .first()
248 .ok_or(err(arg1.span(), &msg))?
249 .ident
250 .to_string()),
251 syn::Type::Tuple(tt) => {
252 if !tt.elems.is_empty() {
253 return Err(err(arg1.span(), &msg))
254 };
255 Ok("()".to_string())
256 },
257 _ => Err(err(ok_ty.span(), &msg)),
258 }?;
259 let returns = match ok_ty_str.as_str() {
260 "()" => Ok(HostFnReturn::Unit),
261 "u32" => Ok(HostFnReturn::U32),
262 "u64" => Ok(HostFnReturn::U64),
263 "ReturnErrorCode" => Ok(HostFnReturn::ReturnCode),
264 _ => Err(err(arg1.span(), &msg)),
265 }?;
266
267 Ok(Self { item, is_stable, name, returns, cfg })
268 },
269 _ => Err(err(span, &msg)),
270 }
271 },
272 _ => Err(err(span, &msg)),
273 }
274 }
275}
276
277fn is_valid_special_arg(idx: usize, arg: &FnArg) -> bool {
278 match (idx, arg) {
279 (0, FnArg::Receiver(rec)) => rec.reference.is_some() && rec.mutability.is_some(),
280 (1, FnArg::Typed(pat)) => {
281 let ident =
282 if let syn::Pat::Ident(ref ident) = *pat.pat { &ident.ident } else { return false };
283 if !(ident == "memory" || ident == "_memory") {
284 return false
285 }
286 matches!(*pat.ty, syn::Type::Reference(_))
287 },
288 _ => false,
289 }
290}
291
292fn arg_decoder<'a, P, I>(param_names: P, param_types: I) -> TokenStream2
293where
294 P: Iterator<Item = &'a std::boxed::Box<syn::Pat>> + Clone,
295 I: Iterator<Item = &'a std::boxed::Box<syn::Type>> + Clone,
296{
297 const ALLOWED_REGISTERS: usize = 6;
298
299 if param_names.clone().count() > ALLOWED_REGISTERS {
301 panic!("Syscalls take a maximum of {ALLOWED_REGISTERS} arguments");
302 }
303
304 if !param_types.clone().all(|ty| {
307 let syn::Type::Path(path) = &**ty else {
308 panic!("Type needs to be path");
309 };
310 let Some(ident) = path.path.get_ident() else {
311 panic!("Type needs to be ident");
312 };
313 matches!(ident.to_string().as_ref(), "u8" | "u16" | "u32" | "u64")
314 }) {
315 panic!("Only primitive unsigned integers are allowed as arguments to syscalls");
316 }
317
318 let bindings = param_names.zip(param_types).enumerate().map(|(idx, (name, ty))| {
320 let reg = quote::format_ident!("__a{}__", idx);
321 quote! {
322 let #name = #reg as #ty;
323 }
324 });
325 quote! {
326 #( #bindings )*
327 }
328}
329
330fn expand_env(def: &EnvDef) -> TokenStream2 {
335 let impls = expand_functions(def);
336 let bench_impls = expand_bench_functions(def);
337 let docs = expand_func_doc(def);
338 let stable_syscalls = expand_func_list(def, false);
339 let all_syscalls = expand_func_list(def, true);
340
341 quote! {
342 pub fn list_syscalls(include_unstable: bool) -> &'static [&'static [u8]] {
343 if include_unstable {
344 #all_syscalls
345 } else {
346 #stable_syscalls
347 }
348 }
349
350 impl<'a, E: Ext, M: PolkaVmInstance<E::T>> Runtime<'a, E, M> {
351 fn handle_ecall(
352 &mut self,
353 memory: &mut M,
354 __syscall_symbol__: &[u8],
355 ) -> Result<Option<u64>, TrapReason>
356 {
357 #impls
358 }
359 }
360
361 #[cfg(feature = "runtime-benchmarks")]
362 impl<'a, E: Ext, M: ?Sized + Memory<E::T>> Runtime<'a, E, M> {
363 #bench_impls
364 }
365
366 #[cfg(doc)]
376 pub trait SyscallDoc {
377 #docs
378 }
379 }
380}
381
382fn expand_functions(def: &EnvDef) -> TokenStream2 {
383 let impls = def.host_funcs.iter().map(|f| {
384 let params = f.item.sig.inputs.iter().skip(2);
386 let param_names = params.clone().filter_map(|arg| {
387 let FnArg::Typed(arg) = arg else {
388 return None;
389 };
390 Some(&arg.pat)
391 });
392 let param_types = params.clone().filter_map(|arg| {
393 let FnArg::Typed(arg) = arg else {
394 return None;
395 };
396 Some(&arg.ty)
397 });
398 let arg_decoder = arg_decoder(param_names, param_types);
399 let cfg = &f.cfg;
400 let name = &f.name;
401 let syscall_symbol = Literal::byte_string(name.as_bytes());
402 let body = &f.item.block;
403 let map_output = f.returns.map_output();
404 let output = &f.item.sig.output;
405
406 let wrapped_body_with_trace = {
409 let trace_fmt_args = params.clone().filter_map(|arg| match arg {
410 syn::FnArg::Receiver(_) => None,
411 syn::FnArg::Typed(p) => match *p.pat.clone() {
412 syn::Pat::Ident(ref pat_ident) => Some(pat_ident.ident.clone()),
413 _ => None,
414 },
415 });
416
417 let params_fmt_str = trace_fmt_args
418 .clone()
419 .map(|s| format!("{s}: {{:?}}"))
420 .collect::<Vec<_>>()
421 .join(", ");
422 let trace_fmt_str = format!("{}({}) = {{:?}} gas_consumed: {{:?}}", name, params_fmt_str);
423
424 quote! {
425 let result = (|| #body)();
427 ::log::trace!(target: "runtime::revive::strace", #trace_fmt_str, #( #trace_fmt_args, )* result, self.ext.gas_meter().gas_consumed());
428 result
429 }
430 };
431
432 quote! {
433 #cfg
434 #syscall_symbol => {
435 (|| #output {
437 #arg_decoder
438 #wrapped_body_with_trace
439 })().map(#map_output)
440 },
441 }
442 });
443
444 quote! {
445 let __gas_left_before__ = self
447 .ext
448 .gas_meter_mut()
449 .sync_from_executor(memory.gas())
450 .map_err(TrapReason::from)?;
451
452 self.charge_gas(crate::vm::RuntimeCosts::HostFn).map_err(TrapReason::from)?;
454
455 let (__a0__, __a1__, __a2__, __a3__, __a4__, __a5__) = memory.read_input_regs();
457
458 let result = (|| match __syscall_symbol__ {
460 #( #impls )*
461 _ => Err(TrapReason::SupervisorError(Error::<E::T>::InvalidSyscall.into()))
462 })();
463
464 let gas = self.ext.gas_meter_mut().sync_to_executor(__gas_left_before__).map_err(TrapReason::from)?;
466 memory.set_gas(gas.into());
467 result
468 }
469}
470
471fn expand_bench_functions(def: &EnvDef) -> TokenStream2 {
472 let impls = def.host_funcs.iter().map(|f| {
473 let params = f.item.sig.inputs.iter().skip(2);
475 let cfg = &f.cfg;
476 let name = &f.name;
477 let body = &f.item.block;
478 let output = &f.item.sig.output;
479
480 let name = Ident::new(&format!("bench_{name}"), Span::call_site());
481 quote! {
482 #cfg
483 pub fn #name(&mut self, memory: &mut M, #(#params),*) #output {
484 #body
485 }
486 }
487 });
488
489 quote! {
490 #( #impls )*
491 }
492}
493
494fn expand_func_doc(def: &EnvDef) -> TokenStream2 {
495 let docs = def.host_funcs.iter().map(|func| {
496 let func_decl = {
498 let mut sig = func.item.sig.clone();
499 sig.inputs = sig
500 .inputs
501 .iter()
502 .skip(2)
503 .map(|p| p.clone())
504 .collect::<Punctuated<FnArg, Comma>>();
505 sig.output = func.returns.success_type();
506 sig.to_token_stream()
507 };
508 let func_doc = {
509 let func_docs = {
510 let docs = func.item.attrs.iter().filter(|a| a.path().is_ident("doc")).map(|d| {
511 let docs = d.to_token_stream();
512 quote! { #docs }
513 });
514 quote! { #( #docs )* }
515 };
516 let availability = if func.is_stable {
517 let info = "\n# Stable API\nThis API is stable and will never change.";
518 quote! { #[doc = #info] }
519 } else {
520 let info =
521 "\n# Unstable API\nThis API is not standardized and only available for testing.";
522 quote! { #[doc = #info] }
523 };
524 quote! {
525 #func_docs
526 #availability
527 }
528 };
529 quote! {
530 #func_doc
531 #func_decl;
532 }
533 });
534
535 quote! {
536 #( #docs )*
537 }
538}
539
540fn expand_func_list(def: &EnvDef, include_unstable: bool) -> TokenStream2 {
541 let docs = def.host_funcs.iter().filter(|f| include_unstable || f.is_stable).map(|f| {
542 let name = Literal::byte_string(f.name.as_bytes());
543 quote! {
544 #name.as_slice()
545 }
546 });
547 let len = docs.clone().count();
548
549 quote! {
550 {
551 static FUNCS: [&[u8]; #len] = [#(#docs),*];
552 FUNCS.as_slice()
553 }
554 }
555}