libp2p_core/transport/
memory.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::transport::{ListenerId, Transport, TransportError, TransportEvent};
22use fnv::FnvHashMap;
23use futures::{
24    channel::mpsc,
25    future::{self, Ready},
26    prelude::*,
27    task::Context,
28    task::Poll,
29};
30use multiaddr::{Multiaddr, Protocol};
31use once_cell::sync::Lazy;
32use parking_lot::Mutex;
33use rw_stream_sink::RwStreamSink;
34use std::{
35    collections::{hash_map::Entry, VecDeque},
36    error, fmt, io,
37    num::NonZeroU64,
38    pin::Pin,
39};
40
41static HUB: Lazy<Hub> = Lazy::new(|| Hub(Mutex::new(FnvHashMap::default())));
42
43struct Hub(Mutex<FnvHashMap<NonZeroU64, ChannelSender>>);
44
45/// A [`mpsc::Sender`] enabling a [`DialFuture`] to send a [`Channel`] and the
46/// port of the dialer to a [`Listener`].
47type ChannelSender = mpsc::Sender<(Channel<Vec<u8>>, NonZeroU64)>;
48
49/// A [`mpsc::Receiver`] enabling a [`Listener`] to receive a [`Channel`] and
50/// the port of the dialer from a [`DialFuture`].
51type ChannelReceiver = mpsc::Receiver<(Channel<Vec<u8>>, NonZeroU64)>;
52
53impl Hub {
54    /// Registers the given port on the hub.
55    ///
56    /// Randomizes port when given port is `0`. Returns [`None`] when given port
57    /// is already occupied.
58    fn register_port(&self, port: u64) -> Option<(ChannelReceiver, NonZeroU64)> {
59        let mut hub = self.0.lock();
60
61        let port = if let Some(port) = NonZeroU64::new(port) {
62            port
63        } else {
64            loop {
65                let port = match NonZeroU64::new(rand::random()) {
66                    Some(p) => p,
67                    None => continue,
68                };
69                if !hub.contains_key(&port) {
70                    break port;
71                }
72            }
73        };
74
75        let (tx, rx) = mpsc::channel(2);
76        match hub.entry(port) {
77            Entry::Occupied(_) => return None,
78            Entry::Vacant(e) => e.insert(tx),
79        };
80
81        Some((rx, port))
82    }
83
84    fn unregister_port(&self, port: &NonZeroU64) -> Option<ChannelSender> {
85        self.0.lock().remove(port)
86    }
87
88    fn get(&self, port: &NonZeroU64) -> Option<ChannelSender> {
89        self.0.lock().get(port).cloned()
90    }
91}
92
93/// Transport that supports `/memory/N` multiaddresses.
94#[derive(Default)]
95pub struct MemoryTransport {
96    listeners: VecDeque<Pin<Box<Listener>>>,
97}
98
99impl MemoryTransport {
100    pub fn new() -> Self {
101        Self::default()
102    }
103}
104
105/// Connection to a `MemoryTransport` currently being opened.
106pub struct DialFuture {
107    /// Ephemeral source port.
108    ///
109    /// These ports mimic TCP ephemeral source ports but are not actually used
110    /// by the memory transport due to the direct use of channels. They merely
111    /// ensure that every connection has a unique address for each dialer, which
112    /// is not at the same time a listen address (analogous to TCP).
113    dial_port: NonZeroU64,
114    sender: ChannelSender,
115    channel_to_send: Option<Channel<Vec<u8>>>,
116    channel_to_return: Option<Channel<Vec<u8>>>,
117}
118
119impl DialFuture {
120    fn new(port: NonZeroU64) -> Option<Self> {
121        let sender = HUB.get(&port)?;
122
123        let (_dial_port_channel, dial_port) = HUB
124            .register_port(0)
125            .expect("there to be some random unoccupied port.");
126
127        let (a_tx, a_rx) = mpsc::channel(4096);
128        let (b_tx, b_rx) = mpsc::channel(4096);
129        Some(DialFuture {
130            dial_port,
131            sender,
132            channel_to_send: Some(RwStreamSink::new(Chan {
133                incoming: a_rx,
134                outgoing: b_tx,
135                dial_port: None,
136            })),
137            channel_to_return: Some(RwStreamSink::new(Chan {
138                incoming: b_rx,
139                outgoing: a_tx,
140                dial_port: Some(dial_port),
141            })),
142        })
143    }
144}
145
146impl Future for DialFuture {
147    type Output = Result<Channel<Vec<u8>>, MemoryTransportError>;
148
149    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
150        match self.sender.poll_ready(cx) {
151            Poll::Pending => return Poll::Pending,
152            Poll::Ready(Ok(())) => {}
153            Poll::Ready(Err(_)) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
154        }
155
156        let channel_to_send = self
157            .channel_to_send
158            .take()
159            .expect("Future should not be polled again once complete");
160        let dial_port = self.dial_port;
161        if self
162            .sender
163            .start_send((channel_to_send, dial_port))
164            .is_err()
165        {
166            return Poll::Ready(Err(MemoryTransportError::Unreachable));
167        }
168
169        Poll::Ready(Ok(self
170            .channel_to_return
171            .take()
172            .expect("Future should not be polled again once complete")))
173    }
174}
175
176impl Transport for MemoryTransport {
177    type Output = Channel<Vec<u8>>;
178    type Error = MemoryTransportError;
179    type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
180    type Dial = DialFuture;
181
182    fn listen_on(
183        &mut self,
184        id: ListenerId,
185        addr: Multiaddr,
186    ) -> Result<(), TransportError<Self::Error>> {
187        let port = if let Ok(port) = parse_memory_addr(&addr) {
188            port
189        } else {
190            return Err(TransportError::MultiaddrNotSupported(addr));
191        };
192
193        let (rx, port) = match HUB.register_port(port) {
194            Some((rx, port)) => (rx, port),
195            None => return Err(TransportError::Other(MemoryTransportError::Unreachable)),
196        };
197
198        let listener = Listener {
199            id,
200            port,
201            addr: Protocol::Memory(port.get()).into(),
202            receiver: rx,
203            tell_listen_addr: true,
204        };
205        self.listeners.push_back(Box::pin(listener));
206
207        Ok(())
208    }
209
210    fn remove_listener(&mut self, id: ListenerId) -> bool {
211        if let Some(index) = self.listeners.iter().position(|listener| listener.id == id) {
212            let listener = self.listeners.get_mut(index).unwrap();
213            let val_in = HUB.unregister_port(&listener.port);
214            debug_assert!(val_in.is_some());
215            listener.receiver.close();
216            true
217        } else {
218            false
219        }
220    }
221
222    fn dial(&mut self, addr: Multiaddr) -> Result<DialFuture, TransportError<Self::Error>> {
223        let port = if let Ok(port) = parse_memory_addr(&addr) {
224            if let Some(port) = NonZeroU64::new(port) {
225                port
226            } else {
227                return Err(TransportError::Other(MemoryTransportError::Unreachable));
228            }
229        } else {
230            return Err(TransportError::MultiaddrNotSupported(addr));
231        };
232
233        DialFuture::new(port).ok_or(TransportError::Other(MemoryTransportError::Unreachable))
234    }
235
236    fn dial_as_listener(
237        &mut self,
238        addr: Multiaddr,
239    ) -> Result<DialFuture, TransportError<Self::Error>> {
240        self.dial(addr)
241    }
242
243    fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option<Multiaddr> {
244        None
245    }
246
247    fn poll(
248        mut self: Pin<&mut Self>,
249        cx: &mut Context<'_>,
250    ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>>
251    where
252        Self: Sized,
253    {
254        let mut remaining = self.listeners.len();
255        while let Some(mut listener) = self.listeners.pop_back() {
256            if listener.tell_listen_addr {
257                listener.tell_listen_addr = false;
258                let listen_addr = listener.addr.clone();
259                let listener_id = listener.id;
260                self.listeners.push_front(listener);
261                return Poll::Ready(TransportEvent::NewAddress {
262                    listen_addr,
263                    listener_id,
264                });
265            }
266
267            let event = match Stream::poll_next(Pin::new(&mut listener.receiver), cx) {
268                Poll::Pending => None,
269                Poll::Ready(Some((channel, dial_port))) => Some(TransportEvent::Incoming {
270                    listener_id: listener.id,
271                    upgrade: future::ready(Ok(channel)),
272                    local_addr: listener.addr.clone(),
273                    send_back_addr: Protocol::Memory(dial_port.get()).into(),
274                }),
275                Poll::Ready(None) => {
276                    // Listener was closed.
277                    return Poll::Ready(TransportEvent::ListenerClosed {
278                        listener_id: listener.id,
279                        reason: Ok(()),
280                    });
281                }
282            };
283
284            self.listeners.push_front(listener);
285            if let Some(event) = event {
286                return Poll::Ready(event);
287            } else {
288                remaining -= 1;
289                if remaining == 0 {
290                    break;
291                }
292            }
293        }
294        Poll::Pending
295    }
296}
297
/// Errors produced by the [`MemoryTransport`].
#[derive(Clone, Copy, Debug)]
pub enum MemoryTransportError {
    /// Nothing usable is listening on the requested port.
    Unreachable,
    /// The requested listen port is already taken.
    AlreadyInUse,
}

impl fmt::Display for MemoryTransportError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::Unreachable => "No listener on the given port.",
            Self::AlreadyInUse => "Port already occupied.",
        };
        f.write_str(description)
    }
}

