referrerpolicy=no-referrer-when-downgrade

polkadot_node_core_pvf_common/worker/
mod.rs

1// Copyright (C) Parity Technologies (UK) Ltd.
2// This file is part of Polkadot.
3
4// Polkadot is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// Polkadot is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with Polkadot.  If not, see <http://www.gnu.org/licenses/>.
16
17//! Functionality common to both prepare and execute workers.
18
19pub mod security;
20
21use crate::{
22	framed_recv_blocking, framed_send_blocking, SecurityStatus, WorkerHandshake, LOG_TARGET,
23};
24use codec::{Decode, Encode};
25use cpu_time::ProcessTime;
26use futures::never::Never;
27use nix::{errno::Errno, sys::resource::Usage};
28use std::{
29	any::Any,
30	fmt::{self},
31	fs::File,
32	io::{self, Read, Write},
33	os::{
34		fd::{AsRawFd, FromRawFd, RawFd},
35		unix::net::UnixStream,
36	},
37	path::PathBuf,
38	sync::mpsc::{Receiver, RecvTimeoutError},
39	time::Duration,
40};
41
/// Use this macro to declare a `fn main() {}` that will create an executable that can be used for
/// spawning the desired worker.
///
/// The generated `main`:
/// - prints help when run without a subcommand or with `--help`/`-h`;
/// - prints `$worker_version` for `--version`/`-v`, and `$worker_version-$worker_version_hash` for
///   `--full-version`;
/// - for the `--check-can-*` subcommands, runs the corresponding security probe and exits with
///   status `0` on success or `-1` on failure (always `-1` on unsupported platforms);
/// - otherwise requires the first argument to equal `$expected_command`, parses the
///   `--socket-path`, `--worker-dir-path` and `--node-impl-version` options (each followed by its
///   value) and calls `$entrypoint(socket_path, worker_dir_path, node_version,
///   Some($worker_version))`.
///
/// Malformed command lines (unknown options, options without a value, missing required options)
/// cause a panic with a descriptive message.
#[macro_export]
macro_rules! decl_worker_main {
	($expected_command:expr, $entrypoint:expr, $worker_version:expr, $worker_version_hash:expr $(,)*) => {
		fn get_full_version() -> String {
			format!("{}-{}", $worker_version, $worker_version_hash)
		}

		fn print_help(expected_command: &str) {
			println!("{} {}", expected_command, $worker_version);
			println!("commit: {}", $worker_version_hash);
			println!();
			println!("PVF worker that is called by polkadot.");
		}

		fn main() {
			#[cfg(target_os = "linux")]
			use $crate::worker::security;

			$crate::sp_tracing::try_init_simple();

			let args = std::env::args().collect::<Vec<_>>();
			// `< 2` (rather than `== 1`) also covers an empty argv, which would otherwise panic
			// when indexing `args[1]` below.
			if args.len() < 2 {
				print_help($expected_command);
				return
			}

			match args[1].as_ref() {
				"--help" | "-h" => {
					print_help($expected_command);
					return
				},
				"--version" | "-v" => {
					println!("{}", $worker_version);
					return
				},
				// Useful for debugging. --version is used for version checks.
				"--full-version" => {
					println!("{}", get_full_version());
					return
				},

				"--check-can-enable-landlock" => {
					#[cfg(target_os = "linux")]
					let status = if let Err(err) = security::landlock::check_can_fully_enable() {
						// Write the error to stderr, log it on the host-side.
						eprintln!("{}", err);
						-1
					} else {
						0
					};
					#[cfg(not(target_os = "linux"))]
					let status = -1;
					std::process::exit(status)
				},
				"--check-can-enable-seccomp" => {
					#[cfg(all(target_os = "linux", target_arch = "x86_64"))]
					let status = if let Err(err) = security::seccomp::check_can_fully_enable() {
						// Write the error to stderr, log it on the host-side.
						eprintln!("{}", err);
						-1
					} else {
						0
					};
					#[cfg(not(all(target_os = "linux", target_arch = "x86_64")))]
					let status = -1;
					std::process::exit(status)
				},
				"--check-can-unshare-user-namespace-and-change-root" => {
					#[cfg(target_os = "linux")]
					let cache_path_tempdir = std::path::Path::new(args.get(2).expect(
						"--check-can-unshare-user-namespace-and-change-root requires a cache path argument",
					));
					#[cfg(target_os = "linux")]
					let status = if let Err(err) =
						security::change_root::check_can_fully_enable(&cache_path_tempdir)
					{
						// Write the error to stderr, log it on the host-side.
						eprintln!("{}", err);
						-1
					} else {
						0
					};
					#[cfg(not(target_os = "linux"))]
					let status = -1;
					std::process::exit(status)
				},
				"--check-can-do-secure-clone" => {
					#[cfg(target_os = "linux")]
					// SAFETY: new process is spawned within a single threaded process. This
					// invariant is enforced by tests.
					let status = if let Err(err) = unsafe { security::clone::check_can_fully_clone() } {
						// Write the error to stderr, log it on the host-side.
						eprintln!("{}", err);
						-1
					} else {
						0
					};
					#[cfg(not(target_os = "linux"))]
					let status = -1;
					std::process::exit(status)
				},

				"test-sleep" => {
					std::thread::sleep(std::time::Duration::from_secs(5));
					return
				},

				subcommand => {
					// Must be passed for compatibility with the single-binary test workers.
					if subcommand != $expected_command {
						panic!(
							"trying to run {} binary with the {} subcommand",
							$expected_command, subcommand
						)
					}
				},
			}

			let mut socket_path = None;
			let mut worker_dir_path = None;
			let mut node_version = None;

			// The remaining arguments are `--option value` pairs.
			let mut rest = args[2..].iter().map(String::as_str);
			while let Some(flag) = rest.next() {
				let slot = match flag {
					"--socket-path" => &mut socket_path,
					"--worker-dir-path" => &mut worker_dir_path,
					"--node-impl-version" => &mut node_version,
					arg => panic!("Unexpected argument found: {}", arg),
				};
				let value = rest
					.next()
					.unwrap_or_else(|| panic!("the {} argument requires a value", flag));
				*slot = Some(value);
			}
			let socket_path = socket_path.expect("the --socket-path argument is required");
			let worker_dir_path =
				worker_dir_path.expect("the --worker-dir-path argument is required");

			let socket_path = std::path::Path::new(socket_path).to_owned();
			let worker_dir_path = std::path::Path::new(worker_dir_path).to_owned();

			$entrypoint(socket_path, worker_dir_path, node_version, Some($worker_version));
		}
	};
}
194
195//taken from the os_pipe crate. Copied here to reduce one dependency and
196// because its type-safe abstractions do not play well with nix's clone
197#[cfg(not(target_os = "macos"))]
198pub fn pipe2_cloexec() -> io::Result<(libc::c_int, libc::c_int)> {
199	let mut fds: [libc::c_int; 2] = [0; 2];
200	let res = unsafe { libc::pipe2(fds.as_mut_ptr(), libc::O_CLOEXEC) };
201	if res != 0 {
202		return Err(io::Error::last_os_error())
203	}
204	Ok((fds[0], fds[1]))
205}
206
/// Creates an anonymous pipe with the close-on-exec flag set on both ends.
///
/// macOS has no `pipe2`, so the flag is applied with `fcntl` after `pipe` returns. If applying it
/// fails, both descriptors are closed before the error is returned, so nothing is leaked.
///
/// On success returns `(read_end, write_end)` as raw file descriptors, which the caller then owns.
#[cfg(target_os = "macos")]
pub fn pipe2_cloexec() -> io::Result<(libc::c_int, libc::c_int)> {
	let mut fds: [libc::c_int; 2] = [0; 2];
	// SAFETY: `fds` is a writable array of exactly two `c_int`s, which is what `pipe` fills in.
	let res = unsafe { libc::pipe(fds.as_mut_ptr()) };
	if res != 0 {
		return Err(io::Error::last_os_error())
	}
	for fd in fds {
		// SAFETY: `fd` was just returned by a successful `pipe` call and is still open.
		let res = unsafe { libc::fcntl(fd, libc::F_SETFD, libc::FD_CLOEXEC) };
		if res != 0 {
			// Capture errno before `close` gets a chance to overwrite it.
			let err = io::Error::last_os_error();
			// SAFETY: both fds are open, owned by this function and not yet handed to the caller.
			unsafe {
				libc::close(fds[0]);
				libc::close(fds[1]);
			}
			return Err(err)
		}
	}
	Ok((fds[0], fds[1]))
}
224
/// Owns one end of a pipe and exposes only what the workers need from it: raw-fd access,
/// construction from a raw fd, and reading/writing.
pub struct PipeFd {
	file: File,
}

