1use crate::transport::{ListenerId, Transport, TransportError, TransportEvent};
22use fnv::FnvHashMap;
23use futures::{
24 channel::mpsc,
25 future::{self, Ready},
26 prelude::*,
27 task::Context,
28 task::Poll,
29};
30use multiaddr::{Multiaddr, Protocol};
31use once_cell::sync::Lazy;
32use parking_lot::Mutex;
33use rw_stream_sink::RwStreamSink;
34use std::{
35 collections::{hash_map::Entry, VecDeque},
36 error, fmt, io,
37 num::NonZeroU64,
38 pin::Pin,
39};
40
41static HUB: Lazy<Hub> = Lazy::new(|| Hub(Mutex::new(FnvHashMap::default())));
42
43struct Hub(Mutex<FnvHashMap<NonZeroU64, ChannelSender>>);
44
45type ChannelSender = mpsc::Sender<(Channel<Vec<u8>>, NonZeroU64)>;
48
49type ChannelReceiver = mpsc::Receiver<(Channel<Vec<u8>>, NonZeroU64)>;
52
53impl Hub {
54 fn register_port(&self, port: u64) -> Option<(ChannelReceiver, NonZeroU64)> {
59 let mut hub = self.0.lock();
60
61 let port = if let Some(port) = NonZeroU64::new(port) {
62 port
63 } else {
64 loop {
65 let port = match NonZeroU64::new(rand::random()) {
66 Some(p) => p,
67 None => continue,
68 };
69 if !hub.contains_key(&port) {
70 break port;
71 }
72 }
73 };
74
75 let (tx, rx) = mpsc::channel(2);
76 match hub.entry(port) {
77 Entry::Occupied(_) => return None,
78 Entry::Vacant(e) => e.insert(tx),
79 };
80
81 Some((rx, port))
82 }
83
84 fn unregister_port(&self, port: &NonZeroU64) -> Option<ChannelSender> {
85 self.0.lock().remove(port)
86 }
87
88 fn get(&self, port: &NonZeroU64) -> Option<ChannelSender> {
89 self.0.lock().get(port).cloned()
90 }
91}
92
/// Transport for `/memory/<port>` addresses.
///
/// Connections are made through channels registered in the process-global `HUB`, so only
/// transports in the same process can reach each other.
#[derive(Default)]
pub struct MemoryTransport {
    /// Listeners created by `listen_on`. `Transport::poll` visits them round-robin: it pops
    /// from the back and re-queues at the front.
    listeners: VecDeque<Pin<Box<Listener>>>,
}
98
99impl MemoryTransport {
100 pub fn new() -> Self {
101 Self::default()
102 }
103}
104
/// Future returned by `MemoryTransport::dial`; resolves to the dialer's end of the connection.
pub struct DialFuture {
    /// Ephemeral port registered in the hub for the dialer. It is sent to the listener together
    /// with the connection and reported there as the `send_back_addr`.
    dial_port: NonZeroU64,
    /// Sender feeding the target listener's queue of incoming connections.
    sender: ChannelSender,
    /// The listener's end of the connection, handed over when the listener's queue has room.
    channel_to_send: Option<Channel<Vec<u8>>>,
    /// The dialer's end of the connection, returned as the future's output.
    channel_to_return: Option<Channel<Vec<u8>>>,
}
118
119impl DialFuture {
120 fn new(port: NonZeroU64) -> Option<Self> {
121 let sender = HUB.get(&port)?;
122
123 let (_dial_port_channel, dial_port) = HUB
124 .register_port(0)
125 .expect("there to be some random unoccupied port.");
126
127 let (a_tx, a_rx) = mpsc::channel(4096);
128 let (b_tx, b_rx) = mpsc::channel(4096);
129 Some(DialFuture {
130 dial_port,
131 sender,
132 channel_to_send: Some(RwStreamSink::new(Chan {
133 incoming: a_rx,
134 outgoing: b_tx,
135 dial_port: None,
136 })),
137 channel_to_return: Some(RwStreamSink::new(Chan {
138 incoming: b_rx,
139 outgoing: a_tx,
140 dial_port: Some(dial_port),
141 })),
142 })
143 }
144}
145
146impl Future for DialFuture {
147 type Output = Result<Channel<Vec<u8>>, MemoryTransportError>;
148
149 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
150 match self.sender.poll_ready(cx) {
151 Poll::Pending => return Poll::Pending,
152 Poll::Ready(Ok(())) => {}
153 Poll::Ready(Err(_)) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
154 }
155
156 let channel_to_send = self
157 .channel_to_send
158 .take()
159 .expect("Future should not be polled again once complete");
160 let dial_port = self.dial_port;
161 if self
162 .sender
163 .start_send((channel_to_send, dial_port))
164 .is_err()
165 {
166 return Poll::Ready(Err(MemoryTransportError::Unreachable));
167 }
168
169 Poll::Ready(Ok(self
170 .channel_to_return
171 .take()
172 .expect("Future should not be polled again once complete")))
173 }
174}
175
176impl Transport for MemoryTransport {
177 type Output = Channel<Vec<u8>>;
178 type Error = MemoryTransportError;
179 type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
180 type Dial = DialFuture;
181
182 fn listen_on(
183 &mut self,
184 id: ListenerId,
185 addr: Multiaddr,
186 ) -> Result<(), TransportError<Self::Error>> {
187 let port = if let Ok(port) = parse_memory_addr(&addr) {
188 port
189 } else {
190 return Err(TransportError::MultiaddrNotSupported(addr));
191 };
192
193 let (rx, port) = match HUB.register_port(port) {
194 Some((rx, port)) => (rx, port),
195 None => return Err(TransportError::Other(MemoryTransportError::Unreachable)),
196 };
197
198 let listener = Listener {
199 id,
200 port,
201 addr: Protocol::Memory(port.get()).into(),
202 receiver: rx,
203 tell_listen_addr: true,
204 };
205 self.listeners.push_back(Box::pin(listener));
206
207 Ok(())
208 }
209
210 fn remove_listener(&mut self, id: ListenerId) -> bool {
211 if let Some(index) = self.listeners.iter().position(|listener| listener.id == id) {
212 let listener = self.listeners.get_mut(index).unwrap();
213 let val_in = HUB.unregister_port(&listener.port);
214 debug_assert!(val_in.is_some());
215 listener.receiver.close();
216 true
217 } else {
218 false
219 }
220 }
221
222 fn dial(&mut self, addr: Multiaddr) -> Result<DialFuture, TransportError<Self::Error>> {
223 let port = if let Ok(port) = parse_memory_addr(&addr) {
224 if let Some(port) = NonZeroU64::new(port) {
225 port
226 } else {
227 return Err(TransportError::Other(MemoryTransportError::Unreachable));
228 }
229 } else {
230 return Err(TransportError::MultiaddrNotSupported(addr));
231 };
232
233 DialFuture::new(port).ok_or(TransportError::Other(MemoryTransportError::Unreachable))
234 }
235
236 fn dial_as_listener(
237 &mut self,
238 addr: Multiaddr,
239 ) -> Result<DialFuture, TransportError<Self::Error>> {
240 self.dial(addr)
241 }
242
243 fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option<Multiaddr> {
244 None
245 }
246
247 fn poll(
248 mut self: Pin<&mut Self>,
249 cx: &mut Context<'_>,
250 ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>>
251 where
252 Self: Sized,
253 {
254 let mut remaining = self.listeners.len();
255 while let Some(mut listener) = self.listeners.pop_back() {
256 if listener.tell_listen_addr {
257 listener.tell_listen_addr = false;
258 let listen_addr = listener.addr.clone();
259 let listener_id = listener.id;
260 self.listeners.push_front(listener);
261 return Poll::Ready(TransportEvent::NewAddress {
262 listen_addr,
263 listener_id,
264 });
265 }
266
267 let event = match Stream::poll_next(Pin::new(&mut listener.receiver), cx) {
268 Poll::Pending => None,
269 Poll::Ready(Some((channel, dial_port))) => Some(TransportEvent::Incoming {
270 listener_id: listener.id,
271 upgrade: future::ready(Ok(channel)),
272 local_addr: listener.addr.clone(),
273 send_back_addr: Protocol::Memory(dial_port.get()).into(),
274 }),
275 Poll::Ready(None) => {
276 return Poll::Ready(TransportEvent::ListenerClosed {
278 listener_id: listener.id,
279 reason: Ok(()),
280 });
281 }
282 };
283
284 self.listeners.push_front(listener);
285 if let Some(event) = event {
286 return Poll::Ready(event);
287 } else {
288 remaining -= 1;
289 if remaining == 0 {
290 break;
291 }
292 }
293 }
294 Poll::Pending
295 }
296}
297
/// Error that can be produced by the `MemoryTransport`.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can match on and compare specific failures.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum MemoryTransportError {
    /// Nobody is listening on the dialed port.
    Unreachable,
    /// The requested listen port is already registered.
    AlreadyInUse,
}
306
307impl fmt::Display for MemoryTransportError {
308 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
309 match *self {
310 MemoryTransportError::Unreachable => write!(f, "No listener on the given port."),
311 MemoryTransportError::AlreadyInUse => write!(f, "Port already occupied."),
312 }
313 }
314}
315
316impl error::Error for MemoryTransportError {}
317
/// A listening port registered by `MemoryTransport::listen_on`.
pub struct Listener {
    /// Identifier supplied to `listen_on`.
    id: ListenerId,
    /// Port under which this listener is registered in the hub.
    port: NonZeroU64,
    /// The `/memory/<port>` address reported as the listen address and as `local_addr`.
    addr: Multiaddr,
    /// Receives `(connection, dialer port)` pairs sent by dialers.
    receiver: ChannelReceiver,
    /// Whether `addr` still has to be reported through a `NewAddress` event.
    tell_listen_addr: bool,
}
330
331fn parse_memory_addr(a: &Multiaddr) -> Result<u64, ()> {
333 let mut protocols = a.iter();
334 match protocols.next() {
335 Some(Protocol::Memory(port)) => match protocols.next() {
336 None | Some(Protocol::P2p(_)) => Ok(port),
337 _ => Err(()),
338 },
339 _ => Err(()),
340 }
341}
342
/// One end of a memory connection, wrapped in `RwStreamSink` so it can be read from and
/// written to as a byte stream.
pub type Channel<T> = RwStreamSink<Chan<T>>;

