// Source: polkadot_node_core_pvf_common/worker/mod.rs

// Copyright (C) Parity Technologies (UK) Ltd.
// This file is part of Polkadot.

// Polkadot is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// Polkadot is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with Polkadot.  If not, see <http://www.gnu.org/licenses/>.

//! Functionality common to both prepare and execute workers.

pub mod security;

use crate::{
	framed_recv_blocking, framed_send_blocking, SecurityStatus, WorkerHandshake, LOG_TARGET,
};
use codec::{Decode, Encode};
use cpu_time::ProcessTime;
use futures::never::Never;
use nix::{errno::Errno, sys::resource::Usage};
use std::{
	any::Any,
	fmt::{self},
	fs::File,
	io::{self, Read, Write},
	os::{
		fd::{AsRawFd, FromRawFd, RawFd},
		unix::net::UnixStream,
	},
	path::PathBuf,
	sync::mpsc::{Receiver, RecvTimeoutError},
	time::Duration,
};

/// Use this macro to declare a `fn main() {}` that will create an executable that can be used for
/// spawning the desired worker.
///
/// - `$expected_command`: the subcommand the binary must be invoked with. Any first argument that
///   is neither a known flag nor this command makes the binary panic.
/// - `$entrypoint`: invoked as `$entrypoint(socket_path, worker_dir_path, node_version,
///   Some($worker_version))` after parsing the required `--socket-path` and `--worker-dir-path`
///   flags and the optional `--node-impl-version` flag.
/// - `$worker_version`, `$worker_version_hash`: reported by `--help`, `--version` and
///   `--full-version`.
///
/// The generated binary also supports `--check-can-*` subcommands that probe whether a security
/// feature can be enabled on this machine. They call `std::process::exit(0)` on success. On failure,
/// or on platforms where the feature is unsupported, they call `std::process::exit(-1)` and print
/// the error, if there is one, to stderr.
#[macro_export]
macro_rules! decl_worker_main {
	($expected_command:expr, $entrypoint:expr, $worker_version:expr, $worker_version_hash:expr $(,)*) => {
		fn get_full_version() -> String {
			format!("{}-{}", $worker_version, $worker_version_hash)
		}

		fn print_help(expected_command: &str) {
			println!("{} {}", expected_command, $worker_version);
			println!("commit: {}", $worker_version_hash);
			println!();
			println!("PVF worker that is called by polkadot.");
		}

		fn main() {
			#[cfg(target_os = "linux")]
			use $crate::worker::security;

			$crate::sp_tracing::try_init_simple();

			let args = std::env::args().collect::<Vec<_>>();
			if args.len() == 1 {
				print_help($expected_command);
				return;
			}

			match args[1].as_ref() {
				"--help" | "-h" => {
					print_help($expected_command);
					return;
				},
				"--version" | "-v" => {
					println!("{}", $worker_version);
					return;
				},
				// Useful for debugging. --version is used for version checks.
				"--full-version" => {
					println!("{}", get_full_version());
					return;
				},

				"--check-can-enable-landlock" => {
					#[cfg(target_os = "linux")]
					let status = if let Err(err) = security::landlock::check_can_fully_enable() {
						// Write the error to stderr, log it on the host-side.
						eprintln!("{}", err);
						-1
					} else {
						0
					};
					#[cfg(not(target_os = "linux"))]
					let status = -1;
					std::process::exit(status)
				},
				"--check-can-enable-seccomp" => {
					#[cfg(all(target_os = "linux", target_arch = "x86_64"))]
					let status = if let Err(err) = security::seccomp::check_can_fully_enable() {
						// Write the error to stderr, log it on the host-side.
						eprintln!("{}", err);
						-1
					} else {
						0
					};
					#[cfg(not(all(target_os = "linux", target_arch = "x86_64")))]
					let status = -1;
					std::process::exit(status)
				},
				"--check-can-unshare-user-namespace-and-change-root" => {
					// The cache path is passed positionally, right after the subcommand. Fail with a
					// descriptive message rather than an index-out-of-bounds panic if it is missing.
					#[cfg(target_os = "linux")]
					let cache_path_tempdir = std::path::Path::new(args.get(2).expect(
						"--check-can-unshare-user-namespace-and-change-root requires a cache path argument",
					));
					#[cfg(target_os = "linux")]
					let status = if let Err(err) =
						security::change_root::check_can_fully_enable(cache_path_tempdir)
					{
						// Write the error to stderr, log it on the host-side.
						eprintln!("{}", err);
						-1
					} else {
						0
					};
					#[cfg(not(target_os = "linux"))]
					let status = -1;
					std::process::exit(status)
				},
				"--check-can-do-secure-clone" => {
					#[cfg(target_os = "linux")]
					// SAFETY: new process is spawned within a single threaded process. This
					// invariant is enforced by tests.
					let status = if let Err(err) = unsafe { security::clone::check_can_fully_clone() } {
						// Write the error to stderr, log it on the host-side.
						eprintln!("{}", err);
						-1
					} else {
						0
					};
					#[cfg(not(target_os = "linux"))]
					let status = -1;
					std::process::exit(status)
				},

				"test-sleep" => {
					std::thread::sleep(std::time::Duration::from_secs(5));
					return;
				},

				subcommand => {
					// Must be passed for compatibility with the single-binary test workers.
					if subcommand != $expected_command {
						panic!(
							"trying to run {} binary with the {} subcommand",
							$expected_command, subcommand
						)
					}
				},
			}

			let mut socket_path = None;
			let mut worker_dir_path = None;
			let mut node_version = None;

			// Every flag takes exactly one value, found right after it. A missing value is reported
			// explicitly instead of panicking with an index-out-of-bounds error.
			let mut i = 2;
			while i < args.len() {
				match args[i].as_ref() {
					"--socket-path" => {
						socket_path = Some(
							args.get(i + 1)
								.expect("the --socket-path argument requires a value")
								.as_str(),
						);
						i += 1
					},
					"--worker-dir-path" => {
						worker_dir_path = Some(
							args.get(i + 1)
								.expect("the --worker-dir-path argument requires a value")
								.as_str(),
						);
						i += 1
					},
					"--node-impl-version" => {
						node_version = Some(
							args.get(i + 1)
								.expect("the --node-impl-version argument requires a value")
								.as_str(),
						);
						i += 1
					},
					arg => panic!("Unexpected argument found: {}", arg),
				}
				i += 1;
			}
			let socket_path = socket_path.expect("the --socket-path argument is required");
			let worker_dir_path =
				worker_dir_path.expect("the --worker-dir-path argument is required");

			let socket_path = std::path::Path::new(socket_path).to_owned();
			let worker_dir_path = std::path::Path::new(worker_dir_path).to_owned();

			$entrypoint(socket_path, worker_dir_path, node_version, Some($worker_version));
		}
	};
}

