1use proc_macro::TokenStream;
24use proc_macro2::{Literal, Span, TokenStream as TokenStream2};
25use quote::{quote, ToTokens};
26use syn::{parse_quote, punctuated::Punctuated, spanned::Spanned, token::Comma, FnArg, Ident};
27
28#[proc_macro_attribute]
39pub fn define_env(attr: TokenStream, item: TokenStream) -> TokenStream {
40 if !attr.is_empty() {
41 let msg = r#"Invalid `define_env` attribute macro: expected no attributes:
42 - `#[define_env]`"#;
43 let span = TokenStream2::from(attr).span();
44 return syn::Error::new(span, msg).to_compile_error().into();
45 }
46
47 let item = syn::parse_macro_input!(item as syn::ItemMod);
48
49 match EnvDef::try_from(item) {
50 Ok(mut def) => expand_env(&mut def).into(),
51 Err(e) => e.to_compile_error().into(),
52 }
53}
54
55struct EnvDef {
57 host_funcs: Vec<HostFn>,
58}
59
60struct HostFn {
62 item: syn::ItemFn,
63 name: String,
64 returns: HostFnReturn,
65 cfg: Option<syn::Attribute>,
66}
67
/// Success type a host function may declare inside `Result<_, TrapReason>`.
enum HostFnReturn {
    /// `Result<(), TrapReason>`
    Unit,
    /// `Result<u32, TrapReason>`
    U32,
    /// `Result<u64, TrapReason>`
    U64,
    /// `Result<ReturnErrorCode, TrapReason>`
    ReturnCode,
}
74
75impl HostFnReturn {
76 fn map_output(&self) -> TokenStream2 {
77 match self {
78 Self::Unit => quote! { |_| None },
79 _ => quote! { |ret_val| Some(ret_val.into()) },
80 }
81 }
82
83 fn success_type(&self) -> syn::ReturnType {
84 match self {
85 Self::Unit => syn::ReturnType::Default,
86 Self::U32 => parse_quote! { -> u32 },
87 Self::U64 => parse_quote! { -> u64 },
88 Self::ReturnCode => parse_quote! { -> ReturnErrorCode },
89 }
90 }
91
92 fn trace_return_value(&self) -> TokenStream2 {
93 match self {
94 Self::Unit => quote! { None },
95 Self::U32 => quote! { result.as_ref().ok().map(|r| *r as u64) },
96 Self::ReturnCode => quote! { result.as_ref().ok().copied().map(u64::from) },
97 Self::U64 => quote! { result.as_ref().ok().copied() },
98 }
99 }
100}
101
102impl EnvDef {
103 pub fn try_from(item: syn::ItemMod) -> syn::Result<Self> {
104 let span = item.span();
105 let err = |msg| syn::Error::new(span, msg);
106 let items = &item
107 .content
108 .as_ref()
109 .ok_or(err("Invalid environment definition, expected `mod` to be inlined."))?
110 .1;
111
112 let extract_fn = |i: &syn::Item| match i {
113 syn::Item::Fn(i_fn) => Some(i_fn.clone()),
114 _ => None,
115 };
116
117 let host_funcs = items
118 .iter()
119 .filter_map(extract_fn)
120 .map(HostFn::try_from)
121 .collect::<Result<Vec<_>, _>>()?;
122
123 Ok(Self { host_funcs })
124 }
125}
126
127impl HostFn {
128 pub fn try_from(mut item: syn::ItemFn) -> syn::Result<Self> {
129 let err = |span, msg| {
130 let msg = format!("Invalid host function definition.\n{}", msg);
131 syn::Error::new(span, msg)
132 };
133
134 let msg = "Only #[cfg] and #[mutating] attributes are allowed.";
136 let span = item.span();
137 let mut attrs = item.attrs.clone();
138 attrs.retain(|a| !a.path().is_ident("doc"));
139 let mut mutating = false;
140 let mut cfg = None;
141 while let Some(attr) = attrs.pop() {
142 let ident = attr.path().get_ident().ok_or(err(span, msg))?.to_string();
143 match ident.as_str() {
144 "mutating" => {
145 if mutating {
146 return Err(err(span, "#[mutating] can only be specified once"));
147 }
148 mutating = true;
149 },
150 "cfg" => {
151 if cfg.is_some() {
152 return Err(err(span, "#[cfg] can only be specified once"));
153 }
154 cfg = Some(attr);
155 },
156 id => return Err(err(span, &format!("Unsupported attribute \"{id}\". {msg}"))),
157 }
158 }
159
160 if mutating {
161 let stmt = syn::parse_quote! {
162 if self.ext().is_read_only() {
163 return Err(Error::<E::T>::StateChangeDenied.into());
164 }
165 };
166 item.block.stmts.insert(0, stmt);
167 }
168
169 let name = item.sig.ident.to_string();
170
171 let msg = "Every function must start with these two parameters: &mut self, memory: &mut M";
172 let special_args = item
173 .sig
174 .inputs
175 .iter()
176 .take(2)
177 .enumerate()
178 .map(|(i, arg)| is_valid_special_arg(i, arg))
179 .fold(0u32, |acc, valid| if valid { acc + 1 } else { acc });
180
181 if special_args != 2 {
182 return Err(err(span, msg));
183 }
184
185 let msg = r#"Should return one of the following:
187 - Result<(), TrapReason>,
188 - Result<ReturnErrorCode, TrapReason>,
189 - Result<u32, TrapReason>,
190 - Result<u64, TrapReason>"#;
191 let ret_ty = match item.clone().sig.output {
192 syn::ReturnType::Type(_, ty) => Ok(ty.clone()),
193 _ => Err(err(span, &msg)),
194 }?;
195 match *ret_ty {
196 syn::Type::Path(tp) => {
197 let result = &tp.path.segments.last().ok_or(err(span, &msg))?;
198 let (id, span) = (result.ident.to_string(), result.ident.span());
199 id.eq(&"Result".to_string()).then_some(()).ok_or(err(span, &msg))?;
200
201 match &result.arguments {
202 syn::PathArguments::AngleBracketed(group) => {
203 if group.args.len() != 2 {
204 return Err(err(span, &msg));
205 };
206
207 let arg2 = group.args.last().ok_or(err(span, &msg))?;
208
209 let err_ty = match arg2 {
210 syn::GenericArgument::Type(ty) => Ok(ty.clone()),
211 _ => Err(err(arg2.span(), &msg)),
212 }?;
213
214 match err_ty {
215 syn::Type::Path(tp) => Ok(tp
216 .path
217 .segments
218 .first()
219 .ok_or(err(arg2.span(), &msg))?
220 .ident
221 .to_string()),
222 _ => Err(err(tp.span(), &msg)),
223 }?
224 .eq("TrapReason")
225 .then_some(())
226 .ok_or(err(span, &msg))?;
227
228 let arg1 = group.args.first().ok_or(err(span, &msg))?;
229 let ok_ty = match arg1 {
230 syn::GenericArgument::Type(ty) => Ok(ty.clone()),
231 _ => Err(err(arg1.span(), &msg)),
232 }?;
233 let ok_ty_str = match ok_ty {
234 syn::Type::Path(tp) => Ok(tp
235 .path
236 .segments
237 .first()
238 .ok_or(err(arg1.span(), &msg))?
239 .ident
240 .to_string()),
241 syn::Type::Tuple(tt) => {
242 if !tt.elems.is_empty() {
243 return Err(err(arg1.span(), &msg));
244 };
245 Ok("()".to_string())
246 },
247 _ => Err(err(ok_ty.span(), &msg)),
248 }?;
249 let returns = match ok_ty_str.as_str() {
250 "()" => Ok(HostFnReturn::Unit),
251 "u32" => Ok(HostFnReturn::U32),
252 "u64" => Ok(HostFnReturn::U64),
253 "ReturnErrorCode" => Ok(HostFnReturn::ReturnCode),
254 _ => Err(err(arg1.span(), &msg)),
255 }?;
256
257 Ok(Self { item, name, returns, cfg })
258 },
259 _ => Err(err(span, &msg)),
260 }
261 },
262 _ => Err(err(span, &msg)),
263 }
264 }
265}
266
267fn is_valid_special_arg(idx: usize, arg: &FnArg) -> bool {
268 match (idx, arg) {
269 (0, FnArg::Receiver(rec)) => rec.reference.is_some() && rec.mutability.is_some(),
270 (1, FnArg::Typed(pat)) => {
271 let ident = if let syn::Pat::Ident(ref ident) = *pat.pat {
272 &ident.ident
273 } else {
274 return false;
275 };
276 if !(ident == "memory" || ident == "_memory") {
277 return false;
278 }
279 matches!(*pat.ty, syn::Type::Reference(_))
280 },
281 _ => false,
282 }
283}
284
285fn arg_decoder<'a, P, I>(param_names: P, param_types: I) -> TokenStream2
286where
287 P: Iterator<Item = &'a std::boxed::Box<syn::Pat>> + Clone,
288 I: Iterator<Item = &'a std::boxed::Box<syn::Type>> + Clone,
289{
290 const ALLOWED_REGISTERS: usize = 6;
291
292 if param_names.clone().count() > ALLOWED_REGISTERS {
294 panic!("Syscalls take a maximum of {ALLOWED_REGISTERS} arguments");
295 }
296
297 if !param_types.clone().all(|ty| {
300 let syn::Type::Path(path) = &**ty else {
301 panic!("Type needs to be path");
302 };
303 let Some(ident) = path.path.get_ident() else {
304 panic!("Type needs to be ident");
305 };
306 matches!(ident.to_string().as_ref(), "u8" | "u16" | "u32" | "u64")
307 }) {
308 panic!("Only primitive unsigned integers are allowed as arguments to syscalls");
309 }
310
311 let bindings = param_names.zip(param_types).enumerate().map(|(idx, (name, ty))| {
313 let reg = quote::format_ident!("__a{}__", idx);
314 quote! {
315 let #name = #reg as #ty;
316 }
317 });
318 quote! {
319 #( #bindings )*
320 }
321}
322
323fn expand_env(def: &EnvDef) -> TokenStream2 {
328 let impls = expand_functions(def);
329 let bench_impls = expand_bench_functions(def);
330 let docs = expand_func_doc(def);
331 let all_syscalls = expand_func_list(def);
332 let lookup_syscall = expand_func_lookup(def);
333 let all_trace_ops = expand_trace_op_list(def);
334 let lookup_trace_op = expand_trace_op_lookup(def);
335
336 quote! {
337 pub fn list_syscalls() -> &'static [&'static [u8]] {
339 #all_syscalls
340 }
341
342 pub fn lookup_syscall_index(name: &'static str) -> Option<u8> {
344 #lookup_syscall
345 }
346
347 pub fn list_trace_ops() -> &'static [&'static [u8]] {
349 #all_trace_ops
350 }
351
352 pub fn lookup_trace_op_index(name: &'static str) -> Option<u8> {
354 #lookup_trace_op
355 }
356
357 impl<'a, E: Ext, M: PolkaVmInstance<E::T>> Runtime<'a, E, M> {
358 fn handle_ecall(
359 &mut self,
360 memory: &mut M,
361 __syscall_symbol__: &[u8],
362 ) -> Result<Option<u64>, TrapReason>
363 {
364 #impls
365 }
366 }
367
368 #[cfg(feature = "runtime-benchmarks")]
369 impl<'a, E: Ext, M: ?Sized + Memory<E::T>> Runtime<'a, E, M> {
370 #bench_impls
371 }
372
373 #[cfg(doc)]
383 pub trait SyscallDoc {
384 #docs
385 }
386 }
387}
388
389fn expand_functions(def: &EnvDef) -> TokenStream2 {
390 let impls = def.host_funcs.iter().map(|f| {
391 let params = f.item.sig.inputs.iter().skip(2);
393 let param_names = params.clone().filter_map(|arg| {
394 let FnArg::Typed(arg) = arg else {
395 return None;
396 };
397 Some(&arg.pat)
398 });
399 let param_types = params.clone().filter_map(|arg| {
400 let FnArg::Typed(arg) = arg else {
401 return None;
402 };
403 Some(&arg.ty)
404 });
405 let arg_decoder = arg_decoder(param_names, param_types);
406 let cfg = &f.cfg;
407 let name = &f.name;
408 let syscall_symbol = Literal::byte_string(name.as_bytes());
409 let body = &f.item.block;
410 let map_output = f.returns.map_output();
411 let trace_return = f.returns.trace_return_value();
412 let output = &f.item.sig.output;
413
414 let wrapped_body_with_trace = {
416 let trace_fmt_args = params.clone().filter_map(|arg| match arg {
417 syn::FnArg::Receiver(_) => None,
418 syn::FnArg::Typed(p) => match *p.pat.clone() {
419 syn::Pat::Ident(ref pat_ident) => Some(pat_ident.ident.clone()),
420 _ => None,
421 },
422 });
423
424 let params_fmt_str = trace_fmt_args
425 .clone()
426 .map(|s| format!("{s}: {{:?}}"))
427 .collect::<Vec<_>>()
428 .join(", ");
429 let trace_fmt_str = format!("{}({}) = {{:?}} weight_consumed: {{:?}}", name, params_fmt_str);
430 let trace_args_for_tracer: Vec<_> = trace_fmt_args.clone().collect();
431
432 quote! {
433 crate::tracing::if_tracing(|tracer| {
434 tracer.enter_ecall(#name, &[#( #trace_args_for_tracer as u64 ),*], self)
435 });
436
437 let result = (|| #body)();
439 ::log::trace!(target: "runtime::revive::strace", #trace_fmt_str, #( #trace_fmt_args, )* result, self.ext.frame_meter().weight_consumed());
440
441 crate::tracing::if_tracing(|tracer| tracer.exit_step(self, #trace_return));
442 result
443 }
444 };
445
446 quote! {
447 #cfg
448 #syscall_symbol => {
449 (|| #output {
451 #arg_decoder
452 #wrapped_body_with_trace
453 })().map(#map_output)
454 },
455 }
456 });
457
458 quote! {
459 crate::tracing::if_tracing(|tracer| {
460 tracer.enter_ecall(crate::tracing::PVM_FUEL_NAME, &[], self)
461 });
462
463 let __sync_result__ = self.ext
464 .frame_meter_mut()
465 .sync_from_executor(memory.gas())
466 .map_err(TrapReason::from);
467
468 crate::tracing::if_tracing(|tracer| tracer.exit_step(self, None));
469
470 __sync_result__?;
471
472 self.charge_gas(crate::vm::RuntimeCosts::HostFn).map_err(TrapReason::from)?;
474
475 let (__a0__, __a1__, __a2__, __a3__, __a4__, __a5__) = memory.read_input_regs();
477
478 let result = (|| match __syscall_symbol__ {
480 #( #impls )*
481 _ => Err(TrapReason::SupervisorError(Error::<E::T>::InvalidSyscall.into()))
482 })();
483
484 let gas = self.ext.frame_meter_mut().sync_to_executor();
486 memory.set_gas(gas.into());
487 result
488 }
489}
490
491fn expand_bench_functions(def: &EnvDef) -> TokenStream2 {
492 let impls = def.host_funcs.iter().map(|f| {
493 let params = f.item.sig.inputs.iter().skip(2);
495 let cfg = &f.cfg;
496 let name = &f.name;
497 let body = &f.item.block;
498 let output = &f.item.sig.output;
499
500 let name = Ident::new(&format!("bench_{name}"), Span::call_site());
501 quote! {
502 #cfg
503 pub fn #name(&mut self, memory: &mut M, #(#params),*) #output {
504 #body
505 }
506 }
507 });
508
509 quote! {
510 #( #impls )*
511 }
512}
513
514fn expand_func_doc(def: &EnvDef) -> TokenStream2 {
515 let docs = def.host_funcs.iter().map(|func| {
516 let func_decl = {
518 let mut sig = func.item.sig.clone();
519 sig.inputs = sig
520 .inputs
521 .iter()
522 .skip(2)
523 .map(|p| p.clone())
524 .collect::<Punctuated<FnArg, Comma>>();
525 sig.output = func.returns.success_type();
526 sig.to_token_stream()
527 };
528 let func_doc = {
529 let func_docs = {
530 let docs = func.item.attrs.iter().filter(|a| a.path().is_ident("doc")).map(|d| {
531 let docs = d.to_token_stream();
532 quote! { #docs }
533 });
534 quote! { #( #docs )* }
535 };
536 quote! {
537 #func_docs
538 }
539 };
540 quote! {
541 #func_doc
542 #func_decl;
543 }
544 });
545
546 quote! {
547 #( #docs )*
548 }
549}
550
551fn expand_func_list(def: &EnvDef) -> TokenStream2 {
552 let docs = def.host_funcs.iter().map(|f| {
553 let name = Literal::byte_string(f.name.as_bytes());
554 quote! {
555 #name.as_slice()
556 }
557 });
558 let len = docs.clone().count();
559
560 quote! {
561 {
562 static FUNCS: [&[u8]; #len] = [#(#docs),*];
563 FUNCS.as_slice()
564 }
565 }
566}
567
568fn expand_func_lookup(def: &EnvDef) -> TokenStream2 {
569 let arms = def.host_funcs.iter().enumerate().map(|(idx, f)| {
570 let name_str = &f.name;
571 quote! {
572 #name_str => Some(#idx as u8)
573 }
574 });
575 quote! {
576 match name {
577 #( #arms, )*
578 _ => None,
579 }
580 }
581}
582
583fn expand_trace_op_list(def: &EnvDef) -> TokenStream2 {
584 let syscalls = def.host_funcs.iter().map(|f| {
585 let name = Literal::byte_string(f.name.as_bytes());
586 quote! {
587 #name.as_slice()
588 }
589 });
590 let len = syscalls.clone().count() + 1;
591
592 quote! {
593 {
594 static OPS: [&[u8]; #len] = [
595 #(#syscalls,)*
596 crate::tracing::PVM_FUEL_NAME.as_bytes(),
597 ];
598 OPS.as_slice()
599 }
600 }
601}
602
603fn expand_trace_op_lookup(def: &EnvDef) -> TokenStream2 {
604 let arms = def.host_funcs.iter().enumerate().map(|(idx, f)| {
605 let name_str = &f.name;
606 quote! {
607 #name_str => Some(#idx as u8)
608 }
609 });
610 let pvm_fuel_idx = def.host_funcs.len();
611
612 quote! {
613 match name {
614 #( #arms, )*
615 crate::tracing::PVM_FUEL_NAME => Some(#pvm_fuel_idx as u8),
616 _ => None,
617 }
618 }
619}