impl error::Error for MemoryTransportError {}
317
/// Listener for memory connections.
///
/// Created by `MemoryTransport::listen_on` and kept in the transport's queue until
/// `poll` sees its receiver end and emits `TransportEvent::ListenerClosed`.
///
/// NOTE(review): this file defines no `Drop` impl for `Listener`. Its port is
/// unregistered from `HUB` only by `MemoryTransport::remove_listener`, so dropping
/// a transport that still has active listeners appears to leave those ports
/// registered. Confirm whether that is intended.
pub struct Listener {
    /// Identifier supplied by the caller of `listen_on`.
    id: ListenerId,
    /// Port we're listening on.
    port: NonZeroU64,
    /// The address we are listening on.
    addr: Multiaddr,
    /// Receives incoming connections: the listener's half of each connection
    /// together with the dialer's ephemeral port.
    receiver: ChannelReceiver,
    /// Generate [`TransportEvent::NewAddress`] to inform about our listen address.
    tell_listen_addr: bool,
}
330
331/// If the address is `/memory/n`, returns the value of `n`.
332fn parse_memory_addr(a: &Multiaddr) -> Result<u64, ()> {
333    let mut protocols = a.iter();
334    match protocols.next() {
335        Some(Protocol::Memory(port)) => match protocols.next() {
336            None | Some(Protocol::P2p(_)) => Ok(port),
337            _ => Err(()),
338        },
339        _ => Err(()),
340    }
341}
342
343/// A channel represents an established, in-memory, logical connection between two endpoints.
344///
345/// Implements `AsyncRead` and `AsyncWrite`.
346pub type Channel<T> = RwStreamSink<Chan<T>>;
347
348/// A channel represents an established, in-memory, logical connection between two endpoints.
349///
350/// Implements `Sink` and `Stream`.
351pub struct Chan<T = Vec<u8>> {
352    incoming: mpsc::Receiver<T>,
353    outgoing: mpsc::Sender<T>,
354
355    // Needed in [`Drop`] implementation of [`Chan`] to unregister the dialing
356    // port with the global [`HUB`]. Is [`Some`] when [`Chan`] of dialer and
357    // [`None`] when [`Chan`] of listener.
358    //
359    // Note: Listening port is unregistered in [`Drop`] implementation of
360    // [`Listener`].
361    dial_port: Option<NonZeroU64>,
362}
363
364impl<T> Unpin for Chan<T> {}
365
366impl<T> Stream for Chan<T> {
367    type Item = Result<T, io::Error>;
368
369    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
370        match Stream::poll_next(Pin::new(&mut self.incoming), cx) {
371            Poll::Pending => Poll::Pending,
372            Poll::Ready(None) => Poll::Ready(None),
373            Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
374        }
375    }
376}
377
378impl<T> Sink<T> for Chan<T> {
379    type Error = io::Error;
380
381    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
382        self.outgoing
383            .poll_ready(cx)
384            .map(|v| v.map_err(|_| io::ErrorKind::BrokenPipe.into()))
385    }
386
387    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
388        self.outgoing
389            .start_send(item)
390            .map_err(|_| io::ErrorKind::BrokenPipe.into())
391    }
392
393    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
394        Poll::Ready(Ok(()))
395    }
396
397    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
398        Poll::Ready(Ok(()))
399    }
400}
401
402impl<T: AsRef<[u8]>> From<Chan<T>> for RwStreamSink<Chan<T>> {
403    fn from(channel: Chan<T>) -> RwStreamSink<Chan<T>> {
404        RwStreamSink::new(channel)
405    }
406}
407
408impl<T> Drop for Chan<T> {
409    fn drop(&mut self) {
410        if let Some(port) = self.dial_port {
411            let channel_sender = HUB.unregister_port(&port);
412            debug_assert!(channel_sender.is_some());
413        }
414    }
415}
416
// Tests for the in-memory transport.
//
// NOTE(review): all tests share the process-wide `HUB`, and several use
// fixed port numbers; collisions with concurrently running tests are unlikely
// but not impossible.
#[cfg(test)]
mod tests {
    use super::*;