/// One end of a bidirectional in-memory connection made of two mpsc channels.
pub struct Chan<T = Vec<u8>> {
    /// Messages written by the other end.
    incoming: mpsc::Receiver<T>,
    /// Messages destined for the other end.
    outgoing: mpsc::Sender<T>,

    /// Port registered in the hub on behalf of the dialer. Only the dialer's end carries it
    /// (`Some`); it is unregistered when this `Chan` is dropped.
    dial_port: Option<NonZeroU64>,
}

// Lets the `Stream`/`Sink` methods reach the fields through `Pin<&mut Self>` directly.
impl<T> Unpin for Chan<T> {}
365
366impl<T> Stream for Chan<T> {
367 type Item = Result<T, io::Error>;
368
369 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
370 match Stream::poll_next(Pin::new(&mut self.incoming), cx) {
371 Poll::Pending => Poll::Pending,
372 Poll::Ready(None) => Poll::Ready(None),
373 Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
374 }
375 }
376}
377
378impl<T> Sink<T> for Chan<T> {
379 type Error = io::Error;
380
381 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
382 self.outgoing
383 .poll_ready(cx)
384 .map(|v| v.map_err(|_| io::ErrorKind::BrokenPipe.into()))
385 }
386
387 fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
388 self.outgoing
389 .start_send(item)
390 .map_err(|_| io::ErrorKind::BrokenPipe.into())
391 }
392
393 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
394 Poll::Ready(Ok(()))
395 }
396
397 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
398 Poll::Ready(Ok(()))
399 }
400}
401
402impl<T: AsRef<[u8]>> From<Chan<T>> for RwStreamSink<Chan<T>> {
403 fn from(channel: Chan<T>) -> RwStreamSink<Chan<T>> {
404 RwStreamSink::new(channel)
405 }
406}
407
408impl<T> Drop for Chan<T> {
409 fn drop(&mut self) {
410 if let Some(port) = self.dial_port {
411 let channel_sender = HUB.unregister_port(&port);
412 debug_assert!(channel_sender.is_some());
413 }
414 }
415}
416
#[cfg(test)]
mod tests {
    use super::*;

