libp2p_core/transport/memory.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::transport::{DialOpts, ListenerId, Transport, TransportError, TransportEvent};
22use fnv::FnvHashMap;
23use futures::{channel::mpsc, future::Ready, prelude::*, task::Context, task::Poll};
24use multiaddr::{Multiaddr, Protocol};
25use once_cell::sync::Lazy;
26use parking_lot::Mutex;
27use rw_stream_sink::RwStreamSink;
28use std::{
29    collections::{hash_map::Entry, VecDeque},
30    error, fmt, io,
31    num::NonZeroU64,
32    pin::Pin,
33};
34
35static HUB: Lazy<Hub> = Lazy::new(|| Hub(Mutex::new(FnvHashMap::default())));
36
37struct Hub(Mutex<FnvHashMap<NonZeroU64, ChannelSender>>);
38
39/// A [`mpsc::Sender`] enabling a [`DialFuture`] to send a [`Channel`] and the
40/// port of the dialer to a [`Listener`].
41type ChannelSender = mpsc::Sender<(Channel<Vec<u8>>, NonZeroU64)>;
42
43/// A [`mpsc::Receiver`] enabling a [`Listener`] to receive a [`Channel`] and
44/// the port of the dialer from a [`DialFuture`].
45type ChannelReceiver = mpsc::Receiver<(Channel<Vec<u8>>, NonZeroU64)>;
46
47impl Hub {
48    /// Registers the given port on the hub.
49    ///
50    /// Randomizes port when given port is `0`. Returns [`None`] when given port
51    /// is already occupied.
52    fn register_port(&self, port: u64) -> Option<(ChannelReceiver, NonZeroU64)> {
53        let mut hub = self.0.lock();
54
55        let port = if let Some(port) = NonZeroU64::new(port) {
56            port
57        } else {
58            loop {
59                let Some(port) = NonZeroU64::new(rand::random()) else {
60                    continue;
61                };
62                if !hub.contains_key(&port) {
63                    break port;
64                }
65            }
66        };
67
68        let (tx, rx) = mpsc::channel(2);
69        match hub.entry(port) {
70            Entry::Occupied(_) => return None,
71            Entry::Vacant(e) => e.insert(tx),
72        };
73
74        Some((rx, port))
75    }
76
77    fn unregister_port(&self, port: &NonZeroU64) -> Option<ChannelSender> {
78        self.0.lock().remove(port)
79    }
80
81    fn get(&self, port: &NonZeroU64) -> Option<ChannelSender> {
82        self.0.lock().get(port).cloned()
83    }
84}
85
86/// Transport that supports `/memory/N` multiaddresses.
87#[derive(Default)]
88pub struct MemoryTransport {
89    listeners: VecDeque<Pin<Box<Listener>>>,
90}
91
92impl MemoryTransport {
93    pub fn new() -> Self {
94        Self::default()
95    }
96}
97
/// Connection to a `MemoryTransport` currently being opened.
///
/// Created by `MemoryTransport::dial`. When polled, it hands the listener's end
/// of a freshly built [`Channel`] to the target [`Listener`] and then resolves
/// to the dialer's end.
pub struct DialFuture {
    /// Ephemeral source port.
    ///
    /// These ports mimic TCP ephemeral source ports but are not actually used
    /// by the memory transport due to the direct use of channels. They merely
    /// ensure that every connection has a unique address for each dialer, which
    /// is not at the same time a listen address (analogous to TCP).
    dial_port: NonZeroU64,
    /// Sender registered in the [`HUB`] for the port being dialed.
    sender: ChannelSender,
    /// The listener's end of the connection; taken once it is sent.
    channel_to_send: Option<Channel<Vec<u8>>>,
    /// The dialer's end of the connection; taken when the future resolves.
    channel_to_return: Option<Channel<Vec<u8>>>,
}
111
112impl DialFuture {
113    fn new(port: NonZeroU64) -> Option<Self> {
114        let sender = HUB.get(&port)?;
115
116        let (_dial_port_channel, dial_port) = HUB
117            .register_port(0)
118            .expect("there to be some random unoccupied port.");
119
120        let (a_tx, a_rx) = mpsc::channel(4096);
121        let (b_tx, b_rx) = mpsc::channel(4096);
122        Some(DialFuture {
123            dial_port,
124            sender,
125            channel_to_send: Some(RwStreamSink::new(Chan {
126                incoming: a_rx,
127                outgoing: b_tx,
128                dial_port: None,
129            })),
130            channel_to_return: Some(RwStreamSink::new(Chan {
131                incoming: b_rx,
132                outgoing: a_tx,
133                dial_port: Some(dial_port),
134            })),
135        })
136    }
137}
138
139impl Future for DialFuture {
140    type Output = Result<Channel<Vec<u8>>, MemoryTransportError>;
141
142    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
143        match self.sender.poll_ready(cx) {
144            Poll::Pending => return Poll::Pending,
145            Poll::Ready(Ok(())) => {}
146            Poll::Ready(Err(_)) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
147        }
148
149        let channel_to_send = self
150            .channel_to_send
151            .take()
152            .expect("Future should not be polled again once complete");
153        let dial_port = self.dial_port;
154        if self
155            .sender
156            .start_send((channel_to_send, dial_port))
157            .is_err()
158        {
159            return Poll::Ready(Err(MemoryTransportError::Unreachable));
160        }
161
162        Poll::Ready(Ok(self
163            .channel_to_return
164            .take()
165            .expect("Future should not be polled again once complete")))
166    }
167}
168
169impl Transport for MemoryTransport {
170    type Output = Channel<Vec<u8>>;
171    type Error = MemoryTransportError;
172    type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
173    type Dial = DialFuture;
174
175    fn listen_on(
176        &mut self,
177        id: ListenerId,
178        addr: Multiaddr,
179    ) -> Result<(), TransportError<Self::Error>> {
180        let port =
181            parse_memory_addr(&addr).map_err(|_| TransportError::MultiaddrNotSupported(addr))?;
182
183        let (rx, port) = HUB
184            .register_port(port)
185            .ok_or(TransportError::Other(MemoryTransportError::Unreachable))?;
186
187        let listener = Listener {
188            id,
189            port,
190            addr: Protocol::Memory(port.get()).into(),
191            receiver: rx,
192            tell_listen_addr: true,
193        };
194        self.listeners.push_back(Box::pin(listener));
195
196        Ok(())
197    }
198
199    fn remove_listener(&mut self, id: ListenerId) -> bool {
200        if let Some(index) = self.listeners.iter().position(|listener| listener.id == id) {
201            let listener = self.listeners.get_mut(index).unwrap();
202            let val_in = HUB.unregister_port(&listener.port);
203            debug_assert!(val_in.is_some());
204            listener.receiver.close();
205            true
206        } else {
207            false
208        }
209    }
210
211    fn dial(
212        &mut self,
213        addr: Multiaddr,
214        _opts: DialOpts,
215    ) -> Result<DialFuture, TransportError<Self::Error>> {
216        let port = if let Ok(port) = parse_memory_addr(&addr) {
217            if let Some(port) = NonZeroU64::new(port) {
218                port
219            } else {
220                return Err(TransportError::Other(MemoryTransportError::Unreachable));
221            }
222        } else {
223            return Err(TransportError::MultiaddrNotSupported(addr));
224        };
225
226        DialFuture::new(port).ok_or(TransportError::Other(MemoryTransportError::Unreachable))
227    }
228
229    fn poll(
230        mut self: Pin<&mut Self>,
231        cx: &mut Context<'_>,
232    ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>>
233    where
234        Self: Sized,
235    {
236        let mut remaining = self.listeners.len();
237        while let Some(mut listener) = self.listeners.pop_back() {
238            if listener.tell_listen_addr {
239                listener.tell_listen_addr = false;
240                let listen_addr = listener.addr.clone();
241                let listener_id = listener.id;
242                self.listeners.push_front(listener);
243                return Poll::Ready(TransportEvent::NewAddress {
244                    listen_addr,
245                    listener_id,
246                });
247            }
248
249            let event = match Stream::poll_next(Pin::new(&mut listener.receiver), cx) {
250                Poll::Pending => None,
251                Poll::Ready(Some((channel, dial_port))) => Some(TransportEvent::Incoming {
252                    listener_id: listener.id,
253                    upgrade: future::ready(Ok(channel)),
254                    local_addr: listener.addr.clone(),
255                    send_back_addr: Protocol::Memory(dial_port.get()).into(),
256                }),
257                Poll::Ready(None) => {
258                    // Listener was closed.
259                    return Poll::Ready(TransportEvent::ListenerClosed {
260                        listener_id: listener.id,
261                        reason: Ok(()),
262                    });
263                }
264            };
265
266            self.listeners.push_front(listener);
267            if let Some(event) = event {
268                return Poll::Ready(event);
269            } else {
270                remaining -= 1;
271                if remaining == 0 {
272                    break;
273                }
274            }
275        }
276        Poll::Pending
277    }
278}
279
/// Error that can be produced from the `MemoryTransport`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum MemoryTransportError {
    /// There's no listener on the given port, or it went away before the
    /// connection could be handed over.
    Unreachable,
    /// Tries to listen on a port that is already in use.
    AlreadyInUse,
}

