1use crate::{
24 error::{AddressError, Error, NegotiationError},
25 transport::{
26 common::listener::{DialAddresses, GetSocketAddr, SocketListener, WebSocketAddress},
27 manager::TransportHandle,
28 websocket::{
29 config::Config,
30 connection::{NegotiatedConnection, WebSocketConnection},
31 },
32 Transport, TransportBuilder, TransportEvent, DIAL_DEADLINE_MULTIPLIER,
33 },
34 types::ConnectionId,
35 utils::futures_stream::FuturesStream,
36 DialError, PeerId,
37};
38
39use futures::{future::BoxFuture, stream::AbortHandle, Stream, StreamExt, TryFutureExt};
40use hickory_resolver::TokioResolver;
41use multiaddr::{Multiaddr, Protocol};
42use socket2::{Domain, Socket, Type};
43use std::{net::SocketAddr, sync::Arc};
44use tokio::net::TcpStream;
45use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
46
47use url::Url;
48
49use std::{
50 collections::HashMap,
51 pin::Pin,
52 task::{Context, Poll},
53 time::Duration,
54};
55
56pub(crate) use substream::Substream;
57
58mod connection;
59mod stream;
60mod substream;
61
62pub mod config;
63
/// `tracing` target used by every log statement of the WebSocket transport.
const LOG_TARGET: &str = "litep2p::websocket";
66
/// Inbound TCP connection yielded by the listener that is waiting for the transport
/// manager to call `accept_pending()` or `reject_pending()`.
///
/// No WebSocket handshake has been performed on it yet; that only starts once the
/// connection is accepted (see `on_inbound_connection`).
struct PendingInboundConnection {
    /// Raw TCP stream returned by the listener.
    connection: TcpStream,

    /// Socket address returned by the listener together with the stream
    /// (presumably the remote peer's address — confirm in `SocketListener`).
    address: SocketAddr,
}
74
/// Outcome of the background future created by `Transport::open()`, which dials a set of
/// addresses concurrently and stops at the first one that is fully negotiated.
#[derive(Debug)]
enum RawConnectionResult {
    /// One address was dialed and negotiated successfully.
    Connected {
        /// The negotiated connection; it carries its own connection ID and endpoint.
        negotiated: NegotiatedConnection,
        /// Errors of the addresses that failed before the successful one completed.
        errors: Vec<(Multiaddr, DialError)>,
    },

    /// Every address failed, or the overall dial deadline elapsed first.
    Failed {
        /// ID of the connection attempt.
        connection_id: ConnectionId,
        /// Errors collected from the addresses that completed before giving up.
        errors: Vec<(Multiaddr, DialError)>,
    },

    /// The future was aborted through `Transport::cancel()`.
    Canceled { connection_id: ConnectionId },
}
92
/// WebSocket transport: TCP sockets upgraded to WebSocket, then negotiated by
/// `WebSocketConnection` (noise + yamux settings come from `config`).
pub(crate) struct WebSocketTransport {
    /// Handle to the transport manager. Provides the keypair, protocol sets, bandwidth sink,
    /// executor and fresh connection IDs.
    context: TransportHandle,

    /// Transport configuration (timeouts, yamux/noise settings, `nodelay`, parallel dial limit).
    config: Config,

    /// Listener yielding inbound TCP connections.
    listener: SocketListener,

    /// Local addresses used to bind outbound sockets (see `dial_peer`).
    dial_addresses: DialAddresses,

    /// Addresses of in-flight `dial()` attempts, used to report `TransportEvent::DialFailure`.
    pending_dials: HashMap<ConnectionId, Multiaddr>,

    /// Inbound connections waiting for `accept_pending()` / `reject_pending()`.
    pending_inbound_connections: HashMap<ConnectionId, PendingInboundConnection>,

    /// Connection negotiations in progress: accepted inbound connections, `dial()` attempts and
    /// connections handed over by `negotiate()`.
    pending_connections:
        FuturesStream<BoxFuture<'static, Result<NegotiatedConnection, (ConnectionId, DialError)>>>,

    /// Multi-address dial futures started by `open()`.
    pending_raw_connections: FuturesStream<BoxFuture<'static, RawConnectionResult>>,

