libp2p_websocket/
lib.rs

1// Copyright 2017-2019 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21//! Implementation of the libp2p `Transport` trait for Websockets.
22
23#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
24
25pub mod error;
26pub mod framed;
27mod quicksink;
28pub mod tls;
29
30use error::Error;
31use framed::{Connection, Incoming};
32use futures::{future::BoxFuture, prelude::*, ready};
33use libp2p_core::{
34    connection::ConnectedPoint,
35    multiaddr::Multiaddr,
36    transport::{map::MapFuture, ListenerId, TransportError, TransportEvent},
37    Transport,
38};
39use rw_stream_sink::RwStreamSink;
40use std::{
41    io,
42    pin::Pin,
43    task::{Context, Poll},
44};
45
46/// A Websocket transport.
47///
48/// DO NOT wrap this transport with a DNS transport if you want Secure Websockets to work.
49///
50/// A Secure Websocket transport needs to wrap DNS transport to resolve domain names after
51/// they are checked against the remote certificates. Use a combination of DNS and TCP transports
52/// to build a Secure Websocket transport.
53///
54/// If you don't need Secure Websocket's support, use a plain TCP transport as an inner transport.
55///
56/// # Dependencies
57///
58/// This transport requires the `zlib` shared library to be installed on the system.
59///
60/// Future releases might lift this requirement, see <https://github.com/paritytech/soketto/issues/72>.
61///
62/// # Examples
63///
64/// Secure Websocket transport:
65///
66/// ```
67/// # use futures::future;
68/// # use libp2p_core::{transport::ListenerId, Transport};
69/// # use libp2p_dns as dns;
70/// # use libp2p_tcp as tcp;
71/// # use libp2p_websocket as websocket;
72/// # use rcgen::generate_simple_self_signed;
73/// # use std::pin::Pin;
74/// #
75/// # #[async_std::main]
76/// # async fn main() {
77///
78/// let mut transport = websocket::WsConfig::new(dns::async_std::Transport::system(
79///     tcp::async_io::Transport::new(tcp::Config::default()),
80/// ).await.unwrap());
81///
82/// let rcgen_cert = generate_simple_self_signed(vec!["localhost".to_string()]).unwrap();
83/// let priv_key = websocket::tls::PrivateKey::new(rcgen_cert.serialize_private_key_der());
84/// let cert = websocket::tls::Certificate::new(rcgen_cert.serialize_der().unwrap());
85/// transport.set_tls_config(websocket::tls::Config::new(priv_key, vec![cert]).unwrap());
86///
87/// let id = transport.listen_on(ListenerId::next(), "/ip4/127.0.0.1/tcp/0/wss".parse().unwrap()).unwrap();
88///
89/// let addr = future::poll_fn(|cx| Pin::new(&mut transport).poll(cx)).await.into_new_address().unwrap();
90/// println!("Listening on {addr}");
91///
92/// # }
93/// ```
94///
95/// Plain Websocket transport:
96///
97/// ```
98/// # use futures::future;
99/// # use libp2p_core::{transport::ListenerId, Transport};
100/// # use libp2p_dns as dns;
101/// # use libp2p_tcp as tcp;
102/// # use libp2p_websocket as websocket;
103/// # use std::pin::Pin;
104/// #
105/// # #[async_std::main]
106/// # async fn main() {
107///
108/// let mut transport = websocket::WsConfig::new(
109///     tcp::async_io::Transport::new(tcp::Config::default()),
110/// );
111///
112/// let id = transport.listen_on(ListenerId::next(), "/ip4/127.0.0.1/tcp/0/ws".parse().unwrap()).unwrap();
113///
114/// let addr = future::poll_fn(|cx| Pin::new(&mut transport).poll(cx)).await.into_new_address().unwrap();
115/// println!("Listening on {addr}");
116///
117/// # }
118/// ```
119#[derive(Debug)]
120pub struct WsConfig<T: Transport>
121where
122    T: Transport,
123    T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static,
124{
125    transport: libp2p_core::transport::map::Map<framed::WsConfig<T>, WrapperFn<T::Output>>,
126}
127
128impl<T: Transport> WsConfig<T>
129where
130    T: Transport + Send + Unpin + 'static,
131    T::Error: Send + 'static,
132    T::Dial: Send + 'static,
133    T::ListenerUpgrade: Send + 'static,
134    T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static,
135{
136    /// Create a new websocket transport based on the given transport.
137    ///
138    /// > **Note*: The given transport must be based on TCP/IP and should
139    /// > usually incorporate DNS resolution, though the latter is not
140    /// > strictly necessary if one wishes to only use the `Ws` protocol
141    /// > with known IP addresses and ports. See [`libp2p-tcp`](https://docs.rs/libp2p-tcp/)
142    /// > and [`libp2p-dns`](https://docs.rs/libp2p-dns) for constructing
143    /// > the inner transport.
144    pub fn new(transport: T) -> Self {
145        Self {
146            transport: framed::WsConfig::new(transport)
147                .map(wrap_connection as WrapperFn<T::Output>),
148        }
149    }
150
151    /// Return the configured maximum number of redirects.
152    pub fn max_redirects(&self) -> u8 {
153        self.transport.inner().max_redirects()
154    }
155
156    /// Set max. number of redirects to follow.
157    pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
158        self.transport.inner_mut().set_max_redirects(max);
159        self
160    }
161
162    /// Get the max. frame data size we support.
163    pub fn max_data_size(&self) -> usize {
164        self.transport.inner().max_data_size()
165    }
166
167    /// Set the max. frame data size we support.
168    pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
169        self.transport.inner_mut().set_max_data_size(size);
170        self
171    }
172
173    /// Set the TLS configuration if TLS support is desired.
174    pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
175        self.transport.inner_mut().set_tls_config(c);
176        self
177    }
178}
179
180impl<T> Transport for WsConfig<T>
181where
182    T: Transport + Send + Unpin + 'static,
183    T::Error: Send + 'static,
184    T::Dial: Send + 'static,
185    T::ListenerUpgrade: Send + 'static,
186    T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
187{
188    type Output = RwStreamSink<BytesConnection<T::Output>>;
189    type Error = Error<T::Error>;
190    type ListenerUpgrade = MapFuture<InnerFuture<T::Output, T::Error>, WrapperFn<T::Output>>;
191    type Dial = MapFuture<InnerFuture<T::Output, T::Error>, WrapperFn<T::Output>>;
192
193    fn listen_on(
194        &mut self,
195        id: ListenerId,
196        addr: Multiaddr,
197    ) -> Result<(), TransportError<Self::Error>> {
198        self.transport.listen_on(id, addr)
199    }
200
201    fn remove_listener(&mut self, id: ListenerId) -> bool {
202        self.transport.remove_listener(id)
203    }
204
205    fn dial(&mut self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
206        self.transport.dial(addr)
207    }
208
209    fn dial_as_listener(
210        &mut self,
211        addr: Multiaddr,
212    ) -> Result<Self::Dial, TransportError<Self::Error>> {
213        self.transport.dial_as_listener(addr)
214    }
215
216    fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
217        self.transport.address_translation(server, observed)
218    }
219
220    fn poll(
221        mut self: Pin<&mut Self>,
222        cx: &mut Context<'_>,
223    ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
224        Pin::new(&mut self.transport).poll(cx)
225    }
226}
227
228/// Type alias corresponding to `framed::WsConfig::Dial` and `framed::WsConfig::ListenerUpgrade`.
229pub type InnerFuture<T, E> = BoxFuture<'static, Result<Connection<T>, Error<E>>>;
230
231/// Function type that wraps a websocket connection (see. `wrap_connection`).
232pub type WrapperFn<T> = fn(Connection<T>, ConnectedPoint) -> RwStreamSink<BytesConnection<T>>;
233
234/// Wrap a websocket connection producing data frames into a `RwStreamSink`
235/// implementing `AsyncRead` + `AsyncWrite`.
236fn wrap_connection<T>(c: Connection<T>, _: ConnectedPoint) -> RwStreamSink<BytesConnection<T>>
237where
238    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
239{
240    RwStreamSink::new(BytesConnection(c))
241}
242
243/// The websocket connection.
244#[derive(Debug)]
245pub struct BytesConnection<T>(Connection<T>);
246
247impl<T> Stream for BytesConnection<T>
248where
249    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
250{
251    type Item = io::Result<Vec<u8>>;
252
253    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
254        loop {
255            if let Some(item) = ready!(self.0.try_poll_next_unpin(cx)?) {
256                if let Incoming::Data(payload) = item {
257                    return Poll::Ready(Some(Ok(payload.into_bytes())));
258                }
259            } else {
260                return Poll::Ready(None);
261            }
262        }
263    }
264}
265
266impl<T> Sink<Vec<u8>> for BytesConnection<T>
267where
268    T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
269{
270    type Error = io::Error;
271
272    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
273        Pin::new(&mut self.0).poll_ready(cx)
274    }
275
276    fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> io::Result<()> {
277        Pin::new(&mut self.0).start_send(framed::OutgoingData::Binary(item))
278    }
279
280    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
281        Pin::new(&mut self.0).poll_flush(cx)
282    }
283
284    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
285        Pin::new(&mut self.0).poll_close(cx)
286    }
287}
288
289// Tests //////////////////////////////////////////////////////////////////////////////////////////
290
#[cfg(test)]
mod tests {
    use super::WsConfig;
    use futures::prelude::*;
    use libp2p_core::{multiaddr::Protocol, transport::ListenerId, Multiaddr, Transport};
    use libp2p_identity::PeerId;
    use libp2p_tcp as tcp;

