use proc_macro2::{Span, TokenStream};
use syn::{
parse::Parse, parse_quote, spanned::Spanned, token, Error, FnArg, Ident, ItemTrait, LitInt,
Pat, PatType, Result, Signature, TraitItem, TraitItemFn, Type,
};
use proc_macro_crate::{crate_name, FoundCrate};
use std::{
collections::{btree_map::Entry, BTreeMap},
env,
};
use quote::quote;
use inflector::Inflector;
mod attributes {
syn::custom_keyword!(register_only);
}
/// A single version of a runtime interface trait method.
///
/// Dereferences to the wrapped [`TraitItemFn`].
pub struct RuntimeInterfaceFunction {
    /// The method, with every `#[trap_on_return]` attribute removed.
    item: TraitItemFn,
    /// Whether the method carried a `#[trap_on_return]` attribute.
    should_trap_on_return: bool,
}
impl std::ops::Deref for RuntimeInterfaceFunction {
    type Target = TraitItemFn;

    /// Gives direct access to the wrapped trait method.
    fn deref(&self) -> &TraitItemFn {
        let Self { item, .. } = self;
        item
    }
}
impl RuntimeInterfaceFunction {
    /// Wraps a clone of `item`, consuming all `#[trap_on_return]` attributes found on it.
    ///
    /// # Errors
    ///
    /// Fails when the method is marked `#[trap_on_return]` but declares a return type
    /// (i.e. its signature has an explicit `-> ...`).
    fn new(item: &TraitItemFn) -> Result<Self> {
        let mut method = item.clone();

        // Strip the marker attribute; it was present iff the attribute list shrank.
        let attr_count = method.attrs.len();
        method.attrs.retain(|attr| !attr.path().is_ident("trap_on_return"));
        let should_trap_on_return = method.attrs.len() != attr_count;

        let declares_output = !matches!(method.sig.output, syn::ReturnType::Default);
        if should_trap_on_return && declares_output {
            return Err(Error::new(
                method.sig.ident.span(),
                "Methods marked as #[trap_on_return] cannot return anything",
            ))
        }

        Ok(Self { item: method, should_trap_on_return })
    }

    /// Returns whether the method was annotated with `#[trap_on_return]`.
    pub fn should_trap_on_return(&self) -> bool {
        let Self { should_trap_on_return, .. } = self;
        *should_trap_on_return
    }
}
/// All declared versions of one runtime interface method.
struct RuntimeInterfaceFunctionSet {
    /// The highest declared version that is not marked `register_only`, if any.
    latest_version_to_call: Option<u32>,
    /// Every declared version of the method, keyed by its version number.
    versions: BTreeMap<u32, RuntimeInterfaceFunction>,
}
impl RuntimeInterfaceFunctionSet {
fn new(version: VersionAttribute, trait_item: &TraitItemFn) -> Result<Self> {
Ok(Self {
latest_version_to_call: version.is_callable().then_some(version.version),
versions: BTreeMap::from([(
version.version,
RuntimeInterfaceFunction::new(trait_item)?,
)]),
})
}
pub fn latest_version_to_call(&self) -> Option<(u32, &RuntimeInterfaceFunction)> {
self.latest_version_to_call.map(|v| {
(
v,
self.versions.get(&v).expect(
"If latest_version_to_call has a value, the key with this value is in the versions; qed",
),
)
})
}
fn add_version(&mut self, version: VersionAttribute, trait_item: &TraitItemFn) -> Result<()> {
if let Some(existing_item) = self.versions.get(&version.version) {
let mut err = Error::new(trait_item.span(), "Duplicated version attribute");
err.combine(Error::new(
existing_item.span(),
"Previous version with the same number defined here",
));
return Err(err)
}
self.versions
.insert(version.version, RuntimeInterfaceFunction::new(trait_item)?);
if self.latest_version_to_call.map_or(true, |v| v < version.version) &&
version.is_callable()
{
self.latest_version_to_call = Some(version.version);
}
Ok(())
}
}
/// The methods of a runtime interface trait, grouped by method name.
pub struct RuntimeInterface {
    /// The version set of each method, keyed by the method identifier.
    items: BTreeMap<syn::Ident, RuntimeInterfaceFunctionSet>,
}
impl RuntimeInterface {
    /// Iterates over the latest callable version of every method, in method-name order.
    ///
    /// Methods that have no callable version (all are `register_only`) are skipped.
    pub fn latest_versions_to_call(
        &self,
    ) -> impl Iterator<Item = (u32, &RuntimeInterfaceFunction)> {
        self.items.values().filter_map(|set| set.latest_version_to_call())
    }