195// taken from the os_pipe crate. Copied here to reduce one dependency and
196// because its type-safe abstractions do not play well with nix's clone
197#[cfg(not(target_os = "macos"))]
198pub fn pipe2_cloexec() -> io::Result<(libc::c_int, libc::c_int)> {
199	let mut fds: [libc::c_int; 2] = [0; 2];
200	let res = unsafe { libc::pipe2(fds.as_mut_ptr(), libc::O_CLOEXEC) };
201	if res != 0 {
202		return Err(io::Error::last_os_error());
203	}
204	Ok((fds[0], fds[1]))
205}
206
/// macOS fallback for `pipe2_cloexec`: macOS has no `pipe2(2)`, so the pipe is created with
/// `pipe(2)` and `FD_CLOEXEC` is then set on each end with `fcntl(2)`.
///
/// Unlike `pipe2`, this is not atomic: a concurrent fork+exec on another thread between the two
/// calls could inherit the descriptors.
///
/// Returns `(read_end, write_end)`, owned by the caller. On failure, any descriptors created here
/// are closed before the error is returned, so nothing leaks.
#[cfg(target_os = "macos")]
pub fn pipe2_cloexec() -> io::Result<(libc::c_int, libc::c_int)> {
	let mut fds: [libc::c_int; 2] = [0; 2];
	// SAFETY: `fds` is a writable buffer of exactly two `c_int`s, as `pipe(2)` requires.
	let res = unsafe { libc::pipe(fds.as_mut_ptr()) };
	if res != 0 {
		return Err(io::Error::last_os_error());
	}

	for &fd in &fds {
		// SAFETY: `fd` was just returned by a successful `pipe(2)` call and is owned by us.
		let res = unsafe { libc::fcntl(fd, libc::F_SETFD, libc::FD_CLOEXEC) };
		// POSIX only guarantees "a value other than -1" on success for `F_SETFD`.
		if res == -1 {
			// Capture errno before `close` has a chance to overwrite it.
			let err = io::Error::last_os_error();
			// SAFETY: both descriptors are open, owned by us, and not handed out to anyone yet.
			unsafe {
				libc::close(fds[0]);
				libc::close(fds[1]);
			}
			return Err(err);
		}
	}

	Ok((fds[0], fds[1]))
}