    // Only `/memory/N`, optionally followed by `/p2p/...` (and anything after
    // that), is accepted.
    #[test]
    fn parse_memory_addr_works() {
        assert_eq!(parse_memory_addr(&"/memory/5".parse().unwrap()), Ok(5));
        assert_eq!(parse_memory_addr(&"/tcp/150".parse().unwrap()), Err(()));
        assert_eq!(parse_memory_addr(&"/memory/0".parse().unwrap()), Ok(0));
        assert_eq!(
            parse_memory_addr(&"/memory/5/tcp/150".parse().unwrap()),
            Err(())
        );
        assert_eq!(
            parse_memory_addr(&"/tcp/150/memory/5".parse().unwrap()),
            Err(())
        );
        assert_eq!(
            parse_memory_addr(&"/memory/1234567890".parse().unwrap()),
            Ok(1_234_567_890)
        );
        assert_eq!(
            parse_memory_addr(
                &"/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6"
                    .parse()
                    .unwrap()
            ),
            Ok(5)
        );
        assert_eq!(
            parse_memory_addr(
                &"/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6/p2p-circuit/p2p/12D3KooWLiQ7i8sY6LkPvHmEymncicEgzrdpXegbxEr3xgN8oxMU"
                    .parse()
                    .unwrap()
            ),
            Ok(5)
        );
    }