    /// Iterates over every declared version of every method, ordered by method name and
    /// then by ascending version number.
    pub fn all_versions(&self) -> impl Iterator<Item = (u32, &RuntimeInterfaceFunction)> {
        self.items.values().flat_map(|set| {
            set.versions.iter().map(|(&version, function)| (version, function))
        })
    }
}
pub fn generate_runtime_interface_include() -> TokenStream {
match crate_name("sp-runtime-interface") {
Ok(FoundCrate::Itself) => quote!(),
Ok(FoundCrate::Name(crate_name)) => {
let crate_name = Ident::new(&crate_name, Span::call_site());
quote!(
#[doc(hidden)]
extern crate #crate_name as proc_macro_runtime_interface;
)
},
Err(e) => {
let err = Error::new(Span::call_site(), e).to_compile_error();
quote!( #err )
},
}
}
/// Returns the path through which generated code refers to `sp-runtime-interface`.
///
/// When the crate being compiled is `sp-runtime-interface` itself, the path is
/// `sp_runtime_interface`. Otherwise it is the `proc_macro_runtime_interface` alias
/// introduced by `generate_runtime_interface_include`.
///
/// If `CARGO_PKG_NAME` is unset or not valid unicode (e.g. the macro is expanded outside
/// of cargo), the crate is treated as an external one rather than panicking.
pub fn generate_crate_access() -> TokenStream {
    if matches!(env::var("CARGO_PKG_NAME").as_deref(), Ok("sp-runtime-interface")) {
        quote!(sp_runtime_interface)
    } else {
        quote!(proc_macro_runtime_interface)
    }
}
/// Builds the identifier `host_<name>` with a call-site span.
pub fn create_exchangeable_host_function_ident(name: &Ident) -> Ident {
    let prefixed = format!("host_{name}");
    Ident::new(&prefixed, Span::call_site())
}
/// Builds the host function identifier
/// `ext_<snake_case trait name>_<name>_version_<version>` with a call-site span.
pub fn create_host_function_ident(name: &Ident, version: u32, trait_name: &Ident) -> Ident {
    let trait_part = trait_name.to_string().to_snake_case();
    let symbol = format!("ext_{trait_part}_{name}_version_{version}");
    Ident::new(&symbol, Span::call_site())
}
/// Builds the identifier `<name>_version_<version>` with a call-site span.
pub fn create_function_ident_with_version(name: &Ident, version: u32) -> Ident {
    let versioned = format!("{name}_version_{version}");
    Ident::new(&versioned, Span::call_site())
}
/// Yields owned copies of the typed (non-`self`) arguments of `sig`.
///
/// A wildcard `_` pattern is replaced with the generated identifier
/// `__runtime_interface_generated_<i>_`, where `<i>` is the argument's position among
/// the typed arguments. This gives every argument a name that generated code can refer to.
pub fn get_function_arguments(sig: &Signature) -> impl Iterator<Item = PatType> + '_ {
    let typed_args = sig.inputs.iter().filter_map(|input| {
        if let FnArg::Typed(typed) = input {
            Some(typed)
        } else {
            None
        }
    });

    typed_args.enumerate().map(|(index, typed)| {
        let mut owned = typed.clone();
        if let Pat::Wild(wild) = &*typed.pat {
            let generated =
                Ident::new(&format!("__runtime_interface_generated_{}_", index), wild.span());
            owned.pat = Box::new(parse_quote!( #generated ));
        }
        owned
    })
}
pub fn get_function_argument_names(sig: &Signature) -> impl Iterator<Item = Box<Pat>> + '_ {
get_function_arguments(sig).map(|pt| pt.pat)
}
pub fn get_function_argument_types(sig: &Signature) -> impl Iterator<Item = Box<Type>> + '_ {
get_function_arguments(sig).map(|pt| pt.ty)
}
pub fn get_function_argument_types_without_ref(
sig: &Signature,
) -> impl Iterator<Item = Box<Type>> + '_ {
get_function_arguments(sig).map(|pt| pt.ty).map(|ty| match *ty {
Type::Reference(type_ref) => type_ref.elem,
_ => ty,
})
}
pub fn get_function_argument_names_and_types_without_ref(
sig: &Signature,
) -> impl Iterator<Item = (Box<Pat>, Box<Type>)> + '_ {
get_function_arguments(sig).map(|pt| match *pt.ty {
Type::Reference(type_ref) => (pt.pat, type_ref.elem),
_ => (pt.pat, pt.ty),
})
}
pub fn get_function_argument_types_ref_and_mut(
sig: &Signature,
) -> impl Iterator<Item = Option<(token::And, Option<token::Mut>)>> + '_ {
get_function_arguments(sig).map(|pt| pt.ty).map(|ty| match *ty {
Type::Reference(type_ref) => Some((type_ref.and_token, type_ref.mutability)),
_ => None,
})
}
/// Yields only the method (`fn`) items of a trait definition, skipping all other items.
fn get_trait_methods(trait_def: &ItemTrait) -> impl Iterator<Item = &TraitItemFn> {
    trait_def.items.iter().filter_map(|item| {
        if let TraitItem::Fn(method) = item {
            Some(method)
        } else {
            None
        }
    })
}
/// Parsed contents of a `#[version(N)]` or `#[version(N, register_only)]` attribute.
struct VersionAttribute {
    /// The version number `N`.
    version: u32,
    /// Present when the `register_only` keyword was given.
    register_only: Option<attributes::register_only>,
}
impl VersionAttribute {
    /// Returns `true` unless the version was marked `register_only`.
    ///
    /// Only callable versions are considered for `latest_version_to_call`.
    fn is_callable(&self) -> bool {
        let Self { register_only, .. } = self;
        register_only.is_none()
    }
}
impl Default for VersionAttribute {
fn default() -> Self {
Self { version: 1, register_only: None }
}
}
impl Parse for VersionAttribute {
    /// Parses `N` or `N, register_only`.
    ///
    /// # Errors
    ///
    /// Fails if `N` is not an integer literal that fits in a `u32`, if anything other than
    /// a comma follows `N`, or if the token after the comma is not `register_only`.
    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
        let version: LitInt = input.parse()?;