/// A wrapper around a file descriptor used to encapsulate and restrict
/// functionality for pipe operations.
///
/// Only raw-fd conversion, reading and writing are exposed. Every operation is forwarded to the
/// wrapped [`File`], which closes the descriptor when the `PipeFd` is dropped.
pub struct PipeFd {
	file: File,
}

impl AsRawFd for PipeFd {
	/// Returns the raw file descriptor wrapped by this `PipeFd`, without giving up ownership.
	fn as_raw_fd(&self) -> RawFd {
		AsRawFd::as_raw_fd(&self.file)
	}
}

impl FromRawFd for PipeFd {
	/// Wraps `fd` in a new `PipeFd`, taking ownership of it.
	///
	/// # Safety
	///
	/// The fd passed in must be an owned file descriptor; in particular, it must be open. It is
	/// closed when the returned `PipeFd` is dropped.
	unsafe fn from_raw_fd(fd: RawFd) -> Self {
		let file = File::from_raw_fd(fd);
		Self { file }
	}
}

impl Read for PipeFd {
	fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
		Read::read(&mut self.file, buf)
	}

	fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
		Read::read_to_end(&mut self.file, buf)
	}
}

impl Write for PipeFd {
	fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
		Write::write(&mut self.file, buf)
	}

	fn flush(&mut self) -> io::Result<()> {
		Write::flush(&mut self.file)
	}

	fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
		Write::write_all(&mut self.file, buf)
	}
}

/// Slack added on top of the remaining CPU-time budget each time the "CPU time monitor" thread on
/// the child process goes to sleep, to account for overhead. The sleep itself is wall-clock time.
pub const JOB_TIMEOUT_OVERHEAD: Duration = Duration::from_millis(50);

/// Which kind of PVF worker a process is. Its `Display` form is used in log messages.
#[derive(Debug, Clone, Copy)]
pub enum WorkerKind {
	/// The prepare worker.
	Prepare,
	/// The execute worker.
	Execute,
	/// A process spawned to check whether pivoting root works.
	CheckPivotRoot,
}

impl fmt::Display for WorkerKind {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		let label = match *self {
			WorkerKind::Prepare => "prepare",
			WorkerKind::Execute => "execute",
			WorkerKind::CheckPivotRoot => "check pivot root",
		};
		f.write_str(label)
	}
}