impl AsRawFd for PipeFd {
	/// Returns the underlying raw file descriptor without giving up ownership of it.
	fn as_raw_fd(&self) -> RawFd {
		AsRawFd::as_raw_fd(&self.file)
	}
}

impl FromRawFd for PipeFd {
	/// Takes ownership of `fd` and wraps it in a `PipeFd`.
	///
	/// # Safety
	///
	/// `fd` must be an open file descriptor owned by the caller; ownership moves into the returned
	/// `PipeFd`, which closes it on drop.
	unsafe fn from_raw_fd(fd: RawFd) -> Self {
		let file = File::from_raw_fd(fd);
		Self { file }
	}
}

impl Read for PipeFd {
	fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
		Read::read(&mut self.file, buf)
	}

	fn read_to_end(&mut self, buf: &mut Vec<u8>) -> io::Result<usize> {
		Read::read_to_end(&mut self.file, buf)
	}
}

impl Write for PipeFd {
	fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
		Write::write(&mut self.file, buf)
	}

	fn flush(&mut self) -> io::Result<()> {
		Write::flush(&mut self.file)
	}

	fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
		Write::write_all(&mut self.file, buf)
	}
}
272
/// Extra slack (50 ms) added to every sleep of the "CPU time monitor" thread on the child process,
/// to account for overhead.
pub const JOB_TIMEOUT_OVERHEAD: Duration = Duration::from_micros(50_000);
276
/// The kind of PVF worker process.
#[derive(Debug, Clone, Copy)]
pub enum WorkerKind {
	Prepare,
	Execute,
	CheckPivotRoot,
}