        let register_only = if input.peek(token::Comma) {
            // `peek` guarantees a comma here; propagate instead of discarding the result.
            input.parse::<token::Comma>()?;
            Some(input.parse()?)
        } else if input.is_empty() {
            None
        } else {
            return Err(Error::new(input.span(), "Unexpected token, expected `,`."))
        };

        Ok(Self { version: version.base10_parse()?, register_only })
    }
}
/// Parses the first `#[version(...)]` attribute on `item`, if there is one.
///
/// Returns `Ok(None)` when the method has no `version` attribute. Only the first matching
/// attribute is inspected; any further `#[version]` attributes are not looked at here.
///
/// # Errors
///
/// Propagates the parse error if the attribute's arguments are malformed.
fn get_item_version(item: &TraitItemFn) -> Result<Option<VersionAttribute>> {
    let version_attr = item.attrs.iter().find(|attr| attr.path().is_ident("version"));
    match version_attr {
        Some(attr) => attr.parse_args().map(Some),
        None => Ok(None),
    }
}
/// Collects the methods of `trait_def` into a [`RuntimeInterface`], grouping every
/// version of a method under its name.
///
/// Methods without a `#[version]` attribute are treated as version `1`.
///
/// # Errors
///
/// Fails if a version attribute is malformed or below `1`, or if a version number is
/// duplicated for the same method. It also fails if a method's versions do not form the
/// contiguous range `1..=N`; the error points at the first version after the gap.
pub fn get_runtime_interface(trait_def: &ItemTrait) -> Result<RuntimeInterface> {
    let mut functions: BTreeMap<syn::Ident, RuntimeInterfaceFunctionSet> = BTreeMap::new();

    for item in get_trait_methods(trait_def) {
        let version = get_item_version(item)?.unwrap_or_default();

        if version.version < 1 {
            return Err(Error::new(item.span(), "Version needs to be at least `1`."))
        }

        // The ident is cloned once, directly into the map key.
        match functions.entry(item.sig.ident.clone()) {
            Entry::Vacant(entry) => {
                entry.insert(RuntimeInterfaceFunctionSet::new(version, item)?);
            },
            Entry::Occupied(mut entry) => {
                entry.get_mut().add_version(version, item)?;
            },
        }
    }

    // `versions` is ordered, so pairing it with 1, 2, 3, ... reveals the first gap.
    for function in functions.values() {
        for (expected, (version, item)) in (1u32..).zip(function.versions.iter()) {
            if *version != expected {
                return Err(Error::new(
                    item.span(),
                    format!(
                        "Unexpected version attribute: missing version '{}' for this function",
                        expected
                    ),
                ))
            }
        }
    }

    Ok(RuntimeInterface { items: functions })
}