1use crate::enter;
2use crate::unpark_mutex::UnparkMutex;
3use futures_core::future::Future;
4use futures_core::task::{Context, Poll};
5use futures_task::{waker_ref, ArcWake};
6use futures_task::{FutureObj, Spawn, SpawnError};
7use futures_util::future::FutureExt;
8use std::cmp;
9use std::fmt;
10use std::io;
11use std::sync::atomic::{AtomicUsize, Ordering};
12use std::sync::mpsc;
13use std::sync::{Arc, Mutex};
14use std::thread;
15
/// A general-purpose thread pool that polls spawned futures to completion on
/// a fixed set of worker threads.
///
/// `ThreadPool` is a cheaply clonable handle: cloning it shares the same
/// underlying pool instead of creating a new one. The workers are told to shut
/// down once the last handle is dropped, including the handles that spawned
/// tasks hold internally.
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPool {
    /// State shared by all handles and all worker threads.
    state: Arc<PoolState>,
}
31
/// Configures and creates a [`ThreadPool`].
///
/// Obtain one with `ThreadPool::builder()` or `ThreadPoolBuilder::new()`, adjust
/// the settings, then call `create()`.
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPoolBuilder {
    /// Number of worker threads to spawn; never zero.
    pool_size: usize,
    /// Worker stack size in bytes; `0` keeps the `std::thread::Builder` default.
    stack_size: usize,
    /// When set, worker `i` is named `<prefix><i>`.
    name_prefix: Option<String>,
    /// Hook called with the worker index before the worker takes any task.
    after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    /// Hook called with the worker index after the worker receives `Close`.
    before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
}
44
// Compile-time assertion: this impl only type-checks while `ThreadPool` is
// `Send + Sync`, so losing either auto trait becomes a build error rather than
// a silent API change.
trait AssertSendSync: Send + Sync {}
impl AssertSendSync for ThreadPool {}
47
/// State shared by every `ThreadPool` handle and every worker thread.
struct PoolState {
    /// Sending half of the work queue; locked for each `send`.
    tx: Mutex<mpsc::Sender<Message>>,
    /// Receiving half of the work queue, shared by all workers. A worker holds
    /// this lock while blocked in `recv`, so at most one worker waits on the
    /// channel at a time.
    rx: Mutex<mpsc::Receiver<Message>>,
    /// Number of live `ThreadPool` handles. When it reaches zero, one `Close`
    /// message per worker is sent.
    cnt: AtomicUsize,
    /// Number of worker threads the pool was configured with.
    size: usize,
}
54
55impl fmt::Debug for ThreadPool {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 f.debug_struct("ThreadPool").field("size", &self.state.size).finish()
58 }
59}
60
61impl fmt::Debug for ThreadPoolBuilder {
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63 f.debug_struct("ThreadPoolBuilder")
64 .field("pool_size", &self.pool_size)
65 .field("name_prefix", &self.name_prefix)
66 .finish()
67 }
68}
69
/// Work-queue message consumed by worker threads in `PoolState::work`.
enum Message {
    /// Run (poll) this task on whichever worker receives it.
    Run(Task),
    /// Stop the receiving worker. One is sent per worker when the last
    /// `ThreadPool` handle is dropped.
    Close,
}
74
75impl ThreadPool {
76 pub fn new() -> Result<Self, io::Error> {
82 ThreadPoolBuilder::new().create()
83 }
84
85 pub fn builder() -> ThreadPoolBuilder {
91 ThreadPoolBuilder::new()
92 }
93
94 pub fn spawn_obj_ok(&self, future: FutureObj<'static, ()>) {
99 let task = Task {
100 future,
101 wake_handle: Arc::new(WakeHandle { exec: self.clone(), mutex: UnparkMutex::new() }),
102 exec: self.clone(),
103 };
104 self.state.send(Message::Run(task));
105 }
106
107 pub fn spawn_ok<Fut>(&self, future: Fut)
125 where
126 Fut: Future<Output = ()> + Send + 'static,
127 {
128 self.spawn_obj_ok(FutureObj::new(Box::new(future)))
129 }
130}
131
132impl Spawn for ThreadPool {
133 fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
134 self.spawn_obj_ok(future);
135 Ok(())
136 }
137}
138
139impl PoolState {
140 fn send(&self, msg: Message) {
141 self.tx.lock().unwrap().send(msg).unwrap();
142 }
143
144 fn work(
145 &self,
146 idx: usize,
147 after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
148 before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
149 ) {
150 let _scope = enter().unwrap();
151 if let Some(after_start) = after_start {
152 after_start(idx);
153 }
154 loop {
155 let msg = self.rx.lock().unwrap().recv().unwrap();
156 match msg {
157 Message::Run(task) => task.run(),
158 Message::Close => break,
159 }
160 }
161 if let Some(before_stop) = before_stop {
162 before_stop(idx);
163 }
164 }
165}
166
167impl Clone for ThreadPool {
168 fn clone(&self) -> Self {
169 self.state.cnt.fetch_add(1, Ordering::Relaxed);
170 Self { state: self.state.clone() }
171 }
172}
173
174impl Drop for ThreadPool {
175 fn drop(&mut self) {
176 if self.state.cnt.fetch_sub(1, Ordering::Relaxed) == 1 {
177 for _ in 0..self.state.size {
178 self.state.send(Message::Close);
179 }
180 }
181 }
182}
183
184impl ThreadPoolBuilder {
185 pub fn new() -> Self {
189 Self {
190 pool_size: cmp::max(1, num_cpus::get()),
191 stack_size: 0,
192 name_prefix: None,
193 after_start: None,
194 before_stop: None,
195 }
196 }
197
198 pub fn pool_size(&mut self, size: usize) -> &mut Self {
207 assert!(size > 0);
208 self.pool_size = size;
209 self
210 }
211
212 pub fn stack_size(&mut self, stack_size: usize) -> &mut Self {
216 self.stack_size = stack_size;
217 self
218 }
219
220 pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self {
227 self.name_prefix = Some(name_prefix.into());
228 self
229 }
230
231 pub fn after_start<F>(&mut self, f: F) -> &mut Self
241 where
242 F: Fn(usize) + Send + Sync + 'static,
243 {
244 self.after_start = Some(Arc::new(f));
245 self
246 }
247
248 pub fn before_stop<F>(&mut self, f: F) -> &mut Self
257 where
258 F: Fn(usize) + Send + Sync + 'static,
259 {
260 self.before_stop = Some(Arc::new(f));
261 self
262 }
263
264 pub fn create(&mut self) -> Result<ThreadPool, io::Error> {
266 let (tx, rx) = mpsc::channel();
267 let pool = ThreadPool {
268 state: Arc::new(PoolState {
269 tx: Mutex::new(tx),
270 rx: Mutex::new(rx),
271 cnt: AtomicUsize::new(1),
272 size: self.pool_size,
273 }),
274 };
275
276 for counter in 0..self.pool_size {
277 let state = pool.state.clone();
278 let after_start = self.after_start.clone();
279 let before_stop = self.before_stop.clone();
280 let mut thread_builder = thread::Builder::new();
281 if let Some(ref name_prefix) = self.name_prefix {
282 thread_builder = thread_builder.name(format!("{}{}", name_prefix, counter));
283 }
284 if self.stack_size > 0 {
285 thread_builder = thread_builder.stack_size(self.stack_size);
286 }
287 thread_builder.spawn(move || state.work(counter, after_start, before_stop))?;
288 }
289 Ok(pool)
290 }
291}
292
293impl Default for ThreadPoolBuilder {
294 fn default() -> Self {
295 Self::new()
296 }
297}
298
/// A spawned future plus the handles needed to run and reschedule it.
struct Task {
    /// The future being driven to completion.
    future: FutureObj<'static, ()>,
    /// Pool handle; holding it keeps the pool's handle count above zero, so
    /// the workers are not closed while this task exists.
    exec: ThreadPool,
    /// Shared wake state; the waker passed to `future` points at it.
    wake_handle: Arc<WakeHandle>,
}
305
/// Waker target for one task.
///
/// `wake_by_ref` asks `mutex` whether the task must be re-queued on `exec`.
struct WakeHandle {
    /// Presumably holds the parked `Task` between polls: `Task::run` hands the
    /// task to `wait`, and `notify` may hand it back.
    mutex: UnparkMutex<Task>,
    /// Pool the task is re-queued on when woken.
    exec: ThreadPool,
}
310
impl Task {
    /// Polls the task's future on the current thread.
    ///
    /// Returns once the future completes, or once it is `Pending` and has been
    /// parked in its `UnparkMutex`.
    fn run(self) {
        let Self { mut future, wake_handle, mut exec } = self;
        // The waker given to the future points back at this task's `WakeHandle`.
        let waker = waker_ref(&wake_handle);
        let mut cx = Context::from_waker(&waker);

        // SAFETY: NOTE(review), confirm against `UnparkMutex`'s documented contract.
        // The `UnparkMutex` calls in this block are presumably `unsafe fn`s that
        // require the caller to be the task's current poller. Owning this `Task`
        // by value (it was taken from the queue or returned by `wait`) is what is
        // assumed to establish that.
        unsafe {
            wake_handle.mutex.start_poll();

            loop {
                let res = future.poll_unpin(&mut cx);
                match res {
                    Poll::Pending => {}
                    // Finished: mark the task complete and stop.
                    Poll::Ready(()) => return wake_handle.mutex.complete(),
                }
                // Re-pack the task so ownership can be handed to the mutex.
                let task = Self { future, wake_handle: wake_handle.clone(), exec };
                match wake_handle.mutex.wait(task) {
                    // Parked; a later wake re-queues it (see `wake_by_ref`).
                    Ok(()) => return,
                    Err(task) => {
                        // Presumably a wake arrived during the poll: the task
                        // was handed back, so restore the parts and poll again.
                        future = task.future;
                        exec = task.exec;
                    }
                }
            }
        }
    }
}
343
344impl fmt::Debug for Task {
345 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
346 f.debug_struct("Task").field("contents", &"...").finish()
347 }
348}
349
350impl ArcWake for WakeHandle {
351 fn wake_by_ref(arc_self: &Arc<Self>) {
352 if let Ok(task) = arc_self.mutex.notify() {
353 arc_self.exec.state.send(Message::Run(task))
354 }
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;

    /// Each worker runs `after_start` once and then drops its reference to the
    /// hook. Once the builder's copy is gone too, the captured sender is
    /// released and the receiver iterator terminates.
    #[test]
    fn test_drop_after_start() {
        {
            let (started_tx, started_rx) = mpsc::sync_channel(2);
            let _cpu_pool = ThreadPoolBuilder::new()
                .pool_size(2)
                .after_start(move |_| started_tx.send(1).unwrap())
                .create()
                .unwrap();

            // Ends only after every copy of the hook, and thus `started_tx`, is dropped.
            let started = started_rx.iter().count();
            assert_eq!(started, 2);
        }
        // The pool is dropped here; give the detached workers time to exit.
        std::thread::sleep(std::time::Duration::from_millis(500));
    }
}