impl fmt::Display for WorkerKind {
	/// Writes the lowercase, human-readable name of the worker kind (used in log messages).
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		let name = match *self {
			WorkerKind::Prepare => "prepare",
			WorkerKind::Execute => "execute",
			WorkerKind::CheckPivotRoot => "check pivot root",
		};
		f.write_str(name)
	}
}
293
/// Identifying information about a running worker process; attached (via `Debug`) to the worker's
/// log messages.
#[derive(Debug)]
pub struct WorkerInfo {
	/// Process id of the worker. `run_worker` sets it from `std::process::id()`.
	pub pid: u32,
	/// Which kind of worker this is.
	pub kind: WorkerKind,
	/// The worker's version string, if one was provided.
	pub version: Option<String>,
	/// The worker directory. `run_worker` replaces this with `/` after successfully changing root
	/// into it.
	pub worker_dir_path: PathBuf,
}
301
302// NOTE: The worker version must be passed in so that we accurately get the version of the worker,
303// and not the version that this crate was compiled with.
304//
305// NOTE: This must not spawn any threads due to safety requirements in `event_loop` and to avoid
306// errors in [`security::change_root::try_restrict`].
307//
308/// Initializes the worker process, then runs the given event loop, which spawns a new job process
309/// to securely handle each incoming request.
310pub fn run_worker<F>(
311	worker_kind: WorkerKind,
312	socket_path: PathBuf,
313	worker_dir_path: PathBuf,
314	node_version: Option<&str>,
315	worker_version: Option<&str>,
316	mut event_loop: F,
317) where
318	F: FnMut(UnixStream, &WorkerInfo, SecurityStatus) -> io::Result<Never>,
319{
320	#[cfg_attr(not(target_os = "linux"), allow(unused_mut))]
321	let mut worker_info = WorkerInfo {
322		pid: std::process::id(),
323		kind: worker_kind,
324		version: worker_version.map(|v| v.to_string()),
325		worker_dir_path,
326	};
327	gum::debug!(
328		target: LOG_TARGET,
329		?worker_info,
330		?socket_path,
331		"starting pvf worker ({})",
332		worker_info.kind
333	);
334
335	// Check for a mismatch between the node and worker versions.
336	if let (Some(node_version), Some(worker_version)) = (node_version, &worker_info.version) {
337		if node_version != worker_version {
338			gum::error!(
339				target: LOG_TARGET,
340				?worker_info,
341				%node_version,
342				"Node and worker version mismatch, node needs restarting, forcing shutdown",
343			);
344			kill_parent_node_in_emergency();
345			worker_shutdown(worker_info, "Version mismatch");
346		}
347	}
348
349	// Make sure that we can read the worker dir path, and log its contents.
350	let entries: io::Result<Vec<_>> = std::fs::read_dir(&worker_info.worker_dir_path)
351		.and_then(|d| d.map(|res| res.map(|e| e.file_name())).collect());
352	match entries {
353		Ok(entries) =>
354			gum::trace!(target: LOG_TARGET, ?worker_info, "content of worker dir: {:?}", entries),
355		Err(err) => {
356			let err = format!("Could not read worker dir: {}", err.to_string());
357			worker_shutdown_error(worker_info, &err);
358		},
359	}
360
361	// Connect to the socket.
362	let stream = || -> io::Result<UnixStream> {
363		let stream = UnixStream::connect(&socket_path)?;
364		let _ = std::fs::remove_file(&socket_path);
365		Ok(stream)
366	}();
367	let mut stream = match stream {
368		Ok(ok) => ok,
369		Err(err) => worker_shutdown_error(worker_info, &err.to_string()),
370	};
371
372	let WorkerHandshake { security_status } = match recv_worker_handshake(&mut stream) {
373		Ok(ok) => ok,
374		Err(err) => worker_shutdown_error(worker_info, &err.to_string()),
375	};
376
377	// Enable some security features.
378	{
379		gum::trace!(target: LOG_TARGET, ?security_status, "Enabling security features");
380
381		// First, make sure env vars were cleared, to match the environment we perform the checks
382		// within. (In theory, running checks with different env vars could result in different
383		// outcomes of the checks.)
384		if !security::check_env_vars_were_cleared(&worker_info) {
385			let err = "not all env vars were cleared when spawning the process";
386			gum::error!(
387				target: LOG_TARGET,
388				?worker_info,
389				"{}",
390				err
391			);
392			if security_status.secure_validator_mode {
393				worker_shutdown(worker_info, err);
394			}
395		}
396
397		// Call based on whether we can change root. Error out if it should work but fails.
398		//
399		// NOTE: This should not be called in a multi-threaded context (i.e. inside the tokio
400		// runtime). `unshare(2)`:
401		//
402		//       > CLONE_NEWUSER requires that the calling process is not threaded.
403		#[cfg(target_os = "linux")]
404		if security_status.can_unshare_user_namespace_and_change_root {
405			if let Err(err) = security::change_root::enable_for_worker(&worker_info) {
406				// The filesystem may be in an inconsistent state, always bail out.
407				let err = format!("Could not change root to be the worker cache path: {}", err);
408				worker_shutdown_error(worker_info, &err);
409			}
410			worker_info.worker_dir_path = std::path::Path::new("/").to_owned();
411		}
412
413		#[cfg(target_os = "linux")]
414		if security_status.can_enable_landlock {
415			if let Err(err) = security::landlock::enable_for_worker(&worker_info) {
416				// We previously were able to enable, so this should never happen. Shutdown if
417				// running in secure mode.
418				let err = format!("could not fully enable landlock: {:?}", err);
419				gum::error!(
420					target: LOG_TARGET,
421					?worker_info,
422					"{}. This should not happen, please report an issue",
423					err
424				);
425				if security_status.secure_validator_mode {
426					worker_shutdown(worker_info, &err);
427				}
428			}
429		}
430
431		// TODO: We can enable the seccomp networking blacklist on aarch64 as well, but we need a CI
432		//       job to catch regressions. See issue ci_cd/issues/609.
433		#[cfg(all(target_os = "linux", target_arch = "x86_64"))]
434		if security_status.can_enable_seccomp {
435			if let Err(err) = security::seccomp::enable_for_worker(&worker_info) {
436				// We previously were able to enable, so this should never happen. Shutdown if
437				// running in secure mode.
438				let err = format!("could not fully enable seccomp: {:?}", err);
439				gum::error!(
440					target: LOG_TARGET,
441					?worker_info,
442					"{}. This should not happen, please report an issue",
443					err
444				);
445				if security_status.secure_validator_mode {
446					worker_shutdown(worker_info, &err);
447				}
448			}
449		}
450	}
451
452	// Run the main worker loop.
453	let err = event_loop(stream, &worker_info, security_status)
454		// It's never `Ok` because it's `Ok(Never)`.
455		.unwrap_err();
456
457	worker_shutdown(worker_info, &err.to_string());
458}
459
460/// Provide a consistent message on unexpected worker shutdown.
461fn worker_shutdown(worker_info: WorkerInfo, err: &str) -> ! {
462	gum::warn!(target: LOG_TARGET, ?worker_info, "quitting pvf worker ({}): {}", worker_info.kind, err);
463	std::process::exit(1);
464}
465
466/// Provide a consistent error on unexpected worker shutdown.
467fn worker_shutdown_error(worker_info: WorkerInfo, err: &str) -> ! {
468	gum::error!(target: LOG_TARGET, ?worker_info, "quitting pvf worker ({}): {}", worker_info.kind, err);
469	std::process::exit(1);
470}
471
472/// Loop that runs in the CPU time monitor thread on prepare and execute jobs. Continuously wakes up
473/// and then either blocks for the remaining CPU time, or returns if we exceed the CPU timeout.
474///
475/// Returning `Some` indicates that we should send a `TimedOut` error to the host. Will return
476/// `None` if the other thread finishes first, without us timing out.
477///
478/// NOTE: Sending a `TimedOut` error to the host will cause the worker, whether preparation or
479/// execution, to be killed by the host. We do not kill the process here because it would interfere
480/// with the proper handling of this error.
481pub fn cpu_time_monitor_loop(
482	cpu_time_start: ProcessTime,
483	timeout: Duration,
484	finished_rx: Receiver<()>,
485) -> Option<Duration> {
486	loop {
487		let cpu_time_elapsed = cpu_time_start.elapsed();
488
489		// Treat the timeout as CPU time, which is less subject to variance due to load.
490		if cpu_time_elapsed <= timeout {
491			// Sleep for the remaining CPU time, plus a bit to account for overhead. (And we don't
492			// want to wake up too often -- so, since we just want to halt the worker thread if it
493			// stalled, we can sleep longer than necessary.) Note that the sleep is wall clock time.
494			// The CPU clock may be slower than the wall clock.
495			let sleep_interval = timeout.saturating_sub(cpu_time_elapsed) + JOB_TIMEOUT_OVERHEAD;
496			match finished_rx.recv_timeout(sleep_interval) {
497				// Received finish signal.
498				Ok(()) => return None,
499				// Timed out, restart loop.
500				Err(RecvTimeoutError::Timeout) => continue,
501				Err(RecvTimeoutError::Disconnected) => return None,
502			}
503		}
504
505		return Some(cpu_time_elapsed)
506	}
507}
508
/// Best-effort conversion of an opaque panic payload into a message string.
///
/// Payloads of type `&'static str` or `String` (what `panic!` produces) are returned as-is. Any
/// other payload type yields the placeholder `"unknown panic payload"`.
pub fn stringify_panic_payload(payload: Box<dyn Any + Send + 'static>) -> String {
	if let Some(msg) = payload.downcast_ref::<&'static str>() {
		return String::from(*msg)
	}
	match payload.downcast::<String>() {
		Ok(owned) => *owned,
		// Neither of the types `panic!` produces.
		Err(_) => "unknown panic payload".to_string(),
	}
}
522
523/// In case of node and worker version mismatch (as a result of in-place upgrade), send `SIGTERM`
524/// to the node to tear it down and prevent it from raising disputes on valid candidates. Node
525/// restart should be handled by the node owner. As node exits, Unix sockets opened to workers
526/// get closed by the OS and other workers receive error on socket read and also exit. Preparation
527/// jobs are written to the temporary files that are renamed to real artifacts on the node side, so
528/// no leftover artifacts are possible.
529fn kill_parent_node_in_emergency() {
530	unsafe {
531		// SAFETY: `getpid()` never fails but may return "no-parent" (0) or "parent-init" (1) in
532		// some corner cases, which is checked. `kill()` never fails.
533		let ppid = libc::getppid();
534		if ppid > 1 {
535			libc::kill(ppid, libc::SIGTERM);
536		}
537	}
538}
539
540/// Receives a handshake with information for the worker.
541fn recv_worker_handshake(stream: &mut UnixStream) -> io::Result<WorkerHandshake> {
542	let worker_handshake = framed_recv_blocking(stream)?;
543	let worker_handshake = WorkerHandshake::decode(&mut &worker_handshake[..]).map_err(|e| {
544		io::Error::new(
545			io::ErrorKind::Other,
546			format!("recv_worker_handshake: failed to decode WorkerHandshake: {}", e),
547		)
548	})?;
549	Ok(worker_handshake)
550}
551
552/// Calculate the total CPU time from the given `usage` structure, returned from
553/// [`nix::sys::resource::getrusage`], and calculates the total CPU time spent, including both user
554/// and system time.
555///
556/// # Arguments
557///
558/// - `rusage`: Contains resource usage information.
559///
560/// # Returns
561///
562/// Returns a `Duration` representing the total CPU time.
563pub fn get_total_cpu_usage(rusage: Usage) -> Duration {
564	let micros = (((rusage.user_time().tv_sec() + rusage.system_time().tv_sec()) * 1_000_000) +
565		(rusage.system_time().tv_usec() + rusage.user_time().tv_usec()) as i64) as u64;
566
567	return Duration::from_micros(micros)
568}
569
570/// Get a job response.
571pub fn recv_child_response<T>(
572	received_data: &mut io::BufReader<&[u8]>,
573	context: &'static str,
574) -> io::Result<T>
575where
576	T: Decode,
577{
578	let response_bytes = framed_recv_blocking(received_data)?;
579	T::decode(&mut response_bytes.as_slice()).map_err(|e| {
580		io::Error::new(
581			io::ErrorKind::Other,
582			format!("{} pvf recv_child_response: decode error: {}", context, e),
583		)
584	})
585}
586
587pub fn send_result<T, E>(
588	stream: &mut UnixStream,
589	result: Result<T, E>,
590	worker_info: &WorkerInfo,
591) -> io::Result<()>
592where
593	T: std::fmt::Debug,
594	E: std::fmt::Debug + std::fmt::Display,
595	Result<T, E>: Encode,
596{
597	if let Err(ref err) = result {
598		gum::warn!(
599			target: LOG_TARGET,
600			?worker_info,
601			"worker: error occurred: {}",
602			err
603		);
604	}
605	gum::trace!(
606		target: LOG_TARGET,
607		?worker_info,
608		"worker: sending result to host: {:?}",
609		result
610	);
611
612	framed_send_blocking(stream, &result.encode()).map_err(|err| {
613		gum::warn!(
614			target: LOG_TARGET,
615			?worker_info,
616			"worker: error occurred sending result to host: {}",
617			err
618		);
619		err
620	})
621}
622
623pub fn stringify_errno(context: &'static str, errno: Errno) -> String {
624	format!("{}: {}: {}", context, errno, io::Error::last_os_error())
625}
626
/// Functionality related to threads spawned by the workers.
///
/// The motivation for this module is to coordinate worker threads without using async Rust.
pub mod thread {
	use std::{
		io, panic,
		sync::{Arc, Condvar, Mutex},
		thread,
		time::Duration,
	};

