polkadot_availability_distribution/requester/
session_cache.rs1use std::collections::HashSet;
18
19use rand::{seq::SliceRandom, thread_rng};
20use schnellru::{ByLength, LruMap};
21
22use polkadot_node_subsystem::overseer;
23use polkadot_node_subsystem_util::{request_node_features, runtime::RuntimeInfo};
24use polkadot_primitives::{
25 AuthorityDiscoveryId, GroupIndex, Hash, NodeFeatures, SessionIndex, ValidatorIndex,
26};
27
28use crate::{
29 error::{Error, Result},
30 LOG_TARGET,
31};
32
33pub struct SessionCache {
37 session_info_cache: LruMap<SessionIndex, SessionInfo>,
43}
44
45#[derive(Clone)]
47pub struct SessionInfo {
48 pub session_index: SessionIndex,
50
51 pub validator_groups: Vec<Vec<AuthorityDiscoveryId>>,
58
59 pub our_index: ValidatorIndex,
61
62 pub our_group: Option<GroupIndex>,
67
68 pub node_features: NodeFeatures,
70}
71
72pub struct BadValidators {
77 pub session_index: SessionIndex,
79 pub group_index: GroupIndex,
81 pub bad_validators: Vec<AuthorityDiscoveryId>,
83}
84
85#[overseer::contextbounds(AvailabilityDistribution, prefix = self::overseer)]
86impl SessionCache {
87 pub fn new() -> Self {
89 SessionCache {
90 session_info_cache: LruMap::new(ByLength::new(2)),
92 }
93 }
94
95 pub async fn get_session_info<'a, Context>(
98 &'a mut self,
99 ctx: &mut Context,
100 runtime: &mut RuntimeInfo,
101 parent: Hash,
102 session_index: SessionIndex,
103 ) -> Result<Option<&'a SessionInfo>> {
104 gum::trace!(target: LOG_TARGET, session_index, "Calling `get_session_info`");
105
106 if self.session_info_cache.get(&session_index).is_none() {
107 if let Some(info) =
108 Self::query_info_from_runtime(ctx, runtime, parent, session_index).await?
109 {
110 gum::trace!(target: LOG_TARGET, session_index, "Storing session info in lru!");
111 self.session_info_cache.insert(session_index, info);
112 } else {
113 return Ok(None)
114 }
115 }
116
117 Ok(self.session_info_cache.get(&session_index).map(|i| &*i))
118 }
119
120 pub fn report_bad_log(&mut self, report: BadValidators) {
125 if let Err(err) = self.report_bad(report) {
126 gum::warn!(
127 target: LOG_TARGET,
128 err = ?err,
129 "Reporting bad validators failed with error"
130 );
131 }
132 }
133
134 pub fn report_bad(&mut self, report: BadValidators) -> Result<()> {
139 let available_sessions = self.session_info_cache.iter().map(|(k, _)| *k).collect();
140 let session = self.session_info_cache.get(&report.session_index).ok_or(
141 Error::NoSuchCachedSession {
142 available_sessions,
143 missing_session: report.session_index,
144 },
145 )?;
146 let group = session.validator_groups.get_mut(report.group_index.0 as usize).expect(
147 "A bad validator report must contain a valid group for the reported session. qed.",
148 );
149 let bad_set = report.bad_validators.iter().collect::<HashSet<_>>();
150
151 group.retain(|v| !bad_set.contains(v));
153
154 let mut new_group = report.bad_validators;
156 new_group.append(group);
157 *group = new_group;
158 Ok(())
159 }
160
161 async fn query_info_from_runtime<Context>(
169 ctx: &mut Context,
170 runtime: &mut RuntimeInfo,
171 relay_parent: Hash,
172 session_index: SessionIndex,
173 ) -> Result<Option<SessionInfo>> {
174 let info = runtime
175 .get_session_info_by_index(ctx.sender(), relay_parent, session_index)
176 .await?;
177
178 let node_features = request_node_features(relay_parent, session_index, ctx.sender())
179 .await
180 .await?
181 .map_err(Error::FailedNodeFeatures)?;
182
183 let discovery_keys = info.session_info.discovery_keys.clone();
184 let mut validator_groups = info.session_info.validator_groups.clone();
185
186 if let Some(our_index) = info.validator_info.our_index {
187 let our_group = info.validator_info.our_group;
189
190 let mut rng = thread_rng();
192 for g in validator_groups.iter_mut() {
193 g.shuffle(&mut rng)
194 }
195 let validator_groups: Vec<Vec<_>> = validator_groups
197 .into_iter()
198 .map(|group| {
199 group
200 .into_iter()
201 .map(|index| {
202 discovery_keys.get(index.0 as usize)
203 .expect("There should be a discovery key for each validator of each validator group. qed.")
204 .clone()
205 })
206 .collect()
207 })
208 .collect();
209
210 let info = SessionInfo {
211 validator_groups,
212 our_index,
213 session_index,
214 our_group,
215 node_features,
216 };
217 return Ok(Some(info))
218 }
219 return Ok(None)
220 }
221}