impl fmt::Display for MemoryTransportError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::Unreachable => "No listener on the given port.",
            Self::AlreadyInUse => "Port already occupied.",
        };
        f.write_str(message)
    }
}

impl error::Error for MemoryTransportError {}
299
300/// Listener for memory connections.
301pub struct Listener {
302    id: ListenerId,
303    /// Port we're listening on.
304    port: NonZeroU64,
305    /// The address we are listening on.
306    addr: Multiaddr,
307    /// Receives incoming connections.
308    receiver: ChannelReceiver,
309    /// Generate [`TransportEvent::NewAddress`] to inform about our listen address.
310    tell_listen_addr: bool,
311}
312
313/// If the address is `/memory/n`, returns the value of `n`.
314fn parse_memory_addr(a: &Multiaddr) -> Result<u64, ()> {
315    let mut protocols = a.iter();
316    match protocols.next() {
317        Some(Protocol::Memory(port)) => match protocols.next() {
318            None | Some(Protocol::P2p(_)) => Ok(port),
319            _ => Err(()),
320        },
321        _ => Err(()),
322    }
323}
324
/// A channel represents an established, in-memory, logical connection between two endpoints.
///
/// Implements `AsyncRead` and `AsyncWrite` by wrapping a [`Chan`] in a
/// [`RwStreamSink`]. `Channel<Vec<u8>>` is the connection type the
/// `MemoryTransport` yields for both dialed and accepted connections.
pub type Channel<T> = RwStreamSink<Chan<T>>;
329
/// A channel represents an established, in-memory, logical connection between two endpoints.
///
/// Implements `Sink` and `Stream`.
pub struct Chan<T = Vec<u8>> {
    /// Messages written by the remote end.
    incoming: mpsc::Receiver<T>,
    /// Messages destined for the remote end.
    outgoing: mpsc::Sender<T>,

