1// This file is part of Substrate.
23// Copyright (C) Parity Technologies (UK) Ltd.
4// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
56// This program is free software: you can redistribute it and/or modify
7// it under the terms of the GNU General Public License as published by
8// the Free Software Foundation, either version 3 of the License, or
9// (at your option) any later version.
1011// This program is distributed in the hope that it will be useful,
12// but WITHOUT ANY WARRANTY; without even the implied warranty of
13// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14// GNU General Public License for more details.
1516// You should have received a copy of the GNU General Public License
17// along with this program. If not, see <https://www.gnu.org/licenses/>.
1819use std::{
20 sync::{
21 atomic::{AtomicBool, Ordering},
22 Arc,
23 },
24 time::Duration,
25};
/// Error reported when a task was cancelled because it exceeded its allotted running time.
#[derive(Debug)]
pub struct Timeout;

impl std::fmt::Display for Timeout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "task has been running too long")
    }
}

impl std::error::Error for Timeout {}
3738/// A handle which can be used to check whether the task has been cancelled due to a timeout.
39#[repr(transparent)]
40pub struct IsTimedOut(Arc<AtomicBool>);
4142impl IsTimedOut {
43#[must_use]
44pub fn check_if_timed_out(&self) -> std::result::Result<(), Timeout> {
45if self.0.load(Ordering::Relaxed) {
46Err(Timeout)
47 } else {
48Ok(())
49 }
50 }
51}
5253/// An error for a task which either panicked, or has been cancelled due to a timeout.
54#[derive(Debug)]
55pub enum SpawnWithTimeoutError {
56 JoinError(tokio::task::JoinError),
57 Timeout,
58}
5960impl std::error::Error for SpawnWithTimeoutError {}
61impl std::fmt::Display for SpawnWithTimeoutError {
62fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
63match self {
64 SpawnWithTimeoutError::JoinError(error) => error.fmt(fmt),
65 SpawnWithTimeoutError::Timeout => Timeout.fmt(fmt),
66 }
67 }
68}
/// Drop guard that sets the shared cancellation flag when it goes out of scope.
struct CancelOnDrop(Arc<AtomicBool>);

impl Drop for CancelOnDrop {
    fn drop(&mut self) {
        let Self(flag) = self;
        flag.store(true, Ordering::Relaxed);
    }
}
7677/// Spawns a new blocking task with a given `timeout`.
78///
79/// The `callback` should continuously call [`IsTimedOut::check_if_timed_out`],
80/// which will return an error once the task runs for longer than `timeout`.
81///
82/// If `timeout` is `None` then this works just as a regular `spawn_blocking`.
83pub async fn spawn_blocking_with_timeout<R>(
84 timeout: Option<Duration>,
85 callback: impl FnOnce(IsTimedOut) -> std::result::Result<R, Timeout> + Send + 'static,
86) -> Result<R, SpawnWithTimeoutError>
87where
88R: Send + 'static,
89{
90let is_timed_out_arc = Arc::new(AtomicBool::new(false));
91let is_timed_out = IsTimedOut(is_timed_out_arc.clone());
92let _cancel_on_drop = CancelOnDrop(is_timed_out_arc);
93let task = tokio::task::spawn_blocking(move || callback(is_timed_out));
9495let result = if let Some(timeout) = timeout {
96tokio::select! {
97// Shouldn't really matter, but make sure the task is polled before the timeout,
98 // in case the task finishes after the timeout and the timeout is really short.
99biased;
100101 task_result = task => task_result,
102_ = tokio::time::sleep(timeout) => Ok(Err(Timeout))
103 }
104 } else {
105 task.await
106};
107108match result {
109Ok(Ok(result)) => Ok(result),
110Ok(Err(Timeout)) => Err(SpawnWithTimeoutError::Timeout),
111Err(error) => Err(SpawnWithTimeoutError::JoinError(error)),
112 }
113}
#[cfg(test)]
mod tests {
    use super::*;

    /// A callback that outlives its deadline is reported as timed out. One that finishes well
    /// within it returns its value.
    #[tokio::test]
    async fn spawn_blocking_with_timeout_works() {
        let task: Result<(), SpawnWithTimeoutError> =
            spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_timed_out| {
                std::thread::sleep(Duration::from_millis(200));
                is_timed_out.check_if_timed_out()?;
                unreachable!();
            })
            .await;

        assert_matches::assert_matches!(task, Err(SpawnWithTimeoutError::Timeout));

        let task = spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_timed_out| {
            std::thread::sleep(Duration::from_millis(20));
            is_timed_out.check_if_timed_out()?;
            Ok(())
        })
        .await;

        assert_matches::assert_matches!(task, Ok(()));
    }

    /// With no timeout the helper behaves like a plain `spawn_blocking`, and the flag stays
    /// clear while the caller is still awaiting.
    #[tokio::test]
    async fn spawn_blocking_without_timeout_returns_the_callback_result() {
        let task = spawn_blocking_with_timeout(None, |is_timed_out| {
            is_timed_out.check_if_timed_out()?;
            Ok(42)
        })
        .await;

        assert_matches::assert_matches!(task, Ok(42));
    }

    /// A panicking callback surfaces as a `JoinError` rather than a timeout.
    #[tokio::test]
    async fn spawn_blocking_with_timeout_reports_panics_as_join_errors() {
        let task = spawn_blocking_with_timeout(
            Some(Duration::from_secs(10)),
            |_| -> std::result::Result<(), Timeout> { panic!("boom") },
        )
        .await;

        assert_matches::assert_matches!(
            task,
            Err(SpawnWithTimeoutError::JoinError(error)) if error.is_panic()
        );
    }
}