1use std::{future::Future, pin::Pin, sync::Arc};
26
27use dyn_clone::DynClone;
28
29use sc_client_api::blockchain::HeaderBackend;
30use sp_runtime::traits::{Block as BlockT, Header, NumberFor, One, Zero};
31
32pub type VotingRuleResult<Block> =
34 Pin<Box<dyn Future<Output = Option<(<Block as BlockT>::Hash, NumberFor<Block>)>> + Send>>;
35
36pub trait VotingRule<Block, B>: DynClone + Send + Sync
38where
39 Block: BlockT,
40 B: HeaderBackend<Block>,
41{
42 fn restrict_vote(
55 &self,
56 backend: Arc<B>,
57 base: &Block::Header,
58 best_target: &Block::Header,
59 current_target: &Block::Header,
60 ) -> VotingRuleResult<Block>;
61}
62
63impl<Block, B> VotingRule<Block, B> for ()
64where
65 Block: BlockT,
66 B: HeaderBackend<Block>,
67{
68 fn restrict_vote(
69 &self,
70 _backend: Arc<B>,
71 _base: &Block::Header,
72 _best_target: &Block::Header,
73 _current_target: &Block::Header,
74 ) -> VotingRuleResult<Block> {
75 Box::pin(async { None })
76 }
77}
78
79#[derive(Clone)]
89pub struct BeforeBestBlockBy<N>(pub N);
90impl<Block, B> VotingRule<Block, B> for BeforeBestBlockBy<NumberFor<Block>>
91where
92 Block: BlockT,
93 B: HeaderBackend<Block>,
94{
95 fn restrict_vote(
96 &self,
97 backend: Arc<B>,
98 base: &Block::Header,
99 best_target: &Block::Header,
100 current_target: &Block::Header,
101 ) -> VotingRuleResult<Block> {
102 use sp_arithmetic::traits::Saturating;
103
104 if current_target.number().is_zero() {
105 return Box::pin(async { None })
106 }
107
108 if *base.number() + self.0 > *best_target.number() {
111 return Box::pin(std::future::ready(Some((base.hash(), *base.number()))))
112 }
113
114 let target_number = best_target.number().saturating_sub(self.0);
116
117 if target_number >= *current_target.number() {
119 return Box::pin(async { None })
120 }
121
122 let current_target = current_target.clone();
123
124 Box::pin(std::future::ready(find_target(&*backend, target_number, ¤t_target)))
126 }
127}
128
129#[derive(Clone)]
133pub struct ThreeQuartersOfTheUnfinalizedChain;
134
135impl<Block, B> VotingRule<Block, B> for ThreeQuartersOfTheUnfinalizedChain
136where
137 Block: BlockT,
138 B: HeaderBackend<Block>,
139{
140 fn restrict_vote(
141 &self,
142 backend: Arc<B>,
143 base: &Block::Header,
144 best_target: &Block::Header,
145 current_target: &Block::Header,
146 ) -> VotingRuleResult<Block> {
147 let target_number = {
149 let two = NumberFor::<Block>::one() + One::one();
150 let three = two + One::one();
151 let four = three + One::one();
152
153 let diff = *best_target.number() - *base.number();
154 let diff = ((diff * three) + two) / four;
155
156 *base.number() + diff
157 };
158
159 if target_number >= *current_target.number() {
161 return Box::pin(async { None })
162 }
163
164 Box::pin(std::future::ready(find_target(&*backend, target_number, current_target)))
166 }
167}
168
169fn find_target<Block, B>(
171 backend: &B,
172 target_number: NumberFor<Block>,
173 current_header: &Block::Header,
174) -> Option<(Block::Hash, NumberFor<Block>)>
175where
176 Block: BlockT,
177 B: HeaderBackend<Block>,
178{
179 let mut target_hash = current_header.hash();
180 let mut target_header = current_header.clone();
181
182 loop {
183 if *target_header.number() < target_number {
184 unreachable!(
185 "we are traversing backwards from a known block; \
186 blocks are stored contiguously; \
187 qed"
188 );
189 }
190
191 if *target_header.number() == target_number {
192 return Some((target_hash, target_number))
193 }
194
195 target_hash = *target_header.parent_hash();
196 target_header = backend
197 .header(target_hash)
198 .ok()?
199 .expect("Header known to exist due to the existence of one of its descendants; qed");
200 }
201}
202
203struct VotingRules<Block, B> {
204 rules: Arc<Vec<Box<dyn VotingRule<Block, B>>>>,
205}
206
207impl<B, Block> Clone for VotingRules<B, Block> {
208 fn clone(&self) -> Self {
209 VotingRules { rules: self.rules.clone() }
210 }
211}
212
213impl<Block, B> VotingRule<Block, B> for VotingRules<Block, B>
214where
215 Block: BlockT,
216 B: HeaderBackend<Block> + 'static,
217{
218 fn restrict_vote(
219 &self,
220 backend: Arc<B>,
221 base: &Block::Header,
222 best_target: &Block::Header,
223 current_target: &Block::Header,
224 ) -> VotingRuleResult<Block> {
225 let rules = self.rules.clone();
226 let base = base.clone();
227 let best_target = best_target.clone();
228 let current_target = current_target.clone();
229
230 Box::pin(async move {
231 let mut restricted_target = current_target.clone();
232
233 for rule in rules.iter() {
234 if let Some(header) = rule
235 .restrict_vote(backend.clone(), &base, &best_target, &restricted_target)
236 .await
237 .filter(|(_, restricted_number)| {
238 restricted_number >= base.number() &&
240 restricted_number < restricted_target.number()
241 })
242 .and_then(|(hash, _)| backend.header(hash).ok())
243 .and_then(std::convert::identity)
244 {
245 restricted_target = header;
246 }
247 }
248
249 let restricted_hash = restricted_target.hash();
250
251 if restricted_hash != current_target.hash() {
252 Some((restricted_hash, *restricted_target.number()))
253 } else {
254 None
255 }
256 })
257 }
258}
259
260pub struct VotingRulesBuilder<Block, B> {
263 rules: Vec<Box<dyn VotingRule<Block, B>>>,
264}
265
266impl<Block, B> Default for VotingRulesBuilder<Block, B>
267where
268 Block: BlockT,
269 B: HeaderBackend<Block> + 'static,
270{
271 fn default() -> Self {
272 VotingRulesBuilder::new()
273 .add(BeforeBestBlockBy(2u32.into()))
274 .add(ThreeQuartersOfTheUnfinalizedChain)
275 }
276}
277
278impl<Block, B> VotingRulesBuilder<Block, B>
279where
280 Block: BlockT,
281 B: HeaderBackend<Block> + 'static,
282{
283 pub fn new() -> Self {
285 VotingRulesBuilder { rules: Vec::new() }
286 }
287
288 pub fn add<R>(mut self, rule: R) -> Self
290 where
291 R: VotingRule<Block, B> + 'static,
292 {
293 self.rules.push(Box::new(rule));
294 self
295 }
296
297 pub fn add_all<I>(mut self, rules: I) -> Self
299 where
300 I: IntoIterator<Item = Box<dyn VotingRule<Block, B>>>,
301 {
302 self.rules.extend(rules);
303 self
304 }
305
306 pub fn build(self) -> impl VotingRule<Block, B> + Clone {
309 VotingRules { rules: Arc::new(self.rules) }
310 }
311}
312
313impl<Block, B> VotingRule<Block, B> for Box<dyn VotingRule<Block, B>>
314where
315 Block: BlockT,
316 B: HeaderBackend<Block>,
317 Self: Clone,
318{
319 fn restrict_vote(
320 &self,
321 backend: Arc<B>,
322 base: &Block::Header,
323 best_target: &Block::Header,
324 current_target: &Block::Header,
325 ) -> VotingRuleResult<Block> {
326 (**self).restrict_vote(backend, base, best_target, current_target)
327 }
328}
329
#[cfg(test)]
mod tests {
	use super::*;
	use sc_block_builder::BlockBuilderBuilder;
	use sp_consensus::BlockOrigin;
	use sp_runtime::traits::Header as _;

