use std::{
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::Duration,
};
/// Error signalling that a task has exceeded its time budget.
///
/// Produced by [`IsTimedOut::check_if_timed_out`] once the cancellation flag is raised, and
/// used as the error type of callbacks given to [`spawn_blocking_with_timeout`] so they can
/// bail out early with `?`.
#[derive(Debug)]
pub struct Timeout;

impl std::fmt::Display for Timeout {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "task has been running too long")
    }
}

impl std::error::Error for Timeout {}
/// Cooperative cancellation handle handed to the callback of [`spawn_blocking_with_timeout`].
///
/// The wrapped flag is raised by the spawning side. The callback should call
/// [`IsTimedOut::check_if_timed_out`] at convenient points and propagate the error with `?`.
/// Cloning the handle shares the same underlying flag, so helper threads spawned by the
/// callback can observe cancellation too.
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct IsTimedOut(Arc<AtomicBool>);

impl IsTimedOut {
    /// Reports whether the owner has stopped waiting for this task.
    ///
    /// The flag is read with relaxed ordering. In this module it is only a stop signal and
    /// is not used to publish any other data.
    ///
    /// No `#[must_use]` is needed here: the returned `Result` is already `#[must_use]`
    /// (adding it again trips `clippy::double_must_use`).
    ///
    /// # Errors
    ///
    /// Returns [`Timeout`] once the flag has been raised; `Ok(())` otherwise.
    pub fn check_if_timed_out(&self) -> std::result::Result<(), Timeout> {
        if !self.0.load(Ordering::Relaxed) {
            return Ok(());
        }
        Err(Timeout)
    }
}
/// Failure modes of [`spawn_blocking_with_timeout`].
#[derive(Debug)]
pub enum SpawnWithTimeoutError {
    /// The blocking task could not be joined (for example, the callback panicked).
    JoinError(tokio::task::JoinError),
    /// The deadline elapsed before the callback finished, or the callback itself
    /// returned [`Timeout`].
    Timeout,
}

impl std::error::Error for SpawnWithTimeoutError {}

impl std::fmt::Display for SpawnWithTimeoutError {
    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Call `Display::fmt` by path rather than as a method. A bare `.fmt(..)` only
        // resolves if exactly the right fmt trait is imported, and `Timeout` implements
        // both `Debug` and `Display`, so the method form is either unresolved or ambiguous.
        match self {
            SpawnWithTimeoutError::JoinError(error) => std::fmt::Display::fmt(error, fmt),
            SpawnWithTimeoutError::Timeout => std::fmt::Display::fmt(&Timeout, fmt),
        }
    }
}
/// Guard that raises the shared cancellation flag when it is dropped.
///
/// Because the store happens in `Drop`, the flag is set on every way the owning scope can
/// end: a normal return, an early `?` return, or unwinding from a panic.
struct CancelOnDrop(Arc<AtomicBool>);

impl Drop for CancelOnDrop {
    fn drop(&mut self) {
        let Self(flag) = self;
        flag.store(true, Ordering::Relaxed);
    }
}
/// Runs `callback` on tokio's blocking thread pool, optionally giving up after `timeout`.
///
/// The callback receives an [`IsTimedOut`] handle. Its flag is raised as soon as this function
/// returns or its future is dropped.
///
/// On a timeout the join handle is simply dropped, which detaches the blocking task. The
/// callback then keeps running until it next calls [`IsTimedOut::check_if_timed_out`] (and
/// bails out) or returns on its own.
///
/// When `timeout` is `None` this waits for the callback indefinitely. If the callback finishes
/// in the same poll in which the deadline elapses, the callback's result wins.
///
/// # Errors
///
/// * [`SpawnWithTimeoutError::Timeout`]: the deadline elapsed first, or the callback itself
///   returned `Err(Timeout)`.
/// * [`SpawnWithTimeoutError::JoinError`]: the blocking task could not be joined (for
///   example, the callback panicked).
pub async fn spawn_blocking_with_timeout<R>(
    timeout: Option<Duration>,
    callback: impl FnOnce(IsTimedOut) -> std::result::Result<R, Timeout> + Send + 'static,
) -> Result<R, SpawnWithTimeoutError>
where
    R: Send + 'static,
{
    let flag = Arc::new(AtomicBool::new(false));
    // Raises the flag on every way out of this function, including the caller dropping
    // this future while it is still awaiting.
    let _guard = CancelOnDrop(Arc::clone(&flag));
    let handle = tokio::task::spawn_blocking(move || callback(IsTimedOut(flag)));

    let joined = match timeout {
        None => handle.await,
        Some(limit) => tokio::select! {
            // Poll the task before the timer: if both are ready together, the task wins.
            biased;
            joined = handle => joined,
            _ = tokio::time::sleep(limit) => Ok(Err(Timeout)),
        },
    };

    let outcome = joined.map_err(SpawnWithTimeoutError::JoinError)?;
    outcome.map_err(|_| SpawnWithTimeoutError::Timeout)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn spawn_blocking_with_timeout_works() {
        // After a timeout the callback runs detached and its panics are never joined, so an
        // assertion inside it could not fail this test. Report what it observed over a
        // channel instead.
        let (observed_tx, observed_rx) = std::sync::mpsc::channel();
        let task: Result<(), SpawnWithTimeoutError> =
            spawn_blocking_with_timeout(Some(Duration::from_millis(100)), move |is_timed_out| {
                // Poll until the caller gives up instead of sleeping a fixed amount, so the
                // outcome does not hinge on scheduler timing.
                while is_timed_out.check_if_timed_out().is_ok() {
                    std::thread::sleep(Duration::from_millis(5));
                }
                observed_tx.send(()).unwrap();
                is_timed_out.check_if_timed_out()
            })
            .await;
        assert_matches::assert_matches!(task, Err(SpawnWithTimeoutError::Timeout));
        // Wait off the async runtime for the detached callback to confirm it saw the flag.
        let observed = tokio::task::spawn_blocking(move || {
            observed_rx.recv_timeout(Duration::from_secs(5))
        })
        .await
        .unwrap();
        assert!(observed.is_ok(), "timed-out callback never observed the flag");

        // A generous deadline keeps the fast path from flaking on a loaded machine.
        let task = spawn_blocking_with_timeout(Some(Duration::from_secs(5)), |is_timed_out| {
            std::thread::sleep(Duration::from_millis(20));
            is_timed_out.check_if_timed_out()?;
            Ok(())
        })
        .await;
        assert_matches::assert_matches!(task, Ok(()));
    }

    #[tokio::test]
    async fn spawn_blocking_without_timeout_waits_for_callback() {
        let task = spawn_blocking_with_timeout(None, |_| Ok(42)).await;
        assert_matches::assert_matches!(task, Ok(42));
    }

    #[tokio::test]
    async fn spawn_blocking_with_timeout_reports_panics_as_join_errors() {
        let task = spawn_blocking_with_timeout(None, |_| -> Result<(), Timeout> {
            panic!("callback failed")
        })
        .await;
        assert_matches::assert_matches!(
            task,
            Err(SpawnWithTimeoutError::JoinError(error)) if error.is_panic()
        );
    }
}