futures_executor/
thread_pool.rs

1use crate::enter;
2use crate::unpark_mutex::UnparkMutex;
3use futures_core::future::Future;
4use futures_core::task::{Context, Poll};
5use futures_task::{waker_ref, ArcWake};
6use futures_task::{FutureObj, Spawn, SpawnError};
7use futures_util::future::FutureExt;
8use std::cmp;
9use std::fmt;
10use std::io;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::sync::mpsc;
13use std::sync::{Arc, Mutex};
14use std::thread;
15
16/// A general-purpose thread pool for scheduling tasks that poll futures to
17/// completion.
18///
19/// The thread pool multiplexes any number of tasks onto a fixed number of
20/// worker threads.
21///
22/// This type is a clonable handle to the threadpool itself.
23/// Cloning it will only create a new reference, not a new threadpool.
24///
25/// This type is only available when the `thread-pool` feature of this
26/// library is activated.
27#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
28pub struct ThreadPool {
29    state: Arc<PoolState>,
30}
31
/// Configuration used to build a [`ThreadPool`] via
/// [`ThreadPoolBuilder::create`].
///
/// This type is only available when the `thread-pool` feature of this
/// library is activated.
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPoolBuilder {
    /// Number of worker threads to spawn; never zero.
    pool_size: usize,
    /// Worker stack size in bytes; `0` keeps the standard library default.
    stack_size: usize,
    /// Optional prefix for worker thread names (the worker index is appended).
    name_prefix: Option<String>,
    /// Hook run on each worker, with its index, before it handles any task.
    after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    /// Hook run on each worker, with its index, just before it exits.
    before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
}
44
45trait AssertSendSync: Send + Sync {}
46impl AssertSendSync for ThreadPool {}
47
48struct PoolState {
49    tx: Mutex<mpsc::Sender<Message>>,
50    rx: Mutex<mpsc::Receiver<Message>>,
51    cnt: AtomicUsize,
52    size: usize,
53}
54
55impl fmt::Debug for ThreadPool {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        f.debug_struct("ThreadPool").field("size", &self.state.size).finish()
58    }
59}
60
61impl fmt::Debug for ThreadPoolBuilder {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        f.debug_struct("ThreadPoolBuilder")
64            .field("pool_size", &self.pool_size)
65            .field("name_prefix", &self.name_prefix)
66            .finish()
67    }
68}
69
70enum Message {
71    Run(Task),
72    Close,
73}
74
75impl ThreadPool {
76    /// Creates a new thread pool with the default configuration.
77    ///
78    /// See documentation for the methods in
79    /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
80    /// configuration.
81    pub fn new() -> Result<Self, io::Error> {
82        ThreadPoolBuilder::new().create()
83    }
84
85    /// Create a default thread pool configuration, which can then be customized.
86    ///
87    /// See documentation for the methods in
88    /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
89    /// configuration.
90    pub fn builder() -> ThreadPoolBuilder {
91        ThreadPoolBuilder::new()
92    }
93
94    /// Spawns a future that will be run to completion.
95    ///
96    /// > **Note**: This method is similar to `Spawn::spawn_obj`, except that
97    /// >           it is guaranteed to always succeed.
98    pub fn spawn_obj_ok(&self, future: FutureObj<'static, ()>) {
99        let task = Task {
100            future,
101            wake_handle: Arc::new(WakeHandle { exec: self.clone(), mutex: UnparkMutex::new() }),
102            exec: self.clone(),
103        };
104        self.state.send(Message::Run(task));
105    }
106
107    /// Spawns a task that polls the given future with output `()` to
108    /// completion.
109    ///
110    /// ```
111    /// # {
112    /// use futures::executor::ThreadPool;
113    ///
114    /// let pool = ThreadPool::new().unwrap();
115    ///
116    /// let future = async { /* ... */ };
117    /// pool.spawn_ok(future);
118    /// # }
119    /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
120    /// ```
121    ///
122    /// > **Note**: This method is similar to `SpawnExt::spawn`, except that
123    /// >           it is guaranteed to always succeed.
124    pub fn spawn_ok<Fut>(&self, future: Fut)
125    where
126        Fut: Future<Output = ()> + Send + 'static,
127    {
128        self.spawn_obj_ok(FutureObj::new(Box::new(future)))
129    }
130}
131
132impl Spawn for ThreadPool {
133    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
134        self.spawn_obj_ok(future);
135        Ok(())
136    }
137}
138
139impl PoolState {
140    fn send(&self, msg: Message) {
141        self.tx.lock().unwrap().send(msg).unwrap();
142    }
143
144    fn work(
145        &self,
146        idx: usize,
147        after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
148        before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
149    ) {
150        let _scope = enter().unwrap();
151        if let Some(after_start) = after_start {
152            after_start(idx);
153        }
154        loop {
155            let msg = self.rx.lock().unwrap().recv().unwrap();
156            match msg {
157                Message::Run(task) => task.run(),
158                Message::Close => break,
159            }
160        }
161        if let Some(before_stop) = before_stop {
162            before_stop(idx);
163        }
164    }
165}
166
167impl Clone for ThreadPool {
168    fn clone(&self) -> Self {
169        self.state.cnt.fetch_add(1, Ordering::Relaxed);
170        Self { state: self.state.clone() }
171    }
172}
173
174impl Drop for ThreadPool {
175    fn drop(&mut self) {
176        if self.state.cnt.fetch_sub(1, Ordering::Relaxed) == 1 {
177            for _ in 0..self.state.size {
178                self.state.send(Message::Close);
179            }
180        }
181    }
182}
183
184impl ThreadPoolBuilder {
185    /// Create a default thread pool configuration.
186    ///
187    /// See the other methods on this type for details on the defaults.
188    pub fn new() -> Self {
189        Self {
190            pool_size: cmp::max(1, num_cpus::get()),
191            stack_size: 0,
192            name_prefix: None,
193            after_start: None,
194            before_stop: None,
195        }
196    }
197
198    /// Set size of a future ThreadPool
199    ///
200    /// The size of a thread pool is the number of worker threads spawned. By
201    /// default, this is equal to the number of CPU cores.
202    ///
203    /// # Panics
204    ///
205    /// Panics if `pool_size == 0`.
206    pub fn pool_size(&mut self, size: usize) -> &mut Self {
207        assert!(size > 0);
208        self.pool_size = size;
209        self
210    }
211
212    /// Set stack size of threads in the pool, in bytes.
213    ///
214    /// By default, worker threads use Rust's standard stack size.
215    pub fn stack_size(&mut self, stack_size: usize) -> &mut Self {
216        self.stack_size = stack_size;
217        self
218    }
219
220    /// Set thread name prefix of a future ThreadPool.
221    ///
222    /// Thread name prefix is used for generating thread names. For example, if prefix is
223    /// `my-pool-`, then threads in the pool will get names like `my-pool-1` etc.
224    ///
225    /// By default, worker threads are assigned Rust's standard thread name.
226    pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self {
227        self.name_prefix = Some(name_prefix.into());
228        self
229    }
230
231    /// Execute the closure `f` immediately after each worker thread is started,
232    /// but before running any tasks on it.
233    ///
234    /// This hook is intended for bookkeeping and monitoring.
235    /// The closure `f` will be dropped after the `builder` is dropped
236    /// and all worker threads in the pool have executed it.
237    ///
238    /// The closure provided will receive an index corresponding to the worker
239    /// thread it's running on.
240    pub fn after_start<F>(&mut self, f: F) -> &mut Self
241    where
242        F: Fn(usize) + Send + Sync + 'static,
243    {
244        self.after_start = Some(Arc::new(f));
245        self
246    }
247
248    /// Execute closure `f` just prior to shutting down each worker thread.
249    ///
250    /// This hook is intended for bookkeeping and monitoring.
251    /// The closure `f` will be dropped after the `builder` is dropped
252    /// and all threads in the pool have executed it.
253    ///
254    /// The closure provided will receive an index corresponding to the worker
255    /// thread it's running on.
256    pub fn before_stop<F>(&mut self, f: F) -> &mut Self
257    where
258        F: Fn(usize) + Send + Sync + 'static,
259    {
260        self.before_stop = Some(Arc::new(f));
261        self
262    }
263
264    /// Create a [`ThreadPool`](ThreadPool) with the given configuration.
265    pub fn create(&mut self) -> Result<ThreadPool, io::Error> {
266        let (tx, rx) = mpsc::channel();
267        let pool = ThreadPool {
268            state: Arc::new(PoolState {
269                tx: Mutex::new(tx),
270                rx: Mutex::new(rx),
271                cnt: AtomicUsize::new(1),
272                size: self.pool_size,
273            }),
274        };
275
276        for counter in 0..self.pool_size {
277            let state = pool.state.clone();
278            let after_start = self.after_start.clone();
279            let before_stop = self.before_stop.clone();
280            let mut thread_builder = thread::Builder::new();
281            if let Some(ref name_prefix) = self.name_prefix {
282                thread_builder = thread_builder.name(format!("{}{}", name_prefix, counter));
283            }
284            if self.stack_size > 0 {
285                thread_builder = thread_builder.stack_size(self.stack_size);
286            }
287            thread_builder.spawn(move || state.work(counter, after_start, before_stop))?;
288        }
289        Ok(pool)
290    }
291}
292
293impl Default for ThreadPoolBuilder {
294    fn default() -> Self {
295        Self::new()
296    }
297}
298
299/// A task responsible for polling a future to completion.
300struct Task {
301    future: FutureObj<'static, ()>,
302    exec: ThreadPool,
303    wake_handle: Arc<WakeHandle>,
304}
305
306struct WakeHandle {
307    mutex: UnparkMutex<Task>,
308    exec: ThreadPool,
309}
310
311impl Task {
312    /// Actually run the task (invoking `poll` on the future) on the current
313    /// thread.
314    fn run(self) {
315        let Self { mut future, wake_handle, mut exec } = self;
316        let waker = waker_ref(&wake_handle);
317        let mut cx = Context::from_waker(&waker);
318
319        // Safety: The ownership of this `Task` object is evidence that
320        // we are in the `POLLING`/`REPOLL` state for the mutex.
321        unsafe {
322            wake_handle.mutex.start_poll();
323
324            loop {
325                let res = future.poll_unpin(&mut cx);
326                match res {
327                    Poll::Pending => {}
328                    Poll::Ready(()) => return wake_handle.mutex.complete(),
329                }
330                let task = Self { future, wake_handle: wake_handle.clone(), exec };
331                match wake_handle.mutex.wait(task) {
332                    Ok(()) => return, // we've waited
333                    Err(task) => {
334                        // someone's notified us
335                        future = task.future;
336                        exec = task.exec;
337                    }
338                }
339            }
340        }
341    }
342}
343
344impl fmt::Debug for Task {
345    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
346        f.debug_struct("Task").field("contents", &"...").finish()
347    }
348}
349
350impl ArcWake for WakeHandle {
351    fn wake_by_ref(arc_self: &Arc<Self>) {
352        if let Ok(task) = arc_self.mutex.notify() {
353            arc_self.exec.state.send(Message::Run(task))
354        }
355    }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;

    #[test]
    fn test_drop_after_start() {
        {
            let (started_tx, started_rx) = mpsc::sync_channel(2);
            // Keep the pool alive in a named binding until the end of this scope.
            let _pool = ThreadPoolBuilder::new()
                .pool_size(2)
                .after_start(move |_| started_tx.send(1).unwrap())
                .create()
                .unwrap();

            // Iteration only ends once every sender is gone, i.e. once the
            // temporary builder and each worker have dropped their copy of
            // the `after_start` hook.
            let started = started_rx.iter().count();
            assert_eq!(started, 2);
        }
        std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    }
}