1use crate::{
24 error::{AddressError, Error, NegotiationError},
25 transport::{
26 common::listener::{DialAddresses, GetSocketAddr, SocketListener, WebSocketAddress},
27 manager::TransportHandle,
28 websocket::{
29 config::Config,
30 connection::{NegotiatedConnection, WebSocketConnection},
31 },
32 Transport, TransportBuilder, TransportEvent,
33 },
34 types::ConnectionId,
35 utils::futures_stream::FuturesStream,
36 DialError, PeerId,
37};
38
39use futures::{
40 future::BoxFuture,
41 stream::{AbortHandle, FuturesUnordered},
42 Stream, StreamExt, TryFutureExt,
43};
44use hickory_resolver::TokioResolver;
45use multiaddr::{Multiaddr, Protocol};
46use socket2::{Domain, Socket, Type};
47use std::{net::SocketAddr, sync::Arc};
48use tokio::net::TcpStream;
49use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
50
51use url::Url;
52
53use std::{
54 collections::HashMap,
55 pin::Pin,
56 task::{Context, Poll},
57 time::Duration,
58};
59
60pub(crate) use substream::Substream;
61
62mod connection;
63mod stream;
64mod substream;
65
66pub mod config;
67
/// `tracing` target shared by every log event emitted from the WebSocket transport.
const LOG_TARGET: &str = "litep2p::websocket";
70
71struct PendingInboundConnection {
73 connection: TcpStream,
75 address: SocketAddr,
77}
78
79#[derive(Debug)]
80enum RawConnectionResult {
81 Connected {
83 negotiated: NegotiatedConnection,
84 errors: Vec<(Multiaddr, DialError)>,
85 },
86
87 Failed {
89 connection_id: ConnectionId,
90 errors: Vec<(Multiaddr, DialError)>,
91 },
92
93 Canceled { connection_id: ConnectionId },
95}
96
97pub(crate) struct WebSocketTransport {
99 context: TransportHandle,
101
102 config: Config,
104
105 listener: SocketListener,
107
108 dial_addresses: DialAddresses,
110
111 pending_dials: HashMap<ConnectionId, Multiaddr>,
113
114 pending_inbound_connections: HashMap<ConnectionId, PendingInboundConnection>,
116
117 pending_connections:
119 FuturesStream<BoxFuture<'static, Result<NegotiatedConnection, (ConnectionId, DialError)>>>,
120
121 pending_raw_connections: FuturesStream<BoxFuture<'static, RawConnectionResult>>,
123
124 opened: HashMap<ConnectionId, NegotiatedConnection>,
126
127 cancel_futures: HashMap<ConnectionId, AbortHandle>,
131
132 pending_open: HashMap<ConnectionId, NegotiatedConnection>,
134
135 resolver: Arc<TokioResolver>,
137}
138
139impl WebSocketTransport {
140 fn on_inbound_connection(
142 &mut self,
143 connection_id: ConnectionId,
144 connection: TcpStream,
145 address: SocketAddr,
146 ) {
147 let keypair = self.context.keypair.clone();
148 let yamux_config = self.config.yamux_config.clone();
149 let connection_open_timeout = self.config.connection_open_timeout;
150 let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
151 let max_write_buffer_size = self.config.noise_write_buffer_size;
152 let substream_open_timeout = self.config.substream_open_timeout;
153 let address = Multiaddr::empty()
154 .with(Protocol::from(address.ip()))
155 .with(Protocol::Tcp(address.port()))
156 .with(Protocol::Ws(std::borrow::Cow::Borrowed("/")));
157
158 self.pending_connections.push(Box::pin(async move {
159 match tokio::time::timeout(connection_open_timeout, async move {
160 WebSocketConnection::accept_connection(
161 connection,
162 connection_id,
163 keypair,
164 address,
165 yamux_config,
166 max_read_ahead_factor,
167 max_write_buffer_size,
168 substream_open_timeout,
169 )
170 .await
171 .map_err(|error| (connection_id, error.into()))
172 })
173 .await
174 {
175 Err(_) => Err((connection_id, DialError::Timeout)),
176 Ok(Err(error)) => Err(error),
177 Ok(Ok(result)) => Ok(result),
178 }
179 }));
180 }
181
182 fn multiaddr_into_url(address: Multiaddr) -> Result<(Url, PeerId), AddressError> {
184 let mut protocol_stack = address.iter();
185
186 let dial_address = match protocol_stack.next().ok_or(AddressError::InvalidProtocol)? {
187 Protocol::Ip4(address) => address.to_string(),
188 Protocol::Ip6(address) => format!("[{address}]"),
189 Protocol::Dns(address) | Protocol::Dns4(address) | Protocol::Dns6(address) =>
190 address.to_string(),
191
192 _ => return Err(AddressError::InvalidProtocol),
193 };
194
195 let url = match protocol_stack.next().ok_or(AddressError::InvalidProtocol)? {
196 Protocol::Tcp(port) => match protocol_stack.next() {
197 Some(Protocol::Ws(_)) => format!("ws://{dial_address}:{port}/"),
198 Some(Protocol::Wss(_)) => format!("wss://{dial_address}:{port}/"),
199 _ => return Err(AddressError::InvalidProtocol),
200 },
201 _ => return Err(AddressError::InvalidProtocol),
202 };
203
204 let peer = match protocol_stack.next() {
205 Some(Protocol::P2p(multihash)) => PeerId::from_multihash(multihash)?,
206 protocol => {
207 tracing::warn!(
208 target: LOG_TARGET,
209 ?protocol,
210 "invalid protocol, expected `Protocol::Ws`/`Protocol::Wss`",
211 );
212 return Err(AddressError::PeerIdMissing);
213 }
214 };
215
216 tracing::trace!(target: LOG_TARGET, ?url, "parse address");
217
218 url::Url::parse(&url)
219 .map(|url| (url, peer))
220 .map_err(|_| AddressError::InvalidUrl)
221 }
222
223 async fn dial_peer(
225 address: Multiaddr,
226 dial_addresses: DialAddresses,
227 connection_open_timeout: Duration,
228 nodelay: bool,
229 resolver: Arc<TokioResolver>,
230 ) -> Result<(Multiaddr, WebSocketStream<MaybeTlsStream<TcpStream>>), DialError> {
231 let (url, _) = Self::multiaddr_into_url(address.clone())?;
232
233 let (socket_address, _) = WebSocketAddress::multiaddr_to_socket_address(&address)?;
234 let remote_address =
235 match tokio::time::timeout(connection_open_timeout, socket_address.lookup_ip(resolver))
236 .await
237 {
238 Err(_) => return Err(DialError::Timeout),
239 Ok(Err(error)) => return Err(error.into()),
240 Ok(Ok(address)) => address,
241 };
242
243 let domain = match remote_address.is_ipv4() {
244 true => Domain::IPV4,
245 false => Domain::IPV6,
246 };
247 let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?;
248 if remote_address.is_ipv6() {
249 socket.set_only_v6(true)?;
250 }
251 socket.set_nonblocking(true)?;
252 socket.set_nodelay(nodelay)?;
253
254 match dial_addresses.local_dial_address(&remote_address.ip()) {
255 Ok(Some(dial_address)) => {
256 socket.set_reuse_address(true)?;
257 #[cfg(unix)]
258 socket.set_reuse_port(true)?;
259 socket.bind(&dial_address.into())?;
260 }
261 Ok(None) => {}
262 Err(()) => {
263 tracing::debug!(
264 target: LOG_TARGET,
265 ?remote_address,
266 "tcp listener not enabled for remote address, using ephemeral port",
267 );
268 }
269 }
270
271 let future = async move {
272 match socket.connect(&remote_address.into()) {
273 Ok(()) => {}
274 Err(error) if error.raw_os_error() == Some(libc::EINPROGRESS) => {}
275 Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => {}
276 Err(err) => return Err(DialError::from(err)),
277 }
278
279 let stream = TcpStream::try_from(Into::<std::net::TcpStream>::into(socket))?;
280 stream.writable().await?;
281 if let Some(e) = stream.take_error()? {
282 return Err(DialError::from(e));
283 }
284
285 Ok((
286 address,
287 tokio_tungstenite::client_async_tls(url, stream)
288 .await
289 .map_err(NegotiationError::WebSocket)?
290 .0,
291 ))
292 };
293
294 match tokio::time::timeout(connection_open_timeout, future).await {
295 Err(_) => Err(DialError::Timeout),
296 Ok(Err(error)) => Err(error),
297 Ok(Ok((address, stream))) => Ok((address, stream)),
298 }
299 }
300}
301
302impl TransportBuilder for WebSocketTransport {
303 type Config = Config;
304 type Transport = WebSocketTransport;
305
306 fn new(
308 context: TransportHandle,
309 mut config: Self::Config,
310 resolver: Arc<TokioResolver>,
311 ) -> crate::Result<(Self, Vec<Multiaddr>)>
312 where
313 Self: Sized,
314 {
315 tracing::debug!(
316 target: LOG_TARGET,
317 listen_addresses = ?config.listen_addresses,
318 "start websocket transport",
319 );
320 let (listener, listen_addresses, dial_addresses) = SocketListener::new::<WebSocketAddress>(
321 std::mem::take(&mut config.listen_addresses),
322 config.reuse_port,
323 config.nodelay,
324 );
325
326 Ok((
327 Self {
328 listener,
329 config,
330 context,
331 dial_addresses,
332 opened: HashMap::new(),
333 pending_open: HashMap::new(),
334 pending_dials: HashMap::new(),
335 pending_inbound_connections: HashMap::new(),
336 pending_connections: FuturesStream::new(),
337 pending_raw_connections: FuturesStream::new(),
338 cancel_futures: HashMap::new(),
339 resolver,
340 },
341 listen_addresses,
342 ))
343 }
344}
345
346impl Transport for WebSocketTransport {
347 fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> {
348 let yamux_config = self.config.yamux_config.clone();
349 let keypair = self.context.keypair.clone();
350 let (ws_address, peer) = Self::multiaddr_into_url(address.clone())?;
351 let connection_open_timeout = self.config.connection_open_timeout;
352 let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
353 let max_write_buffer_size = self.config.noise_write_buffer_size;
354 let substream_open_timeout = self.config.substream_open_timeout;
355 let dial_addresses = self.dial_addresses.clone();
356 let nodelay = self.config.nodelay;
357 let resolver = self.resolver.clone();
358
359 self.pending_dials.insert(connection_id, address.clone());
360
361 tracing::debug!(target: LOG_TARGET, ?connection_id, ?address, "open connection");
362
363 let future = async move {
364 let (_, stream) = WebSocketTransport::dial_peer(
365 address.clone(),
366 dial_addresses,
367 connection_open_timeout,
368 nodelay,
369 resolver,
370 )
371 .await
372 .map_err(|error| (connection_id, error))?;
373
374 WebSocketConnection::open_connection(
375 connection_id,
376 keypair,
377 stream,
378 address,
379 peer,
380 ws_address,
381 yamux_config,
382 max_read_ahead_factor,
383 max_write_buffer_size,
384 substream_open_timeout,
385 )
386 .await
387 .map_err(|error| (connection_id, error.into()))
388 };
389
390 self.pending_connections.push(Box::pin(async move {
391 match tokio::time::timeout(connection_open_timeout, future).await {
392 Err(_) => Err((connection_id, DialError::Timeout)),
393 Ok(Err(error)) => Err(error),
394 Ok(Ok(result)) => Ok(result),
395 }
396 }));
397
398 Ok(())
399 }
400
401 fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
402 let context = self
403 .pending_open
404 .remove(&connection_id)
405 .ok_or(Error::ConnectionDoesntExist(connection_id))?;
406 let protocol_set = self.context.protocol_set(connection_id);
407 let bandwidth_sink = self.context.bandwidth_sink.clone();
408 let substream_open_timeout = self.config.substream_open_timeout;
409
410 tracing::trace!(
411 target: LOG_TARGET,
412 ?connection_id,
413 "start connection",
414 );
415
416 self.context.executor.run(Box::pin(async move {
417 if let Err(error) = WebSocketConnection::new(
418 context,
419 protocol_set,
420 bandwidth_sink,
421 substream_open_timeout,
422 )
423 .start()
424 .await
425 {
426 tracing::debug!(
427 target: LOG_TARGET,
428 ?connection_id,
429 ?error,
430 "connection exited with error",
431 );
432 }
433 }));
434
435 Ok(())
436 }
437
438 fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
439 self.pending_open
440 .remove(&connection_id)
441 .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(()))
442 }
443
444 fn accept_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
445 let pending = self.pending_inbound_connections.remove(&connection_id).ok_or_else(|| {
446 tracing::error!(
447 target: LOG_TARGET,
448 ?connection_id,
449 "Cannot accept non existent pending connection",
450 );
451
452 Error::ConnectionDoesntExist(connection_id)
453 })?;
454
455 self.on_inbound_connection(connection_id, pending.connection, pending.address);
456
457 Ok(())
458 }
459
460 fn reject_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
461 self.pending_inbound_connections.remove(&connection_id).map_or_else(
462 || {
463 tracing::error!(
464 target: LOG_TARGET,
465 ?connection_id,
466 "Cannot reject non existent pending connection",
467 );
468
469 Err(Error::ConnectionDoesntExist(connection_id))
470 },
471 |_| Ok(()),
472 )
473 }
474
475 fn open(
476 &mut self,
477 connection_id: ConnectionId,
478 addresses: Vec<Multiaddr>,
479 ) -> crate::Result<()> {
480 let num_addresses = addresses.len();
481
482 let mut futures: FuturesUnordered<_> = addresses
483 .into_iter()
484 .map(|address| {
485 let yamux_config = self.config.yamux_config.clone();
486 let keypair = self.context.keypair.clone();
487 let connection_open_timeout = self.config.connection_open_timeout;
488 let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
489 let max_write_buffer_size = self.config.noise_write_buffer_size;
490 let substream_open_timeout = self.config.substream_open_timeout;
491 let dial_addresses = self.dial_addresses.clone();
492 let nodelay = self.config.nodelay;
493 let resolver = self.resolver.clone();
494
495 async move {
496 let (address, stream) = WebSocketTransport::dial_peer(
497 address.clone(),
498 dial_addresses,
499 connection_open_timeout,
500 nodelay,
501 resolver,
502 )
503 .await
504 .map_err(|error| (address, error))?;
505
506 let open_address = address.clone();
507 let (ws_address, peer) = Self::multiaddr_into_url(address.clone())
508 .map_err(|error| (address.clone(), error.into()))?;
509
510 WebSocketConnection::open_connection(
511 connection_id,
512 keypair,
513 stream,
514 address,
515 peer,
516 ws_address,
517 yamux_config,
518 max_read_ahead_factor,
519 max_write_buffer_size,
520 substream_open_timeout,
521 )
522 .await
523 .map_err(|error| (open_address, error.into()))
524 }
525 })
526 .collect();
527
528 let future = async move {
530 let mut errors = Vec::with_capacity(num_addresses);
531 while let Some(result) = futures.next().await {
532 match result {
533 Ok(negotiated) => return RawConnectionResult::Connected { negotiated, errors },
534 Err(error) => {
535 tracing::debug!(
536 target: LOG_TARGET,
537 ?connection_id,
538 ?error,
539 "failed to open connection",
540 );
541 errors.push(error)
542 }
543 }
544 }
545
546 RawConnectionResult::Failed {
547 connection_id,
548 errors,
549 }
550 };
551
552 let (fut, handle) = futures::future::abortable(future);
553 let fut = fut.unwrap_or_else(move |_| RawConnectionResult::Canceled { connection_id });
554 self.pending_raw_connections.push(Box::pin(fut));
555 self.cancel_futures.insert(connection_id, handle);
556
557 Ok(())
558 }
559
560 fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
561 let negotiated = self
562 .opened
563 .remove(&connection_id)
564 .ok_or(Error::ConnectionDoesntExist(connection_id))?;
565
566 self.pending_connections.push(Box::pin(async move { Ok(negotiated) }));
567
568 Ok(())
569 }
570
571 fn cancel(&mut self, connection_id: ConnectionId) {
572 if let Some(handle) = self.cancel_futures.get(&connection_id) {
575 handle.abort();
576 }
577 }
578}
579
580impl Stream for WebSocketTransport {
581 type Item = TransportEvent;
582
583 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
584 if let Poll::Ready(event) = self.listener.poll_next_unpin(cx) {
585 return match event {
586 None => {
587 tracing::error!(
588 target: LOG_TARGET,
589 "Websocket listener terminated, ignore if the node is stopping",
590 );
591
592 Poll::Ready(None)
593 }
594 Some(Err(error)) => {
595 tracing::error!(
596 target: LOG_TARGET,
597 ?error,
598 "Websocket listener terminated with error",
599 );
600
601 Poll::Ready(None)
602 }
603 Some(Ok((connection, address))) => {
604 let connection_id = self.context.next_connection_id();
605 tracing::trace!(
606 target: LOG_TARGET,
607 ?connection_id,
608 ?address,
609 "pending inbound Websocket connection",
610 );
611
612 self.pending_inbound_connections.insert(
613 connection_id,
614 PendingInboundConnection {
615 connection,
616 address,
617 },
618 );
619
620 Poll::Ready(Some(TransportEvent::PendingInboundConnection {
621 connection_id,
622 }))
623 }
624 };
625 }
626
627 while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) {
628 tracing::trace!(target: LOG_TARGET, ?result, "raw connection result");
629
630 match result {
631 RawConnectionResult::Connected { negotiated, errors } => {
632 let Some(handle) = self.cancel_futures.remove(&negotiated.connection_id())
633 else {
634 tracing::warn!(
635 target: LOG_TARGET,
636 connection_id = ?negotiated.connection_id(),
637 address = ?negotiated.endpoint().address(),
638 ?errors,
639 "raw connection without a cancel handle",
640 );
641 continue;
642 };
643
644 if !handle.is_aborted() {
645 let connection_id = negotiated.connection_id();
646 let address = negotiated.endpoint().address().clone();
647
648 self.opened.insert(connection_id, negotiated);
649
650 return Poll::Ready(Some(TransportEvent::ConnectionOpened {
651 connection_id,
652 address,
653 }));
654 }
655 }
656
657 RawConnectionResult::Failed {
658 connection_id,
659 errors,
660 } => {
661 let Some(handle) = self.cancel_futures.remove(&connection_id) else {
662 tracing::warn!(
663 target: LOG_TARGET,
664 ?connection_id,
665 ?errors,
666 "raw connection without a cancel handle",
667 );
668 continue;
669 };
670
671 if !handle.is_aborted() {
672 return Poll::Ready(Some(TransportEvent::OpenFailure {
673 connection_id,
674 errors,
675 }));
676 }
677 }
678 RawConnectionResult::Canceled { connection_id } => {
679 if self.cancel_futures.remove(&connection_id).is_none() {
680 tracing::warn!(
681 target: LOG_TARGET,
682 ?connection_id,
683 "raw cancelled connection without a cancel handle",
684 );
685 }
686 }
687 }
688 }
689
690 while let Poll::Ready(Some(connection)) = self.pending_connections.poll_next_unpin(cx) {
691 match connection {
692 Ok(connection) => {
693 let peer = connection.peer();
694 let endpoint = connection.endpoint();
695 self.pending_dials.remove(&connection.connection_id());
696 self.pending_open.insert(connection.connection_id(), connection);
697
698 return Poll::Ready(Some(TransportEvent::ConnectionEstablished {
699 peer,
700 endpoint,
701 }));
702 }
703 Err((connection_id, error)) => {
704 if let Some(address) = self.pending_dials.remove(&connection_id) {
705 return Poll::Ready(Some(TransportEvent::DialFailure {
706 connection_id,
707 address,
708 error,
709 }));
710 } else {
711 tracing::debug!(target: LOG_TARGET, ?error, ?connection_id, "Pending inbound connection failed");
712 }
713 }
714 }
715 }
716
717 Poll::Pending
718 }
719}