294#[derive(Debug)]
295pub struct WorkerInfo {
296	pub pid: u32,
297	pub kind: WorkerKind,
298	pub version: Option<String>,
299	pub worker_dir_path: PathBuf,
300}
301
302// NOTE: The worker version must be passed in so that we accurately get the version of the worker,
303// and not the version that this crate was compiled with.
304//
305// NOTE: This must not spawn any threads due to safety requirements in `event_loop` and to avoid
306// errors in [`security::change_root::try_restrict`].
307//
308/// Initializes the worker process, then runs the given event loop, which spawns a new job process
309/// to securely handle each incoming request.
310pub fn run_worker<F>(
311	worker_kind: WorkerKind,
312	socket_path: PathBuf,
313	worker_dir_path: PathBuf,
314	node_version: Option<&str>,
315	worker_version: Option<&str>,
316	mut event_loop: F,
317) where
318	F: FnMut(UnixStream, &WorkerInfo, SecurityStatus) -> io::Result<Never>,
319{
320	#[cfg_attr(not(target_os = "linux"), allow(unused_mut))]
321	let mut worker_info = WorkerInfo {
322		pid: std::process::id(),
323		kind: worker_kind,
324		version: worker_version.map(|v| v.to_string()),
325		worker_dir_path,
326	};
327	gum::debug!(
328		target: LOG_TARGET,
329		?worker_info,
330		?socket_path,
331		"starting pvf worker ({})",
332		worker_info.kind
333	);
334
335	// Check for a mismatch between the node and worker versions.
336	if let (Some(node_version), Some(worker_version)) = (node_version, &worker_info.version) {
337		if node_version != worker_version {
338			gum::error!(
339				target: LOG_TARGET,
340				?worker_info,
341				%node_version,
342				"Node and worker version mismatch, node needs restarting, forcing shutdown",
343			);
344			kill_parent_node_in_emergency();
345			worker_shutdown(worker_info, "Version mismatch");
346		}
347	}
348
349	// Make sure that we can read the worker dir path, and log its contents.
350	let entries: io::Result<Vec<_>> = std::fs::read_dir(&worker_info.worker_dir_path)
351		.and_then(|d| d.map(|res| res.map(|e| e.file_name())).collect());
352	match entries {
353		Ok(entries) => {
354			gum::trace!(target: LOG_TARGET, ?worker_info, "content of worker dir: {:?}", entries)
355		},
356		Err(err) => {
357			let err = format!("Could not read worker dir: {}", err.to_string());
358			worker_shutdown_error(worker_info, &err);
359		},
360	}
361
362	// Connect to the socket.
363	let stream = || -> io::Result<UnixStream> {
364		let stream = UnixStream::connect(&socket_path)?;
365		let _ = std::fs::remove_file(&socket_path);
366		Ok(stream)
367	}();
368	let mut stream = match stream {
369		Ok(ok) => ok,
370		Err(err) => worker_shutdown_error(worker_info, &err.to_string()),
371	};
372
373	let WorkerHandshake { security_status } = match recv_worker_handshake(&mut stream) {
374		Ok(ok) => ok,
375		Err(err) => worker_shutdown_error(worker_info, &err.to_string()),
376	};
377
378	// Enable some security features.
379	{
380		gum::trace!(target: LOG_TARGET, ?security_status, "Enabling security features");
381
382		// First, make sure env vars were cleared, to match the environment we perform the checks
383		// within. (In theory, running checks with different env vars could result in different
384		// outcomes of the checks.)
385		if !security::check_env_vars_were_cleared(&worker_info) {
386			let err = "not all env vars were cleared when spawning the process";
387			gum::error!(
388				target: LOG_TARGET,
389				?worker_info,
390				"{}",
391				err
392			);
393			if security_status.secure_validator_mode {
394				worker_shutdown(worker_info, err);
395			}
396		}
397
398		// Call based on whether we can change root. Error out if it should work but fails.
399		//
400		// NOTE: This should not be called in a multi-threaded context (i.e. inside the tokio
401		// runtime). `unshare(2)`:
402		//
403		//       > CLONE_NEWUSER requires that the calling process is not threaded.
404		#[cfg(target_os = "linux")]
405		if security_status.can_unshare_user_namespace_and_change_root {
406			if let Err(err) = security::change_root::enable_for_worker(&worker_info) {
407				// The filesystem may be in an inconsistent state, always bail out.
408				let err = format!("Could not change root to be the worker cache path: {}", err);
409				worker_shutdown_error(worker_info, &err);
410			}
411			worker_info.worker_dir_path = std::path::Path::new("/").to_owned();
412		}
413
414		#[cfg(target_os = "linux")]
415		if security_status.can_enable_landlock {
416			if let Err(err) = security::landlock::enable_for_worker(&worker_info) {
417				// We previously were able to enable, so this should never happen. Shutdown if
418				// running in secure mode.
419				let err = format!("could not fully enable landlock: {:?}", err);
420				gum::error!(
421					target: LOG_TARGET,
422					?worker_info,
423					"{}. This should not happen, please report an issue",
424					err
425				);
426				if security_status.secure_validator_mode {
427					worker_shutdown(worker_info, &err);
428				}
429			}
430		}
431
432		// TODO: We can enable the seccomp networking blacklist on aarch64 as well, but we need a CI
433		//       job to catch regressions. See issue ci_cd/issues/609.
434		#[cfg(all(target_os = "linux", target_arch = "x86_64"))]
435		if security_status.can_enable_seccomp {
436			if let Err(err) = security::seccomp::enable_for_worker(&worker_info) {
437				// We previously were able to enable, so this should never happen. Shutdown if
438				// running in secure mode.
439				let err = format!("could not fully enable seccomp: {:?}", err);
440				gum::error!(
441					target: LOG_TARGET,
442					?worker_info,
443					"{}. This should not happen, please report an issue",
444					err
445				);
446				if security_status.secure_validator_mode {
447					worker_shutdown(worker_info, &err);
448				}
449			}
450		}
451	}
452
453	// Run the main worker loop.
454	let err = event_loop(stream, &worker_info, security_status)
455		// It's never `Ok` because it's `Ok(Never)`.
456		.unwrap_err();
457
458	worker_shutdown(worker_info, &err.to_string());
459}
460
461/// Provide a consistent message on unexpected worker shutdown.
462fn worker_shutdown(worker_info: WorkerInfo, err: &str) -> ! {
463	gum::warn!(target: LOG_TARGET, ?worker_info, "quitting pvf worker ({}): {}", worker_info.kind, err);
464	std::process::exit(1);
465}
466
467/// Provide a consistent error on unexpected worker shutdown.
468fn worker_shutdown_error(worker_info: WorkerInfo, err: &str) -> ! {
469	gum::error!(target: LOG_TARGET, ?worker_info, "quitting pvf worker ({}): {}", worker_info.kind, err);
470	std::process::exit(1);
471}
472
473/// Loop that runs in the CPU time monitor thread on prepare and execute jobs. Continuously wakes up
474/// and then either blocks for the remaining CPU time, or returns if we exceed the CPU timeout.
475///
476/// Returning `Some` indicates that we should send a `TimedOut` error to the host. Will return
477/// `None` if the other thread finishes first, without us timing out.
478///
479/// NOTE: Sending a `TimedOut` error to the host will cause the worker, whether preparation or
480/// execution, to be killed by the host. We do not kill the process here because it would interfere
481/// with the proper handling of this error.
482pub fn cpu_time_monitor_loop(
483	cpu_time_start: ProcessTime,
484	timeout: Duration,
485	finished_rx: Receiver<()>,
486) -> Option<Duration> {
487	loop {
488		let cpu_time_elapsed = cpu_time_start.elapsed();
489
490		// Treat the timeout as CPU time, which is less subject to variance due to load.
491		if cpu_time_elapsed <= timeout {
492			// Sleep for the remaining CPU time, plus a bit to account for overhead. (And we don't
493			// want to wake up too often -- so, since we just want to halt the worker thread if it
494			// stalled, we can sleep longer than necessary.) Note that the sleep is wall clock time.
495			// The CPU clock may be slower than the wall clock.
496			let sleep_interval = timeout.saturating_sub(cpu_time_elapsed) + JOB_TIMEOUT_OVERHEAD;
497			match finished_rx.recv_timeout(sleep_interval) {
498				// Received finish signal.
499				Ok(()) => return None,
500				// Timed out, restart loop.
501				Err(RecvTimeoutError::Timeout) => continue,
502				Err(RecvTimeoutError::Disconnected) => return None,
503			}
504		}
505
506		return Some(cpu_time_elapsed);
507	}
508}
509
/// Attempt to convert an opaque panic payload to a string.
///
/// Handles the two payload types panics usually carry, `&'static str` and `String`. Anything else
/// yields the placeholder `"unknown panic payload"`. This is a best effort, and is not guaranteed
/// to provide the most accurate value.
pub fn stringify_panic_payload(payload: Box<dyn Any + Send + 'static>) -> String {
	if let Some(static_msg) = payload.downcast_ref::<&'static str>() {
		return static_msg.to_string();
	}

	payload
		.downcast::<String>()
		.map(|owned_msg| *owned_msg)
		// At least we tried...
		.unwrap_or_else(|_| "unknown panic payload".to_string())
}

