1use std::{
22 any::{Any, TypeId},
23 fmt::Debug,
24};
25
26use std::collections::BTreeMap;
27
28use serde::{de::DeserializeOwned, Deserialize, Serialize};
29
/// A group of chain-spec extension parameters that can be changed by forks.
///
/// Every group has a companion [`Fork`] type in which parameters may be left unset, so a
/// fork definition only needs to mention the values it overrides.
pub trait Group: Clone + Sized {
	/// The fork definition corresponding to this group.
	type Fork: Fork<Base = Self>;

	/// Convert this (fully specified) group into its fork representation.
	fn to_fork(self) -> Self::Fork;
}
40
/// The fork definition of a [`Group`]: the same parameters, but possibly only partially set.
///
/// Successive forks are accumulated with [`Fork::combine_with`] and converted back into the
/// group with [`Fork::to_base`] once every parameter is known.
pub trait Fork: Serialize + DeserializeOwned + Clone + Sized {
	/// The [`Group`] this fork definition belongs to.
	type Base: Group<Fork = Self>;

	/// Combine with another fork; parameters set in `other` override the ones in `self`.
	fn combine_with(&mut self, other: Self);

	/// Convert to the base group, returning `None` if some parameter is still unset.
	fn to_base(self) -> Option<Self::Base>;
}
61
/// Implements [`Group`] and [`Fork`] for leaf types that have no sub-parameters.
///
/// Each listed type is its own fork: `to_fork`/`to_base` are identity conversions and
/// combining with another fork replaces the value wholesale. Accepts a (possibly empty)
/// comma-separated list of types.
macro_rules! impl_trivial {
	($( $ty:ty ),* $(,)?) => {
		$(
			impl Group for $ty {
				type Fork = $ty;

				fn to_fork(self) -> Self::Fork {
					self
				}
			}

			impl Fork for $ty {
				type Base = $ty;

				fn combine_with(&mut self, other: Self) {
					// A leaf value has nothing to merge: the newer fork wins outright.
					*self = other;
				}

				fn to_base(self) -> Option<Self::Base> {
					// A leaf value is always complete.
					Some(self)
				}
			}
		)*
	};
}
91
92impl_trivial!((), u8, u16, u32, u64, usize, String, Vec<u8>);
93
94impl<T: Group> Group for Option<T> {
95 type Fork = Option<T::Fork>;
96
97 fn to_fork(self) -> Self::Fork {
98 self.map(|a| a.to_fork())
99 }
100}
101
102impl<T: Fork> Fork for Option<T> {
103 type Base = Option<T::Base>;
104
105 fn combine_with(&mut self, other: Self) {
106 *self = match (self.take(), other) {
107 (Some(mut a), Some(b)) => {
108 a.combine_with(b);
109 Some(a)
110 },
111 (a, b) => a.or(b),
112 };
113 }
114
115 fn to_base(self) -> Option<Self::Base> {
116 self.map(|x| x.to_base())
117 }
118}
119
/// A `ChainSpec` extension.
///
/// Provides type-based lookup of contained sub-extensions (`get`, `get_any`, `get_any_mut`)
/// and extraction of a per-type fork schedule (`forks`).
pub trait Extension: Serialize + DeserializeOwned + Clone {
	/// The [`IsForks`] type naming the block number type and forkable group of this extension.
	type Forks: IsForks;

	/// Get the extension of type `T`, or `None` if this extension does not contain one.
	fn get<T: 'static>(&self) -> Option<&T>;
	/// Get a type-erased reference to the extension identified by `t`.
	///
	/// Implementations may return an unrelated value when no such extension exists. For
	/// example, the `NoExtension` impl always returns itself. Callers must therefore
	/// downcast the result and treat a failed downcast as "not present".
	fn get_any(&self, t: TypeId) -> &dyn Any;
	/// Mutable counterpart of [`Extension::get_any`], with the same downcasting caveat.
	fn get_any_mut(&mut self, t: TypeId) -> &mut dyn Any;