    // Accepted shapes: `/memory/<port>` alone or directly followed by `/p2p/...`.
    #[test]
    fn parse_memory_addr_works() {
        assert_eq!(parse_memory_addr(&"/memory/5".parse().unwrap()), Ok(5));
        assert_eq!(parse_memory_addr(&"/tcp/150".parse().unwrap()), Err(()));
        assert_eq!(parse_memory_addr(&"/memory/0".parse().unwrap()), Ok(0));
        assert_eq!(
            parse_memory_addr(&"/memory/5/tcp/150".parse().unwrap()),
            Err(())
        );
        assert_eq!(
            parse_memory_addr(&"/tcp/150/memory/5".parse().unwrap()),
            Err(())
        );
        assert_eq!(
            parse_memory_addr(&"/memory/1234567890".parse().unwrap()),
            Ok(1_234_567_890)
        );
        assert_eq!(
            parse_memory_addr(
                &"/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6"
                    .parse()
                    .unwrap()
            ),
            Ok(5)
        );
        // Only the component directly after `/memory/<port>` is inspected.
        assert_eq!(
            parse_memory_addr(
                &"/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6/p2p-circuit/p2p/12D3KooWLiQ7i8sY6LkPvHmEymncicEgzrdpXegbxEr3xgN8oxMU"
                    .parse()
                    .unwrap()
            ),
            Ok(5)
        );
    }

