fragile/
sticky.rs

1#![allow(clippy::unit_arg)]
2
3use std::cmp;
4use std::fmt;
5use std::marker::PhantomData;
6use std::mem;
7use std::num::NonZeroUsize;
8
9use crate::errors::InvalidThreadAccess;
10use crate::registry;
11use crate::thread_id;
12use crate::StackToken;
13
14/// A [`Sticky<T>`] keeps a value T stored in a thread.
15///
16/// This type works similar in nature to [`Fragile`](crate::Fragile) and exposes a
17/// similar interface.  The difference is that whereas [`Fragile`](crate::Fragile) has
18/// its destructor called in the thread where the value was sent, a
19/// [`Sticky`] that is moved to another thread will have the internal
20/// destructor called when the originating thread tears down.
21///
22/// Because [`Sticky`] allows values to be kept alive for longer than the
23/// [`Sticky`] itself, it requires all its contents to be `'static` for
24/// soundness.  More importantly it also requires the use of [`StackToken`]s.
25/// For information about how to use stack tokens and why they are neded,
26/// refer to [`stack_token!`](crate::stack_token).
27///
28/// As this uses TLS internally the general rules about the platform limitations
29/// of destructors for TLS apply.
30pub struct Sticky<T: 'static> {
31    item_id: registry::ItemId,
32    thread_id: NonZeroUsize,
33    _marker: PhantomData<*mut T>,
34}
35
36impl<T> Drop for Sticky<T> {
37    fn drop(&mut self) {
38        // if the type needs dropping we can only do so on the
39        // right thread.  worst case we leak the value until the
40        // thread dies.
41        if mem::needs_drop::<T>() {
42            unsafe {
43                if self.is_valid() {
44                    self.unsafe_take_value();
45                }
46            }
47
48        // otherwise we take the liberty to drop the value
49        // right here and now.  We can however only do that if
50        // we are on the right thread.  If we are not, we again
51        // need to wait for the thread to shut down.
52        } else if let Some(entry) = registry::try_remove(self.item_id, self.thread_id) {
53            unsafe {
54                (entry.drop)(entry.ptr);
55            }
56        }
57    }
58}
59
60impl<T> Sticky<T> {
61    /// Creates a new [`Sticky`] wrapping a `value`.
62    ///
63    /// The value that is moved into the [`Sticky`] can be non `Send` and
64    /// will be anchored to the thread that created the object.  If the
65    /// sticky wrapper type ends up being send from thread to thread
66    /// only the original thread can interact with the value.
67    pub fn new(value: T) -> Self {
68        let entry = registry::Entry {
69            ptr: Box::into_raw(Box::new(value)).cast(),
70            drop: |ptr| {
71                let ptr = ptr.cast::<T>();
72                // SAFETY: This callback will only be called once, with the
73                // above pointer.
74                drop(unsafe { Box::from_raw(ptr) });
75            },
76        };
77
78        let thread_id = thread_id::get();
79        let item_id = registry::insert(thread_id, entry);
80
81        Sticky {
82            item_id,
83            thread_id,
84            _marker: PhantomData,
85        }
86    }
87
88    #[inline(always)]
89    fn with_value<F: FnOnce(*mut T) -> R, R>(&self, f: F) -> R {
90        self.assert_thread();
91
92        registry::with(self.item_id, self.thread_id, |entry| {
93            f(entry.ptr.cast::<T>())
94        })
95    }
96
97    /// Returns `true` if the access is valid.
98    ///
99    /// This will be `false` if the value was sent to another thread.
100    #[inline(always)]
101    pub fn is_valid(&self) -> bool {
102        thread_id::get() == self.thread_id
103    }
104
105    #[inline(always)]
106    fn assert_thread(&self) {
107        if !self.is_valid() {
108            panic!("trying to access wrapped value in sticky container from incorrect thread.");
109        }
110    }
111
112    /// Consumes the `Sticky`, returning the wrapped value.
113    ///
114    /// # Panics
115    ///
116    /// Panics if called from a different thread than the one where the
117    /// original value was created.
118    pub fn into_inner(mut self) -> T {
119        self.assert_thread();
120        unsafe {
121            let rv = self.unsafe_take_value();
122            mem::forget(self);
123            rv
124        }
125    }
126
127    unsafe fn unsafe_take_value(&mut self) -> T {
128        let ptr = registry::remove(self.item_id, self.thread_id)
129            .ptr
130            .cast::<T>();
131        *Box::from_raw(ptr)
132    }
133
134    /// Consumes the `Sticky`, returning the wrapped value if successful.
135    ///
136    /// The wrapped value is returned if this is called from the same thread
137    /// as the one where the original value was created, otherwise the
138    /// `Sticky` is returned as `Err(self)`.
139    pub fn try_into_inner(self) -> Result<T, Self> {
140        if self.is_valid() {
141            Ok(self.into_inner())
142        } else {
143            Err(self)
144        }
145    }
146
147    /// Immutably borrows the wrapped value.
148    ///
149    /// # Panics
150    ///
151    /// Panics if the calling thread is not the one that wrapped the value.
152    /// For a non-panicking variant, use [`try_get`](#method.try_get`).
153    pub fn get<'stack>(&'stack self, _proof: &'stack StackToken) -> &'stack T {
154        self.with_value(|value| unsafe { &*value })
155    }
156
157    /// Mutably borrows the wrapped value.
158    ///
159    /// # Panics
160    ///
161    /// Panics if the calling thread is not the one that wrapped the value.
162    /// For a non-panicking variant, use [`try_get_mut`](#method.try_get_mut`).
163    pub fn get_mut<'stack>(&'stack mut self, _proof: &'stack StackToken) -> &'stack mut T {
164        self.with_value(|value| unsafe { &mut *value })
165    }
166
167    /// Tries to immutably borrow the wrapped value.
168    ///
169    /// Returns `None` if the calling thread is not the one that wrapped the value.
170    pub fn try_get<'stack>(
171        &'stack self,
172        _proof: &'stack StackToken,
173    ) -> Result<&'stack T, InvalidThreadAccess> {
174        if self.is_valid() {
175            Ok(self.with_value(|value| unsafe { &*value }))
176        } else {
177            Err(InvalidThreadAccess)
178        }
179    }
180
181    /// Tries to mutably borrow the wrapped value.
182    ///
183    /// Returns `None` if the calling thread is not the one that wrapped the value.
184    pub fn try_get_mut<'stack>(
185        &'stack mut self,
186        _proof: &'stack StackToken,
187    ) -> Result<&'stack mut T, InvalidThreadAccess> {
188        if self.is_valid() {
189            Ok(self.with_value(|value| unsafe { &mut *value }))
190        } else {
191            Err(InvalidThreadAccess)
192        }
193    }
194}
195
196impl<T> From<T> for Sticky<T> {
197    #[inline]
198    fn from(t: T) -> Sticky<T> {
199        Sticky::new(t)
200    }
201}
202
203impl<T: Clone> Clone for Sticky<T> {
204    #[inline]
205    fn clone(&self) -> Sticky<T> {
206        crate::stack_token!(tok);
207        Sticky::new(self.get(tok).clone())
208    }
209}
210
211impl<T: Default> Default for Sticky<T> {
212    #[inline]
213    fn default() -> Sticky<T> {
214        Sticky::new(T::default())
215    }
216}
217
218impl<T: PartialEq> PartialEq for Sticky<T> {
219    #[inline]
220    fn eq(&self, other: &Sticky<T>) -> bool {
221        crate::stack_token!(tok);
222        *self.get(tok) == *other.get(tok)
223    }
224}
225
226impl<T: Eq> Eq for Sticky<T> {}
227
228impl<T: PartialOrd> PartialOrd for Sticky<T> {
229    #[inline]
230    fn partial_cmp(&self, other: &Sticky<T>) -> Option<cmp::Ordering> {
231        crate::stack_token!(tok);
232        self.get(tok).partial_cmp(other.get(tok))
233    }
234
235    #[inline]
236    fn lt(&self, other: &Sticky<T>) -> bool {
237        crate::stack_token!(tok);
238        *self.get(tok) < *other.get(tok)
239    }
240
241    #[inline]
242    fn le(&self, other: &Sticky<T>) -> bool {
243        crate::stack_token!(tok);
244        *self.get(tok) <= *other.get(tok)
245    }
246
247    #[inline]
248    fn gt(&self, other: &Sticky<T>) -> bool {
249        crate::stack_token!(tok);
250        *self.get(tok) > *other.get(tok)
251    }
252
253    #[inline]
254    fn ge(&self, other: &Sticky<T>) -> bool {
255        crate::stack_token!(tok);
256        *self.get(tok) >= *other.get(tok)
257    }
258}
259
260impl<T: Ord> Ord for Sticky<T> {
261    #[inline]
262    fn cmp(&self, other: &Sticky<T>) -> cmp::Ordering {
263        crate::stack_token!(tok);
264        self.get(tok).cmp(other.get(tok))
265    }
266}
267
268impl<T: fmt::Display> fmt::Display for Sticky<T> {
269    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
270        crate::stack_token!(tok);
271        fmt::Display::fmt(self.get(tok), f)
272    }
273}
274
275impl<T: fmt::Debug> fmt::Debug for Sticky<T> {
276    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
277        crate::stack_token!(tok);
278        match self.try_get(tok) {
279            Ok(value) => f.debug_struct("Sticky").field("value", value).finish(),
280            Err(..) => {
281                struct InvalidPlaceholder;
282                impl fmt::Debug for InvalidPlaceholder {
283                    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
284                        f.write_str("<invalid thread>")
285                    }
286                }
287
288                f.debug_struct("Sticky")
289                    .field("value", &InvalidPlaceholder)
290                    .finish()
291            }
292        }
293    }
294}
295
296// similar as for fragile ths type is sync because it only accesses TLS data
297// which is thread local.  There is nothing that needs to be synchronized.
298unsafe impl<T> Sync for Sticky<T> {}
299
300// The entire point of this type is to be Send
301unsafe impl<T> Send for Sticky<T> {}
302
#[test]
fn test_basic() {
    use std::thread;

    let sticky = Sticky::new(true);
    crate::stack_token!(tok);

    // On the owning thread every accessor works.
    assert_eq!(sticky.to_string(), "true");
    assert_eq!(sticky.get(tok), &true);
    assert!(sticky.try_get(tok).is_ok());

    // Once moved to another thread, access must be refused.
    let handle = thread::spawn(move || {
        crate::stack_token!(tok);
        assert!(sticky.try_get(tok).is_err());
    });
    handle.join().unwrap();
}
318
#[test]
fn test_mut() {
    let mut sticky = Sticky::new(true);
    crate::stack_token!(tok);

    let slot = sticky.get_mut(tok);
    *slot = false;

    assert_eq!(sticky.to_string(), "false");
    assert_eq!(sticky.get(tok), &false);
}
327
#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;

    let sticky = Sticky::new(true);
    // `get` from a foreign thread panics there; the join error is then
    // re-raised on this thread by `unwrap`.
    let outcome = thread::spawn(move || {
        crate::stack_token!(tok);
        let _ = sticky.get(tok);
    })
    .join();
    outcome.unwrap();
}
340
#[test]
fn test_drop_same_thread() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    struct SetOnDrop(Arc<AtomicBool>);

    impl Drop for SetOnDrop {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let dropped = Arc::new(AtomicBool::new(false));
    let sticky = Sticky::new(SetOnDrop(Arc::clone(&dropped)));
    mem::drop(sticky);
    assert!(dropped.load(Ordering::SeqCst));
}
356
#[test]
fn test_noop_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    let was_called = Arc::new(AtomicBool::new(false));
    let flag = Arc::clone(&was_called);

    let owner = thread::spawn(move || {
        struct X(Arc<AtomicBool>);

        impl Drop for X {
            fn drop(&mut self) {
                self.0.store(true, Ordering::SeqCst);
            }
        }

        let val = Sticky::new(X(Arc::clone(&flag)));

        // Move the sticky to a foreign thread; dropping it there must not
        // run the destructor.
        let foreign = thread::spawn(move || {
            crate::stack_token!(tok);
            val.try_get(tok).ok();
        });
        assert!(foreign.join().is_ok());

        assert!(!flag.load(Ordering::SeqCst));
    });
    owner.join().unwrap();

    // The destructor runs once the owning thread has torn down.
    assert!(was_called.load(Ordering::SeqCst));
}
392
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::thread;

    // `Rc` is `!Send`, but wrapped in a `Sticky` it may still be moved.
    let sticky = Sticky::new(Rc::new(true));
    let handle = thread::spawn(move || {
        crate::stack_token!(tok);
        let access = sticky.try_get(tok);
        assert!(access.is_err());
    });
    handle.join().unwrap();
}
405
#[test]
fn test_two_stickies() {
    struct Wat;

    impl Drop for Wat {
        fn drop(&mut self) {
            // intentionally empty: this impl only gives `Wat` drop glue
        }
    }

    let first = Sticky::new(Wat);
    let second = Sticky::new(Wat);

    // Dropping two independent stickies on the owning thread must work.
    drop(first);
    drop(second);
}