    // A port can have only one listener at a time. `remove_listener` frees the
    // port immediately, without waiting for `poll`, so it can be reused right away.
    #[test]
    fn listening_twice() {
        let mut transport = MemoryTransport::default();

        let addr_1: Multiaddr = "/memory/1639174018481".parse().unwrap();
        let addr_2: Multiaddr = "/memory/8459375923478".parse().unwrap();

        let listener_id_1 = ListenerId::next();

        transport.listen_on(listener_id_1, addr_1.clone()).unwrap();
        assert!(
            transport.remove_listener(listener_id_1),
            "Listener doesn't exist."
        );

        let listener_id_2 = ListenerId::next();
        transport.listen_on(listener_id_2, addr_1.clone()).unwrap();
        let listener_id_3 = ListenerId::next();
        transport.listen_on(listener_id_3, addr_2.clone()).unwrap();

        // Both ports are now occupied.
        assert!(transport
            .listen_on(ListenerId::next(), addr_1.clone())
            .is_err());
        assert!(transport
            .listen_on(ListenerId::next(), addr_2.clone())
            .is_err());

        assert!(
            transport.remove_listener(listener_id_2),
            "Listener doesn't exist."
        );
        assert!(transport.listen_on(ListenerId::next(), addr_1).is_ok());
        assert!(transport
            .listen_on(ListenerId::next(), addr_2.clone())
            .is_err());

        assert!(
            transport.remove_listener(listener_id_3),
            "Listener doesn't exist."
        );
        assert!(transport.listen_on(ListenerId::next(), addr_2).is_ok());
    }

    // Dialing only succeeds once something is registered on the target port.
    #[test]
    fn port_not_in_use() {
        let mut transport = MemoryTransport::default();
        assert!(transport
            .dial("/memory/810172461024613".parse().unwrap())
            .is_err());
        transport
            .listen_on(
                ListenerId::next(),
                "/memory/810172461024613".parse().unwrap(),
            )
            .unwrap();
        assert!(transport
            .dial("/memory/810172461024613".parse().unwrap())
            .is_ok());
    }

