1#![allow(clippy::unit_arg)]
2
3use std::cmp;
4use std::fmt;
5use std::marker::PhantomData;
6use std::mem;
7use std::num::NonZeroUsize;
8
9use crate::errors::InvalidThreadAccess;
10use crate::registry;
11use crate::thread_id;
12use crate::StackToken;
13
14pub struct Sticky<T: 'static> {
31 item_id: registry::ItemId,
32 thread_id: NonZeroUsize,
33 _marker: PhantomData<*mut T>,
34}
35
36impl<T> Drop for Sticky<T> {
37 fn drop(&mut self) {
38 if mem::needs_drop::<T>() {
42 unsafe {
43 if self.is_valid() {
44 self.unsafe_take_value();
45 }
46 }
47
48 } else if let Some(entry) = registry::try_remove(self.item_id, self.thread_id) {
53 unsafe {
54 (entry.drop)(entry.ptr);
55 }
56 }
57 }
58}
59
60impl<T> Sticky<T> {
61 pub fn new(value: T) -> Self {
68 let entry = registry::Entry {
69 ptr: Box::into_raw(Box::new(value)).cast(),
70 drop: |ptr| {
71 let ptr = ptr.cast::<T>();
72 drop(unsafe { Box::from_raw(ptr) });
75 },
76 };
77
78 let thread_id = thread_id::get();
79 let item_id = registry::insert(thread_id, entry);
80
81 Sticky {
82 item_id,
83 thread_id,
84 _marker: PhantomData,
85 }
86 }
87
88 #[inline(always)]
89 fn with_value<F: FnOnce(*mut T) -> R, R>(&self, f: F) -> R {
90 self.assert_thread();
91
92 registry::with(self.item_id, self.thread_id, |entry| {
93 f(entry.ptr.cast::<T>())
94 })
95 }
96
97 #[inline(always)]
101 pub fn is_valid(&self) -> bool {
102 thread_id::get() == self.thread_id
103 }
104
105 #[inline(always)]
106 fn assert_thread(&self) {
107 if !self.is_valid() {
108 panic!("trying to access wrapped value in sticky container from incorrect thread.");
109 }
110 }
111
112 pub fn into_inner(mut self) -> T {
119 self.assert_thread();
120 unsafe {
121 let rv = self.unsafe_take_value();
122 mem::forget(self);
123 rv
124 }
125 }
126
127 unsafe fn unsafe_take_value(&mut self) -> T {
128 let ptr = registry::remove(self.item_id, self.thread_id)
129 .ptr
130 .cast::<T>();
131 *Box::from_raw(ptr)
132 }
133
134 pub fn try_into_inner(self) -> Result<T, Self> {
140 if self.is_valid() {
141 Ok(self.into_inner())
142 } else {
143 Err(self)
144 }
145 }
146
147 pub fn get<'stack>(&'stack self, _proof: &'stack StackToken) -> &'stack T {
154 self.with_value(|value| unsafe { &*value })
155 }
156
157 pub fn get_mut<'stack>(&'stack mut self, _proof: &'stack StackToken) -> &'stack mut T {
164 self.with_value(|value| unsafe { &mut *value })
165 }
166
167 pub fn try_get<'stack>(
171 &'stack self,
172 _proof: &'stack StackToken,
173 ) -> Result<&'stack T, InvalidThreadAccess> {
174 if self.is_valid() {
175 Ok(self.with_value(|value| unsafe { &*value }))
176 } else {
177 Err(InvalidThreadAccess)
178 }
179 }
180
181 pub fn try_get_mut<'stack>(
185 &'stack mut self,
186 _proof: &'stack StackToken,
187 ) -> Result<&'stack mut T, InvalidThreadAccess> {
188 if self.is_valid() {
189 Ok(self.with_value(|value| unsafe { &mut *value }))
190 } else {
191 Err(InvalidThreadAccess)
192 }
193 }
194}
195
196impl<T> From<T> for Sticky<T> {
197 #[inline]
198 fn from(t: T) -> Sticky<T> {
199 Sticky::new(t)
200 }
201}
202
203impl<T: Clone> Clone for Sticky<T> {
204 #[inline]
205 fn clone(&self) -> Sticky<T> {
206 crate::stack_token!(tok);
207 Sticky::new(self.get(tok).clone())
208 }
209}
210
211impl<T: Default> Default for Sticky<T> {
212 #[inline]
213 fn default() -> Sticky<T> {
214 Sticky::new(T::default())
215 }
216}
217
218impl<T: PartialEq> PartialEq for Sticky<T> {
219 #[inline]
220 fn eq(&self, other: &Sticky<T>) -> bool {
221 crate::stack_token!(tok);
222 *self.get(tok) == *other.get(tok)
223 }
224}
225
226impl<T: Eq> Eq for Sticky<T> {}
227
228impl<T: PartialOrd> PartialOrd for Sticky<T> {
229 #[inline]
230 fn partial_cmp(&self, other: &Sticky<T>) -> Option<cmp::Ordering> {
231 crate::stack_token!(tok);
232 self.get(tok).partial_cmp(other.get(tok))
233 }
234
235 #[inline]
236 fn lt(&self, other: &Sticky<T>) -> bool {
237 crate::stack_token!(tok);
238 *self.get(tok) < *other.get(tok)
239 }
240
241 #[inline]
242 fn le(&self, other: &Sticky<T>) -> bool {
243 crate::stack_token!(tok);
244 *self.get(tok) <= *other.get(tok)
245 }
246
247 #[inline]
248 fn gt(&self, other: &Sticky<T>) -> bool {
249 crate::stack_token!(tok);
250 *self.get(tok) > *other.get(tok)
251 }
252
253 #[inline]
254 fn ge(&self, other: &Sticky<T>) -> bool {
255 crate::stack_token!(tok);
256 *self.get(tok) >= *other.get(tok)
257 }
258}
259
260impl<T: Ord> Ord for Sticky<T> {
261 #[inline]
262 fn cmp(&self, other: &Sticky<T>) -> cmp::Ordering {
263 crate::stack_token!(tok);
264 self.get(tok).cmp(other.get(tok))
265 }
266}
267
268impl<T: fmt::Display> fmt::Display for Sticky<T> {
269 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
270 crate::stack_token!(tok);
271 fmt::Display::fmt(self.get(tok), f)
272 }
273}
274
275impl<T: fmt::Debug> fmt::Debug for Sticky<T> {
276 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
277 crate::stack_token!(tok);
278 match self.try_get(tok) {
279 Ok(value) => f.debug_struct("Sticky").field("value", value).finish(),
280 Err(..) => {
281 struct InvalidPlaceholder;
282 impl fmt::Debug for InvalidPlaceholder {
283 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
284 f.write_str("<invalid thread>")
285 }
286 }
287
288 f.debug_struct("Sticky")
289 .field("value", &InvalidPlaceholder)
290 .finish()
291 }
292 }
293 }
294}
295
296unsafe impl<T> Sync for Sticky<T> {}
299
300unsafe impl<T> Send for Sticky<T> {}
302
/// Access succeeds on the creating thread and is refused (not panicking) on
/// another thread via `try_get`.
#[test]
fn test_basic() {
    use std::thread;

    let sticky = Sticky::new(true);
    crate::stack_token!(tok);
    assert_eq!(sticky.to_string(), "true");
    assert_eq!(sticky.get(tok), &true);
    assert!(sticky.try_get(tok).is_ok());

    let foreign = thread::spawn(move || {
        crate::stack_token!(tok);
        assert!(sticky.try_get(tok).is_err());
    });
    foreign.join().unwrap();
}
318
/// Writes through `get_mut` are visible to later reads and to `Display`.
#[test]
fn test_mut() {
    let mut flag = Sticky::new(true);
    crate::stack_token!(tok);
    {
        let slot = flag.get_mut(tok);
        *slot = false;
    }
    assert_eq!(flag.to_string(), "false");
    assert_eq!(flag.get(tok), &false);
}
327
/// `get` from a foreign thread must panic because of the thread check.
///
/// The spawned thread's panic payload is inspected directly. A bare
/// `#[should_panic]` would pass on any panic, and `expected` cannot be used
/// because the main-thread panic would come from `join().unwrap()`.
#[test]
fn test_access_other_thread() {
    use std::thread;

    let val = Sticky::new(true);
    let payload = thread::spawn(move || {
        crate::stack_token!(tok);
        val.get(tok);
    })
    .join()
    .expect_err("accessing a sticky from a foreign thread must panic");

    // A literal `panic!` message is carried as `&'static str` or `String`,
    // depending on edition and toolchain; accept either.
    let message = payload
        .downcast_ref::<&str>()
        .copied()
        .or_else(|| payload.downcast_ref::<String>().map(String::as_str))
        .unwrap_or("");
    assert!(
        message.contains("incorrect thread"),
        "unexpected panic message: {:?}",
        message
    );
}
340
/// Dropping a `Sticky` on its own thread runs the wrapped value's destructor.
#[test]
fn test_drop_same_thread() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Sets the shared flag when dropped.
    struct SetOnDrop(Arc<AtomicBool>);

    impl Drop for SetOnDrop {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let dropped = Arc::new(AtomicBool::new(false));
    mem::drop(Sticky::new(SetOnDrop(Arc::clone(&dropped))));
    assert!(dropped.load(Ordering::SeqCst));
}
356
/// Dropping a `Sticky` on a foreign thread does not run the destructor. The
/// value is dropped only after its owning thread has finished.
#[test]
fn test_noop_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    /// Sets the shared flag when dropped.
    struct SetOnDrop(Arc<AtomicBool>);

    impl Drop for SetOnDrop {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let dropped = Arc::new(AtomicBool::new(false));
    let dropped_for_owner = Arc::clone(&dropped);

    let owner = thread::spawn(move || {
        let sticky = Sticky::new(SetOnDrop(Arc::clone(&dropped_for_owner)));

        // Move the handle into a second thread, touch it from the wrong
        // thread, and let it be dropped there.
        let foreign = thread::spawn(move || {
            crate::stack_token!(tok);
            sticky.try_get(tok).ok();
        });
        assert!(foreign.join().is_ok());

        // The foreign drop must not have run the destructor.
        assert!(!dropped_for_owner.load(Ordering::SeqCst));
    });
    owner.join().unwrap();

    // After the owning thread has finished, the value has been dropped.
    assert!(dropped.load(Ordering::SeqCst));
}
392
/// A `Sticky<Rc<_>>` can be sent to another thread (which `Rc` alone cannot),
/// but its value is not reachable there.
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::thread;

    let shared = Rc::new(true);
    let sticky = Sticky::new(shared);
    let outcome = thread::spawn(move || {
        crate::stack_token!(tok);
        assert!(sticky.try_get(tok).is_err());
    })
    .join();
    outcome.unwrap();
}
405
/// Two live stickies can be created and dropped one after the other.
#[test]
fn test_two_stickies() {
    /// The empty `Drop` impl makes `mem::needs_drop::<Wat>()` true, so each
    /// drop goes through the owning-thread removal path.
    struct Wat;

    impl Drop for Wat {
        fn drop(&mut self) {}
    }

    let first = Sticky::new(Wat);
    let second = Sticky::new(Wat);

    // Remove the first entry while the second one is still registered.
    drop(first);
    drop(second);
}