	use substrate_test_runtime_client::{
		runtime::{Block, Header},
		Backend, Client, ClientBlockImportExt, DefaultTestClientBuilderExt, TestClientBuilder,
		TestClientBuilderExt,
	};

	/// Test rule that always suggests the block `self.0` below the current target.
	#[derive(Clone)]
	struct Subtract(u64);
	impl VotingRule<Block, Client<Backend>> for Subtract {
		fn restrict_vote(
			&self,
			backend: Arc<Client<Backend>>,
			_base: &Header,
			_best_target: &Header,
			current_target: &Header,
		) -> VotingRuleResult<Block> {
			let target_number = current_target.number() - self.0;
			let suggestion = backend
				.hash(target_number)
				.unwrap()
				.map(|target_hash| (target_hash, target_number));

			Box::pin(std::future::ready(suggestion))
		}
	}

	/// Builds and imports `count` blocks on top of the current best block and
	/// returns their hashes in import order.
	fn import_blocks(
		client: &Arc<Client<Backend>>,
		count: usize,
	) -> Vec<<Block as BlockT>::Hash> {
		let mut hashes = Vec::with_capacity(count);

		for _ in 0..count {
			let block = BlockBuilderBuilder::new(&**client)
				.on_parent_block(client.chain_info().best_hash)
				.with_parent_block_number(client.chain_info().best_number)
				.build()
				.unwrap()
				.build()
				.unwrap()
				.block;
			hashes.push(block.hash());

			futures::executor::block_on(client.import(BlockOrigin::Own, block)).unwrap();
		}

		hashes
	}