    // Needed in [`Drop`] implementation of [`Chan`] to unregister the dialing
    // port with the global [`HUB`]. Is [`Some`] when [`Chan`] of dialer and
    // [`None`] when [`Chan`] of listener.
    //
    // Note: listening ports are not released here; releasing them is the
    // responsibility of the listener side of the transport.
    dial_port: Option<NonZeroU64>,
}
345
// `Chan` only holds mpsc channel halves and a port number. The explicit impl
// makes it `Unpin` for every `T`, which the `Stream`/`Sink` impls below rely on
// to reach its fields mutably through `Pin<&mut Self>`.
impl<T> Unpin for Chan<T> {}
347
348impl<T> Stream for Chan<T> {
349    type Item = Result<T, io::Error>;
350
351    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
352        match Stream::poll_next(Pin::new(&mut self.incoming), cx) {
353            Poll::Pending => Poll::Pending,
354            Poll::Ready(None) => Poll::Ready(None),
355            Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
356        }
357    }
358}
359
360impl<T> Sink<T> for Chan<T> {
361    type Error = io::Error;
362
363    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
364        self.outgoing
365            .poll_ready(cx)
366            .map(|v| v.map_err(|_| io::ErrorKind::BrokenPipe.into()))
367    }
368
369    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
370        self.outgoing
371            .start_send(item)
372            .map_err(|_| io::ErrorKind::BrokenPipe.into())
373    }
374
375    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
376        Poll::Ready(Ok(()))
377    }
378
379    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
380        Poll::Ready(Ok(()))
381    }
382}
383
384impl<T: AsRef<[u8]>> From<Chan<T>> for RwStreamSink<Chan<T>> {
385    fn from(channel: Chan<T>) -> RwStreamSink<Chan<T>> {
386        RwStreamSink::new(channel)
387    }
388}
389
390impl<T> Drop for Chan<T> {
391    fn drop(&mut self) {
392        if let Some(port) = self.dial_port {
393            let channel_sender = HUB.unregister_port(&port);
394            debug_assert!(channel_sender.is_some());
395        }
396    }
397}
398
#[cfg(test)]
mod tests {
    use crate::{transport::PortUse, Endpoint};

    use super::*;