524/// In case of node and worker version mismatch (as a result of in-place upgrade), send `SIGTERM`
525/// to the node to tear it down and prevent it from raising disputes on valid candidates. Node
526/// restart should be handled by the node owner. As node exits, Unix sockets opened to workers
527/// get closed by the OS and other workers receive error on socket read and also exit. Preparation
528/// jobs are written to the temporary files that are renamed to real artifacts on the node side, so
529/// no leftover artifacts are possible.
530fn kill_parent_node_in_emergency() {
531	unsafe {
532		// SAFETY: `getpid()` never fails but may return "no-parent" (0) or "parent-init" (1) in
533		// some corner cases, which is checked. `kill()` never fails.
534		let ppid = libc::getppid();
535		if ppid > 1 {
536			libc::kill(ppid, libc::SIGTERM);
537		}
538	}
539}
540
541/// Receives a handshake with information for the worker.
542fn recv_worker_handshake(stream: &mut UnixStream) -> io::Result<WorkerHandshake> {
543	let worker_handshake = framed_recv_blocking(stream)?;
544	let worker_handshake = WorkerHandshake::decode(&mut &worker_handshake[..]).map_err(|e| {
545		io::Error::new(
546			io::ErrorKind::Other,
547			format!("recv_worker_handshake: failed to decode WorkerHandshake: {}", e),
548		)
549	})?;
550	Ok(worker_handshake)
551}
552
553/// Calculate the total CPU time from the given `usage` structure, returned from
554/// [`nix::sys::resource::getrusage`], and calculates the total CPU time spent, including both user
555/// and system time.
556///
557/// # Arguments
558///
559/// - `rusage`: Contains resource usage information.
560///
561/// # Returns
562///
563/// Returns a `Duration` representing the total CPU time.
564pub fn get_total_cpu_usage(rusage: Usage) -> Duration {
565	let micros = (((rusage.user_time().tv_sec() + rusage.system_time().tv_sec()) * 1_000_000) +
566		(rusage.system_time().tv_usec() + rusage.user_time().tv_usec()) as i64) as u64;
567
568	return Duration::from_micros(micros);
569}
570
571/// Get a job response.
572pub fn recv_child_response<T>(
573	received_data: &mut io::BufReader<&[u8]>,
574	context: &'static str,
575) -> io::Result<T>
576where
577	T: Decode,
578{
579	let response_bytes = framed_recv_blocking(received_data)?;
580	T::decode(&mut response_bytes.as_slice()).map_err(|e| {
581		io::Error::new(
582			io::ErrorKind::Other,
583			format!("{} pvf recv_child_response: decode error: {}", context, e),
584		)
585	})
586}
587
588pub fn send_result<T, E>(
589	stream: &mut UnixStream,
590	result: Result<T, E>,
591	worker_info: &WorkerInfo,
592) -> io::Result<()>
593where
594	T: std::fmt::Debug,
595	E: std::fmt::Debug + std::fmt::Display,
596	Result<T, E>: Encode,
597{
598	if let Err(ref err) = result {
599		gum::warn!(
600			target: LOG_TARGET,
601			?worker_info,
602			"worker: error occurred: {}",
603			err
604		);
605	}
606	gum::trace!(
607		target: LOG_TARGET,
608		?worker_info,
609		"worker: sending result to host: {:?}",
610		result
611	);
612
613	framed_send_blocking(stream, &result.encode()).map_err(|err| {
614		gum::warn!(
615			target: LOG_TARGET,
616			?worker_info,
617			"worker: error occurred sending result to host: {}",
618			err
619		);
620		err
621	})
622}
623
624pub fn stringify_errno(context: &'static str, errno: Errno) -> String {
625	format!("{}: {}: {}", context, errno, io::Error::last_os_error())
626}
627
/// Functionality related to threads spawned by the workers.
///
/// The motivation for this module is to coordinate worker threads without using async Rust.
pub mod thread {
	use std::{
		io, panic,
		sync::{Arc, Condvar, Mutex},
		thread,
		time::Duration,
	};

