1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
24
25pub mod error;
26pub mod framed;
27mod quicksink;
28pub mod tls;
29
30use error::Error;
31use framed::{Connection, Incoming};
32use futures::{future::BoxFuture, prelude::*, ready};
33use libp2p_core::{
34 connection::ConnectedPoint,
35 multiaddr::Multiaddr,
36 transport::{map::MapFuture, ListenerId, TransportError, TransportEvent},
37 Transport,
38};
39use rw_stream_sink::RwStreamSink;
40use std::{
41 io,
42 pin::Pin,
43 task::{Context, Poll},
44};
45
46#[derive(Debug)]
120pub struct WsConfig<T: Transport>
121where
122 T: Transport,
123 T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static,
124{
125 transport: libp2p_core::transport::map::Map<framed::WsConfig<T>, WrapperFn<T::Output>>,
126}
127
128impl<T: Transport> WsConfig<T>
129where
130 T: Transport + Send + Unpin + 'static,
131 T::Error: Send + 'static,
132 T::Dial: Send + 'static,
133 T::ListenerUpgrade: Send + 'static,
134 T::Output: AsyncRead + AsyncWrite + Send + Unpin + 'static,
135{
136 pub fn new(transport: T) -> Self {
145 Self {
146 transport: framed::WsConfig::new(transport)
147 .map(wrap_connection as WrapperFn<T::Output>),
148 }
149 }
150
151 pub fn max_redirects(&self) -> u8 {
153 self.transport.inner().max_redirects()
154 }
155
156 pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
158 self.transport.inner_mut().set_max_redirects(max);
159 self
160 }
161
162 pub fn max_data_size(&self) -> usize {
164 self.transport.inner().max_data_size()
165 }
166
167 pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
169 self.transport.inner_mut().set_max_data_size(size);
170 self
171 }
172
173 pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
175 self.transport.inner_mut().set_tls_config(c);
176 self
177 }
178}
179
180impl<T> Transport for WsConfig<T>
181where
182 T: Transport + Send + Unpin + 'static,
183 T::Error: Send + 'static,
184 T::Dial: Send + 'static,
185 T::ListenerUpgrade: Send + 'static,
186 T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
187{
188 type Output = RwStreamSink<BytesConnection<T::Output>>;
189 type Error = Error<T::Error>;
190 type ListenerUpgrade = MapFuture<InnerFuture<T::Output, T::Error>, WrapperFn<T::Output>>;
191 type Dial = MapFuture<InnerFuture<T::Output, T::Error>, WrapperFn<T::Output>>;
192
193 fn listen_on(
194 &mut self,
195 id: ListenerId,
196 addr: Multiaddr,
197 ) -> Result<(), TransportError<Self::Error>> {
198 self.transport.listen_on(id, addr)
199 }
200
201 fn remove_listener(&mut self, id: ListenerId) -> bool {
202 self.transport.remove_listener(id)
203 }
204
205 fn dial(&mut self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
206 self.transport.dial(addr)
207 }
208
209 fn dial_as_listener(
210 &mut self,
211 addr: Multiaddr,
212 ) -> Result<Self::Dial, TransportError<Self::Error>> {
213 self.transport.dial_as_listener(addr)
214 }
215
216 fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
217 self.transport.address_translation(server, observed)
218 }
219
220 fn poll(
221 mut self: Pin<&mut Self>,
222 cx: &mut Context<'_>,
223 ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
224 Pin::new(&mut self.transport).poll(cx)
225 }
226}
227
228pub type InnerFuture<T, E> = BoxFuture<'static, Result<Connection<T>, Error<E>>>;
230
231pub type WrapperFn<T> = fn(Connection<T>, ConnectedPoint) -> RwStreamSink<BytesConnection<T>>;
233
234fn wrap_connection<T>(c: Connection<T>, _: ConnectedPoint) -> RwStreamSink<BytesConnection<T>>
237where
238 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
239{
240 RwStreamSink::new(BytesConnection(c))
241}
242
243#[derive(Debug)]
245pub struct BytesConnection<T>(Connection<T>);
246
247impl<T> Stream for BytesConnection<T>
248where
249 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
250{
251 type Item = io::Result<Vec<u8>>;
252
253 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
254 loop {
255 if let Some(item) = ready!(self.0.try_poll_next_unpin(cx)?) {
256 if let Incoming::Data(payload) = item {
257 return Poll::Ready(Some(Ok(payload.into_bytes())));
258 }
259 } else {
260 return Poll::Ready(None);
261 }
262 }
263 }
264}
265
266impl<T> Sink<Vec<u8>> for BytesConnection<T>
267where
268 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
269{
270 type Error = io::Error;
271
272 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
273 Pin::new(&mut self.0).poll_ready(cx)
274 }
275
276 fn start_send(mut self: Pin<&mut Self>, item: Vec<u8>) -> io::Result<()> {
277 Pin::new(&mut self.0).start_send(framed::OutgoingData::Binary(item))
278 }
279
280 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
281 Pin::new(&mut self.0).poll_flush(cx)
282 }
283
284 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
285 Pin::new(&mut self.0).poll_close(cx)
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::WsConfig;
    use futures::prelude::*;
    use libp2p_core::{multiaddr::Protocol, transport::ListenerId, Multiaddr, Transport};
    use libp2p_identity::PeerId;
    use libp2p_tcp as tcp;

    /// Parses `addr` as a multiaddr and drives `connect` to completion on a blocking executor.
    fn run_connect(addr: &str) {
        let listen_addr: Multiaddr = addr.parse().unwrap();
        futures::executor::block_on(connect(listen_addr));
    }

    #[test]
    fn dialer_connects_to_listener_ipv4() {
        run_connect("/ip4/127.0.0.1/tcp/0/ws");
    }

    #[test]
    fn dialer_connects_to_listener_ipv6() {
        run_connect("/ip6/::1/tcp/0/ws");
    }

    /// Builds a WebSocket transport over an `async_io` TCP transport with default config.
    fn new_ws_config() -> WsConfig<tcp::async_io::Transport> {
        let tcp_transport = tcp::async_io::Transport::new(tcp::Config::default());
        WsConfig::new(tcp_transport)
    }

    /// Listens on `listen_addr`, dials the reported listen address from a second
    /// transport, and asserts that both the inbound upgrade and the outbound dial succeed.
    async fn connect(listen_addr: Multiaddr) {
        let mut listener = new_ws_config().boxed();
        listener
            .listen_on(ListenerId::next(), listen_addr)
            .expect("listener");

        let addr = listener
            .next()
            .await
            .expect("no error")
            .into_new_address()
            .expect("listen address");

        // The reported address must keep its `/ws` component and carry a non-zero port.
        let tcp_component = addr.iter().nth(1);
        let ws_component = addr.iter().nth(2);
        assert_eq!(Some(Protocol::Ws("/".into())), ws_component);
        assert_ne!(Some(Protocol::Tcp(0)), tcp_component);

        let inbound = async move {
            let (upgrade, _remote_addr) = listener
                .select_next_some()
                .map(|event| event.into_incoming())
                .await
                .unwrap();
            upgrade.await
        };

        let dial_addr = addr.with(Protocol::P2p(PeerId::random()));
        let outbound = new_ws_config().boxed().dial(dial_addr).unwrap();

        let (inbound_result, outbound_result) = futures::join!(inbound, outbound);
        inbound_result.and(outbound_result).unwrap();
    }
}