litep2p/transport/websocket/
connection.rs

1// Copyright 2023 litep2p developers
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::{
22    config::Role,
23    crypto::{
24        ed25519::Keypair,
25        noise::{self, NoiseSocket},
26    },
27    error::{Error, NegotiationError, SubstreamError},
28    multistream_select::{dialer_select_proto, listener_select_proto, Negotiated, Version},
29    protocol::{Direction, Permit, ProtocolCommand, ProtocolSet},
30    substream,
31    transport::{
32        websocket::{stream::BufferedStream, substream::Substream},
33        Endpoint,
34    },
35    types::{protocol::ProtocolName, ConnectionId, SubstreamId},
36    BandwidthSink, PeerId,
37};
38
39use futures::{future::BoxFuture, stream::FuturesUnordered, AsyncRead, AsyncWrite, StreamExt};
40use multiaddr::{multihash::Multihash, Multiaddr, Protocol};
41use tokio::net::TcpStream;
42use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
43use tokio_util::compat::FuturesAsyncReadCompatExt;
44use url::Url;
45
46use std::time::Duration;
47
/// Definitions pulled in from files that live in the build's `OUT_DIR`.
mod schema {
    /// Contents of `$OUT_DIR/noise.rs`, included verbatim at compile time.
    ///
    /// NOTE(review): the file is not part of the source tree; presumably it is emitted by the
    /// crate's build script — confirm.
    pub(super) mod noise {
        include!(concat!(env!("OUT_DIR"), "/noise.rs"));
    }
}
53
/// Logging target for the file; passed as `target:` to every `tracing` event emitted here.
const LOG_TARGET: &str = "litep2p::websocket::connection";
56
/// Negotiated substream and its context.
///
/// Produced by `WebSocketConnection::accept_substream()` and
/// `WebSocketConnection::open_substream()` once protocol negotiation has succeeded.
pub struct NegotiatedSubstream {
    /// Substream direction: `Direction::Inbound` for accepted substreams and
    /// `Direction::Outbound(substream_id)` for locally opened ones.
    direction: Direction,

    /// Substream ID.
    substream_id: SubstreamId,

    /// Name of the protocol that was selected during negotiation.
    protocol: ProtocolName,

    /// Yamux substream, taken out of the negotiation wrapper with `inner()`.
    io: crate::yamux::Stream,

    /// Permit handed in when the substream was accepted/opened; it is forwarded to the substream
    /// object when the substream is reported to the protocol.
    permit: Permit,
}
74
/// WebSocket connection error.
///
/// Result type of the futures stored in `WebSocketConnection::pending_substreams`. The optional
/// context is filled in (`Some`) for outbound substreams and left empty (`None`) for inbound
/// ones; the event loop only reports a failure to the protocol when both values are present.
#[derive(Debug)]
enum ConnectionError {
    /// Substream negotiation did not finish within the substream open timeout.
    Timeout {
        /// Protocol, if known.
        protocol: Option<ProtocolName>,

        /// Substream ID, if known.
        substream_id: Option<SubstreamId>,
    },

    /// Failed to negotiate connection/substream.
    FailedToNegotiate {
        /// Protocol, if known.
        protocol: Option<ProtocolName>,

        /// Substream ID, if known.
        substream_id: Option<SubstreamId>,

        /// Underlying error.
        error: SubstreamError,
    },
}
99
/// Negotiated connection.
///
/// Returned by `WebSocketConnection::negotiate_connection()` after the Noise handshake and
/// Yamux negotiation have completed, and consumed by `WebSocketConnection::new()`.
pub(super) struct NegotiatedConnection {
    /// Remote peer ID, as returned by the Noise handshake.
    peer: PeerId,

    /// Endpoint (dialer or listener) holding the remote address and the connection ID.
    endpoint: Endpoint,

    /// Yamux connection.
    connection:
        crate::yamux::ControlledConnection<NoiseSocket<BufferedStream<MaybeTlsStream<TcpStream>>>>,

