libp2p_websocket/
framed.rs

1// Copyright 2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::{error::Error, quicksink, tls};
22use either::Either;
23use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream};
24use futures_rustls::{client, rustls, server};
25use libp2p_core::{
26    connection::Endpoint,
27    multiaddr::{Multiaddr, Protocol},
28    transport::{ListenerId, TransportError, TransportEvent},
29    Transport,
30};
31use log::{debug, trace};
32use parking_lot::Mutex;
33use soketto::{
34    connection::{self, CloseReason},
35    handshake,
36};
37use std::{collections::HashMap, ops::DerefMut, sync::Arc};
38use std::{convert::TryInto, fmt, io, mem, pin::Pin, task::Context, task::Poll};
39use url::Url;
40
/// Default upper bound, in bytes, on the payload of a single websocket frame:
/// 256 MiB (`256 << 20`).
///
/// This is the initial value returned by `WsConfig::max_data_size`.
const MAX_DATA_SIZE: usize = 256 << 20;
43
44/// A Websocket transport whose output type is a [`Stream`] and [`Sink`] of
45/// frame payloads which does not implement [`AsyncRead`] or
46/// [`AsyncWrite`]. See [`crate::WsConfig`] if you require the latter.
47#[derive(Debug)]
48pub struct WsConfig<T> {
49    transport: Arc<Mutex<T>>,
50    max_data_size: usize,
51    tls_config: tls::Config,
52    max_redirects: u8,
53    /// Websocket protocol of the inner listener.
54    ///
55    /// This is the suffix of the address provided in `listen_on`.
56    /// Can only be [`Protocol::Ws`] or [`Protocol::Wss`].
57    listener_protos: HashMap<ListenerId, Protocol<'static>>,
58}
59
60impl<T> WsConfig<T>
61where
62    T: Send,
63{
64    /// Create a new websocket transport based on another transport.
65    pub fn new(transport: T) -> Self {
66        WsConfig {
67            transport: Arc::new(Mutex::new(transport)),
68            max_data_size: MAX_DATA_SIZE,
69            tls_config: tls::Config::client(),
70            max_redirects: 0,
71            listener_protos: HashMap::new(),
72        }
73    }
74
75    /// Return the configured maximum number of redirects.
76    pub fn max_redirects(&self) -> u8 {
77        self.max_redirects
78    }
79
80    /// Set max. number of redirects to follow.
81    pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
82        self.max_redirects = max;
83        self
84    }
85
86    /// Get the max. frame data size we support.
87    pub fn max_data_size(&self) -> usize {
88        self.max_data_size
89    }
90
91    /// Set the max. frame data size we support.
92    pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
93        self.max_data_size = size;
94        self
95    }
96
97    /// Set the TLS configuration if TLS support is desired.
98    pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
99        self.tls_config = c;
100        self
101    }
102}
103
104type TlsOrPlain<T> = future::Either<future::Either<client::TlsStream<T>, server::TlsStream<T>>, T>;
105
106impl<T> Transport for WsConfig<T>
107where
108    T: Transport + Send + Unpin + 'static,
109    T::Error: Send + 'static,
110    T::Dial: Send + 'static,
111    T::ListenerUpgrade: Send + 'static,
112    T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
113{
114    type Output = Connection<T::Output>;
115    type Error = Error<T::Error>;
116    type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
117    type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
118
119    fn listen_on(
120        &mut self,
121        id: ListenerId,
122        addr: Multiaddr,
123    ) -> Result<(), TransportError<Self::Error>> {
124        let mut inner_addr = addr.clone();
125        let proto = match inner_addr.pop() {
126            Some(p @ Protocol::Wss(_)) => {
127                if self.tls_config.server.is_some() {
128                    p
129                } else {
130                    debug!("/wss address but TLS server support is not configured");
131                    return Err(TransportError::MultiaddrNotSupported(addr));
132                }
133            }
134            Some(p @ Protocol::Ws(_)) => p,
135            _ => {
136                debug!("{} is not a websocket multiaddr", addr);
137                return Err(TransportError::MultiaddrNotSupported(addr));
138            }
139        };
140        match self.transport.lock().listen_on(id, inner_addr) {
141            Ok(()) => {
142                self.listener_protos.insert(id, proto);
143                Ok(())
144            }
145            Err(e) => Err(e.map(Error::Transport)),
146        }
147    }
148
149    fn remove_listener(&mut self, id: ListenerId) -> bool {
150        self.transport.lock().remove_listener(id)
151    }
152
153    fn dial(&mut self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
154        self.do_dial(addr, Endpoint::Dialer)
155    }
156
157    fn dial_as_listener(
158        &mut self,
159        addr: Multiaddr,
160    ) -> Result<Self::Dial, TransportError<Self::Error>> {
161        self.do_dial(addr, Endpoint::Listener)
162    }
163
164    fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
165        self.transport.lock().address_translation(server, observed)
166    }
167
168    fn poll(
169        mut self: Pin<&mut Self>,
170        cx: &mut Context<'_>,
171    ) -> Poll<libp2p_core::transport::TransportEvent<Self::ListenerUpgrade, Self::Error>> {
172        let inner_event = {
173            let mut transport = self.transport.lock();
174            match Transport::poll(Pin::new(transport.deref_mut()), cx) {
175                Poll::Ready(ev) => ev,
176                Poll::Pending => return Poll::Pending,
177            }
178        };
179        let event = match inner_event {
180            TransportEvent::NewAddress {
181                listener_id,
182                mut listen_addr,
183            } => {
184                // Append the ws / wss protocol back to the inner address.
185                let proto = self
186                    .listener_protos
187                    .get(&listener_id)
188                    .expect("Protocol was inserted in Transport::listen_on.");
189                listen_addr.push(proto.clone());
190                debug!("Listening on {}", listen_addr);
191                TransportEvent::NewAddress {
192                    listener_id,
193                    listen_addr,
194                }
195            }
196            TransportEvent::AddressExpired {
197                listener_id,
198                mut listen_addr,
199            } => {
200                let proto = self
201                    .listener_protos
202                    .get(&listener_id)
203                    .expect("Protocol was inserted in Transport::listen_on.");
204                listen_addr.push(proto.clone());
205                TransportEvent::AddressExpired {
206                    listener_id,
207                    listen_addr,
208                }
209            }
210            TransportEvent::ListenerError { listener_id, error } => TransportEvent::ListenerError {
211                listener_id,
212                error: Error::Transport(error),
213            },
214            TransportEvent::ListenerClosed {
215                listener_id,
216                reason,
217            } => {
218                self.listener_protos
219                    .remove(&listener_id)
220                    .expect("Protocol was inserted in Transport::listen_on.");
221                TransportEvent::ListenerClosed {
222                    listener_id,
223                    reason: reason.map_err(Error::Transport),
224                }
225            }
226            TransportEvent::Incoming {
227                listener_id,
228                upgrade,
229                mut local_addr,
230                mut send_back_addr,
231            } => {
232                let proto = self
233                    .listener_protos
234                    .get(&listener_id)
235                    .expect("Protocol was inserted in Transport::listen_on.");
236                let use_tls = match proto {
237                    Protocol::Wss(_) => true,
238                    Protocol::Ws(_) => false,
239                    _ => unreachable!("Map contains only ws and wss protocols."),
240                };
241                local_addr.push(proto.clone());
242                send_back_addr.push(proto.clone());
243                let upgrade = self.map_upgrade(upgrade, send_back_addr.clone(), use_tls);
244                TransportEvent::Incoming {
245                    listener_id,
246                    upgrade,
247                    local_addr,
248                    send_back_addr,
249                }
250            }
251        };
252        Poll::Ready(event)
253    }
254}
255
256impl<T> WsConfig<T>
257where
258    T: Transport + Send + Unpin + 'static,
259    T::Error: Send + 'static,
260    T::Dial: Send + 'static,
261    T::ListenerUpgrade: Send + 'static,
262    T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
263{
264    fn do_dial(
265        &mut self,
266        addr: Multiaddr,
267        role_override: Endpoint,
268    ) -> Result<<Self as Transport>::Dial, TransportError<<Self as Transport>::Error>> {
269        let mut addr = match parse_ws_dial_addr(addr) {
270            Ok(addr) => addr,
271            Err(Error::InvalidMultiaddr(a)) => {
272                return Err(TransportError::MultiaddrNotSupported(a))
273            }
274            Err(e) => return Err(TransportError::Other(e)),
275        };
276
277        // We are looping here in order to follow redirects (if any):
278        let mut remaining_redirects = self.max_redirects;
279
280        let transport = self.transport.clone();
281        let tls_config = self.tls_config.clone();
282        let max_redirects = self.max_redirects;
283
284        let future = async move {
285            loop {
286                match Self::dial_once(transport.clone(), addr, tls_config.clone(), role_override)
287                    .await
288                {
289                    Ok(Either::Left(redirect)) => {
290                        if remaining_redirects == 0 {
291                            debug!("Too many redirects (> {})", max_redirects);
292                            return Err(Error::TooManyRedirects);
293                        }
294                        remaining_redirects -= 1;
295                        addr = parse_ws_dial_addr(location_to_multiaddr(&redirect)?)?
296                    }
297                    Ok(Either::Right(conn)) => return Ok(conn),
298                    Err(e) => return Err(e),
299                }
300            }
301        };
302
303        Ok(Box::pin(future))
304    }
305
306    /// Attempts to dial the given address and perform a websocket handshake.
307    async fn dial_once(
308        transport: Arc<Mutex<T>>,
309        addr: WsAddress,
310        tls_config: tls::Config,
311        role_override: Endpoint,
312    ) -> Result<Either<String, Connection<T::Output>>, Error<T::Error>> {
313        trace!("Dialing websocket address: {:?}", addr);
314
315        let dial = match role_override {
316            Endpoint::Dialer => transport.lock().dial(addr.tcp_addr),
317            Endpoint::Listener => transport.lock().dial_as_listener(addr.tcp_addr),
318        }
319        .map_err(|e| match e {
320            TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a),
321            TransportError::Other(e) => Error::Transport(e),
322        })?;
323
324        let stream = dial.map_err(Error::Transport).await?;
325        trace!("TCP connection to {} established.", addr.host_port);
326
327        let stream = if addr.use_tls {
328            // begin TLS session
329            let dns_name = addr
330                .dns_name
331                .expect("for use_tls we have checked that dns_name is some");
332            trace!("Starting TLS handshake with {:?}", dns_name);
333            let stream = tls_config
334                .client
335                .connect(dns_name.clone(), stream)
336                .map_err(|e| {
337                    debug!("TLS handshake with {:?} failed: {}", dns_name, e);
338                    Error::Tls(tls::Error::from(e))
339                })
340                .await?;
341
342            let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Left(stream));
343            stream
344        } else {
345            // continue with plain stream
346            future::Either::Right(stream)
347        };
348
349        trace!("Sending websocket handshake to {}", addr.host_port);
350
351        let mut client = handshake::Client::new(stream, &addr.host_port, addr.path.as_ref());
352
353        match client
354            .handshake()
355            .map_err(|e| Error::Handshake(Box::new(e)))
356            .await?
357        {
358            handshake::ServerResponse::Redirect {
359                status_code,
360                location,
361            } => {
362                debug!(
363                    "received redirect ({}); location: {}",
364                    status_code, location
365                );
366                Ok(Either::Left(location))
367            }
368            handshake::ServerResponse::Rejected { status_code } => {
369                let msg = format!("server rejected handshake; status code = {status_code}");
370                Err(Error::Handshake(msg.into()))
371            }
372            handshake::ServerResponse::Accepted { .. } => {
373                trace!("websocket handshake with {} successful", addr.host_port);
374                Ok(Either::Right(Connection::new(client.into_builder())))
375            }
376        }
377    }
378
379    fn map_upgrade(
380        &self,
381        upgrade: T::ListenerUpgrade,
382        remote_addr: Multiaddr,
383        use_tls: bool,
384    ) -> <Self as Transport>::ListenerUpgrade {
385        let remote_addr2 = remote_addr.clone(); // used for logging
386        let tls_config = self.tls_config.clone();
387        let max_size = self.max_data_size;
388
389        async move {
390            let stream = upgrade.map_err(Error::Transport).await?;
391            trace!("incoming connection from {}", remote_addr);
392
393            let stream = if use_tls {
394                // begin TLS session
395                let server = tls_config
396                    .server
397                    .expect("for use_tls we checked server is not none");
398
399                trace!("awaiting TLS handshake with {}", remote_addr);
400
401                let stream = server
402                    .accept(stream)
403                    .map_err(move |e| {
404                        debug!("TLS handshake with {} failed: {}", remote_addr, e);
405                        Error::Tls(tls::Error::from(e))
406                    })
407                    .await?;
408
409                let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Right(stream));
410
411                stream
412            } else {
413                // continue with plain stream
414                future::Either::Right(stream)
415            };
416
417            trace!(
418                "receiving websocket handshake request from {}",
419                remote_addr2
420            );
421
422            let mut server = handshake::Server::new(stream);
423
424            let ws_key = {
425                let request = server
426                    .receive_request()
427                    .map_err(|e| Error::Handshake(Box::new(e)))
428                    .await?;
429                request.key()
430            };
431
432            trace!(
433                "accepting websocket handshake request from {}",
434                remote_addr2
435            );
436
437            let response = handshake::server::Response::Accept {
438                key: ws_key,
439                protocol: None,
440            };
441
442            server
443                .send_response(&response)
444                .map_err(|e| Error::Handshake(Box::new(e)))
445                .await?;
446
447            let conn = {
448                let mut builder = server.into_builder();
449                builder.set_max_message_size(max_size);
450                builder.set_max_frame_size(max_size);
451                Connection::new(builder)
452            };
453
454            Ok(conn)
455        }
456        .boxed()
457    }
458}
459
460#[derive(Debug)]
461struct WsAddress {
462    host_port: String,
463    path: String,
464    dns_name: Option<rustls::ServerName>,
465    use_tls: bool,
466    tcp_addr: Multiaddr,
467}
468
469/// Tries to parse the given `Multiaddr` into a `WsAddress` used
470/// for dialing.
471///
472/// Fails if the given `Multiaddr` does not represent a TCP/IP-based
473/// websocket protocol stack.
474fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
475    // The encapsulating protocol must be based on TCP/IP, possibly via DNS.
476    // We peek at it in order to learn the hostname and port to use for
477    // the websocket handshake.
478    let mut protocols = addr.iter();
479    let mut ip = protocols.next();
480    let mut tcp = protocols.next();
481    let (host_port, dns_name) = loop {
482        match (ip, tcp) {
483            (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => {
484                break (format!("{ip}:{port}"), None)
485            }
486            (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => {
487                break (format!("{ip}:{port}"), None)
488            }
489            (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port)))
490            | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port)))
491            | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port)))
492            | (Some(Protocol::Dnsaddr(h)), Some(Protocol::Tcp(port))) => {
493                break (format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?))
494            }
495            (Some(_), Some(p)) => {
496                ip = Some(p);
497                tcp = protocols.next();
498            }
499            _ => return Err(Error::InvalidMultiaddr(addr)),
500        }
501    };
502
503    // Now consume the `Ws` / `Wss` protocol from the end of the address,
504    // preserving the trailing `P2p` protocol that identifies the remote,
505    // if any.
506    let mut protocols = addr.clone();
507    let mut p2p = None;
508    let (use_tls, path) = loop {
509        match protocols.pop() {
510            p @ Some(Protocol::P2p(_)) => p2p = p,
511            Some(Protocol::Ws(path)) => break (false, path.into_owned()),
512            Some(Protocol::Wss(path)) => {
513                if dns_name.is_none() {
514                    debug!("Missing DNS name in WSS address: {}", addr);
515                    return Err(Error::InvalidMultiaddr(addr));
516                }
517                break (true, path.into_owned());
518            }
519            _ => return Err(Error::InvalidMultiaddr(addr)),
520        }
521    };
522
523    // The original address, stripped of the `/ws` and `/wss` protocols,
524    // makes up the the address for the inner TCP-based transport.
525    let tcp_addr = match p2p {
526        Some(p) => protocols.with(p),
527        None => protocols,
528    };
529
530    Ok(WsAddress {
531        host_port,
532        dns_name,
533        path,
534        use_tls,
535        tcp_addr,
536    })
537}
538
539// Given a location URL, build a new websocket [`Multiaddr`].
540fn location_to_multiaddr<T>(location: &str) -> Result<Multiaddr, Error<T>> {
541    match Url::parse(location) {
542        Ok(url) => {
543            let mut a = Multiaddr::empty();
544            match url.host() {
545                Some(url::Host::Domain(h)) => a.push(Protocol::Dns(h.into())),
546                Some(url::Host::Ipv4(ip)) => a.push(Protocol::Ip4(ip)),
547                Some(url::Host::Ipv6(ip)) => a.push(Protocol::Ip6(ip)),
548                None => return Err(Error::InvalidRedirectLocation),
549            }
550            if let Some(p) = url.port() {
551                a.push(Protocol::Tcp(p))
552            }
553            let s = url.scheme();
554            if s.eq_ignore_ascii_case("https") | s.eq_ignore_ascii_case("wss") {
555                a.push(Protocol::Wss(url.path().into()))
556            } else if s.eq_ignore_ascii_case("http") | s.eq_ignore_ascii_case("ws") {
557                a.push(Protocol::Ws(url.path().into()))
558            } else {
559                debug!("unsupported scheme: {}", s);
560                return Err(Error::InvalidRedirectLocation);
561            }
562            Ok(a)
563        }
564        Err(e) => {
565            debug!("failed to parse url as multi-address: {:?}", e);
566            Err(Error::InvalidRedirectLocation)
567        }
568    }
569}
570
571/// The websocket connection.
572pub struct Connection<T> {
573    receiver: BoxStream<'static, Result<Incoming, connection::Error>>,
574    sender: Pin<Box<dyn Sink<OutgoingData, Error = quicksink::Error<connection::Error>> + Send>>,
575    _marker: std::marker::PhantomData<T>,
576}
577
578/// Data or control information received over the websocket connection.
579#[derive(Debug, Clone)]
580pub enum Incoming {
581    /// Application data.
582    Data(Data),
583    /// PONG control frame data.
584    Pong(Vec<u8>),
585    /// Close reason.
586    Closed(CloseReason),
587}
588
/// Application payload received over the websocket connection, tagged with
/// the kind of frame it arrived in.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Data {
    /// Payload of a text frame (UTF-8 encoded).
    Text(Vec<u8>),
    /// Payload of a binary frame.
    Binary(Vec<u8>),
}

