1use crate::SubscriptionTaskExecutor;
22use futures::{
23 future::{self, Either, Fuse, FusedFuture},
24 Future, FutureExt, Stream, StreamExt,
25};
26use jsonrpsee::{
27 types::SubscriptionId, DisconnectError, PendingSubscriptionSink, SubscriptionMessage,
28 SubscriptionSink,
29};
30use sp_runtime::Serialize;
31use std::collections::VecDeque;
32
/// Capacity used by `BoundedVecDeque::default()`.
const DEFAULT_BUF_SIZE: usize = 16;

/// A buffer that holds stream items which have not been sent to a subscriber yet.
///
/// Used by `pipe_from_stream` in this module. The implementations provided here
/// are first-in/first-out: `push` appends at the back and `pop` takes from the front.
pub trait Buffer {
    /// Type of the buffered items.
    type Item;

    /// Add `item` to the buffer.
    ///
    /// # Errors
    ///
    /// Returns `Err(())` if the implementation refuses to store the item
    /// (e.g. [`BoundedVecDeque`] when it is full). Callers in this module drop
    /// the subscription in that case.
    fn push(&mut self, item: Self::Item) -> Result<(), ()>;

    /// Remove and return the next item, or `None` if the buffer is empty.
    fn pop(&mut self) -> Option<Self::Item>;
}

/// A [`Buffer`] backed by a [`VecDeque`] that refuses new items once it holds
/// `max_cap` of them.
#[derive(Debug)]
pub struct BoundedVecDeque<T> {
    inner: VecDeque<T>,
    max_cap: usize,
}

impl<T> Default for BoundedVecDeque<T> {
    /// Create a buffer that holds at most `DEFAULT_BUF_SIZE` (16) items.
    fn default() -> Self {
        Self { inner: VecDeque::with_capacity(DEFAULT_BUF_SIZE), max_cap: DEFAULT_BUF_SIZE }
    }
}

impl<T> BoundedVecDeque<T> {
    /// Create a buffer that holds at most `cap` items.
    ///
    /// With `cap == 0` every `push` fails.
    pub fn new(cap: usize) -> Self {
        Self { inner: VecDeque::with_capacity(cap), max_cap: cap }
    }
}

impl<T> Buffer for BoundedVecDeque<T> {
    type Item = T;

    /// Append `item` to the back.
    ///
    /// Returns `Err(())` without storing the item if the buffer already holds
    /// `max_cap` items.
    fn push(&mut self, item: Self::Item) -> Result<(), ()> {
        if self.inner.len() >= self.max_cap {
            Err(())
        } else {
            self.inner.push_back(item);
            Ok(())
        }
    }

    /// Remove and return the oldest item.
    fn pop(&mut self) -> Option<T> {
        self.inner.pop_front()
    }
}

/// A [`Buffer`] backed by a [`VecDeque`] that never refuses an item.
///
/// When full, it evicts the oldest item to make room for the new one.
#[derive(Debug)]
pub struct RingBuffer<T> {
    inner: VecDeque<T>,
    cap: usize,
}

impl<T> RingBuffer<T> {
    /// Create a ring buffer that keeps at most `cap` items.
    ///
    /// Note: with `cap == 0` the buffer still keeps the most recently pushed
    /// item, i.e. it behaves like `cap == 1`.
    pub fn new(cap: usize) -> Self {
        Self { inner: VecDeque::with_capacity(cap), cap }
    }
}

impl<T> Buffer for RingBuffer<T> {
    type Item = T;

    /// Append `item` to the back, first evicting the oldest item if the buffer
    /// is full. Always returns `Ok(())`.
    fn push(&mut self, item: T) -> Result<(), ()> {
        if self.inner.len() >= self.cap {
            self.inner.pop_front();
        }

        self.inner.push_back(item);

        Ok(())
    }

