litep2p/transport/websocket/
mod.rs

1// Copyright 2023 litep2p developers
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! WebSocket transport.
22
23use crate::{
24    config::Role,
25    error::{AddressError, Error, NegotiationError},
26    transport::{
27        common::listener::{DialAddresses, GetSocketAddr, SocketListener, WebSocketAddress},
28        manager::TransportHandle,
29        websocket::{
30            config::Config,
31            connection::{NegotiatedConnection, WebSocketConnection},
32        },
33        Transport, TransportBuilder, TransportEvent,
34    },
35    types::ConnectionId,
36    DialError, PeerId,
37};
38
39use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt};
40use multiaddr::{Multiaddr, Protocol};
41use socket2::{Domain, Socket, Type};
42use std::net::SocketAddr;
43use tokio::net::TcpStream;
44use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
45
46use url::Url;
47
48use std::{
49    collections::{HashMap, HashSet},
50    pin::Pin,
51    task::{Context, Poll},
52    time::Duration,
53};
54
55pub(crate) use substream::Substream;
56
57mod connection;
58mod stream;
59mod substream;
60
61pub mod config;
62
63/// Logging target for the `tracing` events emitted from this file.
64const LOG_TARGET: &str = "litep2p::websocket";
65
66/// Inbound connection stored until `Transport::accept_pending()` is called for it.
67struct PendingInboundConnection {
68    /// TCP stream received from the socket listener.
69    connection: TcpStream,
70    /// Address of the remote peer.
71    address: SocketAddr,
72}
73
74/// WebSocket transport.
75pub(crate) struct WebSocketTransport {
76    /// Transport handle; provides the keypair, executor, protocol sets and connection IDs.
77    context: TransportHandle,
78
79    /// Transport configuration.
80    config: Config,
81
82    /// Listener yielding inbound TCP connections for the WebSocket listen addresses.
83    listener: SocketListener,
84
85    /// Local addresses returned by the listener, consulted when binding outbound sockets.
86    dial_addresses: DialAddresses,
87
88    /// Addresses of outbound connections being negotiated, reported in `DialFailure` events.
89    pending_dials: HashMap<ConnectionId, Multiaddr>,
90
91    /// Inbound connections received from the listener, waiting for `accept_pending()`.
92    pending_inbound_connections: HashMap<ConnectionId, PendingInboundConnection>,
93
94    /// Futures negotiating inbound and outbound connections.
95    pending_connections: FuturesUnordered<
96        BoxFuture<'static, Result<NegotiatedConnection, (ConnectionId, DialError)>>,
97    >,
98
99    /// Futures opening raw, unnegotiated connections started by `open()`.
100    pending_raw_connections: FuturesUnordered<
101        BoxFuture<
102            'static,
103            Result<
104                (
105                    ConnectionId,
106                    Multiaddr,
107                    WebSocketStream<MaybeTlsStream<TcpStream>>,
108                ),
109                (ConnectionId, Vec<(Multiaddr, DialError)>),
110            >,
111        >,
112    >,
113
114    /// Opened raw connection, waiting for approval/rejection from `TransportManager`.
115    opened_raw: HashMap<ConnectionId, (WebSocketStream<MaybeTlsStream<TcpStream>>, Multiaddr)>,
116
117    /// IDs passed to `cancel()`; results of raw connections with these IDs are discarded.
118    canceled: HashSet<ConnectionId>,
119
120    /// Negotiated connections waiting for `accept()` or `reject()`.
121    pending_open: HashMap<ConnectionId, NegotiatedConnection>,
122}
123
124impl WebSocketTransport {
125    /// Handle inbound connection.
126    fn on_inbound_connection(
127        &mut self,
128        connection_id: ConnectionId,
129        connection: TcpStream,
130        address: SocketAddr,
131    ) {
132        let keypair = self.context.keypair.clone();
133        let yamux_config = self.config.yamux_config.clone();
134        let connection_open_timeout = self.config.connection_open_timeout;
135        let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
136        let max_write_buffer_size = self.config.noise_write_buffer_size;
137        let address = Multiaddr::empty()
138            .with(Protocol::from(address.ip()))
139            .with(Protocol::Tcp(address.port()))
140            .with(Protocol::Ws(std::borrow::Cow::Borrowed("/")));
141
142        self.pending_connections.push(Box::pin(async move {
143            match tokio::time::timeout(connection_open_timeout, async move {
144                WebSocketConnection::accept_connection(
145                    connection,
146                    connection_id,
147                    keypair,
148                    address,
149                    yamux_config,
150                    max_read_ahead_factor,
151                    max_write_buffer_size,
152                )
153                .await
154                .map_err(|error| (connection_id, error.into()))
155            })
156            .await
157            {
158                Err(_) => Err((connection_id, DialError::Timeout)),
159                Ok(Err(error)) => Err(error),
160                Ok(Ok(result)) => Ok(result),
161            }
162        }));
163    }
164
165    /// Convert `Multiaddr` into `url::Url`
166    fn multiaddr_into_url(address: Multiaddr) -> Result<(Url, PeerId), AddressError> {
167        let mut protocol_stack = address.iter();
168
169        let dial_address = match protocol_stack.next().ok_or(AddressError::InvalidProtocol)? {
170            Protocol::Ip4(address) => address.to_string(),
171            Protocol::Ip6(address) => format!("[{address}]"),
172            Protocol::Dns(address) | Protocol::Dns4(address) | Protocol::Dns6(address) =>
173                address.to_string(),
174
175            _ => return Err(AddressError::InvalidProtocol),
176        };
177
178        let url = match protocol_stack.next().ok_or(AddressError::InvalidProtocol)? {
179            Protocol::Tcp(port) => match protocol_stack.next() {
180                Some(Protocol::Ws(_)) => format!("ws://{dial_address}:{port}/"),
181                Some(Protocol::Wss(_)) => format!("wss://{dial_address}:{port}/"),
182                _ => return Err(AddressError::InvalidProtocol),
183            },
184            _ => return Err(AddressError::InvalidProtocol),
185        };
186
187        let peer = match protocol_stack.next() {
188            Some(Protocol::P2p(multihash)) => PeerId::from_multihash(multihash)?,
189            protocol => {
190                tracing::warn!(
191                    target: LOG_TARGET,
192                    ?protocol,
193                    "invalid protocol, expected `Protocol::Ws`/`Protocol::Wss`",
194                );
195                return Err(AddressError::PeerIdMissing);
196            }
197        };
198
199        tracing::trace!(target: LOG_TARGET, ?url, "parse address");
200
201        url::Url::parse(&url)
202            .map(|url| (url, peer))
203            .map_err(|_| AddressError::InvalidUrl)
204    }
205
206    /// Dial remote peer over `address`.
207    async fn dial_peer(
208        address: Multiaddr,
209        dial_addresses: DialAddresses,
210        connection_open_timeout: Duration,
211        nodelay: bool,
212    ) -> Result<(Multiaddr, WebSocketStream<MaybeTlsStream<TcpStream>>), DialError> {
213        let (url, _) = Self::multiaddr_into_url(address.clone())?;
214
215        let (socket_address, _) = WebSocketAddress::multiaddr_to_socket_address(&address)?;
216        let remote_address =
217            match tokio::time::timeout(connection_open_timeout, socket_address.lookup_ip()).await {
218                Err(_) => return Err(DialError::Timeout),
219                Ok(Err(error)) => return Err(error.into()),
220                Ok(Ok(address)) => address,
221            };
222
223        let domain = match remote_address.is_ipv4() {
224            true => Domain::IPV4,
225            false => Domain::IPV6,
226        };
227        let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?;
228        if remote_address.is_ipv6() {
229            socket.set_only_v6(true)?;
230        }
231        socket.set_nonblocking(true)?;
232        socket.set_nodelay(nodelay)?;
233
234        match dial_addresses.local_dial_address(&remote_address.ip()) {
235            Ok(Some(dial_address)) => {
236                socket.set_reuse_address(true)?;
237                #[cfg(unix)]
238                socket.set_reuse_port(true)?;
239                socket.bind(&dial_address.into())?;
240            }
241            Ok(None) => {}
242            Err(()) => {
243                tracing::debug!(
244                    target: LOG_TARGET,
245                    ?remote_address,
246                    "tcp listener not enabled for remote address, using ephemeral port",
247                );
248            }
249        }
250
251        let future = async move {
252            match socket.connect(&remote_address.into()) {
253                Ok(()) => {}
254                Err(error) if error.raw_os_error() == Some(libc::EINPROGRESS) => {}
255                Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => {}
256                Err(err) => return Err(DialError::from(err)),
257            }
258
259            let stream = TcpStream::try_from(Into::<std::net::TcpStream>::into(socket))?;
260            stream.writable().await?;
261            if let Some(e) = stream.take_error()? {
262                return Err(DialError::from(e));
263            }
264
265            Ok((
266                address,
267                tokio_tungstenite::client_async_tls(url, stream)
268                    .await
269                    .map_err(NegotiationError::WebSocket)?
270                    .0,
271            ))
272        };
273
274        match tokio::time::timeout(connection_open_timeout, future).await {
275            Err(_) => Err(DialError::Timeout),
276            Ok(Err(error)) => Err(error),
277            Ok(Ok((address, stream))) => Ok((address, stream)),
278        }
279    }
280}
281
282impl TransportBuilder for WebSocketTransport {
283    type Config = Config;
284    type Transport = WebSocketTransport;
285
286    /// Create new [`Transport`] object.
287    fn new(
288        context: TransportHandle,
289        mut config: Self::Config,
290    ) -> crate::Result<(Self, Vec<Multiaddr>)>
291    where
292        Self: Sized,
293    {
294        tracing::debug!(
295            target: LOG_TARGET,
296            listen_addresses = ?config.listen_addresses,
297            "start websocket transport",
298        );
299        let (listener, listen_addresses, dial_addresses) = SocketListener::new::<WebSocketAddress>(
300            std::mem::take(&mut config.listen_addresses),
301            config.reuse_port,
302            config.nodelay,
303        );
304
305        Ok((
306            Self {
307                listener,
308                config,
309                context,
310                dial_addresses,
311                canceled: HashSet::new(),
312                opened_raw: HashMap::new(),
313                pending_open: HashMap::new(),
314                pending_dials: HashMap::new(),
315                pending_inbound_connections: HashMap::new(),
316                pending_connections: FuturesUnordered::new(),
317                pending_raw_connections: FuturesUnordered::new(),
318            },
319            listen_addresses,
320        ))
321    }
322}
323
324impl Transport for WebSocketTransport {
325    fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> {
326        let yamux_config = self.config.yamux_config.clone();
327        let keypair = self.context.keypair.clone();
328        let (ws_address, peer) = Self::multiaddr_into_url(address.clone())?;
329        let connection_open_timeout = self.config.connection_open_timeout;
330        let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
331        let max_write_buffer_size = self.config.noise_write_buffer_size;
332        let dial_addresses = self.dial_addresses.clone();
333        let nodelay = self.config.nodelay;
334
335        self.pending_dials.insert(connection_id, address.clone());
336
337        tracing::debug!(target: LOG_TARGET, ?connection_id, ?address, "open connection");
338
339        let future = async move {
340            let (_, stream) = WebSocketTransport::dial_peer(
341                address.clone(),
342                dial_addresses,
343                connection_open_timeout,
344                nodelay,
345            )
346            .await
347            .map_err(|error| (connection_id, error))?;
348
349            WebSocketConnection::open_connection(
350                connection_id,
351                keypair,
352                stream,
353                address,
354                peer,
355                ws_address,
356                yamux_config,
357                max_read_ahead_factor,
358                max_write_buffer_size,
359            )
360            .await
361            .map_err(|error| (connection_id, error.into()))
362        };
363
364        self.pending_connections.push(Box::pin(async move {
365            match tokio::time::timeout(connection_open_timeout, future).await {
366                Err(_) => Err((connection_id, DialError::Timeout)),
367                Ok(Err(error)) => Err(error),
368                Ok(Ok(result)) => Ok(result),
369            }
370        }));
371
372        Ok(())
373    }
374
375    fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
376        let context = self
377            .pending_open
378            .remove(&connection_id)
379            .ok_or(Error::ConnectionDoesntExist(connection_id))?;
380        let protocol_set = self.context.protocol_set(connection_id);
381        let bandwidth_sink = self.context.bandwidth_sink.clone();
382        let substream_open_timeout = self.config.substream_open_timeout;
383
384        tracing::trace!(
385            target: LOG_TARGET,
386            ?connection_id,
387            "start connection",
388        );
389
390        self.context.executor.run(Box::pin(async move {
391            if let Err(error) = WebSocketConnection::new(
392                context,
393                protocol_set,
394                bandwidth_sink,
395                substream_open_timeout,
396            )
397            .start()
398            .await
399            {
400                tracing::debug!(
401                    target: LOG_TARGET,
402                    ?connection_id,
403                    ?error,
404                    "connection exited with error",
405                );
406            }
407        }));
408
409        Ok(())
410    }
411
412    fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
413        self.pending_open
414            .remove(&connection_id)
415            .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(()))
416    }
417
418    fn accept_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
419        let pending = self
420            .pending_inbound_connections
421            .remove(&connection_id)
422            .ok_or(Error::ConnectionDoesntExist(connection_id))?;
423
424        self.on_inbound_connection(connection_id, pending.connection, pending.address);
425
426        Ok(())
427    }
428
429    fn reject_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
430        self.pending_open
431            .remove(&connection_id)
432            .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(()))
433    }
434
435    fn open(
436        &mut self,
437        connection_id: ConnectionId,
438        addresses: Vec<Multiaddr>,
439    ) -> crate::Result<()> {
440        let num_addresses = addresses.len();
441        let mut futures: FuturesUnordered<_> = addresses
442            .into_iter()
443            .map(|address| {
444                let connection_open_timeout = self.config.connection_open_timeout;
445                let dial_addresses = self.dial_addresses.clone();
446                let nodelay = self.config.nodelay;
447
448                async move {
449                    WebSocketTransport::dial_peer(
450                        address.clone(),
451                        dial_addresses,
452                        connection_open_timeout,
453                        nodelay,
454                    )
455                    .await
456                    .map_err(|error| (address, error))
457                }
458            })
459            .collect();
460
461        self.pending_raw_connections.push(Box::pin(async move {
462            let mut errors = Vec::with_capacity(num_addresses);
463
464            while let Some(result) = futures.next().await {
465                match result {
466                    Ok((address, stream)) => return Ok((connection_id, address, stream)),
467                    Err(error) => {
468                        tracing::debug!(
469                            target: LOG_TARGET,
470                            ?connection_id,
471                            ?error,
472                            "failed to open connection",
473                        );
474                        errors.push(error)
475                    }
476                }
477            }
478
479            Err((connection_id, errors))
480        }));
481
482        Ok(())
483    }
484
485    fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
486        let (stream, address) = self
487            .opened_raw
488            .remove(&connection_id)
489            .ok_or(Error::ConnectionDoesntExist(connection_id))?;
490
491        let peer = match address.iter().find(|protocol| std::matches!(protocol, Protocol::P2p(_))) {
492            Some(Protocol::P2p(multihash)) => PeerId::from_multihash(multihash)?,
493            _ => return Err(Error::InvalidState),
494        };
495        let yamux_config = self.config.yamux_config.clone();
496        let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
497        let max_write_buffer_size = self.config.noise_write_buffer_size;
498        let connection_open_timeout = self.config.connection_open_timeout;
499        let keypair = self.context.keypair.clone();
500
501        tracing::trace!(
502            target: LOG_TARGET,
503            ?peer,
504            ?connection_id,
505            ?address,
506            "negotiate connection",
507        );
508
509        self.pending_dials.insert(connection_id, address.clone());
510        self.pending_connections.push(Box::pin(async move {
511            match tokio::time::timeout(connection_open_timeout, async move {
512                WebSocketConnection::negotiate_connection(
513                    stream,
514                    Some(peer),
515                    Role::Dialer,
516                    address,
517                    connection_id,
518                    keypair,
519                    yamux_config,
520                    max_read_ahead_factor,
521                    max_write_buffer_size,
522                )
523                .await
524                .map_err(|error| (connection_id, error.into()))
525            })
526            .await
527            {
528                Err(_) => Err((connection_id, DialError::Timeout)),
529                Ok(Err(error)) => Err(error),
530                Ok(Ok(connection)) => Ok(connection),
531            }
532        }));
533
534        Ok(())
535    }
536
537    fn cancel(&mut self, connection_id: ConnectionId) {
538        self.canceled.insert(connection_id);
539    }
540}
541
542impl Stream for WebSocketTransport {
543    type Item = TransportEvent;
544
545    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
546        if let Poll::Ready(Some(connection)) = self.listener.poll_next_unpin(cx) {
547            return match connection {
548                Err(_) => Poll::Ready(None),
549                Ok((connection, address)) => {
550                    let connection_id = self.context.next_connection_id();
551
552                    self.pending_inbound_connections.insert(
553                        connection_id,
554                        PendingInboundConnection {
555                            connection,
556                            address,
557                        },
558                    );
559
560                    Poll::Ready(Some(TransportEvent::PendingInboundConnection {
561                        connection_id,
562                    }))
563                }
564            };
565        }
566
567        while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) {
568            match result {
569                Ok((connection_id, address, stream)) => {
570                    tracing::trace!(
571                        target: LOG_TARGET,
572                        ?connection_id,
573                        ?address,
574                        canceled = self.canceled.contains(&connection_id),
575                        "connection opened",
576                    );
577
578                    if !self.canceled.remove(&connection_id) {
579                        self.opened_raw.insert(connection_id, (stream, address.clone()));
580
581                        return Poll::Ready(Some(TransportEvent::ConnectionOpened {
582                            connection_id,
583                            address,
584                        }));
585                    }
586                }
587                Err((connection_id, errors)) =>
588                    if !self.canceled.remove(&connection_id) {
589                        return Poll::Ready(Some(TransportEvent::OpenFailure {
590                            connection_id,
591                            errors,
592                        }));
593                    },
594            }
595        }
596
597        while let Poll::Ready(Some(connection)) = self.pending_connections.poll_next_unpin(cx) {
598            match connection {
599                Ok(connection) => {
600                    let peer = connection.peer();
601                    let endpoint = connection.endpoint();
602                    self.pending_open.insert(connection.connection_id(), connection);
603
604                    return Poll::Ready(Some(TransportEvent::ConnectionEstablished {
605                        peer,
606                        endpoint,
607                    }));
608                }
609                Err((connection_id, error)) => {
610                    if let Some(address) = self.pending_dials.remove(&connection_id) {
611                        return Poll::Ready(Some(TransportEvent::DialFailure {
612                            connection_id,
613                            address,
614                            error,
615                        }));
616                    } else {
617                        tracing::debug!(target: LOG_TARGET, ?error, ?connection_id, "Pending inbound connection failed");
618                    }
619                }
620            }
621        }
622
623        Poll::Pending
624    }
625}