1use std::{
20 sync::{
21 atomic::{AtomicBool, Ordering},
22 Arc,
23 },
24 time::Duration,
25};
26
/// Error signalling that a task exceeded the time it was allowed to run.
#[derive(Debug)]
pub struct Timeout;

impl std::fmt::Display for Timeout {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "task has been running too long")
    }
}

impl std::error::Error for Timeout {}
37
38#[repr(transparent)]
40pub struct IsTimedOut(Arc<AtomicBool>);
41
42impl IsTimedOut {
43 #[must_use]
44 pub fn check_if_timed_out(&self) -> std::result::Result<(), Timeout> {
45 if self.0.load(Ordering::Relaxed) {
46 Err(Timeout)
47 } else {
48 Ok(())
49 }
50 }
51}
52
53#[derive(Debug)]
55pub enum SpawnWithTimeoutError {
56 JoinError(tokio::task::JoinError),
57 Timeout,
58}
59
60impl std::error::Error for SpawnWithTimeoutError {}
61impl std::fmt::Display for SpawnWithTimeoutError {
62 fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
63 match self {
64 SpawnWithTimeoutError::JoinError(error) => error.fmt(fmt),
65 SpawnWithTimeoutError::Timeout => Timeout.fmt(fmt),
66 }
67 }
68}
69
/// Guard that sets the shared flag to `true` when it is dropped.
struct CancelOnDrop(Arc<AtomicBool>);

impl Drop for CancelOnDrop {
    fn drop(&mut self) {
        let CancelOnDrop(flag) = self;
        flag.store(true, Ordering::Relaxed);
    }
}
76
77pub async fn spawn_blocking_with_timeout<R>(
84 timeout: Option<Duration>,
85 callback: impl FnOnce(IsTimedOut) -> std::result::Result<R, Timeout> + Send + 'static,
86) -> Result<R, SpawnWithTimeoutError>
87where
88 R: Send + 'static,
89{
90 let is_timed_out_arc = Arc::new(AtomicBool::new(false));
91 let is_timed_out = IsTimedOut(is_timed_out_arc.clone());
92 let _cancel_on_drop = CancelOnDrop(is_timed_out_arc);
93 let task = tokio::task::spawn_blocking(move || callback(is_timed_out));
94
95 let result = if let Some(timeout) = timeout {
96 tokio::select! {
97 biased;
100
101 task_result = task => task_result,
102 _ = tokio::time::sleep(timeout) => Ok(Err(Timeout))
103 }
104 } else {
105 task.await
106 };
107
108 match result {
109 Ok(Ok(result)) => Ok(result),
110 Ok(Err(Timeout)) => Err(SpawnWithTimeoutError::Timeout),
111 Err(error) => Err(SpawnWithTimeoutError::JoinError(error)),
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn spawn_blocking_with_timeout_works() {
        let task: Result<(), SpawnWithTimeoutError> =
            spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_timed_out| {
                std::thread::sleep(Duration::from_millis(200));
                is_timed_out.check_if_timed_out()?;
                unreachable!();
            })
            .await;

        assert_matches::assert_matches!(task, Err(SpawnWithTimeoutError::Timeout));

        let task = spawn_blocking_with_timeout(Some(Duration::from_millis(100)), |is_timed_out| {
            std::thread::sleep(Duration::from_millis(20));
            is_timed_out.check_if_timed_out()?;
            Ok(())
        })
        .await;

        assert_matches::assert_matches!(task, Ok(()));
    }

    /// The callback outlives the deadline in a detached task, so the check above never
    /// reports back; here it sends what it observed so the test can verify that a callback
    /// still running after the deadline actually sees the cancellation flag.
    #[tokio::test]
    async fn timed_out_callback_observes_cancellation() {
        let (sender, receiver) = std::sync::mpsc::channel();

        let task: Result<(), SpawnWithTimeoutError> =
            spawn_blocking_with_timeout(Some(Duration::from_millis(50)), move |is_timed_out| {
                std::thread::sleep(Duration::from_millis(150));
                let observed = is_timed_out.check_if_timed_out();
                sender.send(observed.is_err()).unwrap();
                observed
            })
            .await;

        assert_matches::assert_matches!(task, Err(SpawnWithTimeoutError::Timeout));

        let observed_timeout = receiver
            .recv_timeout(Duration::from_secs(5))
            .expect("callback should report after its sleep");
        assert!(observed_timeout);
    }

    /// Exercises the `None` branch, which waits for the callback without a deadline.
    #[tokio::test]
    async fn spawn_blocking_without_timeout_runs_to_completion() {
        let task = spawn_blocking_with_timeout(None, |is_timed_out| {
            std::thread::sleep(Duration::from_millis(20));
            is_timed_out.check_if_timed_out()?;
            Ok(42)
        })
        .await;

        assert_matches::assert_matches!(task, Ok(42));
    }

    /// A panicking callback must surface as a `JoinError`, not as a timeout.
    #[tokio::test]
    async fn panicking_callback_is_reported_as_join_error() {
        let task: Result<(), SpawnWithTimeoutError> =
            spawn_blocking_with_timeout(Some(Duration::from_secs(10)), |_is_timed_out| {
                panic!("callback panicked")
            })
            .await;

        assert_matches::assert_matches!(
            task,
            Err(SpawnWithTimeoutError::JoinError(error)) if error.is_panic()
        );
    }

    /// A callback that reports `Timeout` on its own maps to `SpawnWithTimeoutError::Timeout`.
    #[tokio::test]
    async fn callback_reported_timeout_is_propagated() {
        let task: Result<(), SpawnWithTimeoutError> =
            spawn_blocking_with_timeout(None, |_is_timed_out| Err(Timeout)).await;

        assert_matches::assert_matches!(task, Err(SpawnWithTimeoutError::Timeout));
    }
}