    /// Connections produced by `open()` (reported as `ConnectionOpened`) awaiting `negotiate()`.
    opened: HashMap<ConnectionId, NegotiatedConnection>,

    /// Abort handles of the `open()` futures. `cancel()` only aborts; entries are removed in
    /// `poll_next` once the corresponding future resolves.
    cancel_futures: HashMap<ConnectionId, AbortHandle>,

    /// Established connections (reported as `ConnectionEstablished`) awaiting `accept()` /
    /// `reject()`.
    pending_open: HashMap<ConnectionId, NegotiatedConnection>,

    /// DNS resolver used by `dial_peer` for `/dns*` addresses.
    resolver: Arc<TokioResolver>,
}
134
135impl WebSocketTransport {
136 fn on_inbound_connection(
138 &mut self,
139 connection_id: ConnectionId,
140 connection: TcpStream,
141 address: SocketAddr,
142 ) {
143 let keypair = self.context.keypair.clone();
144 let yamux_config = self.config.yamux_config.clone();
145 let connection_open_timeout = self.config.connection_open_timeout;
146 let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
147 let max_write_buffer_size = self.config.noise_write_buffer_size;
148 let substream_open_timeout = self.config.substream_open_timeout;
149 let address = Multiaddr::empty()
150 .with(Protocol::from(address.ip()))
151 .with(Protocol::Tcp(address.port()))
152 .with(Protocol::Ws(std::borrow::Cow::Borrowed("/")));
153
154 self.pending_connections.push(Box::pin(async move {
155 match tokio::time::timeout(connection_open_timeout, async move {
156 WebSocketConnection::accept_connection(
157 connection,
158 connection_id,
159 keypair,
160 address,
161 yamux_config,
162 max_read_ahead_factor,
163 max_write_buffer_size,
164 substream_open_timeout,
165 )
166 .await
167 .map_err(|error| (connection_id, error.into()))
168 })
169 .await
170 {
171 Err(_) => Err((connection_id, DialError::Timeout)),
172 Ok(Err(error)) => Err(error),
173 Ok(Ok(result)) => Ok(result),
174 }
175 }));
176 }
177
178 fn multiaddr_into_url(address: Multiaddr) -> Result<(Url, PeerId), AddressError> {
180 let mut protocol_stack = address.iter();
181
182 let dial_address = match protocol_stack.next().ok_or(AddressError::InvalidProtocol)? {
183 Protocol::Ip4(address) => address.to_string(),
184 Protocol::Ip6(address) => format!("[{address}]"),
185 Protocol::Dns(address) | Protocol::Dns4(address) | Protocol::Dns6(address) =>
186 address.to_string(),
187
188 _ => return Err(AddressError::InvalidProtocol),
189 };
190
191 let url = match protocol_stack.next().ok_or(AddressError::InvalidProtocol)? {
192 Protocol::Tcp(port) => match protocol_stack.next() {
193 Some(Protocol::Ws(_)) => format!("ws://{dial_address}:{port}/"),
194 Some(Protocol::Wss(_)) => format!("wss://{dial_address}:{port}/"),
195 _ => return Err(AddressError::InvalidProtocol),
196 },
197 _ => return Err(AddressError::InvalidProtocol),
198 };
199
200 let peer = match protocol_stack.next() {
201 Some(Protocol::P2p(multihash)) => PeerId::from_multihash(multihash)?,
202 protocol => {
203 tracing::warn!(
204 target: LOG_TARGET,
205 ?protocol,
206 "invalid protocol, expected `Protocol::Ws`/`Protocol::Wss`",
207 );
208 return Err(AddressError::PeerIdMissing);
209 }
210 };
211
212 tracing::trace!(target: LOG_TARGET, ?url, "parse address");
213
214 url::Url::parse(&url)
215 .map(|url| (url, peer))
216 .map_err(|_| AddressError::InvalidUrl)
217 }
218
219 async fn dial_peer(
221 address: Multiaddr,
222 dial_addresses: DialAddresses,
223 connection_open_timeout: Duration,
224 nodelay: bool,
225 resolver: Arc<TokioResolver>,
226 ) -> Result<(Multiaddr, WebSocketStream<MaybeTlsStream<TcpStream>>), DialError> {
227 let (url, _) = Self::multiaddr_into_url(address.clone())?;
228
229 let (socket_address, _) = WebSocketAddress::multiaddr_to_socket_address(&address)?;
230 let remote_address =
231 match tokio::time::timeout(connection_open_timeout, socket_address.lookup_ip(resolver))
232 .await
233 {
234 Err(_) => return Err(DialError::Timeout),
235 Ok(Err(error)) => return Err(error.into()),
236 Ok(Ok(address)) => address,
237 };
238
239 let domain = match remote_address.is_ipv4() {
240 true => Domain::IPV4,
241 false => Domain::IPV6,
242 };
243 let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?;
244 if remote_address.is_ipv6() {
245 socket.set_only_v6(true)?;
246 }
247 socket.set_nonblocking(true)?;
248 socket.set_nodelay(nodelay)?;
249
250 match dial_addresses.local_dial_address(&remote_address.ip()) {
251 Ok(Some(dial_address)) => {
252 socket.set_reuse_address(true)?;
253 #[cfg(unix)]
254 socket.set_reuse_port(true)?;
255 socket.bind(&dial_address.into())?;
256 }
257 Ok(None) => {}
258 Err(()) => {
259 tracing::debug!(
260 target: LOG_TARGET,
261 ?remote_address,
262 "tcp listener not enabled for remote address, using ephemeral port",
263 );
264 }
265 }
266
267 let future = async move {
268 match socket.connect(&remote_address.into()) {
269 Ok(()) => {}
270 Err(error) if error.raw_os_error() == Some(libc::EINPROGRESS) => {}
271 Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => {}
272 Err(err) => return Err(DialError::from(err)),
273 }
274
275 let stream = TcpStream::try_from(Into::<std::net::TcpStream>::into(socket))?;
276 stream.writable().await?;
277 if let Some(e) = stream.take_error()? {
278 return Err(DialError::from(e));
279 }
280
281 Ok((
282 address,
283 tokio_tungstenite::client_async_tls(url, stream)
284 .await
285 .map_err(NegotiationError::WebSocket)?
286 .0,
287 ))
288 };
289
290 match tokio::time::timeout(connection_open_timeout, future).await {
291 Err(_) => Err(DialError::Timeout),
292 Ok(Err(error)) => Err(error),
293 Ok(Ok((address, stream))) => Ok((address, stream)),
294 }
295 }
296}
297
298impl TransportBuilder for WebSocketTransport {
299 type Config = Config;
300 type Transport = WebSocketTransport;
301
302 fn new(
304 context: TransportHandle,
305 mut config: Self::Config,
306 resolver: Arc<TokioResolver>,
307 ) -> crate::Result<(Self, Vec<Multiaddr>)>
308 where
309 Self: Sized,
310 {
311 tracing::debug!(
312 target: LOG_TARGET,
313 listen_addresses = ?config.listen_addresses,
314 "start websocket transport",
315 );
316 let (listener, listen_addresses, dial_addresses) = SocketListener::new::<WebSocketAddress>(
317 std::mem::take(&mut config.listen_addresses),
318 config.reuse_port,
319 config.nodelay,
320 );
321
322 Ok((
323 Self {
324 listener,
325 config,
326 context,
327 dial_addresses,
328 opened: HashMap::new(),
329 pending_open: HashMap::new(),
330 pending_dials: HashMap::new(),
331 pending_inbound_connections: HashMap::new(),
332 pending_connections: FuturesStream::new(),
333 pending_raw_connections: FuturesStream::new(),
334 cancel_futures: HashMap::new(),
335 resolver,
336 },
337 listen_addresses,
338 ))
339 }
340}
341
342impl Transport for WebSocketTransport {
343 fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> {
344 let yamux_config = self.config.yamux_config.clone();
345 let keypair = self.context.keypair.clone();
346 let (ws_address, peer) = Self::multiaddr_into_url(address.clone())?;
347 let connection_open_timeout = self.config.connection_open_timeout;
348 let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
349 let max_write_buffer_size = self.config.noise_write_buffer_size;
350 let substream_open_timeout = self.config.substream_open_timeout;
351 let dial_addresses = self.dial_addresses.clone();
352 let nodelay = self.config.nodelay;
353 let resolver = self.resolver.clone();
354
355 self.pending_dials.insert(connection_id, address.clone());
356
357 tracing::debug!(target: LOG_TARGET, ?connection_id, ?address, "open connection");
358
359 let future = async move {
360 let (_, stream) = WebSocketTransport::dial_peer(
361 address.clone(),
362 dial_addresses,
363 connection_open_timeout,
364 nodelay,
365 resolver,
366 )
367 .await
368 .map_err(|error| (connection_id, error))?;
369
370 WebSocketConnection::open_connection(
371 connection_id,
372 keypair,
373 stream,
374 address,
375 peer,
376 ws_address,
377 yamux_config,
378 max_read_ahead_factor,
379 max_write_buffer_size,
380 substream_open_timeout,
381 )
382 .await
383 .map_err(|error| (connection_id, error.into()))
384 };
385
386 self.pending_connections.push(Box::pin(async move {
387 match tokio::time::timeout(connection_open_timeout, future).await {
388 Err(_) => Err((connection_id, DialError::Timeout)),
389 Ok(Err(error)) => Err(error),
390 Ok(Ok(result)) => Ok(result),
391 }
392 }));
393
394 Ok(())
395 }
396
397 fn accept(
398 &mut self,
399 connection_id: ConnectionId,
400 ) -> crate::Result<BoxFuture<'static, crate::Result<()>>> {
401 let context = self
402 .pending_open
403 .remove(&connection_id)
404 .ok_or(Error::ConnectionDoesntExist(connection_id))?;
405 let mut protocol_set = self.context.protocol_set(connection_id);
406 let bandwidth_sink = self.context.bandwidth_sink.clone();
407 let substream_open_timeout = self.config.substream_open_timeout;
408 let executor = self.context.executor.clone();
409
410 tracing::trace!(
411 target: LOG_TARGET,
412 ?connection_id,
413 "start connection",
414 );
415
416 let peer = context.peer();
417 let endpoint = context.endpoint();
418
419 Ok(Box::pin(async move {
420 protocol_set.report_connection_established(peer, endpoint).await?;
422
423 executor.run(Box::pin(async move {
425 if let Err(error) = WebSocketConnection::new(
426 context,
427 protocol_set,
428 bandwidth_sink,
429 substream_open_timeout,
430 )
431 .start()
432 .await
433 {
434 tracing::debug!(
435 target: LOG_TARGET,
436 ?connection_id,
437 ?error,
438 "connection exited with error",
439 );
440 }
441 }));
442
443 Ok(())
444 }))
445 }
446
447 fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
448 self.pending_open
449 .remove(&connection_id)
450 .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(()))
451 }
452
453 fn accept_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
454 let pending = self.pending_inbound_connections.remove(&connection_id).ok_or_else(|| {
455 tracing::error!(
456 target: LOG_TARGET,
457 ?connection_id,
458 "Cannot accept non existent pending connection",
459 );
460
461 Error::ConnectionDoesntExist(connection_id)
462 })?;
463
464 self.on_inbound_connection(connection_id, pending.connection, pending.address);
465
466 Ok(())
467 }
468
469 fn reject_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
470 self.pending_inbound_connections.remove(&connection_id).map_or_else(
471 || {
472 tracing::error!(
473 target: LOG_TARGET,
474 ?connection_id,
475 "Cannot reject non existent pending connection",
476 );
477
478 Err(Error::ConnectionDoesntExist(connection_id))
479 },
480 |_| Ok(()),
481 )
482 }
483
484 fn open(
485 &mut self,
486 connection_id: ConnectionId,
487 addresses: Vec<Multiaddr>,
488 ) -> crate::Result<()> {
489 let num_addresses = addresses.len();
490
491 let yamux_config = self.config.yamux_config.clone();
492 let keypair = self.context.keypair.clone();
493 let connection_open_timeout = self.config.connection_open_timeout;
494 let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
495 let max_write_buffer_size = self.config.noise_write_buffer_size;
496 let substream_open_timeout = self.config.substream_open_timeout;
497 let max_parallel_dials = self.config.max_parallel_dials;
498 let dial_addresses = self.dial_addresses.clone();
499 let nodelay = self.config.nodelay;
500 let resolver = self.resolver.clone();
501
502 let futures = futures::stream::iter(addresses.into_iter().map(move |address| {
503 let yamux_config = yamux_config.clone();
504 let keypair = keypair.clone();
505 let dial_addresses = dial_addresses.clone();
506 let resolver = resolver.clone();
507
508 async move {
509 let (address, stream) = WebSocketTransport::dial_peer(
510 address.clone(),
511 dial_addresses,
512 connection_open_timeout,
513 nodelay,
514 resolver,
515 )
516 .await
517 .map_err(|error| (address, error))?;
518
519 let open_address = address.clone();
520 let (ws_address, peer) = Self::multiaddr_into_url(address.clone())
521 .map_err(|error| (address.clone(), error.into()))?;
522
523 WebSocketConnection::open_connection(
524 connection_id,
525 keypair,
526 stream,
527 address,
528 peer,
529 ws_address,
530 yamux_config,
531 max_read_ahead_factor,
532 max_write_buffer_size,
533 substream_open_timeout,
534 )
535 .await
536 .map_err(|error| (open_address, error.into()))
537 }
538 }))
539 .buffer_unordered(max_parallel_dials);
540
541 let future = async move {
546 let mut errors = Vec::with_capacity(num_addresses);
547 let dial_deadline = DIAL_DEADLINE_MULTIPLIER * connection_open_timeout;
550 let deadline = tokio::time::sleep(dial_deadline);
551
552 tokio::pin!(deadline);
553 tokio::pin!(futures);
554
555 loop {
556 tokio::select! {
557 result = futures.next() => {
558 match result {
559 Some(Ok(negotiated)) => {
560 return RawConnectionResult::Connected {
561 negotiated,
562 errors,
563 };
564 }
565 Some(Err(error)) => {
566 tracing::debug!(
567 target: LOG_TARGET,
568 ?connection_id,
569 ?error,
570 "failed to open connection",
571 );
572 errors.push(error);
573 }
574 None => {
575 return RawConnectionResult::Failed {
576 connection_id,
577 errors,
578 };
579 }
580 }
581 }
582 _ = &mut deadline => {
583 tracing::debug!(
584 target: LOG_TARGET,
585 ?connection_id,
586 ?dial_deadline,
587 "overall dial timeout exceeded",
588 );
589 return RawConnectionResult::Failed {
590 connection_id,
591 errors,
592 };
593 }
594 }
595 }
596 };
597
598 let (fut, handle) = futures::future::abortable(future);
599 let fut = fut.unwrap_or_else(move |_| RawConnectionResult::Canceled { connection_id });
600 self.pending_raw_connections.push(Box::pin(fut));
601 self.cancel_futures.insert(connection_id, handle);
602
603 Ok(())
604 }
605
606 fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
607 let negotiated = self
608 .opened
609 .remove(&connection_id)
610 .ok_or(Error::ConnectionDoesntExist(connection_id))?;
611
612 self.pending_connections.push(Box::pin(async move { Ok(negotiated) }));
613
614 Ok(())
615 }
616
617 fn cancel(&mut self, connection_id: ConnectionId) {
618 if let Some(handle) = self.cancel_futures.get(&connection_id) {
621 handle.abort();
622 }
623 }
624}
625
impl Stream for WebSocketTransport {
    type Item = TransportEvent;