    /// Yamux control, used for opening outbound substreams.
    control: crate::yamux::Control,
}
115
116impl NegotiatedConnection {
117    /// Get `ConnectionId` of the negotiated connection.
118    pub fn connection_id(&self) -> ConnectionId {
119        self.endpoint.connection_id()
120    }
121
122    /// Get `PeerId` of the negotiated connection.
123    pub fn peer(&self) -> PeerId {
124        self.peer
125    }
126
127    /// Get `Endpoint` of the negotiated connection.
128    pub fn endpoint(&self) -> Endpoint {
129        self.endpoint.clone()
130    }
131}
132
/// WebSocket connection.
///
/// Built from a [`NegotiatedConnection`] with `WebSocketConnection::new()` and driven by
/// `WebSocketConnection::start()`, which polls the Yamux connection, the pending substream
/// negotiations and the protocol command stream.
pub(crate) struct WebSocketConnection {
    /// Protocol context; receives connection/substream events and yields protocol commands.
    protocol_set: ProtocolSet,

    /// Yamux connection; polled for inbound substreams.
    connection:
        crate::yamux::ControlledConnection<NoiseSocket<BufferedStream<MaybeTlsStream<TcpStream>>>>,

    /// Yamux control; cloned for every outbound substream that is opened.
    control: crate::yamux::Control,

    /// Remote peer ID.
    peer: PeerId,

    /// Endpoint; handed to the protocol set when the connection is reported as established.
    endpoint: Endpoint,

    /// Substream open timeout, applied to both inbound and outbound substream negotiation.
    substream_open_timeout: Duration,

    /// Connection ID, copied from `endpoint` on construction.
    connection_id: ConnectionId,

    /// Bandwidth sink; a clone is given to every opened substream.
    bandwidth_sink: BandwidthSink,