    /// `parse_memory_addr` accepts `/memory/N`, alone or followed by `/p2p/...`.
    /// It rejects addresses that do not start with `/memory` or that continue
    /// with any other protocol.
    #[test]
    fn parse_memory_addr_works() {
        assert_eq!(parse_memory_addr(&"/memory/5".parse().unwrap()), Ok(5));
        assert_eq!(parse_memory_addr(&"/tcp/150".parse().unwrap()), Err(()));
        assert_eq!(parse_memory_addr(&"/memory/0".parse().unwrap()), Ok(0));
        assert_eq!(
            parse_memory_addr(&"/memory/5/tcp/150".parse().unwrap()),
            Err(())
        );
        assert_eq!(
            parse_memory_addr(&"/tcp/150/memory/5".parse().unwrap()),
            Err(())
        );
        assert_eq!(
            parse_memory_addr(&"/memory/1234567890".parse().unwrap()),
            Ok(1_234_567_890)
        );
        assert_eq!(
            parse_memory_addr(
                &"/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6"
                    .parse()
                    .unwrap()
            ),
            Ok(5)
        );
        // Only the component right after `/memory` is inspected, so longer
        // relay-style suffixes are accepted as well.
        assert_eq!(
            parse_memory_addr(
                &"/memory/5/p2p/12D3KooWETLZBFBfkzvH3BQEtA1TJZPmjb4a18ss5TpwNU7DHDX6/p2p-circuit/p2p/12D3KooWLiQ7i8sY6LkPvHmEymncicEgzrdpXegbxEr3xgN8oxMU"
                    .parse()
                    .unwrap()
            ),
            Ok(5)
        );
    }

    /// A port cannot be listened on twice, but becomes available again as soon
    /// as its listener has been removed.
    #[test]
    fn listening_twice() {
        let mut transport = MemoryTransport::default();

        let addr_1: Multiaddr = "/memory/1639174018481".parse().unwrap();
        let addr_2: Multiaddr = "/memory/8459375923478".parse().unwrap();

        let listener_id_1 = ListenerId::next();

        transport.listen_on(listener_id_1, addr_1.clone()).unwrap();
        assert!(
            transport.remove_listener(listener_id_1),
            "Listener doesn't exist."
        );

        // `addr_1` was released by the removal above, so it can be reused.
        let listener_id_2 = ListenerId::next();
        transport.listen_on(listener_id_2, addr_1.clone()).unwrap();
        let listener_id_3 = ListenerId::next();
        transport.listen_on(listener_id_3, addr_2.clone()).unwrap();

        // Both ports are now occupied.
        assert!(transport
            .listen_on(ListenerId::next(), addr_1.clone())
            .is_err());
        assert!(transport
            .listen_on(ListenerId::next(), addr_2.clone())
            .is_err());

        assert!(
            transport.remove_listener(listener_id_2),
            "Listener doesn't exist."
        );
        assert!(transport.listen_on(ListenerId::next(), addr_1).is_ok());
        assert!(transport
            .listen_on(ListenerId::next(), addr_2.clone())
            .is_err());

        assert!(
            transport.remove_listener(listener_id_3),
            "Listener doesn't exist."
        );
        assert!(transport.listen_on(ListenerId::next(), addr_2).is_ok());
    }

    /// Dialing fails while nobody listens on the port and succeeds afterwards.
    #[test]
    fn port_not_in_use() {
        let mut transport = MemoryTransport::default();
        assert!(transport
            .dial(
                "/memory/810172461024613".parse().unwrap(),
                DialOpts {
                    role: Endpoint::Dialer,
                    port_use: PortUse::New
                }
            )
            .is_err());
        transport
            .listen_on(
                ListenerId::next(),
                "/memory/810172461024613".parse().unwrap(),
            )
            .unwrap();
        assert!(transport
            .dial(
                "/memory/810172461024613".parse().unwrap(),
                DialOpts {
                    role: Endpoint::Dialer,
                    port_use: PortUse::New
                }
            )
            .is_ok());
    }

    /// A listener first reports its address, then reports `ListenerClosed`
    /// once removed. After that, its id is no longer known.
    #[test]
    fn stop_listening() {
        // `saturating_add(1)` keeps the port non-zero, so the reported address
        // is exactly `addr`.
        let rand_port = rand::random::<u64>().saturating_add(1);
        let addr: Multiaddr = format!("/memory/{rand_port}").parse().unwrap();

        let mut transport = MemoryTransport::default().boxed();
        futures::executor::block_on(async {
            let listener_id = ListenerId::next();
            transport.listen_on(listener_id, addr.clone()).unwrap();
            let reported_addr = transport
                .select_next_some()
                .await
                .into_new_address()
                .expect("new address");
            assert_eq!(addr, reported_addr);
            assert!(transport.remove_listener(listener_id));
            match transport.select_next_some().await {
                TransportEvent::ListenerClosed {
                    listener_id: id,
                    reason,
                } => {
                    assert_eq!(id, listener_id);
                    assert!(reason.is_ok())
                }
                other => panic!("Unexpected transport event: {other:?}"),
            }
            assert!(!transport.remove_listener(listener_id));
        })
    }