    /// Remove and return the oldest item still held.
    fn pop(&mut self) -> Option<T> {
        self.inner.pop_front()
    }
}
116
117pub struct PendingSubscription(PendingSubscriptionSink);
119
120impl From<PendingSubscriptionSink> for PendingSubscription {
121 fn from(p: PendingSubscriptionSink) -> Self {
122 Self(p)
123 }
124}
125
126impl PendingSubscription {
127 pub async fn pipe_from_stream<S, T, B>(self, mut stream: S, mut buf: B)
130 where
131 S: Stream<Item = T> + Unpin + Send + 'static,
132 T: Serialize + Send + 'static,
133 B: Buffer<Item = T>,
134 {
135 let method = self.0.method_name().to_string();
136 let conn_id = self.0.connection_id().0;
137 let accept_fut = self.0.accept();
138
139 futures::pin_mut!(accept_fut);
140
141 let sink = loop {
145 match future::select(accept_fut, stream.next()).await {
146 Either::Left((Ok(sink), _)) => break sink,
147 Either::Right((Some(msg), f)) => {
148 if buf.push(msg).is_err() {
149 log::debug!(target: "rpc", "Subscription::accept buffer full for subscription={method} conn_id={conn_id}; dropping subscription");
150 return
151 }
152 accept_fut = f;
153 },
154 _ => return,
156 }
157 };
158
159 Subscription(sink).pipe_from_stream(stream, buf).await
160 }
161}
162
163#[derive(Clone, Debug)]
165pub struct Subscription(SubscriptionSink);
166
167impl From<SubscriptionSink> for Subscription {
168 fn from(sink: SubscriptionSink) -> Self {
169 Self(sink)
170 }
171}
172
173impl Subscription {
174 pub async fn pipe_from_stream<S, T, B>(self, mut stream: S, mut buf: B)
177 where
178 S: Stream<Item = T> + Unpin + Send + 'static,
179 T: Serialize + Send + 'static,
180 B: Buffer<Item = T>,
181 {
182 let mut next_fut = Box::pin(Fuse::terminated());
183 let mut next_item = stream.next();
184 let closed = self.0.closed();
185
186 futures::pin_mut!(closed);
187
188 loop {
189 if next_fut.is_terminated() {
190 if let Some(v) = buf.pop() {
191 let val = self.to_sub_message(&v);
192 next_fut.set(async { self.0.send(val).await }.fuse());
193 }
194 }
195
196 match future::select(closed, future::select(next_fut, next_item)).await {
197 Either::Right((Either::Left((_, n)), c)) => {
199 next_item = n;
200 closed = c;
201 next_fut = Box::pin(Fuse::terminated());
202 },
203 Either::Right((Either::Right((Some(v), n)), c)) => {
205 if buf.push(v).is_err() {
206 log::debug!(
207 target: "rpc",
208 "Subscription buffer full for subscription={} conn_id={}; dropping subscription",
209 self.0.method_name(),
210 self.0.connection_id().0
211 );
212 return
213 }
214
215 next_fut = n;
216 closed = c;
217 next_item = stream.next();
218 },
219 Either::Right((Either::Right((None, pending_fut)), _)) => {
223 if !pending_fut.is_terminated() && pending_fut.await.is_err() {
224 return;
225 }
226
227 while let Some(v) = buf.pop() {
228 if self.send(&v).await.is_err() {
229 return;
230 }
231 }
232
233 return;
234 },
235 Either::Left(_) => return,
237 }
238 }
239 }
240
241 pub async fn send(&self, result: &impl Serialize) -> Result<(), DisconnectError> {
243 self.0.send(self.to_sub_message(result)).await
244 }
245
246 pub fn subscription_id(&self) -> SubscriptionId {
248 self.0.subscription_id()
249 }
250
251 pub async fn closed(&self) {
253 self.0.closed().await
254 }
255
256 fn to_sub_message(&self, result: &impl Serialize) -> SubscriptionMessage {
258 SubscriptionMessage::new(self.0.method_name(), self.0.subscription_id(), result)
259 .expect("Serialize infallible; qed")
260 }
261}
262
263pub fn spawn_subscription_task(
265 executor: &SubscriptionTaskExecutor,
266 fut: impl Future<Output = ()> + Send + 'static,
267) {
268 executor.spawn("substrate-rpc-subscription", Some("rpc"), fut.boxed());
269}
270
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;
    // The explicit import of jsonrpsee's `Subscription` takes precedence over the
    // glob-imported `super::Subscription` inside this module.
    use jsonrpsee::{core::EmptyServerParams, RpcModule, Subscription};

    /// Register a subscription that pipes 16 zeros through a
    /// `BoundedVecDeque` of capacity 16, then subscribe to it.
    ///
    /// The last argument to `module.subscribe` is presumably the client-side
    /// buffer size; confirm against jsonrpsee.
    async fn subscribe() -> Subscription {
        let mut module = RpcModule::new(());
        module
            .register_subscription("sub", "my_sub", "unsub", |_, pending, _, _| async move {
                let stream = futures::stream::iter([0; 16]);
                PendingSubscription::from(pending)
                    .pipe_from_stream(stream, BoundedVecDeque::new(16))
                    .await;
                Ok(())
            })
            .unwrap();

        module.subscribe("sub", EmptyServerParams::new(), 1).await.unwrap()
    }

    /// All 16 stream items reach the subscriber.
    #[tokio::test]
    async fn pipe_from_stream_works() {
        let mut sub = subscribe().await;
        let mut rx = 0;

        while let Some(Ok(_)) = sub.next::<usize>().await {
            rx += 1;
        }

        assert_eq!(rx, 16);
    }

    /// A 32-item stream with a 16-item `BoundedVecDeque` is expected to end the
    /// subscription. The client is never read before `pipe_from_stream`
    /// returns, so presumably the buffer overflows. The context channel
    /// signals that `pipe_from_stream` has returned.
    #[tokio::test]
    async fn pipe_from_stream_with_bounded_vec() {
        let (tx, mut rx) = futures::channel::mpsc::unbounded::<()>();

        let mut module = RpcModule::new(tx);
        module
            .register_subscription("sub", "my_sub", "unsub", |_, pending, ctx, _| async move {
                let stream = futures::stream::iter([0; 32]);
                PendingSubscription::from(pending)
                    .pipe_from_stream(stream, BoundedVecDeque::new(16))
                    .await;
                _ = ctx.unbounded_send(());
                Ok(())
            })
            .unwrap();

        let mut sub = module.subscribe("sub", EmptyServerParams::new(), 1).await.unwrap();

        // Wait until `pipe_from_stream` has returned.
        _ = rx.next().await.unwrap();
        assert!(sub.next::<usize>().await.is_none());
    }

    /// With an empty stream, `pipe_from_stream` returns and the handler's
    /// notification fires.
    #[tokio::test]
    async fn subscription_is_dropped_when_stream_is_empty() {
        let notify_rx = std::sync::Arc::new(tokio::sync::Notify::new());
        let notify_tx = notify_rx.clone();

        let mut module = RpcModule::new(notify_tx);
        module
            .register_subscription(
                "sub",
                "my_sub",
                "unsub",
                |_, pending, notify_tx, _| async move {
                    let stream = futures::stream::empty::<()>();
                    PendingSubscription::from(pending)
                        .pipe_from_stream(stream, BoundedVecDeque::default())
                        .await;
                    notify_tx.notify_one();
                    Ok(())
                },
            )
            .unwrap();
        module.subscribe("sub", EmptyServerParams::new(), 1).await.unwrap();

        // Completes only once `pipe_from_stream` has returned.
        notify_rx.notified().await;
    }

    /// A `RingBuffer` of capacity 3 keeps only the newest items when the
    /// client reads slowly.
    #[tokio::test]
    async fn subscription_replace_old_messages() {
        let mut module = RpcModule::new(());
        module
            .register_subscription("sub", "my_sub", "unsub", |_, pending, _, _| async move {
                let stream = futures::stream::iter(0..20);
                PendingSubscription::from(pending)
                    .pipe_from_stream(stream, RingBuffer::new(3))
                    .await;
                Ok(())
            })
            .unwrap();

        let mut sub = module.subscribe("sub", EmptyServerParams::new(), 1).await.unwrap();

        // Simulate a very slow client so older buffered items get evicted.
        // NOTE(review): this makes the test take ~10s of wall-clock time.
        tokio::time::sleep(std::time::Duration::from_secs(10)).await;

        let mut res = Vec::new();

        while let Some(Ok((v, _))) = sub.next::<usize>().await {
            res.push(v);
        }

        // 17, 18 and 19 are the last three items kept by the ring buffer. 0 is
        // presumably included because it had already been handed to the sink
        // before the client stopped reading, and an in-flight send is never
        // cancelled.
        assert_eq!(res, vec![0, 17, 18, 19]);
    }
}