referrerpolicy=no-referrer-when-downgrade

sc_rpc/state/
utils.rs

1// This file is part of Substrate.
2
3// Copyright (C) Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
5
6// This program is free software: you can redistribute it and/or modify
7// it under the terms of the GNU General Public License as published by
8// the Free Software Foundation, either version 3 of the License, or
9// (at your option) any later version.
10
11// This program is distributed in the hope that it will be useful,
12// but WITHOUT ANY WARRANTY; without even the implied warranty of
13// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14// GNU General Public License for more details.
15
16// You should have received a copy of the GNU General Public License
17// along with this program. If not, see <https://www.gnu.org/licenses/>.
18
19use std::{
20	sync::{
21		atomic::{AtomicBool, Ordering},
22		Arc,
23	},
24	time::Duration,
25};
26
/// Error value reported when a task is cancelled because it exceeded its time limit.
///
/// Its `Display` output is the fixed message `task has been running too long`.
#[derive(Debug)]
pub struct Timeout;

impl std::fmt::Display for Timeout {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		write!(f, "task has been running too long")
	}
}

// No underlying cause to expose, so the default `Error` methods are sufficient.
impl std::error::Error for Timeout {}
37
38/// A handle which can be used to check whether the task has been cancelled due to a timeout.
39#[repr(transparent)]
40pub struct IsTimedOut(Arc<AtomicBool>);
41
42impl IsTimedOut {
43	#[must_use]
44	pub fn check_if_timed_out(&self) -> std::result::Result<(), Timeout> {
45		if self.0.load(Ordering::Relaxed) {
46			Err(Timeout)
47		} else {
48			Ok(())
49		}
50	}
51}
52
53/// An error for a task which either panicked, or has been cancelled due to a timeout.
54#[derive(Debug)]
55pub enum SpawnWithTimeoutError {
56	JoinError(tokio::task::JoinError),
57	Timeout,
58}
59
60impl std::error::Error for SpawnWithTimeoutError {}
61impl std::fmt::Display for SpawnWithTimeoutError {
62	fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
63		match self {
64			SpawnWithTimeoutError::JoinError(error) => error.fmt(fmt),
65			SpawnWithTimeoutError::Timeout => Timeout.fmt(fmt),
66		}
67	}
68}
69
/// Guard which sets the wrapped flag to `true` when it is dropped.
struct CancelOnDrop(Arc<AtomicBool>);

impl Drop for CancelOnDrop {
	fn drop(&mut self) {
		let Self(flag) = self;
		flag.store(true, Ordering::Relaxed);
	}
}
76
77/// Spawns a new blocking task with a given `timeout`.
78///
79/// The `callback` should continuously call [`IsTimedOut::check_if_timed_out`],
80/// which will return an error once the task runs for longer than `timeout`.
81///
82/// If `timeout` is `None` then this works just as a regular `spawn_blocking`.
83pub async fn spawn_blocking_with_timeout<R>(
84	timeout: Option<Duration>,
85	callback: impl FnOnce(IsTimedOut) -> std::result::Result<R, Timeout> + Send + 'static,
86) -> Result<R, SpawnWithTimeoutError>
87where
88	R: Send + 'static,
89{
90	let is_timed_out_arc = Arc::new(AtomicBool::new(false));
91	let is_timed_out = IsTimedOut(is_timed_out_arc.clone());
92	let _cancel_on_drop = CancelOnDrop(is_timed_out_arc);
93	let task = tokio::task::spawn_blocking(move || callback(is_timed_out));
94
95	let result = if let Some(timeout) = timeout {
96		tokio::select! {
97			// Shouldn't really matter, but make sure the task is polled before the timeout,
98			// in case the task finishes after the timeout and the timeout is really short.
99			biased;
100
101			task_result = task => task_result,
102			_ = tokio::time::sleep(timeout) => Ok(Err(Timeout))
103		}
104	} else {
105		task.await
106	};
107
108	match result {
109		Ok(Ok(result)) => Ok(result),
110		Ok(Err(Timeout)) => Err(SpawnWithTimeoutError::Timeout),
111		Err(error) => Err(SpawnWithTimeoutError::JoinError(error)),
112	}
113}
114
#[cfg(test)]
mod tests {
	use super::*;

	/// Exercises both outcomes of `spawn_blocking_with_timeout`: a callback which
	/// outlasts the time limit and one which finishes well within it.
	#[tokio::test]
	async fn spawn_blocking_with_timeout_works() {
		let limit = Duration::from_millis(100);

		// The callback sleeps past the limit; by the time it wakes up the timeout has
		// already been reported, so the check must fail and return early.
		let too_slow: Result<(), SpawnWithTimeoutError> =
			spawn_blocking_with_timeout(Some(limit), |is_timed_out| {
				std::thread::sleep(Duration::from_millis(200));
				is_timed_out.check_if_timed_out()?;
				unreachable!();
			})
			.await;
		assert_matches::assert_matches!(too_slow, Err(SpawnWithTimeoutError::Timeout));

		// The callback finishes comfortably before the limit, so its value is returned.
		let fast_enough = spawn_blocking_with_timeout(Some(limit), |is_timed_out| {
			std::thread::sleep(Duration::from_millis(20));
			is_timed_out.check_if_timed_out()?;
			Ok(())
		})
		.await;
		assert_matches::assert_matches!(fast_enough, Ok(()));
	}
}