    // A port can be held by only one listener at a time; removing a listener frees its port
    // for a new `listen_on`. The order of the calls below is what is being tested.
    #[test]
    fn listening_twice() {
        let mut transport = MemoryTransport::default();

        let addr_1: Multiaddr = "/memory/1639174018481".parse().unwrap();
        let addr_2: Multiaddr = "/memory/8459375923478".parse().unwrap();

        let listener_id_1 = ListenerId::next();

        transport.listen_on(listener_id_1, addr_1.clone()).unwrap();
        assert!(
            transport.remove_listener(listener_id_1),
            "Listener doesn't exist."
        );

        let listener_id_2 = ListenerId::next();
        transport.listen_on(listener_id_2, addr_1.clone()).unwrap();
        let listener_id_3 = ListenerId::next();
        transport.listen_on(listener_id_3, addr_2.clone()).unwrap();

        assert!(transport
            .listen_on(ListenerId::next(), addr_1.clone())
            .is_err());
        assert!(transport
            .listen_on(ListenerId::next(), addr_2.clone())
            .is_err());

        assert!(
            transport.remove_listener(listener_id_2),
            "Listener doesn't exist."
        );
        assert!(transport.listen_on(ListenerId::next(), addr_1).is_ok());
        assert!(transport
            .listen_on(ListenerId::next(), addr_2.clone())
            .is_err());

        assert!(
            transport.remove_listener(listener_id_3),
            "Listener doesn't exist."
        );
        assert!(transport.listen_on(ListenerId::next(), addr_2).is_ok());
    }

    // Dialing fails while nobody listens on the port and succeeds once someone does.
    #[test]
    fn port_not_in_use() {
        let mut transport = MemoryTransport::default();
        assert!(transport
            .dial("/memory/810172461024613".parse().unwrap())
            .is_err());
        transport
            .listen_on(
                ListenerId::next(),
                "/memory/810172461024613".parse().unwrap(),
            )
            .unwrap();
        assert!(transport
            .dial("/memory/810172461024613".parse().unwrap())
            .is_ok());
    }