	/// Get the fork schedule of group `T`, keyed by `BlockNumber`.
	///
	/// The default implementation works in two steps:
	/// 1. Look up this extension's `Forks<BlockNumber, _>` value via `get`.
	/// 2. Project it onto `T` with `Forks::for_type`.
	///
	/// Returns `None` if either step finds nothing.
	fn forks<BlockNumber, T>(&self) -> Option<Forks<BlockNumber, T>>
	where
		BlockNumber: Ord + Clone + 'static,
		T: Group + 'static,
		<Self::Forks as IsForks>::Extension: Extension,
		<<Self::Forks as IsForks>::Extension as Group>::Fork: Extension,
	{
		self.get::<Forks<BlockNumber, <Self::Forks as IsForks>::Extension>>()?
			.for_type()
	}
}
146
147impl Extension for crate::NoExtension {
148 type Forks = Self;
149
150 fn get<T: 'static>(&self) -> Option<&T> {
151 None
152 }
153 fn get_any(&self, _t: TypeId) -> &dyn Any {
154 self
155 }
156 fn get_any_mut(&mut self, _: TypeId) -> &mut dyn Any {
157 self
158 }
159}
160
/// Describes the fork schedule carried by an extension: the block number type used as the
/// fork key and the forkable [`Group`] type.
pub trait IsForks {
	/// Type used to key forks by block.
	type BlockNumber: Ord + 'static;
	/// The [`Group`] whose parameters are forked.
	type Extension: Group + 'static;
}
165
166impl IsForks for Option<()> {
167 type BlockNumber = u64;
168 type Extension = Self;
169}
170
/// A fork schedule: a complete `base` group plus per-block fork overrides.
///
/// `base` is `#[serde(flatten)]`ed, so in the serialized form its fields appear next to the
/// `forks` map rather than under their own key.
///
/// NOTE(review): serde documents `deny_unknown_fields` as unsupported in combination with
/// `flatten`. Confirm that unknown keys are actually rejected here.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct Forks<BlockNumber: Ord, T: Group> {
	/// Fork definitions keyed by the block number from which they apply.
	forks: BTreeMap<BlockNumber, T::Fork>,
	/// The parameters in effect before any fork is applied.
	#[serde(flatten)]
	base: T,
}
178
179impl<B: Ord, T: Group + Default> Default for Forks<B, T> {
180 fn default() -> Self {
181 Self { base: Default::default(), forks: Default::default() }
182 }
183}
184
185impl<B: Ord, T: Group> Forks<B, T>
186where
187 T::Fork: Debug,
188{
189 pub fn new(base: T, forks: BTreeMap<B, T::Fork>) -> Self {
191 Self { base, forks }
192 }
193
194 pub fn at_block(&self, block: B) -> T {
196 let mut start = self.base.clone().to_fork();
197
198 for (_, fork) in self.forks.range(..=block) {
199 start.combine_with(fork.clone());
200 }
201
202 start
203 .to_base()
204 .expect("We start from the `base` object, so it's always fully initialized; qed")
205 }
206}
207
208impl<B, T> IsForks for Forks<B, T>
209where
210 B: Ord + 'static,
211 T: Group + 'static,
212{
213 type BlockNumber = B;
214 type Extension = T;
215}
216
217impl<B: Ord + Clone, T: Group + Extension> Forks<B, T>
218where
219 T::Fork: Extension,
220{
221 pub fn for_type<X>(&self) -> Option<Forks<B, X>>
226 where
227 X: Group + 'static,
228 {
229 let base = self.base.get::<X>()?.clone();
230 let forks = self
231 .forks
232 .iter()
233 .filter_map(|(k, v)| Some((k.clone(), v.get::<Option<X::Fork>>()?.clone()?)))
234 .collect();
235
236 Some(Forks { base, forks })
237 }
238}
239
240impl<B, E> Extension for Forks<B, E>
241where
242 B: Serialize + DeserializeOwned + Ord + Clone + 'static,
243 E: Extension + Group + 'static,
244{
245 type Forks = Self;
246
247 fn get<T: 'static>(&self) -> Option<&T> {
248 if TypeId::of::<T>() == TypeId::of::<E>() {
249 <dyn Any>::downcast_ref(&self.base)
250 } else {
251 self.base.get()
252 }
253 }
254
255 fn get_any(&self, t: TypeId) -> &dyn Any {
256 if t == TypeId::of::<E>() {
257 &self.base
258 } else {
259 self.base.get_any(t)
260 }
261 }
262
263 fn get_any_mut(&mut self, t: TypeId) -> &mut dyn Any {
264 if t == TypeId::of::<E>() {
265 &mut self.base
266 } else {
267 self.base.get_any_mut(t)
268 }
269 }
270
271 fn forks<BlockNumber, T>(&self) -> Option<Forks<BlockNumber, T>>
272 where
273 BlockNumber: Ord + Clone + 'static,
274 T: Group + 'static,
275 <Self::Forks as IsForks>::Extension: Extension,
276 <<Self::Forks as IsForks>::Extension as Group>::Fork: Extension,
277 {
278 if TypeId::of::<BlockNumber>() == TypeId::of::<B>() {
279 <dyn Any>::downcast_ref(&self.for_type::<T>()?).cloned()
280 } else {
281 self.get::<Forks<BlockNumber, <Self::Forks as IsForks>::Extension>>()?
282 .for_type()
283 }
284 }
285}
286
/// Object-safe, type-erased access to extensions.
///
/// Mirrors `Extension::get_any` / `Extension::get_any_mut` without the generic methods and
/// serde supertraits, so it can be used as `&dyn GetExtension` (see `get_extension`). Every
/// `Extension` implements it through the blanket impl that follows.
pub trait GetExtension {
	/// Get a type-erased reference to the extension identified by `t`.
	fn get_any(&self, t: TypeId) -> &dyn Any;