    /// Build a plain (non-TLS) websocket transport over async-io TCP.
    fn new_ws_config() -> WsConfig<tcp::async_io::Transport> {
        let tcp_transport = tcp::async_io::Transport::new(tcp::Config::default());
        WsConfig::new(tcp_transport)
    }

    /// Parse `addr` and drive the listen/dial round-trip to completion.
    fn run_connect(addr: &str) {
        let listen_addr: Multiaddr = addr.parse().unwrap();
        futures::executor::block_on(connect(listen_addr))
    }

    #[test]
    fn dialer_connects_to_listener_ipv4() {
        run_connect("/ip4/127.0.0.1/tcp/0/ws")
    }

    #[test]
    fn dialer_connects_to_listener_ipv6() {
        run_connect("/ip6/::1/tcp/0/ws")
    }

    /// Listen on `listen_addr`, dial the reported address with a second transport and
    /// check that both sides obtain a connection.
    async fn connect(listen_addr: Multiaddr) {
        let mut listener = new_ws_config().boxed();
        listener
            .listen_on(ListenerId::next(), listen_addr)
            .expect("listener");

        let new_address_event = listener.next().await.expect("no error");
        let addr = new_address_event
            .into_new_address()
            .expect("listen address");

        // The third component must be `/ws`, and the requested port `0` must have been
        // replaced by a concrete one.
        assert_eq!(Some(Protocol::Ws("/".into())), addr.iter().nth(2));
        assert_ne!(Some(Protocol::Tcp(0)), addr.iter().nth(1));

        let inbound = async move {
            let event = listener.select_next_some().await;
            let (upgrade, _send_back_addr) = event.into_incoming().unwrap();
            upgrade.await
        };

        let dial_addr = addr.with(Protocol::P2p(PeerId::random()));
        let outbound = new_ws_config().boxed().dial(dial_addr).unwrap();

        let (inbound_result, outbound_result) = futures::join!(inbound, outbound);
        inbound_result.and(outbound_result).unwrap();
    }
}