    /// Drive the transport and return the next `TransportEvent`.
    ///
    /// Sources are polled in a fixed order, and the first one that produces an event wins:
    /// 1. the listener (new inbound TCP connections, or listener termination);
    /// 2. `pending_raw_connections` (results of `open()`);
    /// 3. `pending_connections` (results of `dial()`, accepted inbound connections and
    ///    `negotiate()`).
    ///
    /// Returns `Poll::Ready(None)` when the listener ends or yields an error.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        if let Poll::Ready(event) = self.listener.poll_next_unpin(cx) {
            return match event {
                None => {
                    tracing::error!(
                        target: LOG_TARGET,
                        "Websocket listener terminated, ignore if the node is stopping",
                    );

                    Poll::Ready(None)
                }
                Some(Err(error)) => {
                    tracing::error!(
                        target: LOG_TARGET,
                        ?error,
                        "Websocket listener terminated with error",
                    );

                    Poll::Ready(None)
                }
                Some(Ok((connection, address))) => {
                    // Park the raw connection until the manager calls
                    // `accept_pending()` / `reject_pending()`.
                    let connection_id = self.context.next_connection_id();
                    tracing::trace!(
                        target: LOG_TARGET,
                        ?connection_id,
                        ?address,
                        "pending inbound Websocket connection",
                    );

                    self.pending_inbound_connections.insert(
                        connection_id,
                        PendingInboundConnection {
                            connection,
                            address,
                        },
                    );

                    Poll::Ready(Some(TransportEvent::PendingInboundConnection {
                        connection_id,
                    }))
                }
            };
        }

        // Results of `open()`. Each result consumes its abort handle. Results of aborted
        // attempts are dropped without emitting an event.
        while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) {
            tracing::trace!(target: LOG_TARGET, ?result, "raw connection result");

            match result {
                RawConnectionResult::Connected { negotiated, errors } => {
                    let Some(handle) = self.cancel_futures.remove(&negotiated.connection_id())
                    else {
                        tracing::warn!(
                            target: LOG_TARGET,
                            connection_id = ?negotiated.connection_id(),
                            address = ?negotiated.endpoint().address(),
                            ?errors,
                            "raw connection without a cancel handle",
                        );
                        continue;
                    };

                    if !handle.is_aborted() {
                        let connection_id = negotiated.connection_id();
                        let address = negotiated.endpoint().address().clone();

                        // Kept here until the manager calls `negotiate()`.
                        self.opened.insert(connection_id, negotiated);

                        return Poll::Ready(Some(TransportEvent::ConnectionOpened {
                            connection_id,
                            address,
                            errors,
                        }));
                    }
                }

                RawConnectionResult::Failed {
                    connection_id,
                    errors,
                } => {
                    let Some(handle) = self.cancel_futures.remove(&connection_id) else {
                        tracing::warn!(
                            target: LOG_TARGET,
                            ?connection_id,
                            ?errors,
                            "raw connection without a cancel handle",
                        );
                        continue;
                    };

                    if !handle.is_aborted() {
                        return Poll::Ready(Some(TransportEvent::OpenFailure {
                            connection_id,
                            errors,
                        }));
                    }
                }
                RawConnectionResult::Canceled { connection_id } => {
                    // `cancel()` only aborts; the handle is cleaned up here.
                    if self.cancel_futures.remove(&connection_id).is_none() {
                        tracing::warn!(
                            target: LOG_TARGET,
                            ?connection_id,
                            "raw cancelled connection without a cancel handle",
                        );
                    }
                }
            }
        }

        while let Poll::Ready(Some(connection)) = self.pending_connections.poll_next_unpin(cx) {
            match connection {
                Ok(connection) => {
                    // Kept in `pending_open` until the manager calls `accept()` / `reject()`.
                    let peer = connection.peer();
                    let endpoint = connection.endpoint();
                    self.pending_dials.remove(&connection.connection_id());
                    self.pending_open.insert(connection.connection_id(), connection);

                    return Poll::Ready(Some(TransportEvent::ConnectionEstablished {
                        peer,
                        endpoint,
                    }));
                }
                Err((connection_id, error)) => {
                    // Only `dial()` attempts are tracked in `pending_dials`. A failure without an
                    // entry there is treated as a failed inbound connection and only logged.
                    if let Some(address) = self.pending_dials.remove(&connection_id) {
                        return Poll::Ready(Some(TransportEvent::DialFailure {
                            connection_id,
                            address,
                            error,
                        }));
                    } else {
                        tracing::debug!(target: LOG_TARGET, ?error, ?connection_id, "Pending inbound connection failed");
                    }
                }
            }
        }

        Poll::Pending
    }
}