1use crate::{error::Error, quicksink, tls};
22use either::Either;
23use futures::{future::BoxFuture, prelude::*, ready, stream::BoxStream};
24use futures_rustls::{client, rustls, server};
25use libp2p_core::{
26 connection::Endpoint,
27 multiaddr::{Multiaddr, Protocol},
28 transport::{ListenerId, TransportError, TransportEvent},
29 Transport,
30};
31use log::{debug, trace};
32use parking_lot::Mutex;
33use soketto::{
34 connection::{self, CloseReason},
35 handshake,
36};
37use std::{collections::HashMap, ops::DerefMut, sync::Arc};
38use std::{convert::TryInto, fmt, io, mem, pin::Pin, task::Context, task::Poll};
39use url::Url;
40
/// Default data size limit in bytes (256 MiB); `WsConfig::new` uses it to
/// initialise `max_data_size`.
const MAX_DATA_SIZE: usize = 256 * 1024 * 1024;
43
/// A websocket transport layered on top of an inner transport `T`.
///
/// Addresses handled by this transport end in a `/ws` or `/wss` component
/// (optionally followed by `/p2p` when dialing); the rest of the address is
/// handed to the inner transport.
#[derive(Debug)]
pub struct WsConfig<T> {
    /// The inner transport. It sits behind `Arc<Mutex<_>>` so that dial
    /// futures can keep their own handle to it.
    transport: Arc<Mutex<T>>,
    /// Limit passed as both the maximum message size and the maximum frame
    /// size when building accepted (incoming) connections.
    max_data_size: usize,
    /// TLS configuration used for `/wss` addresses.
    tls_config: tls::Config,
    /// How many handshake redirects a dial attempt may follow.
    max_redirects: u8,
    /// The `/ws` or `/wss` protocol each active listener was started with.
    listener_protos: HashMap<ListenerId, Protocol<'static>>,
}
59
60impl<T> WsConfig<T>
61where
62 T: Send,
63{
64 pub fn new(transport: T) -> Self {
66 WsConfig {
67 transport: Arc::new(Mutex::new(transport)),
68 max_data_size: MAX_DATA_SIZE,
69 tls_config: tls::Config::client(),
70 max_redirects: 0,
71 listener_protos: HashMap::new(),
72 }
73 }
74
75 pub fn max_redirects(&self) -> u8 {
77 self.max_redirects
78 }
79
80 pub fn set_max_redirects(&mut self, max: u8) -> &mut Self {
82 self.max_redirects = max;
83 self
84 }
85
86 pub fn max_data_size(&self) -> usize {
88 self.max_data_size
89 }
90
91 pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
93 self.max_data_size = size;
94 self
95 }
96
97 pub fn set_tls_config(&mut self, c: tls::Config) -> &mut Self {
99 self.tls_config = c;
100 self
101 }
102}
103
/// The stream handed to the websocket handshake: a client-side TLS stream
/// (`Left(Left(_))`), a server-side TLS stream (`Left(Right(_))`) or the plain
/// inner stream (`Right(_)`).
type TlsOrPlain<T> = future::Either<future::Either<client::TlsStream<T>, server::TlsStream<T>>, T>;
105
106impl<T> Transport for WsConfig<T>
107where
108 T: Transport + Send + Unpin + 'static,
109 T::Error: Send + 'static,
110 T::Dial: Send + 'static,
111 T::ListenerUpgrade: Send + 'static,
112 T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
113{
114 type Output = Connection<T::Output>;
115 type Error = Error<T::Error>;
116 type ListenerUpgrade = BoxFuture<'static, Result<Self::Output, Self::Error>>;
117 type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
118
119 fn listen_on(
120 &mut self,
121 id: ListenerId,
122 addr: Multiaddr,
123 ) -> Result<(), TransportError<Self::Error>> {
124 let mut inner_addr = addr.clone();
125 let proto = match inner_addr.pop() {
126 Some(p @ Protocol::Wss(_)) => {
127 if self.tls_config.server.is_some() {
128 p
129 } else {
130 debug!("/wss address but TLS server support is not configured");
131 return Err(TransportError::MultiaddrNotSupported(addr));
132 }
133 }
134 Some(p @ Protocol::Ws(_)) => p,
135 _ => {
136 debug!("{} is not a websocket multiaddr", addr);
137 return Err(TransportError::MultiaddrNotSupported(addr));
138 }
139 };
140 match self.transport.lock().listen_on(id, inner_addr) {
141 Ok(()) => {
142 self.listener_protos.insert(id, proto);
143 Ok(())
144 }
145 Err(e) => Err(e.map(Error::Transport)),
146 }
147 }
148
149 fn remove_listener(&mut self, id: ListenerId) -> bool {
150 self.transport.lock().remove_listener(id)
151 }
152
153 fn dial(&mut self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
154 self.do_dial(addr, Endpoint::Dialer)
155 }
156
157 fn dial_as_listener(
158 &mut self,
159 addr: Multiaddr,
160 ) -> Result<Self::Dial, TransportError<Self::Error>> {
161 self.do_dial(addr, Endpoint::Listener)
162 }
163
164 fn address_translation(&self, server: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
165 self.transport.lock().address_translation(server, observed)
166 }
167
168 fn poll(
169 mut self: Pin<&mut Self>,
170 cx: &mut Context<'_>,
171 ) -> Poll<libp2p_core::transport::TransportEvent<Self::ListenerUpgrade, Self::Error>> {
172 let inner_event = {
173 let mut transport = self.transport.lock();
174 match Transport::poll(Pin::new(transport.deref_mut()), cx) {
175 Poll::Ready(ev) => ev,
176 Poll::Pending => return Poll::Pending,
177 }
178 };
179 let event = match inner_event {
180 TransportEvent::NewAddress {
181 listener_id,
182 mut listen_addr,
183 } => {
184 let proto = self
186 .listener_protos
187 .get(&listener_id)
188 .expect("Protocol was inserted in Transport::listen_on.");
189 listen_addr.push(proto.clone());
190 debug!("Listening on {}", listen_addr);
191 TransportEvent::NewAddress {
192 listener_id,
193 listen_addr,
194 }
195 }
196 TransportEvent::AddressExpired {
197 listener_id,
198 mut listen_addr,
199 } => {
200 let proto = self
201 .listener_protos
202 .get(&listener_id)
203 .expect("Protocol was inserted in Transport::listen_on.");
204 listen_addr.push(proto.clone());
205 TransportEvent::AddressExpired {
206 listener_id,
207 listen_addr,
208 }
209 }
210 TransportEvent::ListenerError { listener_id, error } => TransportEvent::ListenerError {
211 listener_id,
212 error: Error::Transport(error),
213 },
214 TransportEvent::ListenerClosed {
215 listener_id,
216 reason,
217 } => {
218 self.listener_protos
219 .remove(&listener_id)
220 .expect("Protocol was inserted in Transport::listen_on.");
221 TransportEvent::ListenerClosed {
222 listener_id,
223 reason: reason.map_err(Error::Transport),
224 }
225 }
226 TransportEvent::Incoming {
227 listener_id,
228 upgrade,
229 mut local_addr,
230 mut send_back_addr,
231 } => {
232 let proto = self
233 .listener_protos
234 .get(&listener_id)
235 .expect("Protocol was inserted in Transport::listen_on.");
236 let use_tls = match proto {
237 Protocol::Wss(_) => true,
238 Protocol::Ws(_) => false,
239 _ => unreachable!("Map contains only ws and wss protocols."),
240 };
241 local_addr.push(proto.clone());
242 send_back_addr.push(proto.clone());
243 let upgrade = self.map_upgrade(upgrade, send_back_addr.clone(), use_tls);
244 TransportEvent::Incoming {
245 listener_id,
246 upgrade,
247 local_addr,
248 send_back_addr,
249 }
250 }
251 };
252 Poll::Ready(event)
253 }
254}
255
256impl<T> WsConfig<T>
257where
258 T: Transport + Send + Unpin + 'static,
259 T::Error: Send + 'static,
260 T::Dial: Send + 'static,
261 T::ListenerUpgrade: Send + 'static,
262 T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static,
263{
264 fn do_dial(
265 &mut self,
266 addr: Multiaddr,
267 role_override: Endpoint,
268 ) -> Result<<Self as Transport>::Dial, TransportError<<Self as Transport>::Error>> {
269 let mut addr = match parse_ws_dial_addr(addr) {
270 Ok(addr) => addr,
271 Err(Error::InvalidMultiaddr(a)) => {
272 return Err(TransportError::MultiaddrNotSupported(a))
273 }
274 Err(e) => return Err(TransportError::Other(e)),
275 };
276
277 let mut remaining_redirects = self.max_redirects;
279
280 let transport = self.transport.clone();
281 let tls_config = self.tls_config.clone();
282 let max_redirects = self.max_redirects;
283
284 let future = async move {
285 loop {
286 match Self::dial_once(transport.clone(), addr, tls_config.clone(), role_override)
287 .await
288 {
289 Ok(Either::Left(redirect)) => {
290 if remaining_redirects == 0 {
291 debug!("Too many redirects (> {})", max_redirects);
292 return Err(Error::TooManyRedirects);
293 }
294 remaining_redirects -= 1;
295 addr = parse_ws_dial_addr(location_to_multiaddr(&redirect)?)?
296 }
297 Ok(Either::Right(conn)) => return Ok(conn),
298 Err(e) => return Err(e),
299 }
300 }
301 };
302
303 Ok(Box::pin(future))
304 }
305
306 async fn dial_once(
308 transport: Arc<Mutex<T>>,
309 addr: WsAddress,
310 tls_config: tls::Config,
311 role_override: Endpoint,
312 ) -> Result<Either<String, Connection<T::Output>>, Error<T::Error>> {
313 trace!("Dialing websocket address: {:?}", addr);
314
315 let dial = match role_override {
316 Endpoint::Dialer => transport.lock().dial(addr.tcp_addr),
317 Endpoint::Listener => transport.lock().dial_as_listener(addr.tcp_addr),
318 }
319 .map_err(|e| match e {
320 TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a),
321 TransportError::Other(e) => Error::Transport(e),
322 })?;
323
324 let stream = dial.map_err(Error::Transport).await?;
325 trace!("TCP connection to {} established.", addr.host_port);
326
327 let stream = if addr.use_tls {
328 let dns_name = addr
330 .dns_name
331 .expect("for use_tls we have checked that dns_name is some");
332 trace!("Starting TLS handshake with {:?}", dns_name);
333 let stream = tls_config
334 .client
335 .connect(dns_name.clone(), stream)
336 .map_err(|e| {
337 debug!("TLS handshake with {:?} failed: {}", dns_name, e);
338 Error::Tls(tls::Error::from(e))
339 })
340 .await?;
341
342 let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Left(stream));
343 stream
344 } else {
345 future::Either::Right(stream)
347 };
348
349 trace!("Sending websocket handshake to {}", addr.host_port);
350
351 let mut client = handshake::Client::new(stream, &addr.host_port, addr.path.as_ref());
352
353 match client
354 .handshake()
355 .map_err(|e| Error::Handshake(Box::new(e)))
356 .await?
357 {
358 handshake::ServerResponse::Redirect {
359 status_code,
360 location,
361 } => {
362 debug!(
363 "received redirect ({}); location: {}",
364 status_code, location
365 );
366 Ok(Either::Left(location))
367 }
368 handshake::ServerResponse::Rejected { status_code } => {
369 let msg = format!("server rejected handshake; status code = {status_code}");
370 Err(Error::Handshake(msg.into()))
371 }
372 handshake::ServerResponse::Accepted { .. } => {
373 trace!("websocket handshake with {} successful", addr.host_port);
374 Ok(Either::Right(Connection::new(client.into_builder())))
375 }
376 }
377 }
378
379 fn map_upgrade(
380 &self,
381 upgrade: T::ListenerUpgrade,
382 remote_addr: Multiaddr,
383 use_tls: bool,
384 ) -> <Self as Transport>::ListenerUpgrade {
385 let remote_addr2 = remote_addr.clone(); let tls_config = self.tls_config.clone();
387 let max_size = self.max_data_size;
388
389 async move {
390 let stream = upgrade.map_err(Error::Transport).await?;
391 trace!("incoming connection from {}", remote_addr);
392
393 let stream = if use_tls {
394 let server = tls_config
396 .server
397 .expect("for use_tls we checked server is not none");
398
399 trace!("awaiting TLS handshake with {}", remote_addr);
400
401 let stream = server
402 .accept(stream)
403 .map_err(move |e| {
404 debug!("TLS handshake with {} failed: {}", remote_addr, e);
405 Error::Tls(tls::Error::from(e))
406 })
407 .await?;
408
409 let stream: TlsOrPlain<_> = future::Either::Left(future::Either::Right(stream));
410
411 stream
412 } else {
413 future::Either::Right(stream)
415 };
416
417 trace!(
418 "receiving websocket handshake request from {}",
419 remote_addr2
420 );
421
422 let mut server = handshake::Server::new(stream);
423
424 let ws_key = {
425 let request = server
426 .receive_request()
427 .map_err(|e| Error::Handshake(Box::new(e)))
428 .await?;
429 request.key()
430 };
431
432 trace!(
433 "accepting websocket handshake request from {}",
434 remote_addr2
435 );
436
437 let response = handshake::server::Response::Accept {
438 key: ws_key,
439 protocol: None,
440 };
441
442 server
443 .send_response(&response)
444 .map_err(|e| Error::Handshake(Box::new(e)))
445 .await?;
446
447 let conn = {
448 let mut builder = server.into_builder();
449 builder.set_max_message_size(max_size);
450 builder.set_max_frame_size(max_size);
451 Connection::new(builder)
452 };
453
454 Ok(conn)
455 }
456 .boxed()
457 }
458}
459
/// The pieces of a websocket dial address, as produced by `parse_ws_dial_addr`.
#[derive(Debug)]
struct WsAddress {
    /// `host:port` string passed to the websocket client handshake and used
    /// in log messages.
    host_port: String,
    /// Path taken from the `/ws` or `/wss` component; passed to the websocket
    /// client handshake.
    path: String,
    /// Server name for the TLS handshake; `Some` only when the address
    /// contained a DNS component.
    dns_name: Option<rustls::ServerName>,
    /// Whether the address ended in `/wss` (as opposed to `/ws`).
    use_tls: bool,
    /// The address without its websocket component; dialed through the inner
    /// transport.
    tcp_addr: Multiaddr,
}
468
469fn parse_ws_dial_addr<T>(addr: Multiaddr) -> Result<WsAddress, Error<T>> {
475 let mut protocols = addr.iter();
479 let mut ip = protocols.next();
480 let mut tcp = protocols.next();
481 let (host_port, dns_name) = loop {
482 match (ip, tcp) {
483 (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port))) => {
484 break (format!("{ip}:{port}"), None)
485 }
486 (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port))) => {
487 break (format!("{ip}:{port}"), None)
488 }
489 (Some(Protocol::Dns(h)), Some(Protocol::Tcp(port)))
490 | (Some(Protocol::Dns4(h)), Some(Protocol::Tcp(port)))
491 | (Some(Protocol::Dns6(h)), Some(Protocol::Tcp(port)))
492 | (Some(Protocol::Dnsaddr(h)), Some(Protocol::Tcp(port))) => {
493 break (format!("{}:{}", &h, port), Some(tls::dns_name_ref(&h)?))
494 }
495 (Some(_), Some(p)) => {
496 ip = Some(p);
497 tcp = protocols.next();
498 }
499 _ => return Err(Error::InvalidMultiaddr(addr)),
500 }
501 };
502
503 let mut protocols = addr.clone();
507 let mut p2p = None;
508 let (use_tls, path) = loop {
509 match protocols.pop() {
510 p @ Some(Protocol::P2p(_)) => p2p = p,
511 Some(Protocol::Ws(path)) => break (false, path.into_owned()),
512 Some(Protocol::Wss(path)) => {
513 if dns_name.is_none() {
514 debug!("Missing DNS name in WSS address: {}", addr);
515 return Err(Error::InvalidMultiaddr(addr));
516 }
517 break (true, path.into_owned());
518 }
519 _ => return Err(Error::InvalidMultiaddr(addr)),
520 }
521 };
522
523 let tcp_addr = match p2p {
526 Some(p) => protocols.with(p),
527 None => protocols,
528 };
529
530 Ok(WsAddress {
531 host_port,
532 dns_name,
533 path,
534 use_tls,
535 tcp_addr,
536 })
537}
538
539fn location_to_multiaddr<T>(location: &str) -> Result<Multiaddr, Error<T>> {
541 match Url::parse(location) {
542 Ok(url) => {
543 let mut a = Multiaddr::empty();
544 match url.host() {
545 Some(url::Host::Domain(h)) => a.push(Protocol::Dns(h.into())),
546 Some(url::Host::Ipv4(ip)) => a.push(Protocol::Ip4(ip)),
547 Some(url::Host::Ipv6(ip)) => a.push(Protocol::Ip6(ip)),
548 None => return Err(Error::InvalidRedirectLocation),
549 }
550 if let Some(p) = url.port() {
551 a.push(Protocol::Tcp(p))
552 }
553 let s = url.scheme();
554 if s.eq_ignore_ascii_case("https") | s.eq_ignore_ascii_case("wss") {
555 a.push(Protocol::Wss(url.path().into()))
556 } else if s.eq_ignore_ascii_case("http") | s.eq_ignore_ascii_case("ws") {
557 a.push(Protocol::Ws(url.path().into()))
558 } else {
559 debug!("unsupported scheme: {}", s);
560 return Err(Error::InvalidRedirectLocation);
561 }
562 Ok(a)
563 }
564 Err(e) => {
565 debug!("failed to parse url as multi-address: {:?}", e);
566 Err(Error::InvalidRedirectLocation)
567 }
568 }
569}
570
/// A websocket connection, usable as a `Stream` of [`Incoming`] items and as
/// a `Sink` of [`OutgoingData`].
pub struct Connection<T> {
    /// Boxed stream of items received from the remote.
    receiver: BoxStream<'static, Result<Incoming, connection::Error>>,
    /// Boxed sink that writes outgoing items to the remote.
    sender: Pin<Box<dyn Sink<OutgoingData, Error = quicksink::Error<connection::Error>> + Send>>,
    /// Ties the connection to the underlying stream type `T` without storing it.
    _marker: std::marker::PhantomData<T>,
}
577
/// An item received over a websocket [`Connection`].
#[derive(Debug, Clone)]
pub enum Incoming {
    /// Application data (text or binary).
    Data(Data),
    /// Payload of a received PONG.
    Pong(Vec<u8>),
    /// The reason reported when the connection was closed.
    Closed(CloseReason),
}
588
/// Application data received over a websocket connection.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Data {
    /// Bytes of a textual message.
    Text(Vec<u8>),
    /// Bytes of a binary message.
    Binary(Vec<u8>),
}