	/// Contains the outcome of waiting on threads, or `Pending` if none are ready.
	///
	/// The non-`Pending` value is set by the first thread that notifies the condvar (see the
	/// `outcome` argument of [`spawn_worker_thread`]). Later notifications do not overwrite it.
	#[derive(Debug, Clone, Copy)]
	pub enum WaitOutcome {
		Finished,
		TimedOut,
		Pending,
	}

	impl WaitOutcome {
		/// Returns `true` while no thread has reported an outcome yet.
		pub fn is_pending(&self) -> bool {
			matches!(self, Self::Pending)
		}
	}

	/// Helper type: the shared [`WaitOutcome`] behind a mutex, plus the condvar used to wake up
	/// the threads waiting for it to leave `Pending`.
	pub type Cond = Arc<(Mutex<WaitOutcome>, Condvar)>;

	/// Gets a condvar initialized to `Pending`.
	pub fn get_condvar() -> Cond {
		Arc::new((Mutex::new(WaitOutcome::Pending), Condvar::new()))
	}

	/// Runs a worker thread. Will run the requested function, and afterwards notify the threads
	/// waiting on the condvar with `outcome`. Catches panics during execution and resumes the
	/// panics after triggering the condvar, so that the waiting thread is notified on panics.
	///
	/// # Returns
	///
	/// Returns the thread's join handle. Calling `.join()` on it returns the result of executing
	/// `f()`, or `Err` with the panic payload if `f()` panicked.
	///
	/// # Errors
	///
	/// Fails if the OS cannot spawn the thread. In that case `f` is never run and `cond` is not
	/// notified.
	pub fn spawn_worker_thread<F, R>(
		name: &str,
		f: F,
		cond: Cond,
		outcome: WaitOutcome,
	) -> io::Result<thread::JoinHandle<R>>
	where
		F: FnOnce() -> R,
		F: Send + 'static + panic::UnwindSafe,
		R: Send + 'static,
	{
		spawn_notifying(thread::Builder::new().name(name.into()), f, cond, outcome)
	}

