1use std::{
27 marker::PhantomData,
28 sync::{Arc, Weak},
29};
30
31use futures::StreamExt;
32use sc_client_api::{Backend, UnpinWorkerMessage};
33
34use sc_utils::mpsc::TracingUnboundedReceiver;
35use schnellru::Limiter;
36use sp_runtime::traits::Block as BlockT;
37
/// Target string attached to the log records emitted by the pinning limiter and worker.
const LOG_TARGET: &str = "db::notification_pinning";
/// Maximum number of entries the worker's map of pinned blocks is allowed to hold; it is
/// handed to the length limiter when a `NotificationPinningWorker` is created via `new`.
const NOTIFICATION_PINNING_LIMIT: usize = 1024;
40
41#[derive(Clone, Debug)]
43struct UnpinningByLengthLimiter<Block: BlockT, B: Backend<Block>> {
44 max_length: usize,
45 backend: Weak<B>,
46 _phantom: PhantomData<Block>,
47}
48
49impl<Block: BlockT, B: Backend<Block>> UnpinningByLengthLimiter<Block, B> {
50 pub fn new(max_length: usize, backend: Weak<B>) -> UnpinningByLengthLimiter<Block, B> {
52 UnpinningByLengthLimiter { max_length, backend, _phantom: PhantomData::<Block>::default() }
53 }
54}
55
56impl<Block: BlockT, B: Backend<Block>> Limiter<Block::Hash, u32>
57 for UnpinningByLengthLimiter<Block, B>
58{
59 type KeyToInsert<'a> = Block::Hash;
60 type LinkType = usize;
61
62 fn is_over_the_limit(&self, length: usize) -> bool {
63 length > self.max_length
64 }
65
66 fn on_insert(
67 &mut self,
68 _length: usize,
69 key: Self::KeyToInsert<'_>,
70 value: u32,
71 ) -> Option<(Block::Hash, u32)> {
72 log::debug!(target: LOG_TARGET, "Pinning block based on notification. hash = {key}");
73 if self.max_length > 0 {
74 Some((key, value))
75 } else {
76 None
77 }
78 }
79
80 fn on_replace(
81 &mut self,
82 _length: usize,
83 _old_key: &mut Block::Hash,
84 _new_key: Block::Hash,
85 _old_value: &mut u32,
86 _new_value: &mut u32,
87 ) -> bool {
88 true
89 }
90
91 fn on_removed(&mut self, key: &mut Block::Hash, references: &mut u32) {
92 if *references > 0 {
97 log::warn!(
98 target: LOG_TARGET,
99 "Notification block pinning limit reached. Unpinning block with hash = {key:?}"
100 );
101 if let Some(backend) = self.backend.upgrade() {
102 (0..*references).for_each(|_| backend.unpin_block(*key));
103 }
104 } else {
105 log::trace!(
106 target: LOG_TARGET,
107 "Unpinned block. hash = {key:?}",
108 )
109 }
110 }
111
112 fn on_cleared(&mut self) {}
113
114 fn on_grow(&mut self, _new_memory_usage: usize) -> bool {
115 true
116 }
117}
118
119pub struct NotificationPinningWorker<Block: BlockT, Back: Backend<Block>> {
125 unpin_message_rx: TracingUnboundedReceiver<UnpinWorkerMessage<Block>>,
126 task_backend: Weak<Back>,
127 pinned_blocks: schnellru::LruMap<Block::Hash, u32, UnpinningByLengthLimiter<Block, Back>>,
128}
129
130impl<Block: BlockT, Back: Backend<Block>> NotificationPinningWorker<Block, Back> {
131 pub fn new(
133 unpin_message_rx: TracingUnboundedReceiver<UnpinWorkerMessage<Block>>,
134 task_backend: Arc<Back>,
135 ) -> Self {
136 let pinned_blocks =
137 schnellru::LruMap::<Block::Hash, u32, UnpinningByLengthLimiter<Block, Back>>::new(
138 UnpinningByLengthLimiter::new(
139 NOTIFICATION_PINNING_LIMIT,
140 Arc::downgrade(&task_backend),
141 ),
142 );
143 Self { unpin_message_rx, task_backend: Arc::downgrade(&task_backend), pinned_blocks }
144 }
145
146 fn handle_announce_message(&mut self, hash: Block::Hash) {
147 if let Some(entry) = self.pinned_blocks.get_or_insert(hash, Default::default) {
148 *entry = *entry + 1;
149 }
150 }
151
152 fn handle_unpin_message(&mut self, hash: Block::Hash) -> Result<(), ()> {
153 if let Some(refcount) = self.pinned_blocks.peek_mut(&hash) {
154 *refcount = *refcount - 1;
155 if *refcount == 0 {
156 self.pinned_blocks.remove(&hash);
157 }
158 if let Some(backend) = self.task_backend.upgrade() {
159 log::debug!(target: LOG_TARGET, "Reducing pinning refcount for block hash = {hash:?}");
160 backend.unpin_block(hash);
161 } else {
162 log::debug!(target: LOG_TARGET, "Terminating unpin-worker, backend reference was dropped.");
163 return Err(())
164 }
165 } else {
166 log::debug!(target: LOG_TARGET, "Received unpin message for already unpinned block. hash = {hash:?}");
167 }
168 Ok(())
169 }
170
171 pub async fn run(mut self) {
176 while let Some(message) = self.unpin_message_rx.next().await {
177 match message {
178 UnpinWorkerMessage::AnnouncePin(hash) => self.handle_announce_message(hash),
179 UnpinWorkerMessage::Unpin(hash) =>
180 if self.handle_unpin_message(hash).is_err() {
181 return
182 },
183 }
184 }
185 log::debug!(target: LOG_TARGET, "Terminating unpin-worker, stream terminated.")
186 }
187}
188
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use sc_client_api::{Backend, UnpinWorkerMessage};
    use sc_utils::mpsc::{tracing_unbounded, TracingUnboundedReceiver};
    use sp_core::H256;
    use sp_runtime::traits::Block as BlockT;

    type Block = substrate_test_runtime_client::runtime::Block;

    use super::{NotificationPinningWorker, UnpinningByLengthLimiter};

    impl<Block: BlockT, Back: Backend<Block>> NotificationPinningWorker<Block, Back> {
        /// Test-only constructor that allows choosing the map's entry limit.
        fn new_with_limit(
            unpin_message_rx: TracingUnboundedReceiver<UnpinWorkerMessage<Block>>,
            task_backend: Arc<Back>,
            limit: usize,
        ) -> Self {
            let weak_backend = Arc::downgrade(&task_backend);
            let limiter: UnpinningByLengthLimiter<Block, Back> =
                UnpinningByLengthLimiter::new(limit, weak_backend.clone());

            Self {
                unpin_message_rx,
                task_backend: weak_backend,
                pinned_blocks: schnellru::LruMap::new(limiter),
            }
        }

        /// Read-only access to the map of tracked blocks.
        fn lru(
            &self,
        ) -> &schnellru::LruMap<Block::Hash, u32, UnpinningByLengthLimiter<Block, Back>> {
            &self.pinned_blocks
        }
    }

    #[test]
    fn pinning_worker_handles_base_case() {
        let (_tx, rx) = tracing_unbounded("testing", 1000);
        let backend = Arc::new(sc_client_api::in_mem::Backend::<Block>::new());
        let hash = H256::random();
        let mut worker = NotificationPinningWorker::new(rx, Arc::clone(&backend));

        // Pin the block once in the backend.
        let _ = backend.pin_block(hash);
        assert_eq!(backend.pin_refs(&hash), Some(1));

        // Announcing the pin makes the worker track the block.
        worker.handle_announce_message(hash);
        assert_eq!(worker.lru().len(), 1);

        // The unpin message releases the backend pin and drops the tracked entry.
        let _ = worker.handle_unpin_message(hash);
        assert_eq!(backend.pin_refs(&hash), Some(0));
        assert!(worker.lru().is_empty());
    }

    #[test]
    fn pinning_worker_handles_multiple_pins() {
        let (_tx, rx) = tracing_unbounded("testing", 1000);
        let backend = Arc::new(sc_client_api::in_mem::Backend::<Block>::new());
        let hash = H256::random();
        let mut worker = NotificationPinningWorker::new(rx, Arc::clone(&backend));

        // Pin the same block three times in the backend.
        for _ in 0..3 {
            let _ = backend.pin_block(hash);
        }
        assert_eq!(backend.pin_refs(&hash), Some(3));

        // Three announcements share a single map entry.
        for _ in 0..3 {
            worker.handle_announce_message(hash);
        }
        assert_eq!(worker.lru().len(), 1);

        // Every unpin message releases exactly one backend pin.
        for expected in [2, 1, 0] {
            let _ = worker.handle_unpin_message(hash);
            assert_eq!(backend.pin_refs(&hash), Some(expected));
        }
        assert!(worker.lru().is_empty());

        // A surplus unpin message leaves the backend count untouched.
        let _ = worker.handle_unpin_message(hash);
        assert_eq!(backend.pin_refs(&hash), Some(0));
    }

    #[test]
    fn pinning_worker_handles_too_many_unpins() {
        let (_tx, rx) = tracing_unbounded("testing", 1000);
        let backend = Arc::new(sc_client_api::in_mem::Backend::<Block>::new());
        let hash = H256::random();
        let hash2 = H256::random();
        let mut worker = NotificationPinningWorker::new(rx, Arc::clone(&backend));

        // Three backend pins, but only one of them is announced to the worker.
        for _ in 0..3 {
            let _ = backend.pin_block(hash);
        }
        assert_eq!(backend.pin_refs(&hash), Some(3));

        worker.handle_announce_message(hash);
        assert_eq!(worker.lru().len(), 1);

        // The first unpin matches the announced pin and is forwarded to the backend.
        let _ = worker.handle_unpin_message(hash);
        assert_eq!(backend.pin_refs(&hash), Some(2));

        // The second unpin finds no tracked entry, so the backend is left alone.
        let _ = worker.handle_unpin_message(hash);
        assert_eq!(backend.pin_refs(&hash), Some(2));
        assert!(worker.lru().is_empty());

        // Unpinning a block that was never pinned or announced changes nothing.
        let _ = worker.handle_unpin_message(hash2);
        assert!(worker.lru().is_empty());
        assert_eq!(backend.pin_refs(&hash2), None);
    }

    #[test]
    fn pinning_worker_should_evict_when_limit_reached() {
        let (_tx, rx) = tracing_unbounded("testing", 1000);
        let backend = Arc::new(sc_client_api::in_mem::Backend::<Block>::new());
        let [hash1, hash2, hash3, hash4] = [(); 4].map(|_| H256::random());

        // The worker may track at most two blocks.
        let mut worker = NotificationPinningWorker::new_with_limit(rx, Arc::clone(&backend), 2);

        // Pin three blocks once each in the backend.
        for pinned in [hash1, hash2, hash3] {
            let _ = backend.pin_block(pinned);
        }
        for pinned in [hash1, hash2, hash3] {
            assert_eq!(backend.pin_refs(&pinned), Some(1));
        }

        // Announce all three; the third announcement exceeds the limit.
        worker.handle_announce_message(hash1);
        assert!(worker.lru().peek(&hash1).is_some());
        worker.handle_announce_message(hash2);
        assert!(worker.lru().peek(&hash2).is_some());
        worker.handle_announce_message(hash3);
        assert!(worker.lru().peek(&hash3).is_some());
        assert!(worker.lru().peek(&hash2).is_some());
        assert_eq!(worker.lru().len(), 2);

        // hash1 was evicted, which unpinned it in the backend; the others stay pinned.
        assert_eq!(backend.pin_refs(&hash1), Some(0));
        assert_eq!(backend.pin_refs(&hash2), Some(1));
        assert_eq!(backend.pin_refs(&hash3), Some(1));

        // Announcing hash2 again raises its tracked count to two.
        worker.handle_announce_message(hash2);
        assert_eq!(worker.lru().peek(&hash2), Some(&2));

        // Inserting hash4 evicts hash3, since hash2 was touched more recently.
        worker.handle_announce_message(hash4);
        assert_eq!(worker.lru().peek(&hash3), None);
    }
}