prioritized_metered_channel/
bounded.rs

1// Copyright 2017-2021 Parity Technologies (UK) Ltd.
2// This file is part of Polkadot.
3
4// Polkadot is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// Polkadot is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with Polkadot.  If not, see <http://www.gnu.org/licenses/>.
16
17#[cfg(feature = "async_channel")]
18use async_channel::{
19	bounded as bounded_channel, Receiver, Sender, TryRecvError, TrySendError as ChannelTrySendError,
20};
21
22#[cfg(feature = "futures_channel")]
23use futures::{
24	channel::mpsc::channel as bounded_channel,
25	channel::mpsc::{Receiver, Sender, TryRecvError, TrySendError as FuturesTrySendError},
26	sink::SinkExt,
27};
28
29use futures::{
30	stream::Stream,
31	task::{Context, Poll},
32};
33use std::{pin::Pin, result};
34
35use super::{prepare_with_tof, MaybeTimeOfFlight, Meter};
36
37/// Create a pair of `MeteredSender` and `MeteredReceiver`. No priorities are provided
38pub fn channel<T>(capacity: usize) -> (MeteredSender<T>, MeteredReceiver<T>) {
39	let (tx, rx) = bounded_channel::<MaybeTimeOfFlight<T>>(capacity);
40
41	let shared_meter = Meter::default();
42	let tx =
43		MeteredSender { meter: shared_meter.clone(), bulk_channel: tx, priority_channel: None };
44	let rx = MeteredReceiver { meter: shared_meter, bulk_channel: rx, priority_channel: None };
45	(tx, rx)
46}
47
48/// Create a pair of `MeteredSender` and `MeteredReceiver`. Priority channel is provided
49pub fn channel_with_priority<T>(
50	capacity_bulk: usize,
51	capacity_priority: usize,
52) -> (MeteredSender<T>, MeteredReceiver<T>) {
53	let (tx, rx) = bounded_channel::<MaybeTimeOfFlight<T>>(capacity_bulk);
54	let (tx_pri, rx_pri) = bounded_channel::<MaybeTimeOfFlight<T>>(capacity_priority);
55
56	let shared_meter = Meter::default();
57	let tx = MeteredSender {
58		meter: shared_meter.clone(),
59		bulk_channel: tx,
60		priority_channel: Some(tx_pri),
61	};
62	let rx =
63		MeteredReceiver { meter: shared_meter, bulk_channel: rx, priority_channel: Some(rx_pri) };
64	(tx, rx)
65}
66
/// A receiver tracking the messages consumed by itself.
///
/// Wraps a mandatory bulk channel and an optional priority channel. The receive
/// methods in this file look at the priority channel first and record every
/// delivered message on the meter.
#[derive(Debug)]
pub struct MeteredReceiver<T> {
	// Metrics accessor; `channel` / `channel_with_priority` give the paired sender a
	// clone of this same meter.
	meter: Meter,
	// Receiving half of the bulk (non-priority) channel; always present.
	bulk_channel: Receiver<MaybeTimeOfFlight<T>>,
	// Receiving half of the priority channel; `None` when created via `channel`.
	priority_channel: Option<Receiver<MaybeTimeOfFlight<T>>>,
}
75
/// A bounded channel error returned by the waiting (`async`) send methods.
#[derive(thiserror::Error, Debug)]
pub enum SendError<T> {
	/// The channel is closed; the message that could not be sent is handed back.
	#[error("Bounded channel has been closed")]
	Closed(T),
	/// The channel is closed and the message could not be recovered. Produced by
	/// the `futures_channel` send path, which has no way to return the message.
	#[error("Bounded channel has been closed and the original message is lost")]
	Terminated,
}
84
85impl<T> SendError<T> {
86	/// Returns the inner value.
87	pub fn into_inner(self) -> Option<T> {
88		match self {
89			Self::Closed(t) => Some(t),
90			Self::Terminated => None,
91		}
92	}
93}
94
/// A bounded channel error when trying to send a message (transparently wraps the inner error type)
///
/// Both variants carry the message that was rejected so the caller can retry or
/// dispose of it.
#[derive(thiserror::Error, Debug)]
pub enum TrySendError<T> {
	/// The channel is closed; holds the rejected message.
	#[error("Bounded channel has been closed")]
	Closed(T),
	/// The channel has no free capacity right now; holds the rejected message.
	#[error("Bounded channel is full")]
	Full(T),
}
103
/// Converts an `async_channel` try-send error whose payload is still wrapped in
/// `MaybeTimeOfFlight`, unwrapping the payload with `.into()` and keeping the
/// error variant.
#[cfg(feature = "async_channel")]
impl<T> From<ChannelTrySendError<MaybeTimeOfFlight<T>>> for TrySendError<T> {
	fn from(error: ChannelTrySendError<MaybeTimeOfFlight<T>>) -> Self {
		match error {
			ChannelTrySendError::Full(wrapped) => Self::Full(wrapped.into()),
			ChannelTrySendError::Closed(wrapped) => Self::Closed(wrapped.into()),
		}
	}
}
113
/// Converts an `async_channel` try-send error carrying a bare `T`, mapping each
/// variant to the variant of the same name.
#[cfg(feature = "async_channel")]
impl<T> From<ChannelTrySendError<T>> for TrySendError<T> {
	fn from(error: ChannelTrySendError<T>) -> Self {
		match error {
			ChannelTrySendError::Full(msg) => Self::Full(msg),
			ChannelTrySendError::Closed(msg) => Self::Closed(msg),
		}
	}
}
123
/// Converts a `futures` mpsc try-send error whose payload is still wrapped in
/// `MaybeTimeOfFlight`. A disconnected error becomes `Closed`, anything else
/// becomes `Full`; the payload is unwrapped with `.into()`.
#[cfg(feature = "futures_channel")]
impl<T> From<FuturesTrySendError<MaybeTimeOfFlight<T>>> for TrySendError<T> {
	fn from(error: FuturesTrySendError<MaybeTimeOfFlight<T>>) -> Self {
		// The flag must be read before `into_inner` consumes the error.
		if error.is_disconnected() {
			Self::Closed(error.into_inner().into())
		} else {
			Self::Full(error.into_inner().into())
		}
	}
}
137
/// Converts a `futures` mpsc try-send error carrying a bare `T`. A disconnected
/// error becomes `Closed`, anything else becomes `Full`.
#[cfg(feature = "futures_channel")]
impl<T> From<FuturesTrySendError<T>> for TrySendError<T> {
	fn from(error: FuturesTrySendError<T>) -> Self {
		// The flag must be read before `into_inner` consumes the error.
		if error.is_disconnected() {
			Self::Closed(error.into_inner())
		} else {
			Self::Full(error.into_inner())
		}
	}
}
150
151impl<T> TrySendError<T> {
152	/// Returns the inner value.
153	pub fn into_inner(self) -> T {
154		match self {
155			Self::Closed(t) => t,
156			Self::Full(t) => t,
157		}
158	}
159
160	/// Returns `true` if we could not send to channel as it was full
161	pub fn is_full(&self) -> bool {
162		match self {
163			Self::Closed(_) => false,
164			Self::Full(_) => true,
165		}
166	}
167
168	/// Returns `true` if we could not send to channel as it was disconnected
169	pub fn is_disconnected(&self) -> bool {
170		match self {
171			Self::Closed(_) => true,
172			Self::Full(_) => false,
173		}
174	}
175
176	/// Transform the inner value.
177	pub fn transform_inner<U, F>(self, f: F) -> TrySendError<U>
178	where
179		F: FnOnce(T) -> U,
180	{
181		match self {
182			Self::Closed(t) => TrySendError::<U>::Closed(f(t)),
183			Self::Full(t) => TrySendError::<U>::Full(f(t)),
184		}
185	}
186
187	/// Transform the inner value, fail-able version.
188	pub fn try_transform_inner<U, F, E>(self, f: F) -> std::result::Result<TrySendError<U>, E>
189	where
190		F: FnOnce(T) -> std::result::Result<U, E>,
191		E: std::fmt::Debug + std::error::Error + Send + Sync + 'static,
192	{
193		Ok(match self {
194			Self::Closed(t) => TrySendError::<U>::Closed(f(t)?),
195			Self::Full(t) => TrySendError::<U>::Full(f(t)?),
196		})
197	}
198}
199
/// Error when receiving from a closed bounded channel
///
/// Carries no data; it only signals that nothing more can be received.
#[derive(thiserror::Error, PartialEq, Eq, Clone, Copy, Debug)]
pub struct RecvError {}
203
204impl std::fmt::Display for RecvError {
205	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206		write!(f, "receiving from an empty and closed channel")
207	}
208}
209
/// Maps the `async_channel` receive error onto this module's `RecvError`.
#[cfg(feature = "async_channel")]
impl From<async_channel::RecvError> for RecvError {
	fn from(_closed: async_channel::RecvError) -> Self {
		Self {}
	}
}
216
217impl<T> std::ops::Deref for MeteredReceiver<T> {
218	type Target = Receiver<MaybeTimeOfFlight<T>>;
219	fn deref(&self) -> &Self::Target {
220		&self.bulk_channel
221	}
222}
223
224impl<T> std::ops::DerefMut for MeteredReceiver<T> {
225	fn deref_mut(&mut self) -> &mut Self::Target {
226		&mut self.bulk_channel
227	}
228}
229
230impl<T> Stream for MeteredReceiver<T> {
231	type Item = T;
232	fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
233		if let Some(priority_channel) = &mut self.priority_channel {
234			match Receiver::poll_next(Pin::new(priority_channel), cx) {
235				Poll::Ready(maybe_value) => return Poll::Ready(self.maybe_meter_tof(maybe_value)),
236				Poll::Pending => {},
237			}
238		}
239		match Receiver::poll_next(Pin::new(&mut self.bulk_channel), cx) {
240			Poll::Ready(maybe_value) => Poll::Ready(self.maybe_meter_tof(maybe_value)),
241			Poll::Pending => Poll::Pending,
242		}
243	}
244
245	/// Don't rely on the unreliable size hint.
246	fn size_hint(&self) -> (usize, Option<usize>) {
247		self.bulk_channel.size_hint()
248	}
249}
250
251impl<T> MeteredReceiver<T> {
252	fn maybe_meter_tof(&mut self, maybe_value: Option<MaybeTimeOfFlight<T>>) -> Option<T> {
253		self.meter.note_received();
254
255		maybe_value.map(|value| {
256			match value {
257				MaybeTimeOfFlight::<T>::WithTimeOfFlight(value, tof_start) => {
258					// do not use `.elapsed()` of `std::time`, it may panic
259					// `coarsetime` does a saturating sub for all `CoarseInstant` substractions
260					let duration = tof_start.elapsed();
261					self.meter.note_time_of_flight(duration);
262					value
263				},
264				MaybeTimeOfFlight::<T>::Bare(value) => value,
265			}
266			.into()
267		})
268	}
269
270	/// Get an updated accessor object for all metrics collected.
271	pub fn meter(&self) -> &Meter {
272		// For async_channel we can update channel length in the meter access
273		// to avoid more expensive updates on each RW operation
274		#[cfg(feature = "async_channel")]
275		self.meter.note_channel_len(self.len());
276
277		&self.meter
278	}
279
280	/// Attempt to receive the next item.
281	/// This function returns:
282	///
283	///    `Ok(Some(t))` when message is fetched
284	///    `Ok(None)` when channel is closed and no messages left in the queue
285	///    `Err(e)` when there are no messages available, but channel is not yet closed
286	#[cfg(feature = "futures_channel")]
287	pub fn try_next(&mut self) -> Result<Option<T>, TryRecvError> {
288		if let Some(priority_channel) = &mut self.priority_channel {
289			match priority_channel.try_next() {
290				Ok(Some(value)) => return Ok(self.maybe_meter_tof(Some(value))),
291				Ok(None) => return Ok(None), // Channel is closed, inform the caller
292				Err(_) => {},                // Channel is not closed but empty, ignore the error
293			}
294		}
295		match self.bulk_channel.try_next()? {
296			Some(value) => Ok(self.maybe_meter_tof(Some(value))),
297			None => Ok(None),
298		}
299	}
300
301	/// Attempt to receive the next item.
302	/// This function returns:
303	///
304	///    `Ok(Some(t))` when message is fetched
305	///    `Ok(None)` when channel is closed and no messages left in the queue
306	///    `Err(e)` when there are no messages available, but channel is not yet closed
307	#[cfg(feature = "async_channel")]
308	pub fn try_next(&mut self) -> Result<Option<T>, TryRecvError> {
309		if let Some(priority_channel) = &mut self.priority_channel {
310			match priority_channel.try_recv() {
311				Ok(value) => return Ok(self.maybe_meter_tof(Some(value))),
312				Err(TryRecvError::Empty) => {},               // Continue to bulk
313				Err(TryRecvError::Closed) => return Ok(None), // Mimic futures_channel behaviour
314			}
315		}
316		match self.bulk_channel.try_recv() {
317			Ok(value) => Ok(self.maybe_meter_tof(Some(value))),
318			Err(TryRecvError::Empty) => Err(TryRecvError::Empty),
319			Err(TryRecvError::Closed) => Ok(None), // Mimic futures_channel behaviour
320		}
321	}
322
323	/// Receive the next item.
324	#[cfg(feature = "async_channel")]
325	pub async fn recv(&mut self) -> Result<T, RecvError> {
326		if let Some(priority_channel) = &mut self.priority_channel {
327			match priority_channel.try_recv() {
328				Ok(value) =>
329					return Ok(self
330						.maybe_meter_tof(Some(value))
331						.expect("wrapped value is always Some, qed")),
332				Err(err) => match err {
333					TryRecvError::Closed => return Err(RecvError {}),
334					TryRecvError::Empty => {}, // We can still have data in the bulk channel
335				},
336			}
337		}
338		match self.bulk_channel.recv().await {
339			Ok(value) =>
340				Ok(self.maybe_meter_tof(Some(value)).expect("wrapped value is always Some, qed")),
341			Err(err) => Err(err.into()),
342		}
343	}
344
345	/// Attempt to receive the next item without blocking
346	#[cfg(feature = "async_channel")]
347	pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
348		if let Some(priority_channel) = &mut self.priority_channel {
349			match priority_channel.try_recv() {
350				Ok(value) =>
351					return Ok(self
352						.maybe_meter_tof(Some(value))
353						.expect("wrapped value is always Some, qed")),
354				Err(err) => match err {
355					TryRecvError::Closed => return Err(err.into()),
356					TryRecvError::Empty => {},
357				},
358			}
359		}
360		match self.bulk_channel.try_recv() {
361			Ok(value) =>
362				Ok(self.maybe_meter_tof(Some(value)).expect("wrapped value is always Some, qed")),
363			Err(err) => Err(err),
364		}
365	}
366
367	#[cfg(feature = "async_channel")]
368	/// Returns the current number of messages in the channel
369	pub fn len(&self) -> usize {
370		self.bulk_channel.len() + self.priority_channel.as_ref().map_or(0, |c| c.len())
371	}
372
373	#[cfg(feature = "futures_channel")]
374	/// Returns the current number of messages in the channel based on meter approximation
375	pub fn len(&self) -> usize {
376		self.meter.calculate_channel_len()
377	}
378}
379
380impl<T> futures::stream::FusedStream for MeteredReceiver<T> {
381	fn is_terminated(&self) -> bool {
382		self.bulk_channel.is_terminated() &&
383			self.priority_channel.as_ref().map_or(true, |c| c.is_terminated())
384	}
385}
386
/// The sender component, tracking the number of items
/// sent across it.
///
/// Wraps a mandatory bulk channel and an optional priority channel; when no
/// priority channel is configured, the priority send methods use the bulk channel.
#[derive(Debug)]
pub struct MeteredSender<T> {
	// Metrics accessor; built as a clone of the meter handed to the paired receiver.
	meter: Meter,
	// Sending half of the bulk (non-priority) channel; always present.
	bulk_channel: Sender<MaybeTimeOfFlight<T>>,
	// Sending half of the priority channel; `None` when created via `channel`.
	priority_channel: Option<Sender<MaybeTimeOfFlight<T>>>,
}
395
396impl<T> Clone for MeteredSender<T> {
397	fn clone(&self) -> Self {
398		Self {
399			meter: self.meter.clone(),
400			bulk_channel: self.bulk_channel.clone(),
401			priority_channel: self.priority_channel.clone(),
402		}
403	}
404}
405
406impl<T> std::ops::Deref for MeteredSender<T> {
407	type Target = Sender<MaybeTimeOfFlight<T>>;
408	fn deref(&self) -> &Self::Target {
409		&self.bulk_channel
410	}
411}
412
413impl<T> std::ops::DerefMut for MeteredSender<T> {
414	fn deref_mut(&mut self) -> &mut Self::Target {
415		&mut self.bulk_channel
416	}
417}
418
419impl<T> MeteredSender<T> {
420	/// Get an updated accessor object for all metrics collected.
421	pub fn meter(&self) -> &Meter {
422		// For async_channel we can update channel length in the meter access
423		// to avoid more expensive updates on each RW operation
424		#[cfg(feature = "async_channel")]
425		self.meter.note_channel_len(self.len());
426		&self.meter
427	}
428
429	/// Send message in bulk channel, wait until capacity is available.
430	pub async fn send(&mut self, msg: T) -> result::Result<(), SendError<T>>
431	where
432		Self: Unpin,
433	{
434		self.send_inner(msg, false).await
435	}
436
437	/// Send message in priority channel (if configured), wait until capacity is available.
438	pub async fn priority_send(&mut self, msg: T) -> result::Result<(), SendError<T>>
439	where
440		Self: Unpin,
441	{
442		self.send_inner(msg, true).await
443	}
444
445	async fn send_inner(
446		&mut self,
447		msg: T,
448		use_priority_channel: bool,
449	) -> result::Result<(), SendError<T>>
450	where
451		Self: Unpin,
452	{
453		let res =
454			if use_priority_channel { self.try_priority_send(msg) } else { self.try_send(msg) };
455
456		match res {
457			Err(send_err) => {
458				if !send_err.is_full() {
459					return Err(SendError::Closed(send_err.into_inner().into()))
460				}
461
462				self.meter.note_blocked();
463				self.meter.note_sent(); // we are going to do full blocking send, so we have to note it here
464				let msg = send_err.into_inner().into();
465				self.send_to_channel(msg, use_priority_channel).await
466			},
467			_ => Ok(()),
468		}
469	}
470
471	// A helper routine to send a message to the channel after `try_send` returned that a channel is full
472	#[cfg(feature = "async_channel")]
473	async fn send_to_channel(
474		&mut self,
475		msg: MaybeTimeOfFlight<T>,
476		use_priority_channel: bool,
477	) -> result::Result<(), SendError<T>> {
478		let channel = if use_priority_channel {
479			self.priority_channel.as_mut().unwrap_or(&mut self.bulk_channel)
480		} else {
481			&mut self.bulk_channel
482		};
483
484		let fut = channel.send(msg);
485		futures::pin_mut!(fut);
486		let result = fut.await.map_err(|err| {
487			self.meter.retract_sent();
488			SendError::Closed(err.0.into())
489		});
490
491		result
492	}
493
494	#[cfg(feature = "futures_channel")]
495	async fn send_to_channel(
496		&mut self,
497		msg: MaybeTimeOfFlight<T>,
498		use_priority_channel: bool,
499	) -> result::Result<(), SendError<T>> {
500		let channel = if use_priority_channel {
501			self.priority_channel.as_mut().unwrap_or(&mut self.bulk_channel)
502		} else {
503			&mut self.bulk_channel
504		};
505		let fut = channel.send(msg);
506		futures::pin_mut!(fut);
507		fut.await.map_err(|_| {
508			self.meter.retract_sent();
509			// Futures channel does not provide a way to save the original message,
510			// so to avoid `T: Clone` bound we just return a generic error
511			SendError::Terminated
512		})
513	}
514
515	/// Attempt to send message or fail immediately.
516	pub fn try_send(&mut self, msg: T) -> result::Result<(), TrySendError<T>> {
517		let msg = prepare_with_tof(&self.meter, msg); // note_sent is called in here
518		self.bulk_channel.try_send(msg).map_err(|e| {
519			self.meter.retract_sent(); // we didn't send it, so we need to undo the note_send
520			TrySendError::from(e)
521		})
522	}
523
524	/// Attempt to send message or fail immediately.
525	pub fn try_priority_send(&mut self, msg: T) -> result::Result<(), TrySendError<T>> {
526		match self.priority_channel.as_mut() {
527			Some(priority_channel) => {
528				let msg = prepare_with_tof(&self.meter, msg);
529				priority_channel.try_send(msg).map_err(|e| {
530					self.meter.retract_sent(); // we didn't send it, so we need to undo the note_send
531					TrySendError::from(e)
532				})
533			},
534			None => self.try_send(msg), // use bulk channel as fallback
535		}
536	}
537
538	#[cfg(feature = "async_channel")]
539	/// Returns the current number of messages in the channel
540	pub fn len(&self) -> usize {
541		self.bulk_channel.len() + self.priority_channel.as_ref().map_or(0, |c| c.len())
542	}
543
544	#[cfg(feature = "futures_channel")]
545	/// Returns the current number of messages in the channel based on meter approximation
546	pub fn len(&self) -> usize {
547		self.meter.calculate_channel_len()
548	}
549}