	/// Get a type-erased mutable reference to the extension identified by `t`.
	fn get_any_mut(&mut self, t: TypeId) -> &mut dyn Any;
}
295
296impl<E: Extension> GetExtension for E {
297 fn get_any(&self, t: TypeId) -> &dyn Any {
298 Extension::get_any(self, t)
299 }
300
301 fn get_any_mut(&mut self, t: TypeId) -> &mut dyn Any {
302 Extension::get_any_mut(self, t)
303 }
304}
305
306pub fn get_extension<T: 'static>(e: &dyn GetExtension) -> Option<&T> {
308 <dyn Any>::downcast_ref(GetExtension::get_any(e, TypeId::of::<T>()))
309}
310
311pub fn get_extension_mut<T: 'static>(e: &mut dyn GetExtension) -> Option<&mut T> {
313 <dyn Any>::downcast_mut(GetExtension::get_any_mut(e, TypeId::of::<T>()))
314}
315
#[cfg(test)]
mod tests {
	use super::*;
	use crate as sc_chain_spec;
	use sc_chain_spec_derive::{ChainSpecExtension, ChainSpecGroup};

	#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ChainSpecGroup)]
	#[serde(deny_unknown_fields)]
	pub struct Extension1 {
		pub test: u64,
	}

	#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ChainSpecGroup)]
	#[serde(deny_unknown_fields)]
	pub struct Extension2 {
		pub test: u8,
	}

	#[derive(
		Debug, Clone, PartialEq, Serialize, Deserialize, ChainSpecGroup, ChainSpecExtension,
	)]
	#[serde(deny_unknown_fields)]
	pub struct Extensions {
		pub ext1: Extension1,
		pub ext2: Extension2,
	}

	#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, ChainSpecExtension)]
	#[serde(deny_unknown_fields)]
	pub struct Ext2 {
		#[serde(flatten)]
		ext1: Extension1,
		#[forks]
		forkable: Forks<u64, Extensions>,
	}

	/// Fixture with two parts:
	/// - a flattened `Extension1` (`"test": 11`);
	/// - a `forkable` section whose base (`ext1 = 15`, `ext2 = 123`) is overridden at
	///   blocks 1, 2 and 5.
	const FIXTURE: &str = r#"
{
	"test": 11,
	"forkable": {
		"ext1": {
			"test": 15
		},
		"ext2": {
			"test": 123
		},
		"forks": {
			"1": {
				"ext1": { "test": 5 }
			},
			"2": {
				"ext2": { "test": 5 }
			},
			"5": {
				"ext2": { "test": 1 }
			}
		}
	}
}
"#;

	#[test]
	fn forks_should_work_correctly() {
		use super::Extension as _;

		let json: serde_json::Value = serde_json::from_str(FIXTURE).unwrap();
		let ext: Ext2 = serde_json::from_value(json).unwrap();

		assert_eq!(ext.get::<Extension1>(), Some(&Extension1 { test: 11 }));

		let forks = ext.get::<Forks<u64, Extensions>>().unwrap();

		// (block, expected `ext1.test`, expected `ext2.test`)
		let schedule = [(0, 15, 123), (1, 5, 123), (2, 5, 5), (4, 5, 5), (5, 5, 1), (10, 5, 1)];
		for (block, ext1_test, ext2_test) in schedule {
			let expected = Extensions {
				ext1: Extension1 { test: ext1_test },
				ext2: Extension2 { test: ext2_test },
			};
			assert_eq!(forks.at_block(block), expected);
		}
		assert!(forks.at_block(10).get::<Extension2>().is_some());

		let only_ext2 = forks.for_type::<Extension2>().unwrap();
		for (block, test) in [(0, 123), (2, 5), (10, 1)] {
			assert_eq!(only_ext2.at_block(block), Extension2 { test });
		}

		let via_forks = forks.forks::<u64, Extension2>().unwrap();
		assert_eq!(only_ext2, via_forks);

		let via_ext = ext.forks::<u64, Extension2>().unwrap();
		assert_eq!(via_forks, via_ext);
	}
}