frame_support_procedural/construct_runtime/expand/origin.rs

// This file is part of Substrate.

// Copyright (C) Parity Technologies (UK) Ltd.
// SPDX-License-Identifier: Apache-2.0

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// 	http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
17
18use crate::construct_runtime::{Pallet, SYSTEM_PALLET_NAME};
19use proc_macro2::TokenStream;
20use quote::quote;
21use std::str::FromStr;
22use syn::{Generics, Ident};
23
24pub fn expand_outer_origin(
25	runtime: &Ident,
26	system_pallet: &Pallet,
27	pallets: &[Pallet],
28	scrate: &TokenStream,
29) -> syn::Result<TokenStream> {
30	let mut caller_variants = TokenStream::new();
31	let mut pallet_conversions = TokenStream::new();
32	let mut query_origin_part_macros = Vec::new();
33
34	for pallet_decl in pallets.iter().filter(|pallet| pallet.name != SYSTEM_PALLET_NAME) {
35		if let Some(pallet_entry) = pallet_decl.find_part("Origin") {
36			let instance = pallet_decl.instance.as_ref();
37			let index = pallet_decl.index;
38			let generics = &pallet_entry.generics;
39			let name = &pallet_decl.name;
40			let path = &pallet_decl.path;
41
42			if instance.is_some() && generics.params.is_empty() {
43				let msg = format!(
44					"Instantiable pallet with no generic `Origin` cannot \
45					 be constructed: pallet `{}` must have generic `Origin`",
46					name
47				);
48				return Err(syn::Error::new(name.span(), msg))
49			}
50
51			caller_variants.extend(expand_origin_caller_variant(
52				runtime,
53				pallet_decl,
54				index,
55				instance,
56				generics,
57			));
58			pallet_conversions.extend(expand_origin_pallet_conversions(
59				scrate,
60				runtime,
61				pallet_decl,
62				instance,
63				generics,
64			));
65			query_origin_part_macros.push(quote! {
66				#path::__substrate_origin_check::is_origin_part_defined!(#name);
67			});
68		}
69	}
70
71	let system_path = &system_pallet.path;
72
73	let system_index = system_pallet.index;
74
75	let system_path_name = system_path.module_name();
76
77	let doc_string = get_intra_doc_string(
78		"Origin is always created with the base filter configured in",
79		&system_path_name,
80	);
81
82	let doc_string_none_origin =
83		get_intra_doc_string("Create with system none origin and", &system_path_name);
84
85	let doc_string_root_origin =
86		get_intra_doc_string("Create with system root origin and", &system_path_name);
87
88	let doc_string_signed_origin =
89		get_intra_doc_string("Create with system signed origin and", &system_path_name);
90
91	let doc_string_runtime_origin =
92		get_intra_doc_string("Convert to runtime origin, using as filter:", &system_path_name);
93
94	let doc_string_runtime_origin_with_caller = get_intra_doc_string(
95		"Convert to runtime origin with caller being system signed or none and use filter",
96		&system_path_name,
97	);
98
99	Ok(quote! {
100		#( #query_origin_part_macros )*
101
102		/// The runtime origin type representing the origin of a call.
103		///
104		#[doc = #doc_string]
105		#[derive(Clone)]
106		pub struct RuntimeOrigin {
107			pub caller: OriginCaller,
108			filter: #scrate::__private::Rc<#scrate::__private::Box<dyn Fn(&<#runtime as #system_path::Config>::RuntimeCall) -> bool>>,
109		}
110
111		#[cfg(not(feature = "std"))]
112		impl core::fmt::Debug for RuntimeOrigin {
113			fn fmt(
114				&self,
115				fmt: &mut core::fmt::Formatter,
116			) -> core::result::Result<(), core::fmt::Error> {
117				fmt.write_str("<wasm:stripped>")
118			}
119		}
120
121		#[cfg(feature = "std")]
122		impl core::fmt::Debug for RuntimeOrigin {
123			fn fmt(
124				&self,
125				fmt: &mut core::fmt::Formatter,
126			) -> core::result::Result<(), core::fmt::Error> {
127				fmt.debug_struct("Origin")
128					.field("caller", &self.caller)
129					.field("filter", &"[function ptr]")
130					.finish()
131			}
132		}
133
134		impl #scrate::traits::OriginTrait for RuntimeOrigin {
135			type Call = <#runtime as #system_path::Config>::RuntimeCall;
136			type PalletsOrigin = OriginCaller;
137			type AccountId = <#runtime as #system_path::Config>::AccountId;
138
139			fn add_filter(&mut self, filter: impl Fn(&Self::Call) -> bool + 'static) {
140				let f = self.filter.clone();
141
142				self.filter = #scrate::__private::Rc::new(#scrate::__private::Box::new(move |call| {
143					f(call) && filter(call)
144				}));
145			}
146
147			fn reset_filter(&mut self) {
148				let filter = <
149					<#runtime as #system_path::Config>::BaseCallFilter
150					as #scrate::traits::Contains<<#runtime as #system_path::Config>::RuntimeCall>
151				>::contains;
152
153				self.filter = #scrate::__private::Rc::new(#scrate::__private::Box::new(filter));
154			}
155
156			fn set_caller_from(&mut self, other: impl Into<Self>) {
157				self.caller = other.into().caller;
158			}
159
160			fn filter_call(&self, call: &Self::Call) -> bool {
161				match self.caller {
162					// Root bypasses all filters
163					OriginCaller::system(#system_path::Origin::<#runtime>::Root) => true,
164					_ => (self.filter)(call),
165				}
166			}
167
168			fn caller(&self) -> &Self::PalletsOrigin {
169				&self.caller
170			}
171
172			fn into_caller(self) -> Self::PalletsOrigin {
173				self.caller
174			}
175
176			fn try_with_caller<R>(
177				mut self,
178				f: impl FnOnce(Self::PalletsOrigin) -> Result<R, Self::PalletsOrigin>,
179			) -> Result<R, Self> {
180				match f(self.caller) {
181					Ok(r) => Ok(r),
182					Err(caller) => { self.caller = caller; Err(self) }
183				}
184			}
185
186			fn none() -> Self {
187				#system_path::RawOrigin::None.into()
188			}
189
190			fn root() -> Self {
191				#system_path::RawOrigin::Root.into()
192			}
193
194			fn signed(by: Self::AccountId) -> Self {
195				#system_path::RawOrigin::Signed(by).into()
196			}
197		}
198
199		#[derive(
200			Clone, PartialEq, Eq, #scrate::__private::RuntimeDebug, #scrate::__private::codec::Encode,
201			#scrate::__private::codec::Decode, #scrate::__private::scale_info::TypeInfo, #scrate::__private::codec::MaxEncodedLen,
202		)]
203		#[allow(non_camel_case_types)]
204		pub enum OriginCaller {
205			#[codec(index = #system_index)]
206			system(#system_path::Origin<#runtime>),
207			#caller_variants
208			#[allow(dead_code)]
209			Void(#scrate::__private::Void)
210		}
211
212		// For backwards compatibility and ease of accessing these functions.
213		#[allow(dead_code)]
214		impl RuntimeOrigin {
215			#[doc = #doc_string_none_origin]
216			pub fn none() -> Self {
217				<RuntimeOrigin as #scrate::traits::OriginTrait>::none()
218			}
219
220			#[doc = #doc_string_root_origin]
221			pub fn root() -> Self {
222				<RuntimeOrigin as #scrate::traits::OriginTrait>::root()
223			}
224
225			#[doc = #doc_string_signed_origin]
226			pub fn signed(by: <#runtime as #system_path::Config>::AccountId) -> Self {
227				<RuntimeOrigin as #scrate::traits::OriginTrait>::signed(by)
228			}
229		}
230
231		impl From<#system_path::Origin<#runtime>> for OriginCaller {
232			fn from(x: #system_path::Origin<#runtime>) -> Self {
233				OriginCaller::system(x)
234			}
235		}
236
237		impl #scrate::traits::CallerTrait<<#runtime as #system_path::Config>::AccountId> for OriginCaller {
238			fn into_system(self) -> Option<#system_path::RawOrigin<<#runtime as #system_path::Config>::AccountId>> {
239				match self {
240					OriginCaller::system(x) => Some(x),
241					_ => None,
242				}
243			}
244			fn as_system_ref(&self) -> Option<&#system_path::RawOrigin<<#runtime as #system_path::Config>::AccountId>> {
245				match &self {
246					OriginCaller::system(o) => Some(o),
247					_ => None,
248				}
249			}
250		}
251
252		impl TryFrom<OriginCaller> for #system_path::Origin<#runtime> {
253			type Error = OriginCaller;
254			fn try_from(x: OriginCaller)
255				-> core::result::Result<#system_path::Origin<#runtime>, OriginCaller>
256			{
257				if let OriginCaller::system(l) = x {
258					Ok(l)
259				} else {
260					Err(x)
261				}
262			}
263		}
264
265		impl From<#system_path::Origin<#runtime>> for RuntimeOrigin {
266
267			#[doc = #doc_string_runtime_origin]
268			fn from(x: #system_path::Origin<#runtime>) -> Self {
269				let o: OriginCaller = x.into();
270				o.into()
271			}
272		}
273
274		impl From<OriginCaller> for RuntimeOrigin {
275			fn from(x: OriginCaller) -> Self {
276				let mut o = RuntimeOrigin {
277					caller: x,
278					filter: #scrate::__private::Rc::new(#scrate::__private::Box::new(|_| true)),
279				};
280
281				#scrate::traits::OriginTrait::reset_filter(&mut o);
282
283				o
284			}
285		}
286
287		impl From<RuntimeOrigin> for core::result::Result<#system_path::Origin<#runtime>, RuntimeOrigin> {
288			/// NOTE: converting to pallet origin loses the origin filter information.
289			fn from(val: RuntimeOrigin) -> Self {
290				if let OriginCaller::system(l) = val.caller {
291					Ok(l)
292				} else {
293					Err(val)
294				}
295			}
296		}
297		impl From<Option<<#runtime as #system_path::Config>::AccountId>> for RuntimeOrigin {
298			#[doc = #doc_string_runtime_origin_with_caller]
299			fn from(x: Option<<#runtime as #system_path::Config>::AccountId>) -> Self {
300				<#system_path::Origin<#runtime>>::from(x).into()
301			}
302		}
303
304		#pallet_conversions
305	})
306}
307
308fn expand_origin_caller_variant(
309	runtime: &Ident,
310	pallet: &Pallet,
311	index: u8,
312	instance: Option<&Ident>,
313	generics: &Generics,
314) -> TokenStream {
315	let part_is_generic = !generics.params.is_empty();
316	let variant_name = &pallet.name;
317	let path = &pallet.path;
318	let attr = pallet.cfg_pattern.iter().fold(TokenStream::new(), |acc, pattern| {
319		let attr = TokenStream::from_str(&format!("#[cfg({})]", pattern.original()))
320			.expect("was successfully parsed before; qed");
321		quote! {
322			#acc
323			#attr
324		}
325	});
326
327	match instance {
328		Some(inst) if part_is_generic => quote! {
329			#attr
330			#[codec(index = #index)]
331			#variant_name(#path::Origin<#runtime, #path::#inst>),
332		},
333		Some(inst) => quote! {
334			#attr
335			#[codec(index = #index)]
336			#variant_name(#path::Origin<#path::#inst>),
337		},
338		None if part_is_generic => quote! {
339			#attr
340			#[codec(index = #index)]
341			#variant_name(#path::Origin<#runtime>),
342		},
343		None => quote! {
344			#attr
345			#[codec(index = #index)]
346			#variant_name(#path::Origin),
347		},
348	}
349}
350
351fn expand_origin_pallet_conversions(
352	_scrate: &TokenStream,
353	runtime: &Ident,
354	pallet: &Pallet,
355	instance: Option<&Ident>,
356	generics: &Generics,
357) -> TokenStream {
358	let path = &pallet.path;
359	let variant_name = &pallet.name;
360
361	let part_is_generic = !generics.params.is_empty();
362	let pallet_origin = match instance {
363		Some(inst) if part_is_generic => quote!(#path::Origin<#runtime, #path::#inst>),
364		Some(inst) => quote!(#path::Origin<#path::#inst>),
365		None if part_is_generic => quote!(#path::Origin<#runtime>),
366		None => quote!(#path::Origin),
367	};
368
369	let doc_string = get_intra_doc_string(" Convert to runtime origin using", &path.module_name());
370	let attr = pallet.cfg_pattern.iter().fold(TokenStream::new(), |acc, pattern| {
371		let attr = TokenStream::from_str(&format!("#[cfg({})]", pattern.original()))
372			.expect("was successfully parsed before; qed");
373		quote! {
374			#acc
375			#attr
376		}
377	});
378
379	quote! {
380		#attr
381		impl From<#pallet_origin> for OriginCaller {
382			fn from(x: #pallet_origin) -> Self {
383				OriginCaller::#variant_name(x)
384			}
385		}
386
387		#attr
388		impl From<#pallet_origin> for RuntimeOrigin {
389			#[doc = #doc_string]
390			fn from(x: #pallet_origin) -> Self {
391				let x: OriginCaller = x.into();
392				x.into()
393			}
394		}
395
396		#attr
397		impl From<RuntimeOrigin> for core::result::Result<#pallet_origin, RuntimeOrigin> {
398			/// NOTE: converting to pallet origin loses the origin filter information.
399			fn from(val: RuntimeOrigin) -> Self {
400				if let OriginCaller::#variant_name(l) = val.caller {
401					Ok(l)
402				} else {
403					Err(val)
404				}
405			}
406		}
407
408		#attr
409		impl TryFrom<OriginCaller> for #pallet_origin {
410			type Error = OriginCaller;
411			fn try_from(
412				x: OriginCaller,
413			) -> core::result::Result<#pallet_origin, OriginCaller> {
414				if let OriginCaller::#variant_name(l) = x {
415					Ok(l)
416				} else {
417					Err(x)
418				}
419			}
420		}
421
422		#attr
423		impl<'a> TryFrom<&'a OriginCaller> for &'a #pallet_origin {
424			type Error = ();
425			fn try_from(
426				x: &'a OriginCaller,
427			) -> core::result::Result<&'a #pallet_origin, ()> {
428				if let OriginCaller::#variant_name(l) = x {
429					Ok(&l)
430				} else {
431					Err(())
432				}
433			}
434		}
435
436		#attr
437		impl<'a> TryFrom<&'a RuntimeOrigin> for &'a #pallet_origin {
438			type Error = ();
439			fn try_from(
440				x: &'a RuntimeOrigin,
441			) -> core::result::Result<&'a #pallet_origin, ()> {
442				if let OriginCaller::#variant_name(l) = &x.caller {
443					Ok(&l)
444				} else {
445					Err(())
446				}
447			}
448		}
449	}
450}
451
/// Builds the text of a generated `#[doc]` attribute.
///
/// The result is `doc_info` followed by an intra-doc link to
/// `{system_path_name}::Config::BaseCallFilter`. It starts with a single space, mirroring how a
/// `/// text` comment desugars to `#[doc = " text"]`. Callers should therefore not start
/// `doc_info` with a space.
fn get_intra_doc_string(doc_info: &str, system_path_name: &str) -> String {
	format!(" {} [`{}::Config::BaseCallFilter`].", doc_info, system_path_name)
}
455}