	/// Contains the outcome of waiting on threads, or `Pending` if none are ready.
	///
	/// Once the shared state has left `Pending`, later notifications leave it unchanged, so the
	/// first recorded outcome wins.
	#[derive(Debug, Clone, Copy, PartialEq, Eq)]
	pub enum WaitOutcome {
		Finished,
		TimedOut,
		Pending,
	}

	impl WaitOutcome {
		/// Returns `true` if no outcome has been recorded yet.
		pub fn is_pending(&self) -> bool {
			matches!(self, Self::Pending)
		}
	}

	/// Helper type: a shared [`WaitOutcome`] guarded by a mutex, plus the condvar used to wake the
	/// threads waiting for it to leave `Pending`.
	pub type Cond = Arc<(Mutex<WaitOutcome>, Condvar)>;

	/// Gets a condvar initialized to `Pending`.
	pub fn get_condvar() -> Cond {
		Arc::new((Mutex::new(WaitOutcome::Pending), Condvar::new()))
	}

	/// Runs a worker thread. The thread runs `f` and afterwards records `outcome` in `cond` and
	/// notifies the threads waiting on it. Panics are caught, the condvar is notified, and the
	/// panic is then resumed, so waiting threads are woken even if `f` panics.
	///
	/// # Returns
	///
	/// Returns the thread's join handle, or the error from spawning the thread. Calling `.join()`
	/// on the handle returns the value produced by `f()`, or `Err` with the panic payload if `f`
	/// panicked.
	pub fn spawn_worker_thread<F, R>(
		name: &str,
		f: F,
		cond: Cond,
		outcome: WaitOutcome,
	) -> io::Result<thread::JoinHandle<R>>
	where
		F: FnOnce() -> R,
		F: Send + 'static + panic::UnwindSafe,
		R: Send + 'static,
	{
		spawn_notifying(thread::Builder::new().name(name.into()), f, cond, outcome)
	}

