environmental/
lib.rs

1// Copyright 2017-2022 Parity Technologies
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Safe global references to stack variables.
16//!
//! Set up a global reference with the `environmental!` macro, giving it a name and a type.
//! Then call the `using` function generated under that name with a mutable reference and a
//! function that takes no parameters; while that function runs, it (and anything it calls) can
//! reach the reference through the `with` function generated alongside `using`.
20//!
21//! # Examples
22//!
23//! ```
24//! #[macro_use] extern crate environmental;
25//! // create a place for the global reference to exist.
26//! environmental!(counter: u32);
27//! fn stuff() {
28//!   // do some stuff, accessing the named reference as desired.
29//!   counter::with(|i| *i += 1);
30//! }
31//! fn main() {
32//!   // declare a stack variable of the same type as our global declaration.
33//!   let mut counter_value = 41u32;
34//!   // call stuff, setting up our `counter` environment as a reference to our counter_value var.
35//!   counter::using(&mut counter_value, stuff);
36//!   println!("The answer is {:?}", counter_value); // will print 42!
37//!   stuff();	// safe! doesn't do anything.
38//! }
39//! ```
40
41#![cfg_attr(not(feature = "std"), no_std)]
42
43extern crate alloc;
44
45#[doc(hidden)]
46pub use core::{cell::RefCell, mem::{transmute, replace}, marker::PhantomData};
47
48#[doc(hidden)]
49pub use alloc::{rc::Rc, vec::Vec};
50
51#[cfg(not(feature = "std"))]
52#[macro_export]
53mod local_key;
54
55#[doc(hidden)]
56#[cfg(not(feature = "std"))]
57pub use local_key::LocalKey;
58
59#[doc(hidden)]
60#[cfg(feature = "std")]
61pub use std::thread::LocalKey;
62
/// Declares the per-thread storage used by `environmental!` when the `std` feature is enabled.
///
/// The declaration is handed, unchanged, to the standard library's `thread_local!`.
#[doc(hidden)]
#[cfg(feature = "std")]
#[macro_export]
macro_rules! thread_local_impl {
	(
		$(#[$attr:meta])*
		static $name:ident : $t:ty = $init:expr
	) => {
		thread_local! {
			$(#[$attr])*
			static $name: $t = $init
		}
	};
}
71
/// Declares the per-thread storage used by `environmental!` when the `std` feature is disabled.
///
/// Emits a `static` of this crate's `LocalKey` type. The initial value comes from a local
/// `__init` function that evaluates `$init`. How and when `__init` is invoked is decided by
/// `local_key_init!` in the `local_key` module (not shown here).
#[doc(hidden)]
#[cfg(not(feature = "std"))]
#[macro_export]
macro_rules! thread_local_impl {
	(
		$(#[$attr:meta])*
		static $name:ident : $t:ty = $init:expr
	) => {
		$(#[$attr])*
		static $name: $crate::LocalKey<$t> = {
			fn __init() -> $t {
				$init
			}

			$crate::local_key_init!(__init)
		};
	};
}
85
/// Stack of values installed by nested `using` calls; the last element is the one `with` sees.
///
/// Each entry sits behind `Rc<RefCell<_>>`. The `Rc` lets `with` clone the top entry and drop
/// its borrow of the stack before running user code. The inner `RefCell` is mutably borrowed
/// for as long as `with` exposes the `&mut T`.
#[doc(hidden)]
pub type GlobalInner<T> = RefCell<Vec<Rc<RefCell<*mut T>>>>;

/// Key under which a per-thread `GlobalInner<T>` stack is stored.
type Global<T> = LocalKey<RefCell<Vec<Rc<RefCell<*mut T>>>>>;
92
93struct PopGlobal<'a, T: 'a + ?Sized> {
94	global_stack: &'a GlobalInner<T>,
95}
96
97impl<'a, T: 'a + ?Sized> Drop for PopGlobal<'a, T> {
98	fn drop(&mut self) {
99		self.global_stack.borrow_mut().pop();
100	}
101}
102
103#[doc(hidden)]
104pub fn using<T: ?Sized, R, F: FnOnce() -> R>(
105	global: &'static Global<T>,
106	protected: &mut T,
107	f: F,
108) -> R {
109	// store the `protected` reference as a pointer so we can provide it to logic running within
110	// `f`.
111	// while we record this pointer (while it's non-zero) we guarantee:
112	// - it will only be used once at any time (no re-entrancy);
113	// - that no other thread will use it; and
114	// - that we do not use the original mutating reference while the pointer exists.
115	global.with(|r| {
116		// Push the new global to the end of the stack.
117		r.borrow_mut().push(
118			Rc::new(RefCell::new(protected as _)),
119		);
120
121		// Even if `f` panics the added global will be popped.
122		let _guard = PopGlobal { global_stack: r };
123
124		f()
125	})
126}
127
128#[doc(hidden)]
129pub fn using_once<T: ?Sized, R, F: FnOnce() -> R>(
130	global: &'static Global<T>,
131	protected: &mut T,
132	f: F,
133) -> R {
134	// store the `protected` reference as a pointer so we can provide it to logic running within
135	// `f`.
136	// while we record this pointer (while it's non-zero) we guarantee:
137	// - it will only be used once at any time (no re-entrancy);
138	// - that no other thread will use it; and
139	// - that we do not use the original mutating reference while the pointer exists.
140	global.with(|r| {
141		// If there is already some state set, we want to use it.
142		if r.borrow().last().is_some() {
143			f()
144		} else {
145			// Push the new global to the end of the stack.
146			r.borrow_mut().push(
147				Rc::new(RefCell::new(protected as _)),
148			);
149
150			// Even if `f` panics the added global will be popped.
151			let _guard = PopGlobal { global_stack: r };
152
153			f()
154		}
155	})
156}
157
158#[doc(hidden)]
159pub fn with<T: ?Sized, R, F: FnOnce(&mut T) -> R>(
160	global: &'static Global<T>,
161	mutator: F,
162) -> Option<R> {
163	global.with(|r| {
164		// We always use the `last` element when we want to access the
165		// currently set global.
166		let last = r.borrow().last().cloned();
167		match last {
168			Some(ptr) => unsafe {
169				// safe because it's only non-zero when it's being called from using, which
170				// is holding on to the underlying reference (and not using it itself) safely.
171				Some(mutator(&mut **ptr.borrow_mut()))
172			}
173			None => None,
174		}
175	})
176}
177
/// Declare a new global reference whose underlying value does not contain references.
///
/// Defines a type named `$name` with three associated functions, called as `$name::using`,
/// `$name::with` and `$name::using_once`:
///
/// * `pub fn using<R, F: FnOnce() -> R>(protected: &mut $t, f: F) -> R`
///   This executes `f`, returning its value. During the call, the global reference is set to
///   be equal to `protected`. When nesting `using` calls it will build a stack of the set values.
///   Each call to `with` will always return the latest value in this stack.
/// * `pub fn with<R, F: FnOnce(&mut $t) -> R>(f: F) -> Option<R>`
///   This executes `f`, returning `Some` of its value if called from code that is being executed
///   as part of a `using` call. If not, it returns `None`. `f` is provided with one argument: the
///   same reference as provided to the most recent `using` call. Calling `with` again from inside
///   `f`, without a nested `using` in between, panics because the current value is already
///   mutably borrowed.
/// * `pub fn using_once<R, F: FnOnce() -> R>(protected: &mut $t, f: F) -> R`
///   This executes `f`, returning its value. During the call, the global reference is set to
///   be equal to `protected` when there is not already a value set. In contrast to `using` this
///   will not build a stack of set values and it will use the already set value.
///
/// Besides `$name: Type`, trait-object forms are accepted:
///
/// * `$name: trait Trait` and `$name: trait Trait<Args, ...>` store a `dyn Trait<Args, ...>`.
///   `using` accepts a trait object of any lifetime, and `with` lends it out only for the
///   duration of its closure.
/// * `$name<Bound>: trait Trait<Concrete>` declares a generic `$name<H: Bound>`. Its `using`
///   takes a `dyn Trait<H>`, and its `with` yields a `dyn Trait<Concrete>`.
///
/// The thread-local storage behind each declaration is scoped inside an anonymous `const`
/// block, so several `environmental!` declarations can share one module.
///
/// # Examples
///
/// Initializing the global context with a given value.
///
/// ```rust
/// #[macro_use] extern crate environmental;
/// environmental!(counter: u32);
/// fn main() {
///   let mut counter_value = 41u32;
///   counter::using(&mut counter_value, || {
///     let odd = counter::with(|value|
///       if *value % 2 == 1 {
///         *value += 1; true
///       } else {
///         *value -= 3; false
///       }).unwrap();	// safe because we're inside a counter::using
///     println!("counter was {}", match odd { true => "odd", _ => "even" });
///   });
///
///   println!("The answer is {:?}", counter_value); // 42
/// }
/// ```
///
/// Roughly the same, but with a trait object:
///
/// ```rust
/// #[macro_use] extern crate environmental;
///
/// trait Increment { fn increment(&mut self); }
///
/// impl Increment for i32 {
///	fn increment(&mut self) { *self += 1 }
/// }
///
/// environmental!(val: dyn Increment + 'static);
///
/// fn main() {
///	let mut local = 0i32;
///	val::using(&mut local, || {
///		val::with(|v| for _ in 0..5 { v.increment() });
///	});
///
///	assert_eq!(local, 5);
/// }
/// ```
#[macro_export]
macro_rules! environmental {
	($name:ident : $t:ty) => {
		#[allow(non_camel_case_types, dead_code)]
		struct $name { __private_field: () }

		// `GLOBAL` lives inside an anonymous const so it cannot clash with the `GLOBAL` of
		// another `environmental!` declaration (or any other item) in the same scope.
		const _: () = {
			$crate::thread_local_impl! {
				static GLOBAL: $crate::GlobalInner<$t> = Default::default()
			}

			impl $name {
				#[allow(unused_imports)]
				#[allow(dead_code)]
				pub fn using<R, F: FnOnce() -> R>(
					protected: &mut $t,
					f: F,
				) -> R {
					$crate::using(&GLOBAL, protected, f)
				}

				#[allow(dead_code)]
				pub fn with<R, F: FnOnce(&mut $t) -> R>(
					f: F,
				) -> Option<R> {
					$crate::with(&GLOBAL, |x| f(x))
				}

				#[allow(dead_code)]
				pub fn using_once<R, F: FnOnce() -> R>(
					protected: &mut $t,
					f: F,
				) -> R {
					$crate::using_once(&GLOBAL, protected, f)
				}
			}
		};
	};
	($name:ident : trait @$t:ident [$($args:ty,)*]) => {
		#[allow(non_camel_case_types, dead_code)]
		struct $name { __private_field: () }

		// Scoped like the plain-type arm so several declarations can share a module.
		const _: () = {
			$crate::thread_local_impl! {
				static GLOBAL: $crate::GlobalInner<(dyn $t<$($args),*> + 'static)>
					= Default::default()
			}

			impl $name {
				#[allow(unused_imports)]
				#[allow(dead_code)]
				pub fn using<R, F: FnOnce() -> R>(
					protected: &mut dyn $t<$($args),*>,
					f: F
				) -> R {
					// SAFETY: only the trait-object lifetime bound is widened to `'static` so the
					// pointer fits in the global. `$crate::using` pops it again before returning
					// (or unwinding), so it is never used after `protected`'s borrow ends.
					let lifetime_extended = unsafe {
						$crate::transmute::<&mut dyn $t<$($args),*>, &mut (dyn $t<$($args),*> + 'static)>(protected)
					};
					$crate::using(&GLOBAL, lifetime_extended, f)
				}

				// The higher-ranked bound keeps `f` from smuggling the object out with the
				// artificial `'static` lifetime stored in `GLOBAL`.
				#[allow(dead_code)]
				pub fn with<R, F: for<'a> FnOnce(&'a mut (dyn $t<$($args),*> + 'a)) -> R>(
					f: F
				) -> Option<R> {
					$crate::with(&GLOBAL, |x| f(x))
				}

				#[allow(unused_imports)]
				#[allow(dead_code)]
				pub fn using_once<R, F: FnOnce() -> R>(
					protected: &mut dyn $t<$($args),*>,
					f: F
				) -> R {
					// SAFETY: as in `using` above; `$crate::using_once` pops anything it pushes
					// before returning.
					let lifetime_extended = unsafe {
						$crate::transmute::<&mut dyn $t<$($args),*>, &mut (dyn $t<$($args),*> + 'static)>(protected)
					};
					$crate::using_once(&GLOBAL, lifetime_extended, f)
				}
			}
		};
	};
	($name:ident<$traittype:ident> : trait $t:ident <$concretetype:ty>) => {
		#[allow(non_camel_case_types, dead_code)]
		struct $name <H: $traittype> { _private_field: $crate::PhantomData<H> }

		// Scoped like the plain-type arm so several declarations can share a module.
		const _: () = {
			$crate::thread_local_impl! {
				static GLOBAL: $crate::GlobalInner<(dyn $t<$concretetype> + 'static)>
					= Default::default()
			}

			impl<H: $traittype> $name<H> {
				#[allow(unused_imports)]
				#[allow(dead_code)]
				pub fn using<R, F: FnOnce() -> R>(
					protected: &mut dyn $t<H>,
					f: F
				) -> R {
					// NOTE(review): besides widening the lifetime to `'static`, this transmute
					// reinterprets a `dyn $t<H>` as a `dyn $t<$concretetype>`. Nothing here
					// checks that `H` and `$concretetype` are the same type; if they differ, the
					// vtable for `$t<H>` is later called as `$t<$concretetype>`. Confirm callers
					// only instantiate `$name` with `$concretetype`.
					let lifetime_extended = unsafe {
						$crate::transmute::<&mut dyn $t<H>, &mut (dyn $t<$concretetype> + 'static)>(protected)
					};
					$crate::using(&GLOBAL, lifetime_extended, f)
				}

				#[allow(dead_code)]
				pub fn with<R, F: for<'a> FnOnce(&'a mut (dyn $t<$concretetype> + 'a)) -> R>(
					f: F
				) -> Option<R> {
					$crate::with(&GLOBAL, |x| f(x))
				}

				#[allow(unused_imports)]
				#[allow(dead_code)]
				pub fn using_once<R, F: FnOnce() -> R>(
					protected: &mut dyn $t<H>,
					f: F
				) -> R {
					// NOTE(review): same unchecked `H` -> `$concretetype` reinterpretation as in
					// `using` above.
					let lifetime_extended = unsafe {
						$crate::transmute::<&mut dyn $t<H>, &mut (dyn $t<$concretetype> + 'static)>(protected)
					};
					$crate::using_once(&GLOBAL, lifetime_extended, f)
				}
			}
		};
	};
	($name:ident : trait $t:ident <>) => { $crate::environmental! { $name : trait @$t [] } };
	($name:ident : trait $t:ident < $($args:ty),* $(,)* >) => {
		$crate::environmental! { $name : trait @$t [$($args,)*] }
	};
	($name:ident : trait $t:ident) => { $crate::environmental! { $name : trait @$t [] } };
}
366
#[cfg(test)]
mod tests {
	// Test trait in item position
	#[allow(dead_code)]
	mod trait_test {
		trait Test {}

		environmental!(item_position_trait: trait Test);
	}

	// Test type in item position
	#[allow(dead_code)]
	mod type_test {
		environmental!(item_position_type: u32);
	}

	#[test]
	fn simple_works() {
		environmental!(counter: u32);

		fn stuff() { counter::with(|value| *value += 1); }

		// declare a stack variable of the same type as our global declaration.
		let mut local = 41u32;

		// call stuff, setting up our `counter` environment as a reference to our local counter var.
		counter::using(&mut local, stuff);
		assert_eq!(local, 42);
		stuff();	// safe! doesn't do anything.
		assert_eq!(local, 42);
	}

	#[test]
	fn overwrite_with_lesser_lifetime() {
		environmental!(items: Vec<u8>);

		let mut local_items = vec![1, 2, 3];
		items::using(&mut local_items, || {
			let dies_at_end = vec![4, 5, 6];
			items::with(|items| *items = dies_at_end);
		});

		assert_eq!(local_items, vec![4, 5, 6]);
	}

	#[test]
	fn declare_with_trait_object() {
		trait Foo {
			fn get(&self) -> i32;
			fn set(&mut self, x: i32);
		}

		impl Foo for i32 {
			fn get(&self) -> i32 { *self }
			fn set(&mut self, x: i32) { *self = x }
		}

		environmental!(foo: dyn Foo + 'static);

		fn stuff() {
			foo::with(|value| {
				let new_val = value.get() + 1;
				value.set(new_val);
			});
		}

		let mut local = 41i32;
		foo::using(&mut local, stuff);

		assert_eq!(local, 42);

		stuff(); // doesn't do anything.

		assert_eq!(local, 42);
	}

	#[test]
	fn unwind_recursive() {
		use std::panic;

		environmental!(items: Vec<u8>);

		let panicked = panic::catch_unwind(|| {
			let mut local_outer = vec![1, 2, 3];

			items::using(&mut local_outer, || {
				let mut local_inner = vec![4, 5, 6];
				items::using(&mut local_inner, || {
					panic!("are you unsafe?");
				})
			});
		}).is_err();

		assert!(panicked);

		// Both `using` guards must have popped their entries while unwinding.
		assert!(items::with(|_items| ()).is_none());
	}

	#[test]
	fn use_non_static_trait() {
		trait Sum { fn sum(&self) -> usize; }
		impl Sum for &[usize] {
			fn sum(&self) -> usize {
				self.iter().fold(0, |a, c| a + c)
			}
		}

		environmental!(sum: trait Sum);
		let numbers = vec![1, 2, 3, 4, 5];
		let mut numbers = &numbers[..];
		let got_sum = sum::using(&mut numbers, || {
			sum::with(|x| x.sum())
		}).unwrap();

		assert_eq!(got_sum, 15);
	}

	#[test]
	fn stacking_globals() {
		trait Sum { fn sum(&self) -> usize; }
		impl Sum for &[usize] {
			fn sum(&self) -> usize {
				self.iter().fold(0, |a, c| a + c)
			}
		}

		environmental!(sum: trait Sum);
		let numbers = vec![1, 2, 3, 4, 5];
		let mut numbers = &numbers[..];
		let got_sum = sum::using(&mut numbers, || {
			sum::with(|_| {
				let numbers2 = vec![1, 2, 3, 4, 5, 6];
				let mut numbers2 = &numbers2[..];
				sum::using(&mut numbers2, || {
					sum::with(|x| x.sum())
				})
			})
		}).unwrap().unwrap();

		assert_eq!(got_sum, 21);

		assert!(sum::with(|_| ()).is_none());
	}

	#[test]
	fn use_generic_trait() {
		trait Plus { fn plus42() -> usize; }
		struct ConcretePlus;
		impl Plus for ConcretePlus {
			fn plus42() -> usize { 42 }
		}
		trait Multiplier<T: Plus> { fn mul_and_add(&self) -> usize; }
		impl<'a, P: Plus> Multiplier<P> for &'a [usize] {
			fn mul_and_add(&self) -> usize {
				self.iter().fold(1, |a, c| a * c) + P::plus42()
			}
		}

		environmental!(foo<Plus>: trait Multiplier<ConcretePlus>);

		let numbers = vec![1, 2, 3];
		let mut numbers = &numbers[..];
		let out = foo::<ConcretePlus>::using(&mut numbers, || {
			foo::<ConcretePlus>::with(|x| x.mul_and_add() )
		}).unwrap();

		assert_eq!(out, 6 + 42);
	}

	#[test]
	fn using_once_is_working() {
		environmental!(value: u32);

		let mut called_inner = false;

		value::using_once(&mut 5, || {
			value::using_once(&mut 10, || {
				assert_eq!(5, value::with(|v| *v).unwrap());

				value::using(&mut 20, || {
					assert_eq!(20, value::with(|v| *v).unwrap());

					value::using_once(&mut 30, || {
						assert_eq!(20, value::with(|v| *v).unwrap());

						called_inner = true;
					})
				})
			})
		});

		assert!(called_inner);
	}

	#[test]
	fn using_once_sets_value_when_unset() {
		environmental!(value: u32);

		let mut local = 1u32;
		value::using_once(&mut local, || {
			value::with(|v| *v += 1);
		});

		assert_eq!(local, 2);
		// The value installed by `using_once` is removed again afterwards.
		assert!(value::with(|_| ()).is_none());
	}

	#[test]
	#[should_panic]
	fn nested_with_panics() {
		environmental!(counter: u32);

		let mut local = 0u32;
		counter::using(&mut local, || {
			counter::with(|_outer| {
				// The current value is still mutably borrowed by the outer `with`.
				counter::with(|_inner| ());
			});
		});
	}
}