	/// Runs a worker thread with the given stack size. See [`spawn_worker_thread`].
	pub fn spawn_worker_thread_with_stack_size<F, R>(
		name: &str,
		f: F,
		cond: Cond,
		outcome: WaitOutcome,
		stack_size: usize,
	) -> io::Result<thread::JoinHandle<R>>
	where
		F: FnOnce() -> R,
		F: Send + 'static + panic::UnwindSafe,
		R: Send + 'static,
	{
		spawn_notifying(
			thread::Builder::new().name(name.into()).stack_size(stack_size),
			f,
			cond,
			outcome,
		)
	}

	/// Spawns `f` on a thread configured by `builder`, wrapped so that `cond` gets notified with
	/// `outcome` once `f` returns or panics. Shared body of the two public spawn functions.
	fn spawn_notifying<F, R>(
		builder: thread::Builder,
		f: F,
		cond: Cond,
		outcome: WaitOutcome,
	) -> io::Result<thread::JoinHandle<R>>
	where
		F: FnOnce() -> R,
		F: Send + 'static + panic::UnwindSafe,
		R: Send + 'static,
	{
		builder.spawn(move || cond_notify_on_done(f, cond, outcome))
	}

	/// Runs a function, afterwards notifying the threads waiting on the condvar. Catches panics and
	/// resumes them after triggering the condvar, so that the waiting thread is notified on panics.
	fn cond_notify_on_done<F, R>(f: F, cond: Cond, outcome: WaitOutcome) -> R
	where
		F: FnOnce() -> R,
		F: panic::UnwindSafe,
	{
		let result = panic::catch_unwind(f);
		cond_notify_all(cond, outcome);
		match result {
			Ok(inner) => inner,
			Err(payload) => panic::resume_unwind(payload),
		}
	}

	/// Helper function to notify all threads waiting on this condvar. Does nothing if an outcome
	/// was already recorded, so the first notification wins.
	fn cond_notify_all(cond: Cond, outcome: WaitOutcome) {
		let (lock, cvar) = &*cond;
		let mut flag = lock.lock().unwrap();
		if !flag.is_pending() {
			// Someone else already triggered the condvar.
			return;
		}
		*flag = outcome;
		cvar.notify_all();
	}

	/// Block the thread while it waits on the condvar, until an outcome other than `Pending` is
	/// recorded; returns that outcome.
	pub fn wait_for_threads(cond: Cond) -> WaitOutcome {
		let (lock, cvar) = &*cond;
		let guard = cvar.wait_while(lock.lock().unwrap(), |flag| flag.is_pending()).unwrap();
		*guard
	}