    /// Bytes written by the dialer arrive unchanged at the listener.
    #[test]
    fn communicating_between_dialer_and_listener() {
        let msg = [1, 2, 3];

        // Setup listener.

        let rand_port = rand::random::<u64>().saturating_add(1);
        let t1_addr: Multiaddr = format!("/memory/{rand_port}").parse().unwrap();
        let cloned_t1_addr = t1_addr.clone();

        let mut t1 = MemoryTransport::default().boxed();

        let listener = async move {
            t1.listen_on(ListenerId::next(), t1_addr.clone()).unwrap();
            let upgrade = loop {
                let event = t1.select_next_some().await;
                if let Some(upgrade) = event.into_incoming() {
                    break upgrade;
                }
            };

            let mut socket = upgrade.0.await.unwrap();

            let mut buf = [0; 3];
            socket.read_exact(&mut buf).await.unwrap();

            assert_eq!(buf, msg);
        };

        // Setup dialer.

        // NOTE: `dial` fails unless `listen_on` has already run, so this
        // relies on `listener` being polled before `dialer` below.
        let mut t2 = MemoryTransport::default();
        let dialer = async move {
            let mut socket = t2
                .dial(
                    cloned_t1_addr,
                    DialOpts {
                        role: Endpoint::Dialer,
                        port_use: PortUse::New,
                    },
                )
                .unwrap()
                .await
                .unwrap();
            socket.write_all(&msg).await.unwrap();
        };

        // Wait for both to finish.

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    /// The address reported for a dialer is its ephemeral port, which is
    /// allocated among free ports and therefore differs from the listen port.
    #[test]
    fn dialer_address_unequal_to_listener_address() {
        let listener_addr: Multiaddr =
            Protocol::Memory(rand::random::<u64>().saturating_add(1)).into();
        let listener_addr_cloned = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();

        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr.clone())
                .unwrap();
            loop {
                if let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                {
                    assert!(
                        send_back_addr != listener_addr,
                        "Expect dialer address not to equal listener address."
                    );
                    return;
                }
            }
        };

        let dialer = async move {
            MemoryTransport::default()
                .dial(
                    listener_addr_cloned,
                    DialOpts {
                        role: Endpoint::Dialer,
                        port_use: PortUse::New,
                    },
                )
                .unwrap()
                .await
                .unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    /// The dialer's ephemeral port stays registered while its connection is
    /// alive and is released when the dialer's `Chan` is dropped.
    #[test]
    fn dialer_port_is_deregistered() {
        // Handshake between the two tasks: the listener asks the dialer to drop
        // its connection (`terminate`), then waits until it has (`terminated`).
        let (terminate, should_terminate) = futures::channel::oneshot::channel();
        let (terminated, is_terminated) = futures::channel::oneshot::channel();

        let listener_addr: Multiaddr =
            Protocol::Memory(rand::random::<u64>().saturating_add(1)).into();
        let listener_addr_cloned = listener_addr.clone();

        let mut listener_transport = MemoryTransport::default().boxed();

        let listener = async move {
            listener_transport
                .listen_on(ListenerId::next(), listener_addr.clone())
                .unwrap();
            loop {
                if let TransportEvent::Incoming { send_back_addr, .. } =
                    listener_transport.select_next_some().await
                {
                    let dialer_port =
                        NonZeroU64::new(parse_memory_addr(&send_back_addr).unwrap()).unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_some(),
                        "Expect dialer port to stay registered while connection is in use.",
                    );

                    terminate.send(()).unwrap();
                    is_terminated.await.unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_none(),
                        "Expect dialer port to be deregistered once connection is dropped.",
                    );

                    return;
                }
            }
        };

        let dialer = async move {
            let chan = MemoryTransport::default()
                .dial(
                    listener_addr_cloned,
                    DialOpts {
                        role: Endpoint::Dialer,
                        port_use: PortUse::New,
                    },
                )
                .unwrap()
                .await
                .unwrap();

            should_terminate.await.unwrap();
            drop(chan);
            terminated.send(()).unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }
}