1use crate::transport::{DialOpts, ListenerId, Transport, TransportError, TransportEvent};
22use fnv::FnvHashMap;
23use futures::{channel::mpsc, future::Ready, prelude::*, task::Context, task::Poll};
24use multiaddr::{Multiaddr, Protocol};
25use once_cell::sync::Lazy;
26use parking_lot::Mutex;
27use rw_stream_sink::RwStreamSink;
28use std::{
29 collections::{hash_map::Entry, VecDeque},
30 error, fmt, io,
31 num::NonZeroU64,
32 pin::Pin,
33};
34
35static HUB: Lazy<Hub> = Lazy::new(|| Hub(Mutex::new(FnvHashMap::default())));
36
37struct Hub(Mutex<FnvHashMap<NonZeroU64, ChannelSender>>);
38
39type ChannelSender = mpsc::Sender<(Channel<Vec<u8>>, NonZeroU64)>;
42
43type ChannelReceiver = mpsc::Receiver<(Channel<Vec<u8>>, NonZeroU64)>;
46
47impl Hub {
48 fn register_port(&self, port: u64) -> Option<(ChannelReceiver, NonZeroU64)> {
53 let mut hub = self.0.lock();
54
55 let port = if let Some(port) = NonZeroU64::new(port) {
56 port
57 } else {
58 loop {
59 let Some(port) = NonZeroU64::new(rand::random()) else {
60 continue;
61 };
62 if !hub.contains_key(&port) {
63 break port;
64 }
65 }
66 };
67
68 let (tx, rx) = mpsc::channel(2);
69 match hub.entry(port) {
70 Entry::Occupied(_) => return None,
71 Entry::Vacant(e) => e.insert(tx),
72 };
73
74 Some((rx, port))
75 }
76
77 fn unregister_port(&self, port: &NonZeroU64) -> Option<ChannelSender> {
78 self.0.lock().remove(port)
79 }
80
81 fn get(&self, port: &NonZeroU64) -> Option<ChannelSender> {
82 self.0.lock().get(port).cloned()
83 }
84}
85
86#[derive(Default)]
88pub struct MemoryTransport {
89 listeners: VecDeque<Pin<Box<Listener>>>,
90}
91
92impl MemoryTransport {
93 pub fn new() -> Self {
94 Self::default()
95 }
96}
97
98pub struct DialFuture {
100 dial_port: NonZeroU64,
107 sender: ChannelSender,
108 channel_to_send: Option<Channel<Vec<u8>>>,
109 channel_to_return: Option<Channel<Vec<u8>>>,
110}
111
112impl DialFuture {
113 fn new(port: NonZeroU64) -> Option<Self> {
114 let sender = HUB.get(&port)?;
115
116 let (_dial_port_channel, dial_port) = HUB
117 .register_port(0)
118 .expect("there to be some random unoccupied port.");
119
120 let (a_tx, a_rx) = mpsc::channel(4096);
121 let (b_tx, b_rx) = mpsc::channel(4096);
122 Some(DialFuture {
123 dial_port,
124 sender,
125 channel_to_send: Some(RwStreamSink::new(Chan {
126 incoming: a_rx,
127 outgoing: b_tx,
128 dial_port: None,
129 })),
130 channel_to_return: Some(RwStreamSink::new(Chan {
131 incoming: b_rx,
132 outgoing: a_tx,
133 dial_port: Some(dial_port),
134 })),
135 })
136 }
137}
138
139impl Future for DialFuture {
140 type Output = Result<Channel<Vec<u8>>, MemoryTransportError>;
141
142 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
143 match self.sender.poll_ready(cx) {
144 Poll::Pending => return Poll::Pending,
145 Poll::Ready(Ok(())) => {}
146 Poll::Ready(Err(_)) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
147 }
148
149 let channel_to_send = self
150 .channel_to_send
151 .take()
152 .expect("Future should not be polled again once complete");
153 let dial_port = self.dial_port;
154 if self
155 .sender
156 .start_send((channel_to_send, dial_port))
157 .is_err()
158 {
159 return Poll::Ready(Err(MemoryTransportError::Unreachable));
160 }
161
162 Poll::Ready(Ok(self
163 .channel_to_return
164 .take()
165 .expect("Future should not be polled again once complete")))
166 }
167}
168
169impl Transport for MemoryTransport {
170 type Output = Channel<Vec<u8>>;
171 type Error = MemoryTransportError;
172 type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
173 type Dial = DialFuture;
174
175 fn listen_on(
176 &mut self,
177 id: ListenerId,
178 addr: Multiaddr,
179 ) -> Result<(), TransportError<Self::Error>> {
180 let port =
181 parse_memory_addr(&addr).map_err(|_| TransportError::MultiaddrNotSupported(addr))?;
182
183 let (rx, port) = HUB
184 .register_port(port)
185 .ok_or(TransportError::Other(MemoryTransportError::Unreachable))?;
186
187 let listener = Listener {
188 id,
189 port,
190 addr: Protocol::Memory(port.get()).into(),
191 receiver: rx,
192 tell_listen_addr: true,
193 };
194 self.listeners.push_back(Box::pin(listener));
195
196 Ok(())
197 }
198
199 fn remove_listener(&mut self, id: ListenerId) -> bool {
200 if let Some(index) = self.listeners.iter().position(|listener| listener.id == id) {
201 let listener = self.listeners.get_mut(index).unwrap();
202 let val_in = HUB.unregister_port(&listener.port);
203 debug_assert!(val_in.is_some());
204 listener.receiver.close();
205 true
206 } else {
207 false
208 }
209 }
210
211 fn dial(
212 &mut self,
213 addr: Multiaddr,
214 _opts: DialOpts,
215 ) -> Result<DialFuture, TransportError<Self::Error>> {
216 let port = if let Ok(port) = parse_memory_addr(&addr) {
217 if let Some(port) = NonZeroU64::new(port) {
218 port
219 } else {
220 return Err(TransportError::Other(MemoryTransportError::Unreachable));
221 }
222 } else {
223 return Err(TransportError::MultiaddrNotSupported(addr));
224 };
225
226 DialFuture::new(port).ok_or(TransportError::Other(MemoryTransportError::Unreachable))
227 }
228
229 fn poll(
230 mut self: Pin<&mut Self>,
231 cx: &mut Context<'_>,
232 ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>>
233 where
234 Self: Sized,
235 {
236 let mut remaining = self.listeners.len();
237 while let Some(mut listener) = self.listeners.pop_back() {
238 if listener.tell_listen_addr {
239 listener.tell_listen_addr = false;
240 let listen_addr = listener.addr.clone();
241 let listener_id = listener.id;
242 self.listeners.push_front(listener);
243 return Poll::Ready(TransportEvent::NewAddress {
244 listen_addr,
245 listener_id,
246 });
247 }
248
249 let event = match Stream::poll_next(Pin::new(&mut listener.receiver), cx) {
250 Poll::Pending => None,
251 Poll::Ready(Some((channel, dial_port))) => Some(TransportEvent::Incoming {
252 listener_id: listener.id,
253 upgrade: future::ready(Ok(channel)),
254 local_addr: listener.addr.clone(),
255 send_back_addr: Protocol::Memory(dial_port.get()).into(),
256 }),
257 Poll::Ready(None) => {
258 return Poll::Ready(TransportEvent::ListenerClosed {
260 listener_id: listener.id,
261 reason: Ok(()),
262 });
263 }
264 };
265
266 self.listeners.push_front(listener);
267 if let Some(event) = event {
268 return Poll::Ready(event);
269 } else {
270 remaining -= 1;
271 if remaining == 0 {
272 break;
273 }
274 }
275 }
276 Poll::Pending
277 }
278}
279
/// Errors produced by the in-memory transport.
#[derive(Debug, Copy, Clone)]
pub enum MemoryTransportError {
    /// No listener could be reached on the requested port.
    Unreachable,
    /// The requested port is already occupied.
    AlreadyInUse,
}

impl fmt::Display for MemoryTransportError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Unreachable => "No listener on the given port.",
            Self::AlreadyInUse => "Port already occupied.",
        };
        f.write_str(message)
    }
}