	/// Runs a worker thread with the given stack size. See [`spawn_worker_thread`].
	pub fn spawn_worker_thread_with_stack_size<F, R>(
		name: &str,
		f: F,
		cond: Cond,
		outcome: WaitOutcome,
		stack_size: usize,
	) -> io::Result<thread::JoinHandle<R>>
	where
		F: FnOnce() -> R,
		F: Send + 'static + panic::UnwindSafe,
		R: Send + 'static,
	{
		let builder = thread::Builder::new().name(name.into()).stack_size(stack_size);
		spawn_notifying(builder, f, cond, outcome)
	}

	/// Spawns `f` on a thread configured by `builder`, wrapped in [`cond_notify_on_done`].
	fn spawn_notifying<F, R>(
		builder: thread::Builder,
		f: F,
		cond: Cond,
		outcome: WaitOutcome,
	) -> io::Result<thread::JoinHandle<R>>
	where
		F: FnOnce() -> R,
		F: Send + 'static + panic::UnwindSafe,
		R: Send + 'static,
	{
		builder.spawn(move || cond_notify_on_done(f, cond, outcome))
	}

	/// Runs a function, afterwards notifying the threads waiting on the condvar. Catches panics and
	/// resumes them after triggering the condvar, so that the waiting thread is notified on panics.
	fn cond_notify_on_done<F, R>(f: F, cond: Cond, outcome: WaitOutcome) -> R
	where
		F: FnOnce() -> R,
		F: panic::UnwindSafe,
	{
		let result = panic::catch_unwind(f);
		cond_notify_all(cond, outcome);
		match result {
			Ok(value) => value,
			Err(payload) => panic::resume_unwind(payload),
		}
	}