impl Data {
    /// Consumes the value and returns its payload, whichever variant it is.
    pub fn into_bytes(self) -> Vec<u8> {
        let (Data::Text(bytes) | Data::Binary(bytes)) = self;
        bytes
    }
}

impl AsRef<[u8]> for Data {
    /// Borrows the payload, whichever variant it is.
    fn as_ref(&self) -> &[u8] {
        let (Data::Text(bytes) | Data::Binary(bytes)) = self;
        bytes
    }
}
615
616impl Incoming {
617 pub fn is_data(&self) -> bool {
618 self.is_binary() || self.is_text()
619 }
620
621 pub fn is_binary(&self) -> bool {
622 matches!(self, Incoming::Data(Data::Binary(_)))
623 }
624
625 pub fn is_text(&self) -> bool {
626 matches!(self, Incoming::Data(Data::Text(_)))
627 }
628
629 pub fn is_pong(&self) -> bool {
630 matches!(self, Incoming::Pong(_))
631 }
632
633 pub fn is_close(&self) -> bool {
634 matches!(self, Incoming::Closed(_))
635 }
636}
637
/// An item that can be sent over a websocket [`Connection`].
#[derive(Debug, Clone)]
pub enum OutgoingData {
    /// Send the bytes as a binary message.
    Binary(Vec<u8>),
    /// Send a PING with the given payload. Sending fails with an
    /// `InvalidInput` error if the payload is too large.
    Ping(Vec<u8>),
    /// Send a PONG with the given payload. Sending fails with an
    /// `InvalidInput` error if the payload is too large.
    Pong(Vec<u8>),
}
649
650impl<T> fmt::Debug for Connection<T> {
651 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
652 f.write_str("Connection")
653 }
654}
655
656impl<T> Connection<T>
657where
658 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
659{
660 fn new(builder: connection::Builder<TlsOrPlain<T>>) -> Self {
661 let (sender, receiver) = builder.finish();
662 let sink = quicksink::make_sink(sender, |mut sender, action| async move {
663 match action {
664 quicksink::Action::Send(OutgoingData::Binary(x)) => {
665 sender.send_binary_mut(x).await?
666 }
667 quicksink::Action::Send(OutgoingData::Ping(x)) => {
668 let data = x[..].try_into().map_err(|_| {
669 io::Error::new(io::ErrorKind::InvalidInput, "PING data must be < 126 bytes")
670 })?;
671 sender.send_ping(data).await?
672 }
673 quicksink::Action::Send(OutgoingData::Pong(x)) => {
674 let data = x[..].try_into().map_err(|_| {
675 io::Error::new(io::ErrorKind::InvalidInput, "PONG data must be < 126 bytes")
676 })?;
677 sender.send_pong(data).await?
678 }
679 quicksink::Action::Flush => sender.flush().await?,
680 quicksink::Action::Close => sender.close().await?,
681 }
682 Ok(sender)
683 });
684 let stream = stream::unfold((Vec::new(), receiver), |(mut data, mut receiver)| async {
685 match receiver.receive(&mut data).await {
686 Ok(soketto::Incoming::Data(soketto::Data::Text(_))) => Some((
687 Ok(Incoming::Data(Data::Text(mem::take(&mut data)))),
688 (data, receiver),
689 )),
690 Ok(soketto::Incoming::Data(soketto::Data::Binary(_))) => Some((
691 Ok(Incoming::Data(Data::Binary(mem::take(&mut data)))),
692 (data, receiver),
693 )),
694 Ok(soketto::Incoming::Pong(pong)) => {
695 Some((Ok(Incoming::Pong(Vec::from(pong))), (data, receiver)))
696 }
697 Ok(soketto::Incoming::Closed(reason)) => {
698 Some((Ok(Incoming::Closed(reason)), (data, receiver)))
699 }
700 Err(connection::Error::Closed) => None,
701 Err(e) => Some((Err(e), (data, receiver))),
702 }
703 });
704 Connection {
705 receiver: stream.boxed(),
706 sender: Box::pin(sink),
707 _marker: std::marker::PhantomData,
708 }
709 }
710
711 pub fn send_data(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
713 self.send(OutgoingData::Binary(data))
714 }
715
716 pub fn send_ping(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
718 self.send(OutgoingData::Ping(data))
719 }
720
721 pub fn send_pong(&mut self, data: Vec<u8>) -> sink::Send<'_, Self, OutgoingData> {
723 self.send(OutgoingData::Pong(data))
724 }
725}
726
727impl<T> Stream for Connection<T>
728where
729 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
730{
731 type Item = io::Result<Incoming>;
732
733 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
734 let item = ready!(self.receiver.poll_next_unpin(cx));
735 let item = item.map(|result| result.map_err(|e| io::Error::new(io::ErrorKind::Other, e)));
736 Poll::Ready(item)
737 }
738}
739
740impl<T> Sink<OutgoingData> for Connection<T>
741where
742 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
743{
744 type Error = io::Error;
745
746 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
747 Pin::new(&mut self.sender)
748 .poll_ready(cx)
749 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
750 }
751
752 fn start_send(mut self: Pin<&mut Self>, item: OutgoingData) -> io::Result<()> {
753 Pin::new(&mut self.sender)
754 .start_send(item)
755 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
756 }
757
758 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
759 Pin::new(&mut self.sender)
760 .poll_flush(cx)
761 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
762 }
763
764 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
765 Pin::new(&mut self.sender)
766 .poll_close(cx)
767 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))
768 }
769}