litep2p/transport/websocket/
mod.rs

1// Copyright 2023 litep2p developers
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! WebSocket transport.
22
23use crate::{
24    error::{AddressError, Error, NegotiationError},
25    transport::{
26        common::listener::{DialAddresses, GetSocketAddr, SocketListener, WebSocketAddress},
27        manager::TransportHandle,
28        websocket::{
29            config::Config,
30            connection::{NegotiatedConnection, WebSocketConnection},
31        },
32        Transport, TransportBuilder, TransportEvent,
33    },
34    types::ConnectionId,
35    utils::futures_stream::FuturesStream,
36    DialError, PeerId,
37};
38
39use futures::{
40    future::BoxFuture,
41    stream::{AbortHandle, FuturesUnordered},
42    Stream, StreamExt, TryFutureExt,
43};
44use hickory_resolver::TokioResolver;
45use multiaddr::{Multiaddr, Protocol};
46use socket2::{Domain, Socket, Type};
47use std::{net::SocketAddr, sync::Arc};
48use tokio::net::TcpStream;
49use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
50
51use url::Url;
52
53use std::{
54    collections::HashMap,
55    pin::Pin,
56    task::{Context, Poll},
57    time::Duration,
58};
59
60pub(crate) use substream::Substream;
61
62mod connection;
63mod stream;
64mod substream;
65
66pub mod config;
67
/// `tracing` target shared by every log statement emitted from the WebSocket transport.
const LOG_TARGET: &str = "litep2p::websocket";
70
71/// Pending inbound connection.
72struct PendingInboundConnection {
73    /// Socket address of the remote peer.
74    connection: TcpStream,
75    /// Address of the remote peer.
76    address: SocketAddr,
77}
78
79#[derive(Debug)]
80enum RawConnectionResult {
81    /// The first successful connection.
82    Connected {
83        negotiated: NegotiatedConnection,
84        errors: Vec<(Multiaddr, DialError)>,
85    },
86
87    /// All connection attempts failed.
88    Failed {
89        connection_id: ConnectionId,
90        errors: Vec<(Multiaddr, DialError)>,
91    },
92
93    /// Future was canceled.
94    Canceled { connection_id: ConnectionId },
95}
96
97/// WebSocket transport.
98pub(crate) struct WebSocketTransport {
99    /// Transport context.
100    context: TransportHandle,
101
102    /// Transport configuration.
103    config: Config,
104
105    /// WebSocket listener.
106    listener: SocketListener,
107
108    /// Dial addresses.
109    dial_addresses: DialAddresses,
110
111    /// Pending dials.
112    pending_dials: HashMap<ConnectionId, Multiaddr>,
113
114    /// Pending inbound connections.
115    pending_inbound_connections: HashMap<ConnectionId, PendingInboundConnection>,
116
117    /// Pending connections.
118    pending_connections:
119        FuturesStream<BoxFuture<'static, Result<NegotiatedConnection, (ConnectionId, DialError)>>>,
120
121    /// Pending raw, unnegotiated connections.
122    pending_raw_connections: FuturesStream<BoxFuture<'static, RawConnectionResult>>,
123
124    /// Opened raw connection, waiting for approval/rejection from `TransportManager`.
125    opened: HashMap<ConnectionId, NegotiatedConnection>,
126
127    /// Cancel raw connections futures.
128    ///
129    /// This is cancelling `Self::pending_raw_connections`.
130    cancel_futures: HashMap<ConnectionId, AbortHandle>,
131
132    /// Negotiated connections waiting validation.
133    pending_open: HashMap<ConnectionId, NegotiatedConnection>,
134
135    /// DNS resolver.
136    resolver: Arc<TokioResolver>,
137}
138
139impl WebSocketTransport {
140    /// Handle inbound connection.
141    fn on_inbound_connection(
142        &mut self,
143        connection_id: ConnectionId,
144        connection: TcpStream,
145        address: SocketAddr,
146    ) {
147        let keypair = self.context.keypair.clone();
148        let yamux_config = self.config.yamux_config.clone();
149        let connection_open_timeout = self.config.connection_open_timeout;
150        let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
151        let max_write_buffer_size = self.config.noise_write_buffer_size;
152        let substream_open_timeout = self.config.substream_open_timeout;
153        let address = Multiaddr::empty()
154            .with(Protocol::from(address.ip()))
155            .with(Protocol::Tcp(address.port()))
156            .with(Protocol::Ws(std::borrow::Cow::Borrowed("/")));
157
158        self.pending_connections.push(Box::pin(async move {
159            match tokio::time::timeout(connection_open_timeout, async move {
160                WebSocketConnection::accept_connection(
161                    connection,
162                    connection_id,
163                    keypair,
164                    address,
165                    yamux_config,
166                    max_read_ahead_factor,
167                    max_write_buffer_size,
168                    substream_open_timeout,
169                )
170                .await
171                .map_err(|error| (connection_id, error.into()))
172            })
173            .await
174            {
175                Err(_) => Err((connection_id, DialError::Timeout)),
176                Ok(Err(error)) => Err(error),
177                Ok(Ok(result)) => Ok(result),
178            }
179        }));
180    }
181
182    /// Convert `Multiaddr` into `url::Url`
183    fn multiaddr_into_url(address: Multiaddr) -> Result<(Url, PeerId), AddressError> {
184        let mut protocol_stack = address.iter();
185
186        let dial_address = match protocol_stack.next().ok_or(AddressError::InvalidProtocol)? {
187            Protocol::Ip4(address) => address.to_string(),
188            Protocol::Ip6(address) => format!("[{address}]"),
189            Protocol::Dns(address) | Protocol::Dns4(address) | Protocol::Dns6(address) =>
190                address.to_string(),
191
192            _ => return Err(AddressError::InvalidProtocol),
193        };
194
195        let url = match protocol_stack.next().ok_or(AddressError::InvalidProtocol)? {
196            Protocol::Tcp(port) => match protocol_stack.next() {
197                Some(Protocol::Ws(_)) => format!("ws://{dial_address}:{port}/"),
198                Some(Protocol::Wss(_)) => format!("wss://{dial_address}:{port}/"),
199                _ => return Err(AddressError::InvalidProtocol),
200            },
201            _ => return Err(AddressError::InvalidProtocol),
202        };
203
204        let peer = match protocol_stack.next() {
205            Some(Protocol::P2p(multihash)) => PeerId::from_multihash(multihash)?,
206            protocol => {
207                tracing::warn!(
208                    target: LOG_TARGET,
209                    ?protocol,
210                    "invalid protocol, expected `Protocol::Ws`/`Protocol::Wss`",
211                );
212                return Err(AddressError::PeerIdMissing);
213            }
214        };
215
216        tracing::trace!(target: LOG_TARGET, ?url, "parse address");
217
218        url::Url::parse(&url)
219            .map(|url| (url, peer))
220            .map_err(|_| AddressError::InvalidUrl)
221    }
222
223    /// Dial remote peer over `address`.
224    async fn dial_peer(
225        address: Multiaddr,
226        dial_addresses: DialAddresses,
227        connection_open_timeout: Duration,
228        nodelay: bool,
229        resolver: Arc<TokioResolver>,
230    ) -> Result<(Multiaddr, WebSocketStream<MaybeTlsStream<TcpStream>>), DialError> {
231        let (url, _) = Self::multiaddr_into_url(address.clone())?;
232
233        let (socket_address, _) = WebSocketAddress::multiaddr_to_socket_address(&address)?;
234        let remote_address =
235            match tokio::time::timeout(connection_open_timeout, socket_address.lookup_ip(resolver))
236                .await
237            {
238                Err(_) => return Err(DialError::Timeout),
239                Ok(Err(error)) => return Err(error.into()),
240                Ok(Ok(address)) => address,
241            };
242
243        let domain = match remote_address.is_ipv4() {
244            true => Domain::IPV4,
245            false => Domain::IPV6,
246        };
247        let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?;
248        if remote_address.is_ipv6() {
249            socket.set_only_v6(true)?;
250        }
251        socket.set_nonblocking(true)?;
252        socket.set_nodelay(nodelay)?;
253
254        match dial_addresses.local_dial_address(&remote_address.ip()) {
255            Ok(Some(dial_address)) => {
256                socket.set_reuse_address(true)?;
257                #[cfg(unix)]
258                socket.set_reuse_port(true)?;
259                socket.bind(&dial_address.into())?;
260            }
261            Ok(None) => {}
262            Err(()) => {
263                tracing::debug!(
264                    target: LOG_TARGET,
265                    ?remote_address,
266                    "tcp listener not enabled for remote address, using ephemeral port",
267                );
268            }
269        }
270
271        let future = async move {
272            match socket.connect(&remote_address.into()) {
273                Ok(()) => {}
274                Err(error) if error.raw_os_error() == Some(libc::EINPROGRESS) => {}
275                Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => {}
276                Err(err) => return Err(DialError::from(err)),
277            }
278
279            let stream = TcpStream::try_from(Into::<std::net::TcpStream>::into(socket))?;
280            stream.writable().await?;
281            if let Some(e) = stream.take_error()? {
282                return Err(DialError::from(e));
283            }
284
285            Ok((
286                address,
287                tokio_tungstenite::client_async_tls(url, stream)
288                    .await
289                    .map_err(NegotiationError::WebSocket)?
290                    .0,
291            ))
292        };
293
294        match tokio::time::timeout(connection_open_timeout, future).await {
295            Err(_) => Err(DialError::Timeout),
296            Ok(Err(error)) => Err(error),
297            Ok(Ok((address, stream))) => Ok((address, stream)),
298        }
299    }
300}
301
302impl TransportBuilder for WebSocketTransport {
303    type Config = Config;
304    type Transport = WebSocketTransport;
305
306    /// Create new [`Transport`] object.
307    fn new(
308        context: TransportHandle,
309        mut config: Self::Config,
310        resolver: Arc<TokioResolver>,
311    ) -> crate::Result<(Self, Vec<Multiaddr>)>
312    where
313        Self: Sized,
314    {
315        tracing::debug!(
316            target: LOG_TARGET,
317            listen_addresses = ?config.listen_addresses,
318            "start websocket transport",
319        );
320        let (listener, listen_addresses, dial_addresses) = SocketListener::new::<WebSocketAddress>(
321            std::mem::take(&mut config.listen_addresses),
322            config.reuse_port,
323            config.nodelay,
324        );
325
326        Ok((
327            Self {
328                listener,
329                config,
330                context,
331                dial_addresses,
332                opened: HashMap::new(),
333                pending_open: HashMap::new(),
334                pending_dials: HashMap::new(),
335                pending_inbound_connections: HashMap::new(),
336                pending_connections: FuturesStream::new(),
337                pending_raw_connections: FuturesStream::new(),
338                cancel_futures: HashMap::new(),
339                resolver,
340            },
341            listen_addresses,
342        ))
343    }
344}
345
346impl Transport for WebSocketTransport {
347    fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> {
348        let yamux_config = self.config.yamux_config.clone();
349        let keypair = self.context.keypair.clone();
350        let (ws_address, peer) = Self::multiaddr_into_url(address.clone())?;
351        let connection_open_timeout = self.config.connection_open_timeout;
352        let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
353        let max_write_buffer_size = self.config.noise_write_buffer_size;
354        let substream_open_timeout = self.config.substream_open_timeout;
355        let dial_addresses = self.dial_addresses.clone();
356        let nodelay = self.config.nodelay;
357        let resolver = self.resolver.clone();
358
359        self.pending_dials.insert(connection_id, address.clone());
360
361        tracing::debug!(target: LOG_TARGET, ?connection_id, ?address, "open connection");
362
363        let future = async move {
364            let (_, stream) = WebSocketTransport::dial_peer(
365                address.clone(),
366                dial_addresses,
367                connection_open_timeout,
368                nodelay,
369                resolver,
370            )
371            .await
372            .map_err(|error| (connection_id, error))?;
373
374            WebSocketConnection::open_connection(
375                connection_id,
376                keypair,
377                stream,
378                address,
379                peer,
380                ws_address,
381                yamux_config,
382                max_read_ahead_factor,
383                max_write_buffer_size,
384                substream_open_timeout,
385            )
386            .await
387            .map_err(|error| (connection_id, error.into()))
388        };
389
390        self.pending_connections.push(Box::pin(async move {
391            match tokio::time::timeout(connection_open_timeout, future).await {
392                Err(_) => Err((connection_id, DialError::Timeout)),
393                Ok(Err(error)) => Err(error),
394                Ok(Ok(result)) => Ok(result),
395            }
396        }));
397
398        Ok(())
399    }
400
401    fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
402        let context = self
403            .pending_open
404            .remove(&connection_id)
405            .ok_or(Error::ConnectionDoesntExist(connection_id))?;
406        let protocol_set = self.context.protocol_set(connection_id);
407        let bandwidth_sink = self.context.bandwidth_sink.clone();
408        let substream_open_timeout = self.config.substream_open_timeout;
409
410        tracing::trace!(
411            target: LOG_TARGET,
412            ?connection_id,
413            "start connection",
414        );
415
416        self.context.executor.run(Box::pin(async move {
417            if let Err(error) = WebSocketConnection::new(
418                context,
419                protocol_set,
420                bandwidth_sink,
421                substream_open_timeout,
422            )
423            .start()
424            .await
425            {
426                tracing::debug!(
427                    target: LOG_TARGET,
428                    ?connection_id,
429                    ?error,
430                    "connection exited with error",
431                );
432            }
433        }));
434
435        Ok(())
436    }
437
438    fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
439        self.pending_open
440            .remove(&connection_id)
441            .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(()))
442    }
443
444    fn accept_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
445        let pending = self.pending_inbound_connections.remove(&connection_id).ok_or_else(|| {
446            tracing::error!(
447                target: LOG_TARGET,
448                ?connection_id,
449                "Cannot accept non existent pending connection",
450            );
451
452            Error::ConnectionDoesntExist(connection_id)
453        })?;
454
455        self.on_inbound_connection(connection_id, pending.connection, pending.address);
456
457        Ok(())
458    }
459
460    fn reject_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
461        self.pending_inbound_connections.remove(&connection_id).map_or_else(
462            || {
463                tracing::error!(
464                    target: LOG_TARGET,
465                    ?connection_id,
466                    "Cannot reject non existent pending connection",
467                );
468
469                Err(Error::ConnectionDoesntExist(connection_id))
470            },
471            |_| Ok(()),
472        )
473    }
474
475    fn open(
476        &mut self,
477        connection_id: ConnectionId,
478        addresses: Vec<Multiaddr>,
479    ) -> crate::Result<()> {
480        let num_addresses = addresses.len();
481
482        let mut futures: FuturesUnordered<_> = addresses
483            .into_iter()
484            .map(|address| {
485                let yamux_config = self.config.yamux_config.clone();
486                let keypair = self.context.keypair.clone();
487                let connection_open_timeout = self.config.connection_open_timeout;
488                let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
489                let max_write_buffer_size = self.config.noise_write_buffer_size;
490                let substream_open_timeout = self.config.substream_open_timeout;
491                let dial_addresses = self.dial_addresses.clone();
492                let nodelay = self.config.nodelay;
493                let resolver = self.resolver.clone();
494
495                async move {
496                    let (address, stream) = WebSocketTransport::dial_peer(
497                        address.clone(),
498                        dial_addresses,
499                        connection_open_timeout,
500                        nodelay,
501                        resolver,
502                    )
503                    .await
504                    .map_err(|error| (address, error))?;
505
506                    let open_address = address.clone();
507                    let (ws_address, peer) = Self::multiaddr_into_url(address.clone())
508                        .map_err(|error| (address.clone(), error.into()))?;
509
510                    WebSocketConnection::open_connection(
511                        connection_id,
512                        keypair,
513                        stream,
514                        address,
515                        peer,
516                        ws_address,
517                        yamux_config,
518                        max_read_ahead_factor,
519                        max_write_buffer_size,
520                        substream_open_timeout,
521                    )
522                    .await
523                    .map_err(|error| (open_address, error.into()))
524                }
525            })
526            .collect();
527
528        // Future that will resolve to the first successful connection.
529        let future = async move {
530            let mut errors = Vec::with_capacity(num_addresses);
531            while let Some(result) = futures.next().await {
532                match result {
533                    Ok(negotiated) => return RawConnectionResult::Connected { negotiated, errors },
534                    Err(error) => {
535                        tracing::debug!(
536                            target: LOG_TARGET,
537                            ?connection_id,
538                            ?error,
539                            "failed to open connection",
540                        );
541                        errors.push(error)
542                    }
543                }
544            }
545
546            RawConnectionResult::Failed {
547                connection_id,
548                errors,
549            }
550        };
551
552        let (fut, handle) = futures::future::abortable(future);
553        let fut = fut.unwrap_or_else(move |_| RawConnectionResult::Canceled { connection_id });
554        self.pending_raw_connections.push(Box::pin(fut));
555        self.cancel_futures.insert(connection_id, handle);
556
557        Ok(())
558    }
559
560    fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
561        let negotiated = self
562            .opened
563            .remove(&connection_id)
564            .ok_or(Error::ConnectionDoesntExist(connection_id))?;
565
566        self.pending_connections.push(Box::pin(async move { Ok(negotiated) }));
567
568        Ok(())
569    }
570
571    fn cancel(&mut self, connection_id: ConnectionId) {
572        // Cancel the future if it exists.
573        // State clean-up happens inside the `poll_next`.
574        if let Some(handle) = self.cancel_futures.get(&connection_id) {
575            handle.abort();
576        }
577    }
578}
579
580impl Stream for WebSocketTransport {
581    type Item = TransportEvent;
582
583    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
584        if let Poll::Ready(event) = self.listener.poll_next_unpin(cx) {
585            return match event {
586                None => {
587                    tracing::error!(
588                        target: LOG_TARGET,
589                        "Websocket listener terminated, ignore if the node is stopping",
590                    );
591
592                    Poll::Ready(None)
593                }
594                Some(Err(error)) => {
595                    tracing::error!(
596                        target: LOG_TARGET,
597                        ?error,
598                        "Websocket listener terminated with error",
599                    );
600
601                    Poll::Ready(None)
602                }
603                Some(Ok((connection, address))) => {
604                    let connection_id = self.context.next_connection_id();
605                    tracing::trace!(
606                        target: LOG_TARGET,
607                        ?connection_id,
608                        ?address,
609                        "pending inbound Websocket connection",
610                    );
611
612                    self.pending_inbound_connections.insert(
613                        connection_id,
614                        PendingInboundConnection {
615                            connection,
616                            address,
617                        },
618                    );
619
620                    Poll::Ready(Some(TransportEvent::PendingInboundConnection {
621                        connection_id,
622                    }))
623                }
624            };
625        }
626
627        while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) {
628            tracing::trace!(target: LOG_TARGET, ?result, "raw connection result");
629
630            match result {
631                RawConnectionResult::Connected { negotiated, errors } => {
632                    let Some(handle) = self.cancel_futures.remove(&negotiated.connection_id())
633                    else {
634                        tracing::warn!(
635                            target: LOG_TARGET,
636                            connection_id = ?negotiated.connection_id(),
637                            address = ?negotiated.endpoint().address(),
638                            ?errors,
639                            "raw connection without a cancel handle",
640                        );
641                        continue;
642                    };
643
644                    if !handle.is_aborted() {
645                        let connection_id = negotiated.connection_id();
646                        let address = negotiated.endpoint().address().clone();
647
648                        self.opened.insert(connection_id, negotiated);
649
650                        return Poll::Ready(Some(TransportEvent::ConnectionOpened {
651                            connection_id,
652                            address,
653                        }));
654                    }
655                }
656
657                RawConnectionResult::Failed {
658                    connection_id,
659                    errors,
660                } => {
661                    let Some(handle) = self.cancel_futures.remove(&connection_id) else {
662                        tracing::warn!(
663                            target: LOG_TARGET,
664                            ?connection_id,
665                            ?errors,
666                            "raw connection without a cancel handle",
667                        );
668                        continue;
669                    };
670
671                    if !handle.is_aborted() {
672                        return Poll::Ready(Some(TransportEvent::OpenFailure {
673                            connection_id,
674                            errors,
675                        }));
676                    }
677                }
678                RawConnectionResult::Canceled { connection_id } => {
679                    if self.cancel_futures.remove(&connection_id).is_none() {
680                        tracing::warn!(
681                            target: LOG_TARGET,
682                            ?connection_id,
683                            "raw cancelled connection without a cancel handle",
684                        );
685                    }
686                }
687            }
688        }
689
690        while let Poll::Ready(Some(connection)) = self.pending_connections.poll_next_unpin(cx) {
691            match connection {
692                Ok(connection) => {
693                    let peer = connection.peer();
694                    let endpoint = connection.endpoint();
695                    self.pending_dials.remove(&connection.connection_id());
696                    self.pending_open.insert(connection.connection_id(), connection);
697
698                    return Poll::Ready(Some(TransportEvent::ConnectionEstablished {
699                        peer,
700                        endpoint,
701                    }));
702                }
703                Err((connection_id, error)) => {
704                    if let Some(address) = self.pending_dials.remove(&connection_id) {
705                        return Poll::Ready(Some(TransportEvent::DialFailure {
706                            connection_id,
707                            address,
708                            error,
709                        }));
710                    } else {
711                        tracing::debug!(target: LOG_TARGET, ?error, ?connection_id, "Pending inbound connection failed");
712                    }
713                }
714            }
715        }
716
717        Poll::Pending
718    }
719}