    /// Pending substreams: in-flight negotiations of inbound and outbound substreams.
    pending_substreams:
        FuturesUnordered<BoxFuture<'static, Result<NegotiatedSubstream, ConnectionError>>>,
}
164
165impl WebSocketConnection {
166    /// Create new [`WebSocketConnection`].
167    pub(super) fn new(
168        connection: NegotiatedConnection,
169        protocol_set: ProtocolSet,
170        bandwidth_sink: BandwidthSink,
171        substream_open_timeout: Duration,
172    ) -> Self {
173        let NegotiatedConnection {
174            peer,
175            endpoint,
176            connection,
177            control,
178        } = connection;
179
180        Self {
181            connection_id: endpoint.connection_id(),
182            protocol_set,
183            connection,
184            control,
185            peer,
186            endpoint,
187            bandwidth_sink,
188            substream_open_timeout,
189            pending_substreams: FuturesUnordered::new(),
190        }
191    }
192
193    /// Negotiate protocol.
194    async fn negotiate_protocol<S: AsyncRead + AsyncWrite + Unpin>(
195        stream: S,
196        role: &Role,
197        protocols: Vec<&str>,
198    ) -> Result<(Negotiated<S>, ProtocolName), NegotiationError> {
199        tracing::trace!(target: LOG_TARGET, ?protocols, "negotiating protocols");
200
201        let (protocol, socket) = match role {
202            Role::Dialer => dialer_select_proto(stream, protocols, Version::V1).await,
203            Role::Listener => listener_select_proto(stream, protocols).await,
204        }
205        .map_err(NegotiationError::MultistreamSelectError)?;
206
207        tracing::trace!(target: LOG_TARGET, ?protocol, "protocol negotiated");
208
209        Ok((socket, ProtocolName::from(protocol.to_string())))
210    }
211
212    /// Open WebSocket connection.
213    pub(super) async fn open_connection(
214        connection_id: ConnectionId,
215        keypair: Keypair,
216        stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
217        address: Multiaddr,
218        dialed_peer: PeerId,
219        ws_address: Url,
220        yamux_config: crate::yamux::Config,
221        max_read_ahead_factor: usize,
222        max_write_buffer_size: usize,
223    ) -> Result<NegotiatedConnection, NegotiationError> {
224        tracing::trace!(
225            target: LOG_TARGET,
226            ?address,
227            ?ws_address,
228            ?connection_id,
229            "open connection to remote peer",
230        );
231
232        Self::negotiate_connection(
233            stream,
234            Some(dialed_peer),
235            Role::Dialer,
236            address,
237            connection_id,
238            keypair,
239            yamux_config,
240            max_read_ahead_factor,
241            max_write_buffer_size,
242        )
243        .await
244    }
245
246    /// Accept WebSocket connection.
247    pub(super) async fn accept_connection(
248        stream: TcpStream,
249        connection_id: ConnectionId,
250        keypair: Keypair,
251        address: Multiaddr,
252        yamux_config: crate::yamux::Config,
253        max_read_ahead_factor: usize,
254        max_write_buffer_size: usize,
255    ) -> Result<NegotiatedConnection, NegotiationError> {
256        let stream = MaybeTlsStream::Plain(stream);
257
258        Self::negotiate_connection(
259            tokio_tungstenite::accept_async(stream)
260                .await
261                .map_err(NegotiationError::WebSocket)?,
262            None,
263            Role::Listener,
264            address,
265            connection_id,
266            keypair,
267            yamux_config,
268            max_read_ahead_factor,
269            max_write_buffer_size,
270        )
271        .await
272    }
273
274    /// Negotiate WebSocket connection.
275    pub(super) async fn negotiate_connection(
276        stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
277        dialed_peer: Option<PeerId>,
278        role: Role,
279        address: Multiaddr,
280        connection_id: ConnectionId,
281        keypair: Keypair,
282        yamux_config: crate::yamux::Config,
283        max_read_ahead_factor: usize,
284        max_write_buffer_size: usize,
285    ) -> Result<NegotiatedConnection, NegotiationError> {
286        tracing::trace!(
287            target: LOG_TARGET,
288            ?connection_id,
289            ?address,
290            ?role,
291            ?dialed_peer,
292            "negotiate connection"
293        );
294        let stream = BufferedStream::new(stream);
295
296        // negotiate `noise`
297        let (stream, _) = Self::negotiate_protocol(stream, &role, vec!["/noise"]).await?;
298
299        tracing::trace!(
300            target: LOG_TARGET,
301            "`multistream-select` and `noise` negotiated"
302        );
303
304        // perform noise handshake
305        let (stream, peer) = noise::handshake(
306            stream.inner(),
307            &keypair,
308            role,
309            max_read_ahead_factor,
310            max_write_buffer_size,
311        )
312        .await?;
313
314        if let Some(dialed_peer) = dialed_peer {
315            if peer != dialed_peer {
316                return Err(NegotiationError::PeerIdMismatch(dialed_peer, peer));
317            }
318        }
319
320        let stream: NoiseSocket<BufferedStream<_>> = stream;
321        tracing::trace!(target: LOG_TARGET, "noise handshake done");
322
323        // negotiate `yamux`
324        let (stream, _) = Self::negotiate_protocol(stream, &role, vec!["/yamux/1.0.0"]).await?;
325        tracing::trace!(target: LOG_TARGET, "`yamux` negotiated");
326
327        let connection = crate::yamux::Connection::new(stream.inner(), yamux_config, role.into());
328        let (control, connection) = crate::yamux::Control::new(connection);
329
330        let address = match role {
331            Role::Dialer => address,
332            Role::Listener => address.with(Protocol::P2p(Multihash::from(peer))),
333        };
334
335        Ok(NegotiatedConnection {
336            peer,
337            control,
338            connection,
339            endpoint: match role {
340                Role::Dialer => Endpoint::dialer(address, connection_id),
341                Role::Listener => Endpoint::listener(address, connection_id),
342            },
343        })
344    }
345
346    /// Accept substream.
347    pub async fn accept_substream(
348        stream: crate::yamux::Stream,
349        permit: Permit,
350        substream_id: SubstreamId,
351        protocols: Vec<ProtocolName>,
352    ) -> Result<NegotiatedSubstream, NegotiationError> {
353        tracing::trace!(
354            target: LOG_TARGET,
355            ?substream_id,
356            "accept inbound substream"
357        );
358
359        let protocols = protocols.iter().map(|protocol| &**protocol).collect::<Vec<&str>>();
360        let (io, protocol) = Self::negotiate_protocol(stream, &Role::Listener, protocols).await?;
361
362        tracing::trace!(
363            target: LOG_TARGET,
364            ?substream_id,
365            "substream accepted and negotiated"
366        );
367
368        Ok(NegotiatedSubstream {
369            io: io.inner(),
370            direction: Direction::Inbound,
371            substream_id,
372            protocol,
373            permit,
374        })
375    }
376
377    /// Open substream for `protocol`.
378    pub async fn open_substream(
379        mut control: crate::yamux::Control,
380        permit: Permit,
381        substream_id: SubstreamId,
382        protocol: ProtocolName,
383        fallback_names: Vec<ProtocolName>,
384    ) -> Result<NegotiatedSubstream, SubstreamError> {
385        tracing::debug!(target: LOG_TARGET, ?protocol, ?substream_id, "open substream");
386
387        let stream = match control.open_stream().await {
388            Ok(stream) => {
389                tracing::trace!(target: LOG_TARGET, ?substream_id, "substream opened");
390                stream
391            }
392            Err(error) => {
393                tracing::debug!(
394                    target: LOG_TARGET,
395                    ?substream_id,
396                    ?error,
397                    "failed to open substream"
398                );
399                return Err(SubstreamError::YamuxError(
400                    error,
401                    Direction::Outbound(substream_id),
402                ));
403            }
404        };
405
406        // TODO: protocols don't change after they've been initialized so this should be done only
407        // once
408        let protocols = std::iter::once(&*protocol)
409            .chain(fallback_names.iter().map(|protocol| &**protocol))
410            .collect();
411
412        let (io, protocol) = Self::negotiate_protocol(stream, &Role::Dialer, protocols).await?;
413
414        Ok(NegotiatedSubstream {
415            io: io.inner(),
416            substream_id,
417            direction: Direction::Outbound(substream_id),
418            protocol,
419            permit,
420        })
421    }
422
423    /// Start connection event loop.
424    pub(crate) async fn start(mut self) -> crate::Result<()> {
425        self.protocol_set
426            .report_connection_established(self.peer, self.endpoint)
427            .await?;
428
429        loop {
430            tokio::select! {
431                substream = self.connection.next() => match substream {
432                    Some(Ok(stream)) => {
433                        let substream = self.protocol_set.next_substream_id();
434                        let protocols = self.protocol_set.protocols();
435                        let permit = self.protocol_set.try_get_permit().ok_or(Error::ConnectionClosed)?;
436                        let substream_open_timeout = self.substream_open_timeout;
437
438                        self.pending_substreams.push(Box::pin(async move {
439                            match tokio::time::timeout(
440                                substream_open_timeout,
441                                Self::accept_substream(stream, permit, substream, protocols),
442                            )
443                            .await
444                            {
445                                Ok(Ok(substream)) => Ok(substream),
446                                Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate {
447                                    protocol: None,
448                                    substream_id: None,
449                                    error: SubstreamError::NegotiationError(error),
450                                }),
451                                Err(_) => Err(ConnectionError::Timeout {
452                                    protocol: None,
453                                    substream_id: None
454                                }),
455                            }
456                        }));
457                    },
458                    Some(Err(error)) => {
459                        tracing::debug!(
460                            target: LOG_TARGET,
461                            peer = ?self.peer,
462                            ?error,
463                            "connection closed with error"
464                        );
465                        self.protocol_set.report_connection_closed(self.peer, self.connection_id).await?;
466
467                        return Ok(())
468                    }
469                    None => {
470                        tracing::debug!(target: LOG_TARGET, peer = ?self.peer, "connection closed");
471                        self.protocol_set.report_connection_closed(self.peer, self.connection_id).await?;
472
473                        return Ok(())
474                    }
475                },
476                // TODO: move this to a function
477                substream = self.pending_substreams.select_next_some(), if !self.pending_substreams.is_empty() => {
478                    match substream {
479                        // TODO: return error to protocol
480                        Err(error) => {
481                            tracing::debug!(
482                                target: LOG_TARGET,
483                                ?error,
484                                "failed to accept/open substream",
485                            );
486
487                            let (protocol, substream_id, error) = match error {
488                                ConnectionError::Timeout { protocol, substream_id } => {
489                                    (protocol, substream_id, SubstreamError::NegotiationError(NegotiationError::Timeout))
490                                }
491                                ConnectionError::FailedToNegotiate { protocol, substream_id, error } => {
492                                    (protocol, substream_id, error)
493                                }
494                            };
495
496                            if let (Some(protocol), Some(substream_id)) = (protocol, substream_id) {
497                                self.protocol_set
498                                    .report_substream_open_failure(protocol, substream_id, error)
499                                    .await?;
500                            }
501                        }
502                        Ok(substream) => {
503                            let protocol = substream.protocol.clone();
504                            let direction = substream.direction;
505                            let substream_id = substream.substream_id;
506                            let socket = FuturesAsyncReadCompatExt::compat(substream.io);
507                            let bandwidth_sink = self.bandwidth_sink.clone();
508
509                            let substream = substream::Substream::new_websocket(
510                                self.peer,
511                                substream_id,
512                                Substream::new(socket, bandwidth_sink, substream.permit),
513                                self.protocol_set.protocol_codec(&protocol)
514                            );
515
516                            self.protocol_set
517                                .report_substream_open(self.peer, protocol, direction, substream)
518                                .await?;
519                        }
520                    }
521                }
522                protocol = self.protocol_set.next() => match protocol {
523                    Some(ProtocolCommand::OpenSubstream { protocol, fallback_names, substream_id, permit }) => {
524                        let control = self.control.clone();
525                        let substream_open_timeout = self.substream_open_timeout;
526
527                        tracing::trace!(
528                            target: LOG_TARGET,
529                            ?protocol,
530                            ?substream_id,
531                            "open substream"
532                        );
533
534                        self.pending_substreams.push(Box::pin(async move {
535                            match tokio::time::timeout(
536                                substream_open_timeout,
537                                Self::open_substream(
538                                    control,
539                                    permit,
540                                    substream_id,
541                                    protocol.clone(),
542                                    fallback_names
543                                ),
544                            )
545                            .await
546                            {
547                                Ok(Ok(substream)) => Ok(substream),
548                                Ok(Err(error)) => Err(ConnectionError::FailedToNegotiate {
549                                    protocol: Some(protocol),
550                                    substream_id: Some(substream_id),
551                                    error,
552                                }),
553                                Err(_) => Err(ConnectionError::Timeout {
554                                    protocol: Some(protocol),
555                                    substream_id: Some(substream_id)
556                                }),
557                            }
558                        }));
559                    }
560                    Some(ProtocolCommand::ForceClose) => {
561                        tracing::debug!(
562                            target: LOG_TARGET,
563                            peer = ?self.peer,
564                            connection_id = ?self.connection_id,
565                            "force closing connection",
566                        );
567
568                        return self.protocol_set.report_connection_closed(self.peer, self.connection_id).await
569                    }
570                    None => {
571                        tracing::debug!(target: LOG_TARGET, "protocols have exited, shutting down connection");
572                        return self.protocol_set.report_connection_closed(self.peer, self.connection_id).await
573                    }
574                }
575            }
576        }
577    }
578}