	/// Block the thread while it waits on the condvar or on a timeout. If the timeout is hit,
	/// returns `None`.
	#[cfg_attr(not(any(target_os = "linux", feature = "jemalloc-allocator")), allow(dead_code))]
	pub fn wait_for_threads_with_timeout(cond: &Cond, dur: Duration) -> Option<WaitOutcome> {
		let (lock, cvar) = &**cond;
		let (guard, wait_result) = cvar
			.wait_timeout_while(lock.lock().unwrap(), dur, |flag| flag.is_pending())
			.unwrap();
		if wait_result.timed_out() {
			None
		} else {
			Some(*guard)
		}
	}

	#[cfg(test)]
	mod tests {
		use super::*;

		#[test]
		fn get_condvar_should_be_pending() {
			let condvar = get_condvar();
			let outcome = *condvar.0.lock().unwrap();
			assert!(outcome.is_pending());
		}

		#[test]
		fn wait_for_threads_with_timeout_return_none_on_time_out() {
			let condvar = Arc::new((Mutex::new(WaitOutcome::Pending), Condvar::new()));
			let outcome = wait_for_threads_with_timeout(&condvar, Duration::from_millis(100));
			assert!(outcome.is_none());
		}

		#[test]
		fn wait_for_threads_with_timeout_returns_outcome() {
			let condvar = Arc::new((Mutex::new(WaitOutcome::Pending), Condvar::new()));
			let condvar2 = condvar.clone();
			cond_notify_all(condvar2, WaitOutcome::Finished);
			let outcome = wait_for_threads_with_timeout(&condvar, Duration::from_secs(2));
			assert!(matches!(outcome, Some(WaitOutcome::Finished)), "got {:?}", outcome);
		}

		#[test]
		fn spawn_worker_thread_should_notify_on_done() {
			let condvar = Arc::new((Mutex::new(WaitOutcome::Pending), Condvar::new()));
			let response =
				spawn_worker_thread("thread", || 2, condvar.clone(), WaitOutcome::TimedOut);
			let (lock, _) = &*condvar;
			let r = response.unwrap().join().unwrap();
			assert_eq!(r, 2);
			let outcome = *lock.lock().unwrap();
			assert!(matches!(outcome, WaitOutcome::TimedOut), "got {:?}", outcome);
		}

		#[test]
		fn spawn_worker_should_not_change_finished_outcome() {
			let condvar = Arc::new((Mutex::new(WaitOutcome::Finished), Condvar::new()));
			let response =
				spawn_worker_thread("thread", move || 2, condvar.clone(), WaitOutcome::TimedOut);

			let r = response.unwrap().join().unwrap();
			assert_eq!(r, 2);
			let outcome = *condvar.0.lock().unwrap();
			assert!(matches!(outcome, WaitOutcome::Finished), "got {:?}", outcome);
		}

		#[test]
		fn cond_notify_on_done_should_update_wait_outcome_when_panic() {
			let condvar = Arc::new((Mutex::new(WaitOutcome::Pending), Condvar::new()));
			let err = panic::catch_unwind(panic::AssertUnwindSafe(|| {
				cond_notify_on_done(|| panic!("test"), condvar.clone(), WaitOutcome::Finished)
			}));

			let outcome = *condvar.0.lock().unwrap();
			assert!(matches!(outcome, WaitOutcome::Finished), "got {:?}", outcome);
			assert!(err.is_err());
		}
	}
}

#[cfg(test)]
mod tests {
	use super::*;
	use std::sync::mpsc::channel;

	/// A zero CPU-time budget is exceeded as soon as any CPU time has been consumed.
	#[test]
	fn cpu_time_monitor_loop_should_return_time_elapsed() {
		let started_at = ProcessTime::now();
		let (_finished_tx, finished_rx) = channel();
		let outcome = cpu_time_monitor_loop(started_at, Duration::ZERO, finished_rx);
		assert!(outcome.is_some());
	}

	/// A finish signal that is already queued wins against a generous timeout.
	#[test]
	fn cpu_time_monitor_loop_should_return_none() {
		let started_at = ProcessTime::now();
		let (finished_tx, finished_rx) = channel();
		finished_tx.send(()).unwrap();
		let outcome = cpu_time_monitor_loop(started_at, Duration::from_secs(10), finished_rx);
		assert!(outcome.is_none());
	}
}