impl Data {
    /// Consumes the value and returns the raw payload bytes, whether the
    /// frame was text or binary.
    pub fn into_bytes(self) -> Vec<u8> {
        let (Data::Text(bytes) | Data::Binary(bytes)) = self;
        bytes
    }
}

impl AsRef<[u8]> for Data {
    /// Borrows the payload bytes regardless of the frame kind.
    fn as_ref(&self) -> &[u8] {
        match self {
            Data::Text(bytes) | Data::Binary(bytes) => bytes.as_slice(),
        }
    }
}
615
616impl Incoming {
617    pub fn is_data(&self) -> bool {
618        self.is_binary() || self.is_text()
619    }
620
621    pub fn is_binary(&self) -> bool {
622        matches!(self, Incoming::Data(Data::Binary(_)))
623    }
624
625    pub fn is_text(&self) -> bool {
626        matches!(self, Incoming::Data(Data::Text(_)))
627    }
628
629    pub fn is_pong(&self) -> bool {
630        matches!(self, Incoming::Pong(_))
631    }
632
633    pub fn is_close(&self) -> bool {
634        matches!(self, Incoming::Closed(_))
635    }
636}
637
/// One item written to a websocket [`Connection`].
#[derive(Debug, Clone)]
pub enum OutgoingData {
    /// Binary application payload.
    Binary(Vec<u8>),
    /// Payload of a PING control frame.
    Ping(Vec<u8>),
    /// Payload of an unsolicited PONG control frame.
    /// (PINGs from the remote are answered automatically.)
    Pong(Vec<u8>),
}
649
650impl<T> fmt::Debug for Connection<T> {
651    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
652        f.write_str("Connection")
653    }
654}
655
656impl<T> Connection<T>
657where
658    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
659{
660    fn new(builder: connection::Builder<TlsOrPlain<T>>) -> Self {
661        let (sender, receiver) = builder.finish();
662        let sink = quicksink::make_sink(sender, |mut sender, action| async move {
663            match action {
664                quicksink::Action::Send(OutgoingData::Binary(x)) => {
665                    sender.send_binary_mut(x).await?
666                }
667                quicksink::Action::Send(OutgoingData::Ping(x)) => {
668                    let data = x[..].try_into().map_err(|_| {
669                        io::Error::new(io::ErrorKind::InvalidInput, "PING data must be < 126 bytes")
670                    })?;
671                    sender.send_ping(data).await?
672                }
673                quicksink::Action::Send(OutgoingData::Pong(x)) => {
674                    let data = x[..].try_into().map_err(|_| {
675                        io::Error::new(io::ErrorKind::InvalidInput, "PONG data must be < 126 bytes")
676                    })?;
677                    sender.send_pong(data).await?
678                }
679                quicksink::Action::Flush => sender.flush().await?,
680                quicksink::Action::Close => sender.close().await?,
681            }
682            Ok(sender)
683        });
684        let stream = stream::unfold((Vec::new(), receiver), |(mut data, mut receiver)| async {
685            match receiver.receive(&mut data).await {
686                Ok(soketto::Incoming::Data(soketto::Data::Text(_))) => Some((
687                    Ok(Incoming::Data(Data::Text(mem::take(&mut data)))),
688                    (data, receiver),
689                )),
690                Ok(soketto::Incoming::Data(soketto::Data::Binary(_))) => Some((
691                    Ok(Incoming::Data(Data::Binary(mem::take(&mut data)))),
692                    (data, receiver),
693                )),
694                Ok(soketto::Incoming::Pong(pong)) => {
695                    Some((Ok(Incoming::Pong(Vec::from(pong))), (data, receiver)))
696                }
697                Ok(soketto::Incoming::Closed(reason)) => {
698                    Some((Ok(Incoming::Closed(reason)), (data, receiver)))
699                }
700                Err(connection::Error::Closed) => None,
701                Err(e) => Some((Err(e), (data, receiver))),
702            }
703        });
704        Connection {
705            receiver: stream.boxed(),
706            sender: Box::pin(sink),
707            _marker: std::marker::PhantomData,
708        }
709    }
710
711    /// Send binary application data to the remote.
712    pub fn send_data(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
713        self.send(OutgoingData::Binary(data))
714    }
715
716    /// Send a PING to the remote.
717    pub fn send_ping(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
718        self.send(OutgoingData::Ping(data))
719    }
720
721    /// Send an unsolicited PONG to the remote.
722    pub fn send_pong(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
723        self.send(OutgoingData::Pong(data))
724    }
725}
726
727impl<T> Stream for Connection<T>
728where
729    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
730{
731    type Item = io::Result<Incoming>;
732
733    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
734        let item = ready!(self.receiver.poll_next_unpin(cx));
735        let item = item.map(|result| result.map_err(|e| io::Error::new(io::ErrorKind::Other, e)));
736        Poll::Ready(item)
737    }
738}
739
740impl<T> Sink<OutgoingData> for Connection<T>
741where
742    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
743{
744    type Error = io::Error;
745
746    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
747        Pin::new(&mut self.sender)
748            .poll_ready(cx)
749            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
750    }
751
752    fn start_send(mut self: Pin<&mut Self>, item: OutgoingData) -> io::Result<()> {
753        Pin::new(&mut self.sender)
754            .start_send(item)
755            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
756    }
757
758    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
759        Pin::new(&mut self.sender)
760            .poll_flush(cx)
761            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
762    }
763
764    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
765        Pin::new(&mut self.sender)
766            .poll_close(cx)
767            .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
768    }
769}