	#[test]
	fn multiple_voting_rules_cannot_restrict_past_base() {
		// Two rules of 50 blocks each: together they move the vote back by 100.
		let rule = VotingRulesBuilder::new().add(Subtract(50)).add(Subtract(50)).build();

		let client = Arc::new(TestClientBuilder::new().build());
		let hashes = import_blocks(&client, 200);

		let genesis = client.header(client.info().genesis_hash).unwrap().unwrap();
		let best = client.header(client.info().best_hash).unwrap().unwrap();

		// With genesis as base both rules apply: 200 - 50 - 50 = 100.
		let (_, number) =
			futures::executor::block_on(rule.restrict_vote(client.clone(), &genesis, &best, &best))
				.unwrap();
		assert_eq!(number, 100);

		// With block #110 as base only the first rule applies (200 - 50 = 150); the
		// second would go to #100, below the base, and is ignored.
		let block110 = client.header(hashes[109]).unwrap().unwrap();
		let (_, number) = futures::executor::block_on(rule.restrict_vote(
			client.clone(),
			&block110,
			&best,
			&best,
		))
		.unwrap();
		assert_eq!(number, 150);
	}

	#[test]
	fn before_best_by_has_cutoff_at_base() {
		let rule = BeforeBestBlockBy(2);

		let client = Arc::new(TestClientBuilder::new().build());

		let n = 5;
		let hashes = import_blocks(&client, n);

		let best = client.header(client.info().best_hash).unwrap().unwrap();
		let best_number = *best.number();

		for (i, hash) in hashes.iter().enumerate() {
			let base = client.header(*hash).unwrap().unwrap();
			let (_, number) = futures::executor::block_on(rule.restrict_vote(
				client.clone(),
				&base,
				&best,
				&best,
			))
			.unwrap();

			// The vote trails the best block by 2 but never drops below the base.
			let expected = std::cmp::max(best_number - 2, *base.number());
			assert_eq!(number, expected, "best = {}, lag = 2, base = {}", best_number, i);
		}
	}
}