    // Lifecycle: `NewAddress` with the requested address, then `ListenerClosed` after removal,
    // after which the listener id is no longer known.
    #[test]
    fn stop_listening() {
        let rand_port = rand::random::<u64>().saturating_add(1);
        let addr: Multiaddr = format!("/memory/{rand_port}").parse().unwrap();

        let mut transport = MemoryTransport::default().boxed();
        futures::executor::block_on(async {
            let listener_id = ListenerId::next();
            transport.listen_on(listener_id, addr.clone()).unwrap();
            let reported_addr = transport
                .select_next_some()
                .await
                .into_new_address()
                .expect("new address");
            assert_eq!(addr, reported_addr);
            assert!(transport.remove_listener(listener_id));
            match transport.select_next_some().await {
                TransportEvent::ListenerClosed {
                    listener_id: id,
                    reason,
                } => {
                    assert_eq!(id, listener_id);
                    assert!(reason.is_ok())
                }
                other => panic!("Unexpected transport event: {other:?}"),
            }
            assert!(!transport.remove_listener(listener_id));
        })
    }

    // Bytes written by the dialer arrive unchanged at the listener.
    #[test]
    fn communicating_between_dialer_and_listener() {
        let msg = [1, 2, 3];

        let rand_port = rand::random::<u64>().saturating_add(1);
        let t1_addr: Multiaddr = format!("/memory/{rand_port}").parse().unwrap();
        let cloned_t1_addr = t1_addr.clone();

        let mut t1 = MemoryTransport::default().boxed();

        let listener = async move {
            t1.listen_on(ListenerId::next(), t1_addr.clone()).unwrap();
            // Skip events (such as `NewAddress`) until the first incoming connection.
            let upgrade = loop {
                let event = t1.select_next_some().await;
                if let Some(upgrade) = event.into_incoming() {
                    break upgrade;
                }
            };

            let mut socket = upgrade.0.await.unwrap();

            let mut buf = [0; 3];
            socket.read_exact(&mut buf).await.unwrap();

            assert_eq!(buf, msg);
        };

        let mut t2 = MemoryTransport::default();
        let dialer = async move {
            let mut socket = t2.dial(cloned_t1_addr).unwrap().await.unwrap();
            socket.write_all(&msg).await.unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    // The listener sees the dialer's own port as the remote address, not its own address.
    #[test]
    fn dialer_address_unequal_to_listener_address() {
        let listener_addr: Multiaddr =
            Protocol::Memory(rand::random::<u64>().saturating_add(1)).into();
        let listener_addr_cloned = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();

        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr.clone())
                .unwrap();
            loop {
                if let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                {
                    assert!(
                        send_back_addr != listener_addr,
                        "Expect dialer address not to equal listener address."
                    );
                    return;
                }
            }
        };

        let dialer = async move {
            MemoryTransport::default()
                .dial(listener_addr_cloned)
                .unwrap()
                .await
                .unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    // The dialer's port stays registered while its connection end is alive and is released
    // once that end is dropped. Two oneshot channels sequence the listener's checks around
    // the dialer's `drop(chan)`.
    #[test]
    fn dialer_port_is_deregistered() {
        let (terminate, should_terminate) = futures::channel::oneshot::channel();
        let (terminated, is_terminated) = futures::channel::oneshot::channel();

        let listener_addr: Multiaddr =
            Protocol::Memory(rand::random::<u64>().saturating_add(1)).into();
        let listener_addr_cloned = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();

        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr.clone())
                .unwrap();
            loop {
                if let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                {
                    let dialer_port =
                        NonZeroU64::new(parse_memory_addr(&send_back_addr).unwrap()).unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_some(),
                        "Expect dialer port to stay registered while connection is in use.",
                    );

                    terminate.send(()).unwrap();
                    is_terminated.await.unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_none(),
                        "Expect dialer port to be deregistered once connection is dropped.",
                    );

                    return;
                }
            }
        };

        let dialer = async move {
            let chan = MemoryTransport::default()
                .dial(listener_addr_cloned)
                .unwrap()
                .await
                .unwrap();

            should_terminate.await.unwrap();
            drop(chan);
            terminated.send(()).unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }
}