impl error::Error for MemoryTransportError {}
299
300pub struct Listener {
302 id: ListenerId,
303 port: NonZeroU64,
305 addr: Multiaddr,
307 receiver: ChannelReceiver,
309 tell_listen_addr: bool,
311}
312
313fn parse_memory_addr(a: &Multiaddr) -> Result<u64, ()> {
315 let mut protocols = a.iter();
316 match protocols.next() {
317 Some(Protocol::Memory(port)) => match protocols.next() {
318 None | Some(Protocol::P2p(_)) => Ok(port),
319 _ => Err(()),
320 },
321 _ => Err(()),
322 }
323}
324
325pub type Channel<T> = RwStreamSink<Chan<T>>;
329
330pub struct Chan<T = Vec<u8>> {
334 incoming: mpsc::Receiver<T>,
335 outgoing: mpsc::Sender<T>,
336
337 dial_port: Option<NonZeroU64>,
344}
345
346impl<T> Unpin for Chan<T> {}
347
348impl<T> Stream for Chan<T> {
349 type Item = Result<T, io::Error>;
350
351 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
352 match Stream::poll_next(Pin::new(&mut self.incoming), cx) {
353 Poll::Pending => Poll::Pending,
354 Poll::Ready(None) => Poll::Ready(None),
355 Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
356 }
357 }
358}
359
360impl<T> Sink<T> for Chan<T> {
361 type Error = io::Error;
362
363 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
364 self.outgoing
365 .poll_ready(cx)
366 .map(|v| v.map_err(|_| io::ErrorKind::BrokenPipe.into()))
367 }
368
369 fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
370 self.outgoing
371 .start_send(item)
372 .map_err(|_| io::ErrorKind::BrokenPipe.into())
373 }
374
375 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
376 Poll::Ready(Ok(()))
377 }
378
379 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
380 Poll::Ready(Ok(()))
381 }
382}
383
384impl<T: AsRef<[u8]>> From<Chan<T>> for RwStreamSink<Chan<T>> {
385 fn from(channel: Chan<T>) -> RwStreamSink<Chan<T>> {
386 RwStreamSink::new(channel)
387 }
388}
389
390impl<T> Drop for Chan<T> {
391 fn drop(&mut self) {
392 if let Some(port) = self.dial_port {
393 let channel_sender = HUB.unregister_port(&port);
394 debug_assert!(channel_sender.is_some());
395 }
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{transport::PortUse, Endpoint};

    /// Dial options shared by every test: a plain outbound dial on a new port.
    fn dialer_opts() -> DialOpts {
        DialOpts {
            role: Endpoint::Dialer,
            port_use: PortUse::New,
        }
    }

    /// A `/memory/<port>` address on a random non-zero port.
    fn random_memory_addr() -> Multiaddr {
        Protocol::Memory(rand::random::<u64>().saturating_add(1)).into()
    }

    #[test]
    fn parse_memory_addr_works() {
        let cases: [(&str, Result<u64, ()>); 8] = [
            ("/memory/5", Ok(5)),
            ("/tcp/150", Err(())),
            ("/memory/0", Ok(0)),
            ("/memory/5/tcp/150", Err(())),
            ("/tcp/150/memory/5", Err(())),
            ("/memory/1234567890", Ok(1_234_567_890)),
            (
                "/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6",
                Ok(5),
            ),
            (
                "/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6/p2p-circuit/p2p/12D3KooWLiQ7i8sY6LkPvHmEymncicEgzrdpXegbxEr3xgN8oxMU",
                Ok(5),
            ),
        ];

        for (input, expected) in cases {
            let addr: Multiaddr = input.parse().unwrap();
            assert_eq!(parse_memory_addr(&addr), expected, "parsing {input}");
        }
    }

    #[test]
    fn listening_twice() {
        let mut transport = MemoryTransport::default();

        let first: Multiaddr = "/memory/1639174018481".parse().unwrap();
        let second: Multiaddr = "/memory/8459375923478".parse().unwrap();

        let id_a = ListenerId::next();
        transport.listen_on(id_a, first.clone()).unwrap();
        assert!(transport.remove_listener(id_a), "Listener doesn't exist.");

        // A removed listener's port can be bound again right away.
        let id_b = ListenerId::next();
        transport.listen_on(id_b, first.clone()).unwrap();
        let id_c = ListenerId::next();
        transport.listen_on(id_c, second.clone()).unwrap();

        // Both ports are now taken.
        assert!(transport.listen_on(ListenerId::next(), first.clone()).is_err());
        assert!(transport.listen_on(ListenerId::next(), second.clone()).is_err());

        assert!(transport.remove_listener(id_b), "Listener doesn't exist.");
        assert!(transport.listen_on(ListenerId::next(), first).is_ok());
        assert!(transport.listen_on(ListenerId::next(), second.clone()).is_err());

        assert!(transport.remove_listener(id_c), "Listener doesn't exist.");
        assert!(transport.listen_on(ListenerId::next(), second).is_ok());
    }

    #[test]
    fn port_not_in_use() {
        let addr = "/memory/810172461024613";
        let mut transport = MemoryTransport::default();

        assert!(transport.dial(addr.parse().unwrap(), dialer_opts()).is_err());
        transport
            .listen_on(ListenerId::next(), addr.parse().unwrap())
            .unwrap();
        assert!(transport.dial(addr.parse().unwrap(), dialer_opts()).is_ok());
    }

    #[test]
    fn stop_listening() {
        let addr = random_memory_addr();
        let mut transport = MemoryTransport::default().boxed();

        futures::executor::block_on(async {
            let listener_id = ListenerId::next();
            transport.listen_on(listener_id, addr.clone()).unwrap();

            let reported_addr = transport
                .select_next_some()
                .await
                .into_new_address()
                .expect("new address");
            assert_eq!(addr, reported_addr);

            assert!(transport.remove_listener(listener_id));
            match transport.select_next_some().await {
                TransportEvent::ListenerClosed {
                    listener_id: closed_id,
                    reason,
                } => {
                    assert_eq!(closed_id, listener_id);
                    assert!(reason.is_ok());
                }
                other => panic!("Unexpected transport event: {other:?}"),
            }

            assert!(!transport.remove_listener(listener_id));
        })
    }

    #[test]
    fn communicating_between_dialer_and_listener() {
        let msg = [1, 2, 3];

        let listen_addr = random_memory_addr();
        let dial_addr = listen_addr.clone();

        let mut listening = MemoryTransport::default().boxed();
        let listener = async move {
            listening.listen_on(ListenerId::next(), listen_addr).unwrap();

            // Skip other events (such as the address report) until a connection arrives.
            let upgrade = loop {
                if let Some(upgrade) = listening.select_next_some().await.into_incoming() {
                    break upgrade;
                }
            };
            let mut socket = upgrade.0.await.unwrap();

            let mut received = [0; 3];
            socket.read_exact(&mut received).await.unwrap();
            assert_eq!(received, msg);
        };

        let mut dialing = MemoryTransport::default();
        let dialer = async move {
            let mut socket = dialing
                .dial(dial_addr, dialer_opts())
                .unwrap()
                .await
                .unwrap();
            socket.write_all(&msg).await.unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    #[test]
    fn dialer_address_unequal_to_listener_address() {
        let listener_addr = random_memory_addr();
        let dial_addr = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();
        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr.clone())
                .unwrap();
            loop {
                let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                else {
                    continue;
                };
                assert!(
                    send_back_addr != listener_addr,
                    "Expect dialer address not to equal listener address."
                );
                return;
            }
        };

        let dialer = async move {
            MemoryTransport::default()
                .dial(dial_addr, dialer_opts())
                .unwrap()
                .await
                .unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    #[test]
    fn dialer_port_is_deregistered() {
        let (terminate_tx, terminate_rx) = futures::channel::oneshot::channel();
        let (terminated_tx, terminated_rx) = futures::channel::oneshot::channel();

        let listener_addr = random_memory_addr();
        let dial_addr = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();
        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr)
                .unwrap();
            loop {
                let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                else {
                    continue;
                };
                let dialer_port =
                    NonZeroU64::new(parse_memory_addr(&send_back_addr).unwrap()).unwrap();

                assert!(
                    HUB.get(&dialer_port).is_some(),
                    "Expect dialer port to stay registered while connection is in use.",
                );

                // Let the dialer drop its connection and wait until it has done so.
                terminate_tx.send(()).unwrap();
                terminated_rx.await.unwrap();

                assert!(
                    HUB.get(&dialer_port).is_none(),
                    "Expect dialer port to be deregistered once connection is dropped.",
                );
                return;
            }
        };

        let dialer = async move {
            let chan = MemoryTransport::default()
                .dial(dial_addr, dialer_opts())
                .unwrap()
                .await
                .unwrap();

            terminate_rx.await.unwrap();
            drop(chan);
            terminated_tx.send(()).unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }
}