	/// Records `outcome` and notifies all threads waiting on this condvar, unless an outcome has
	/// already been recorded.
	fn cond_notify_all(cond: Cond, outcome: WaitOutcome) {
		let (lock, cvar) = &*cond;
		let mut flag = lock.lock().unwrap();
		if !flag.is_pending() {
			// Someone else already triggered the condvar.
			return
		}
		*flag = outcome;
		cvar.notify_all();
	}

	/// Blocks the thread until an outcome has been recorded on the condvar, and returns it.
	pub fn wait_for_threads(cond: Cond) -> WaitOutcome {
		let (lock, cvar) = &*cond;
		let guard = cvar.wait_while(lock.lock().unwrap(), |flag| flag.is_pending()).unwrap();
		*guard
	}

	/// Block the thread while it waits on the condvar or on a timeout. If the timeout is hit while
	/// the outcome is still `Pending`, returns `None`.
	#[cfg_attr(not(any(target_os = "linux", feature = "jemalloc-allocator")), allow(dead_code))]
	pub fn wait_for_threads_with_timeout(cond: &Cond, dur: Duration) -> Option<WaitOutcome> {
		let (lock, cvar) = &**cond;
		let (outcome, wait_result) = cvar
			.wait_timeout_while(lock.lock().unwrap(), dur, |flag| flag.is_pending())
			.unwrap();
		// `timed_out()` is only true if the outcome was still `Pending` when `dur` elapsed.
		if wait_result.timed_out() {
			None
		} else {
			Some(*outcome)
		}
	}

