1use crate::{
18 configuration::{TestAuthorities, TestConfiguration},
19 mock::runtime_api::session_info_for_peers,
20 network::{HandleNetworkMessage, NetworkMessage},
21 NODE_UNDER_TEST,
22};
23use bitvec::vec::BitVec;
24use codec::{Decode, Encode};
25use futures::channel::oneshot;
26use itertools::Itertools;
27use polkadot_node_network_protocol::{
28 request_response::{
29 v2::{AttestedCandidateRequest, AttestedCandidateResponse},
30 Requests,
31 },
32 v3::{
33 BackedCandidateAcknowledgement, StatementDistributionMessage, StatementFilter,
34 ValidationProtocol,
35 },
36 ValidationProtocols,
37};
38use polkadot_node_primitives::{AvailableData, BlockData, PoV};
39use polkadot_node_subsystem_test_helpers::{
40 derive_erasure_chunks_with_proofs_and_root, mock::new_block_import_info,
41};
42use polkadot_overseer::BlockInfo;
43
44use polkadot_primitives::MutateDescriptorV2;
45
46use polkadot_primitives::{
47 BlockNumber, CandidateHash, CandidateReceiptV2 as CandidateReceipt,
48 CommittedCandidateReceiptV2 as CommittedCandidateReceipt, CompactStatement, CoreIndex, Hash,
49 Header, Id, PersistedValidationData, SessionInfo, SignedStatement, SigningContext,
50 UncheckedSigned, ValidatorIndex, ValidatorPair,
51};
52use polkadot_primitives_test_helpers::{
53 dummy_committed_candidate_receipt_v2, dummy_hash, dummy_head_data, dummy_pvd,
54};
55use sc_network::{config::IncomingRequest, ProtocolName};
56use sp_core::{Pair, H256};
57use std::{
58 collections::HashMap,
59 sync::{
60 atomic::{AtomicBool, Ordering},
61 Arc,
62 },
63};
64
/// Session index used for every candidate descriptor and every statement signing context.
const SESSION_INDEX: u32 = 0;
66
/// Pre-generated state of the statement-distribution benchmark, also used to emulate the
/// node's peers.
#[derive(Clone)]
pub struct TestState {
	/// Benchmark configuration this state was generated from.
	pub config: TestConfiguration,
	/// Key pairs, authority ids and peer ids of all emulated validators.
	pub test_authorities: TestAuthorities,
	/// Import info for every benchmark block, in block-number order.
	pub block_infos: Vec<BlockInfo>,
	/// Candidate receipts per relay-parent block hash: one per core, in core-index order.
	pub candidate_receipts: HashMap<H256, Vec<CandidateReceipt>>,
	/// Committed versions of `candidate_receipts`, keyed and ordered the same way.
	pub commited_candidate_receipts: HashMap<H256, Vec<CommittedCandidateReceipt>>,
	/// Persisted validation data shared by every candidate.
	pub pvd: PersistedValidationData,
	/// Headers of the benchmark blocks, keyed by block hash.
	pub block_headers: HashMap<H256, Header>,
	/// Session info of the benchmark session.
	pub session_info: SessionInfo,
	/// `Seconded` statements for each candidate, one from every member of its backing group.
	pub statements: HashMap<CandidateHash, Vec<UncheckedSigned<CompactStatement>>>,
	/// The validator group that contains the node under test.
	pub own_backing_group: Vec<ValidatorIndex>,
	/// Per candidate, one flag per validator index. A flag is set once a statement from the
	/// node for that candidate has been handled for that validator.
	pub statements_tracker: HashMap<CandidateHash, Vec<Arc<AtomicBool>>>,
	/// Per candidate, set once a manifest or acknowledgement for it has been exchanged.
	pub manifests_tracker: HashMap<CandidateHash, Arc<AtomicBool>>,
}
94
95impl TestState {
96 pub fn new(config: &TestConfiguration) -> Self {
97 let test_authorities = config.generate_authorities();
98 let session_info = session_info_for_peers(config, &test_authorities);
99 let own_backing_group = session_info
100 .validator_groups
101 .iter()
102 .find(|g| g.contains(&ValidatorIndex(NODE_UNDER_TEST)))
103 .unwrap()
104 .clone();
105 let mut state = Self {
106 config: config.clone(),
107 test_authorities,
108 block_infos: (1..=config.num_blocks).map(generate_block_info).collect(),
109 candidate_receipts: Default::default(),
110 commited_candidate_receipts: Default::default(),
111 pvd: dummy_pvd(dummy_head_data(), 0),
112 block_headers: Default::default(),
113 statements_tracker: Default::default(),
114 manifests_tracker: Default::default(),
115 session_info,
116 own_backing_group,
117 statements: Default::default(),
118 };
119
120 state.block_headers = state.block_infos.iter().map(generate_block_header).collect();
121
122 let pov_sizes = Vec::from(config.pov_sizes()); let pov_size_to_candidate = generate_pov_size_to_candidate(&pov_sizes);
125 let receipt_templates =
126 generate_receipt_templates(&pov_size_to_candidate, config.n_validators, &state.pvd);
127
128 for block_info in state.block_infos.iter() {
129 for core_idx in 0..config.n_cores {
130 let pov_size = pov_sizes.get(core_idx).expect("This is a cycle; qed");
131 let candidate_index =
132 *pov_size_to_candidate.get(pov_size).expect("pov_size always exists; qed");
133 let mut receipt = receipt_templates[candidate_index].clone();
134 receipt.descriptor.set_para_id(Id::new(core_idx as u32 + 1));
135 receipt.descriptor.set_relay_parent(block_info.hash);
136 receipt.descriptor.set_core_index(CoreIndex(core_idx as u32));
137 receipt.descriptor.set_session_index(SESSION_INDEX);
138
139 state.candidate_receipts.entry(block_info.hash).or_default().push(
140 CandidateReceipt {
141 descriptor: receipt.descriptor.clone(),
142 commitments_hash: receipt.commitments.hash(),
143 },
144 );
145 state.statements_tracker.entry(receipt.hash()).or_default().extend(
146 (0..config.n_validators)
147 .map(|_| Arc::new(AtomicBool::new(false)))
148 .collect_vec(),
149 );
150 state.manifests_tracker.insert(receipt.hash(), Arc::new(AtomicBool::new(false)));
151 state
152 .commited_candidate_receipts
153 .entry(block_info.hash)
154 .or_default()
155 .push(receipt);
156 }
157 }
158
159 let groups = state.session_info.validator_groups.clone();
160
161 for block_info in state.block_infos.iter() {
162 for (index, group) in groups.iter().enumerate() {
163 let candidate =
164 state.candidate_receipts.get(&block_info.hash).unwrap().get(index).unwrap();
165 let statements = group
166 .iter()
167 .map(|&v| {
168 sign_statement(
169 CompactStatement::Seconded(candidate.hash()),
170 block_info.hash,
171 v,
172 state.test_authorities.validator_pairs.get(v.0 as usize).unwrap(),
173 )
174 })
175 .collect_vec();
176 state.statements.insert(candidate.hash(), statements);
177 }
178 }
179
180 state
181 }
182
183 pub fn reset_trackers(&self) {
184 self.statements_tracker.values().for_each(|v| {
185 v.iter()
186 .enumerate()
187 .for_each(|(index, v)| v.as_ref().store(index <= 1, Ordering::SeqCst))
188 });
189 self.manifests_tracker
190 .values()
191 .for_each(|v| v.as_ref().store(false, Ordering::SeqCst));
192 }
193}
194
195fn sign_statement(
196 statement: CompactStatement,
197 relay_parent: H256,
198 validator_index: ValidatorIndex,
199 pair: &ValidatorPair,
200) -> UncheckedSigned<CompactStatement> {
201 let context = SigningContext { parent_hash: relay_parent, session_index: SESSION_INDEX };
202 let payload = statement.signing_payload(&context);
203
204 SignedStatement::new(
205 statement,
206 validator_index,
207 pair.sign(&payload[..]),
208 &context,
209 &pair.public(),
210 )
211 .unwrap()
212 .as_unchecked()
213 .to_owned()
214}
215
216fn generate_block_info(block_num: usize) -> BlockInfo {
217 new_block_import_info(Hash::repeat_byte(block_num as u8), block_num as BlockNumber)
218}
219
220fn generate_block_header(info: &BlockInfo) -> (H256, Header) {
221 (
222 info.hash,
223 Header {
224 digest: Default::default(),
225 number: info.number,
226 parent_hash: info.parent_hash,
227 extrinsics_root: Default::default(),
228 state_root: Default::default(),
229 },
230 )
231}
232
/// Assigns a candidate index to every distinct PoV size.
///
/// Indices are handed out in order of first appearance in `pov_sizes`, starting at zero.
/// The map's values are therefore exactly `0..number_of_distinct_sizes`.
fn generate_pov_size_to_candidate(pov_sizes: &[usize]) -> HashMap<usize, usize> {
	let mut size_to_index = HashMap::new();
	for &size in pov_sizes {
		// Only the first occurrence of a size claims the next free index.
		let next_index = size_to_index.len();
		size_to_index.entry(size).or_insert(next_index);
	}
	size_to_index
}
242
243fn generate_receipt_templates(
244 pov_size_to_candidate: &HashMap<usize, usize>,
245 n_validators: usize,
246 pvd: &PersistedValidationData,
247) -> Vec<CommittedCandidateReceipt> {
248 pov_size_to_candidate
249 .iter()
250 .map(|(&pov_size, &index)| {
251 let mut receipt = dummy_committed_candidate_receipt_v2(dummy_hash());
252 let (_, erasure_root) = derive_erasure_chunks_with_proofs_and_root(
253 n_validators,
254 &AvailableData {
255 validation_data: pvd.clone(),
256 pov: Arc::new(PoV { block_data: BlockData(vec![index as u8; pov_size]) }),
257 },
258 |_, _| {},
259 );
260 receipt.descriptor.set_persisted_validation_data_hash(pvd.hash());
261 receipt.descriptor.set_erasure_root(erasure_root);
262 receipt
263 })
264 .collect()
265}
266
267#[async_trait::async_trait]
268impl HandleNetworkMessage for TestState {
269 async fn handle(
270 &self,
271 message: NetworkMessage,
272 node_sender: &mut futures::channel::mpsc::UnboundedSender<NetworkMessage>,
273 ) -> Option<NetworkMessage> {
274 match message {
275 NetworkMessage::RequestFromNode(_authority_id, requests) => {
276 let Requests::AttestedCandidateV2(req) = *requests else { return None };
277 let payload = req.payload;
278 let candidate_receipt = self
279 .commited_candidate_receipts
280 .values()
281 .flatten()
282 .find(|v| v.hash() == payload.candidate_hash)
283 .unwrap()
284 .clone();
285 let persisted_validation_data = self.pvd.clone();
286 let statements = self.statements.get(&payload.candidate_hash).unwrap().clone();
287 let res = AttestedCandidateResponse {
288 candidate_receipt,
289 persisted_validation_data,
290 statements,
291 };
292 let _ = req.pending_response.send(Ok((res.encode(), ProtocolName::from(""))));
293 None
294 },
295 NetworkMessage::MessageFromNode(
296 authority_id,
297 ValidationProtocols::V3(ValidationProtocol::StatementDistribution(
298 StatementDistributionMessage::Statement(relay_parent, statement),
299 )),
300 ) => {
301 let index = self
302 .test_authorities
303 .validator_authority_id
304 .iter()
305 .position(|v| v == &authority_id)
306 .unwrap();
307 let candidate_hash = *statement.unchecked_payload().candidate_hash();
308
309 let statements_sent_count = self
310 .statements_tracker
311 .get(&candidate_hash)
312 .unwrap()
313 .get(index)
314 .unwrap()
315 .as_ref();
316 if statements_sent_count.load(Ordering::SeqCst) {
317 return None
318 } else {
319 statements_sent_count.store(true, Ordering::SeqCst);
320 }
321
322 let group_statements = self.statements.get(&candidate_hash).unwrap();
323 if !group_statements.iter().any(|s| s.unchecked_validator_index().0 == index as u32)
324 {
325 return None
326 }
327
328 let statement = CompactStatement::Valid(candidate_hash);
329 let context =
330 SigningContext { parent_hash: relay_parent, session_index: SESSION_INDEX };
331 let payload = statement.signing_payload(&context);
332 let pair = self.test_authorities.validator_pairs.get(index).unwrap();
333 let signature = pair.sign(&payload[..]);
334 let statement = SignedStatement::new(
335 statement,
336 ValidatorIndex(index as u32),
337 signature,
338 &context,
339 &pair.public(),
340 )
341 .unwrap()
342 .as_unchecked()
343 .to_owned();
344
345 node_sender
346 .start_send(NetworkMessage::MessageFromPeer(
347 *self.test_authorities.peer_ids.get(index).unwrap(),
348 ValidationProtocols::V3(ValidationProtocol::StatementDistribution(
349 StatementDistributionMessage::Statement(relay_parent, statement),
350 )),
351 ))
352 .unwrap();
353 None
354 },
355 NetworkMessage::MessageFromNode(
356 authority_id,
357 ValidationProtocols::V3(ValidationProtocol::StatementDistribution(
358 StatementDistributionMessage::BackedCandidateManifest(manifest),
359 )),
360 ) => {
361 let index = self
362 .test_authorities
363 .validator_authority_id
364 .iter()
365 .position(|v| v == &authority_id)
366 .unwrap();
367 let backing_group =
368 self.session_info.validator_groups.get(manifest.group_index).unwrap();
369 let group_size = backing_group.len();
370 let is_own_backing_group = backing_group.contains(&ValidatorIndex(NODE_UNDER_TEST));
371 let mut seconded_in_group =
372 BitVec::from_iter((0..group_size).map(|_| !is_own_backing_group));
373 let mut validated_in_group = BitVec::from_iter((0..group_size).map(|_| false));
374
375 if is_own_backing_group {
376 let (pending_response, response_receiver) = oneshot::channel();
377 let peer_id = self.test_authorities.peer_ids.get(index).unwrap().to_owned();
378 node_sender
379 .start_send(NetworkMessage::RequestFromPeer(IncomingRequest {
380 peer: peer_id,
381 payload: AttestedCandidateRequest {
382 candidate_hash: manifest.candidate_hash,
383 mask: StatementFilter::blank(self.own_backing_group.len()),
384 }
385 .encode(),
386 pending_response,
387 }))
388 .unwrap();
389
390 let response = response_receiver.await.unwrap();
391 let response =
392 AttestedCandidateResponse::decode(&mut response.result.unwrap().as_ref())
393 .unwrap();
394
395 for statement in response.statements {
396 let validator_index = statement.unchecked_validator_index();
397 let position_in_group =
398 backing_group.iter().position(|v| *v == validator_index).unwrap();
399 match statement.unchecked_payload() {
400 CompactStatement::Seconded(_) =>
401 seconded_in_group.set(position_in_group, true),
402 CompactStatement::Valid(_) =>
403 validated_in_group.set(position_in_group, true),
404 }
405 }
406 }
407
408 let ack = BackedCandidateAcknowledgement {
409 candidate_hash: manifest.candidate_hash,
410 statement_knowledge: StatementFilter { seconded_in_group, validated_in_group },
411 };
412 node_sender
413 .start_send(NetworkMessage::MessageFromPeer(
414 *self.test_authorities.peer_ids.get(index).unwrap(),
415 ValidationProtocols::V3(ValidationProtocol::StatementDistribution(
416 StatementDistributionMessage::BackedCandidateKnown(ack),
417 )),
418 ))
419 .unwrap();
420
421 self.manifests_tracker
422 .get(&manifest.candidate_hash)
423 .unwrap()
424 .as_ref()
425 .store(true, Ordering::SeqCst);
426
427 None
428 },
429 NetworkMessage::MessageFromNode(
430 _authority_id,
431 ValidationProtocols::V3(ValidationProtocol::StatementDistribution(
432 StatementDistributionMessage::BackedCandidateKnown(ack),
433 )),
434 ) => {
435 self.manifests_tracker
436 .get(&ack.candidate_hash)
437 .unwrap()
438 .as_ref()
439 .store(true, Ordering::SeqCst);
440
441 None
442 },
443 _ => Some(message),
444 }
445 }
446}