    // Expected sequence: NewAddress after `listen_on`, then ListenerClosed after
    // `remove_listener`. After that the id is unknown.
    #[test]
    fn stop_listening() {
        // `saturating_add(1)` keeps the port non-zero (0 would mean "random port").
        let rand_port = rand::random::<u64>().saturating_add(1);
        let addr: Multiaddr = format!("/memory/{rand_port}").parse().unwrap();

        let mut transport = MemoryTransport::default().boxed();
        futures::executor::block_on(async {
            let listener_id = ListenerId::next();
            transport.listen_on(listener_id, addr.clone()).unwrap();
            let reported_addr = transport
                .select_next_some()
                .await
                .into_new_address()
                .expect("new address");
            assert_eq!(addr, reported_addr);
            assert!(transport.remove_listener(listener_id));
            match transport.select_next_some().await {
                TransportEvent::ListenerClosed {
                    listener_id: id,
                    reason,
                } => {
                    assert_eq!(id, listener_id);
                    assert!(reason.is_ok())
                }
                other => panic!("Unexpected transport event: {other:?}"),
            }
            assert!(!transport.remove_listener(listener_id));
        })
    }

    // Bytes written by the dialer arrive unchanged at the listener.
    #[test]
    fn communicating_between_dialer_and_listener() {
        let msg = [1, 2, 3];

        // Setup listener.

        let rand_port = rand::random::<u64>().saturating_add(1);
        let t1_addr: Multiaddr = format!("/memory/{rand_port}").parse().unwrap();
        let cloned_t1_addr = t1_addr.clone();

        let mut t1 = MemoryTransport::default().boxed();

        let listener = async move {
            t1.listen_on(ListenerId::next(), t1_addr.clone()).unwrap();
            // Skip non-incoming events (e.g. NewAddress) until a connection arrives.
            let upgrade = loop {
                let event = t1.select_next_some().await;
                if let Some(upgrade) = event.into_incoming() {
                    break upgrade;
                }
            };

            let mut socket = upgrade.0.await.unwrap();

            let mut buf = [0; 3];
            socket.read_exact(&mut buf).await.unwrap();

            assert_eq!(buf, msg);
        };

        // Setup dialer.

        let mut t2 = MemoryTransport::default();
        let dialer = async move {
            let mut socket = t2.dial(cloned_t1_addr).unwrap().await.unwrap();
            socket.write_all(&msg).await.unwrap();
        };

        // Wait for both to finish.

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    // The listener sees the dialer's ephemeral port, not its own listen port.
    #[test]
    fn dialer_address_unequal_to_listener_address() {
        let listener_addr: Multiaddr =
            Protocol::Memory(rand::random::<u64>().saturating_add(1)).into();
        let listener_addr_cloned = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();

        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr.clone())
                .unwrap();
            loop {
                if let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                {
                    assert!(
                        send_back_addr != listener_addr,
                        "Expect dialer address not to equal listener address."
                    );
                    return;
                }
            }
        };

        let dialer = async move {
            MemoryTransport::default()
                .dial(listener_addr_cloned)
                .unwrap()
                .await
                .unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    // The dialer's ephemeral port stays in the hub while its `Chan` lives and is
    // removed once the `Chan` is dropped. The two oneshot channels order the
    // listener's checks around the dialer's `drop(chan)`.
    #[test]
    fn dialer_port_is_deregistered() {
        let (terminate, should_terminate) = futures::channel::oneshot::channel();
        let (terminated, is_terminated) = futures::channel::oneshot::channel();

        let listener_addr: Multiaddr =
            Protocol::Memory(rand::random::<u64>().saturating_add(1)).into();
        let listener_addr_cloned = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();

        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr.clone())
                .unwrap();
            loop {
                if let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                {
                    let dialer_port =
                        NonZeroU64::new(parse_memory_addr(&send_back_addr).unwrap()).unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_some(),
                        "Expect dialer port to stay registered while connection is in use.",
                    );

                    // Ask the dialer to drop its half, then wait until it has.
                    terminate.send(()).unwrap();
                    is_terminated.await.unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_none(),
                        "Expect dialer port to be deregistered once connection is dropped.",
                    );

                    return;
                }
            }
        };

        let dialer = async move {
            let chan = MemoryTransport::default()
                .dial(listener_addr_cloned)
                .unwrap()
                .await
                .unwrap();

            should_terminate.await.unwrap();
            drop(chan);
            terminated.send(()).unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }
}