	#[cfg(test)]
	mod tests {
		use super::*;

		#[test]
		fn get_condvar_should_be_pending() {
			let condvar = get_condvar();
			let outcome = *condvar.0.lock().unwrap();
			assert!(outcome.is_pending());
		}

		#[test]
		fn wait_for_threads_with_timeout_return_none_on_time_out() {
			let condvar = Arc::new((Mutex::new(WaitOutcome::Pending), Condvar::new()));
			let outcome = wait_for_threads_with_timeout(&condvar, Duration::from_millis(100));
			assert!(outcome.is_none());
		}

		#[test]
		fn wait_for_threads_with_timeout_returns_outcome() {
			let condvar = Arc::new((Mutex::new(WaitOutcome::Pending), Condvar::new()));
			let condvar2 = condvar.clone();
			cond_notify_all(condvar2, WaitOutcome::Finished);
			let outcome = wait_for_threads_with_timeout(&condvar, Duration::from_secs(2));
			assert_eq!(outcome, Some(WaitOutcome::Finished));
		}

		#[test]
		fn spawn_worker_thread_should_notify_on_done() {
			let condvar = Arc::new((Mutex::new(WaitOutcome::Pending), Condvar::new()));
			let response =
				spawn_worker_thread("thread", || 2, condvar.clone(), WaitOutcome::TimedOut);
			let (lock, _) = &*condvar;
			let r = response.unwrap().join().unwrap();
			assert_eq!(r, 2);
			assert_eq!(*lock.lock().unwrap(), WaitOutcome::TimedOut);
		}

		#[test]
		fn spawn_worker_should_not_change_finished_outcome() {
			let condvar = Arc::new((Mutex::new(WaitOutcome::Finished), Condvar::new()));
			let response =
				spawn_worker_thread("thread", move || 2, condvar.clone(), WaitOutcome::TimedOut);

			let r = response.unwrap().join().unwrap();
			assert_eq!(r, 2);
			assert_eq!(*condvar.0.lock().unwrap(), WaitOutcome::Finished);
		}

		#[test]
		fn cond_notify_on_done_should_update_wait_outcome_when_panic() {
			let condvar = Arc::new((Mutex::new(WaitOutcome::Pending), Condvar::new()));
			let err = panic::catch_unwind(panic::AssertUnwindSafe(|| {
				cond_notify_on_done(|| panic!("test"), condvar.clone(), WaitOutcome::Finished)
			}));

			assert_eq!(*condvar.0.lock().unwrap(), WaitOutcome::Finished);
			assert!(err.is_err());
		}
	}
}
814
#[cfg(test)]
mod tests {
	use super::*;
	use std::sync::mpsc::channel;

	#[test]
	fn cpu_time_monitor_loop_should_return_time_elapsed() {
		// With a zero CPU budget the monitor must eventually report a timeout. The sender is kept
		// alive so the loop does not observe a disconnect.
		let (_finished_tx, finished_rx) = channel::<()>();
		let elapsed = cpu_time_monitor_loop(ProcessTime::now(), Duration::ZERO, finished_rx);
		assert!(elapsed.is_some());
	}

	#[test]
	fn cpu_time_monitor_loop_should_return_none() {
		// A completion signal that is already queued must end the loop without a timeout.
		let (finished_tx, finished_rx) = channel();
		finished_tx.send(()).unwrap();
		let elapsed =
			cpu_time_monitor_loop(ProcessTime::now(), Duration::from_secs(10), finished_rx);
		assert!(elapsed.is_none());
	}
}
834		tx.send(()).unwrap();
835		let result = cpu_time_monitor_loop(cpu_time_start, timeout, rx);
836		assert_eq!(result, None);
837	}
838}