1use crate::{
24 config::Role,
25 error::{AddressError, Error, NegotiationError},
26 transport::{
27 common::listener::{DialAddresses, GetSocketAddr, SocketListener, WebSocketAddress},
28 manager::TransportHandle,
29 websocket::{
30 config::Config,
31 connection::{NegotiatedConnection, WebSocketConnection},
32 },
33 Transport, TransportBuilder, TransportEvent,
34 },
35 types::ConnectionId,
36 DialError, PeerId,
37};
38
39use futures::{future::BoxFuture, stream::FuturesUnordered, Stream, StreamExt};
40use multiaddr::{Multiaddr, Protocol};
41use socket2::{Domain, Socket, Type};
42use std::net::SocketAddr;
43use tokio::net::TcpStream;
44use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
45
46use url::Url;
47
48use std::{
49 collections::{HashMap, HashSet},
50 pin::Pin,
51 task::{Context, Poll},
52 time::Duration,
53};
54
55pub(crate) use substream::Substream;
56
57mod connection;
58mod stream;
59mod substream;
60
61pub mod config;
62
/// `tracing` target shared by every log statement emitted from the WebSocket transport.
const LOG_TARGET: &str = "litep2p::websocket";
65
66struct PendingInboundConnection {
68 connection: TcpStream,
70 address: SocketAddr,
72}
73
74pub(crate) struct WebSocketTransport {
76 context: TransportHandle,
78
79 config: Config,
81
82 listener: SocketListener,
84
85 dial_addresses: DialAddresses,
87
88 pending_dials: HashMap<ConnectionId, Multiaddr>,
90
91 pending_inbound_connections: HashMap<ConnectionId, PendingInboundConnection>,
93
94 pending_connections: FuturesUnordered<
96 BoxFuture<'static, Result<NegotiatedConnection, (ConnectionId, DialError)>>,
97 >,
98
99 pending_raw_connections: FuturesUnordered<
101 BoxFuture<
102 'static,
103 Result<
104 (
105 ConnectionId,
106 Multiaddr,
107 WebSocketStream<MaybeTlsStream<TcpStream>>,
108 ),
109 (ConnectionId, Vec<(Multiaddr, DialError)>),
110 >,
111 >,
112 >,
113
114 opened_raw: HashMap<ConnectionId, (WebSocketStream<MaybeTlsStream<TcpStream>>, Multiaddr)>,
116
117 canceled: HashSet<ConnectionId>,
119
120 pending_open: HashMap<ConnectionId, NegotiatedConnection>,
122}
123
124impl WebSocketTransport {
125 fn on_inbound_connection(
127 &mut self,
128 connection_id: ConnectionId,
129 connection: TcpStream,
130 address: SocketAddr,
131 ) {
132 let keypair = self.context.keypair.clone();
133 let yamux_config = self.config.yamux_config.clone();
134 let connection_open_timeout = self.config.connection_open_timeout;
135 let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
136 let max_write_buffer_size = self.config.noise_write_buffer_size;
137 let address = Multiaddr::empty()
138 .with(Protocol::from(address.ip()))
139 .with(Protocol::Tcp(address.port()))
140 .with(Protocol::Ws(std::borrow::Cow::Borrowed("/")));
141
142 self.pending_connections.push(Box::pin(async move {
143 match tokio::time::timeout(connection_open_timeout, async move {
144 WebSocketConnection::accept_connection(
145 connection,
146 connection_id,
147 keypair,
148 address,
149 yamux_config,
150 max_read_ahead_factor,
151 max_write_buffer_size,
152 )
153 .await
154 .map_err(|error| (connection_id, error.into()))
155 })
156 .await
157 {
158 Err(_) => Err((connection_id, DialError::Timeout)),
159 Ok(Err(error)) => Err(error),
160 Ok(Ok(result)) => Ok(result),
161 }
162 }));
163 }
164
165 fn multiaddr_into_url(address: Multiaddr) -> Result<(Url, PeerId), AddressError> {
167 let mut protocol_stack = address.iter();
168
169 let dial_address = match protocol_stack.next().ok_or(AddressError::InvalidProtocol)? {
170 Protocol::Ip4(address) => address.to_string(),
171 Protocol::Ip6(address) => format!("[{address}]"),
172 Protocol::Dns(address) | Protocol::Dns4(address) | Protocol::Dns6(address) =>
173 address.to_string(),
174
175 _ => return Err(AddressError::InvalidProtocol),
176 };
177
178 let url = match protocol_stack.next().ok_or(AddressError::InvalidProtocol)? {
179 Protocol::Tcp(port) => match protocol_stack.next() {
180 Some(Protocol::Ws(_)) => format!("ws://{dial_address}:{port}/"),
181 Some(Protocol::Wss(_)) => format!("wss://{dial_address}:{port}/"),
182 _ => return Err(AddressError::InvalidProtocol),
183 },
184 _ => return Err(AddressError::InvalidProtocol),
185 };
186
187 let peer = match protocol_stack.next() {
188 Some(Protocol::P2p(multihash)) => PeerId::from_multihash(multihash)?,
189 protocol => {
190 tracing::warn!(
191 target: LOG_TARGET,
192 ?protocol,
193 "invalid protocol, expected `Protocol::Ws`/`Protocol::Wss`",
194 );
195 return Err(AddressError::PeerIdMissing);
196 }
197 };
198
199 tracing::trace!(target: LOG_TARGET, ?url, "parse address");
200
201 url::Url::parse(&url)
202 .map(|url| (url, peer))
203 .map_err(|_| AddressError::InvalidUrl)
204 }
205
206 async fn dial_peer(
208 address: Multiaddr,
209 dial_addresses: DialAddresses,
210 connection_open_timeout: Duration,
211 nodelay: bool,
212 ) -> Result<(Multiaddr, WebSocketStream<MaybeTlsStream<TcpStream>>), DialError> {
213 let (url, _) = Self::multiaddr_into_url(address.clone())?;
214
215 let (socket_address, _) = WebSocketAddress::multiaddr_to_socket_address(&address)?;
216 let remote_address =
217 match tokio::time::timeout(connection_open_timeout, socket_address.lookup_ip()).await {
218 Err(_) => return Err(DialError::Timeout),
219 Ok(Err(error)) => return Err(error.into()),
220 Ok(Ok(address)) => address,
221 };
222
223 let domain = match remote_address.is_ipv4() {
224 true => Domain::IPV4,
225 false => Domain::IPV6,
226 };
227 let socket = Socket::new(domain, Type::STREAM, Some(socket2::Protocol::TCP))?;
228 if remote_address.is_ipv6() {
229 socket.set_only_v6(true)?;
230 }
231 socket.set_nonblocking(true)?;
232 socket.set_nodelay(nodelay)?;
233
234 match dial_addresses.local_dial_address(&remote_address.ip()) {
235 Ok(Some(dial_address)) => {
236 socket.set_reuse_address(true)?;
237 #[cfg(unix)]
238 socket.set_reuse_port(true)?;
239 socket.bind(&dial_address.into())?;
240 }
241 Ok(None) => {}
242 Err(()) => {
243 tracing::debug!(
244 target: LOG_TARGET,
245 ?remote_address,
246 "tcp listener not enabled for remote address, using ephemeral port",
247 );
248 }
249 }
250
251 let future = async move {
252 match socket.connect(&remote_address.into()) {
253 Ok(()) => {}
254 Err(error) if error.raw_os_error() == Some(libc::EINPROGRESS) => {}
255 Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => {}
256 Err(err) => return Err(DialError::from(err)),
257 }
258
259 let stream = TcpStream::try_from(Into::<std::net::TcpStream>::into(socket))?;
260 stream.writable().await?;
261 if let Some(e) = stream.take_error()? {
262 return Err(DialError::from(e));
263 }
264
265 Ok((
266 address,
267 tokio_tungstenite::client_async_tls(url, stream)
268 .await
269 .map_err(NegotiationError::WebSocket)?
270 .0,
271 ))
272 };
273
274 match tokio::time::timeout(connection_open_timeout, future).await {
275 Err(_) => Err(DialError::Timeout),
276 Ok(Err(error)) => Err(error),
277 Ok(Ok((address, stream))) => Ok((address, stream)),
278 }
279 }
280}
281
282impl TransportBuilder for WebSocketTransport {
283 type Config = Config;
284 type Transport = WebSocketTransport;
285
286 fn new(
288 context: TransportHandle,
289 mut config: Self::Config,
290 ) -> crate::Result<(Self, Vec<Multiaddr>)>
291 where
292 Self: Sized,
293 {
294 tracing::debug!(
295 target: LOG_TARGET,
296 listen_addresses = ?config.listen_addresses,
297 "start websocket transport",
298 );
299 let (listener, listen_addresses, dial_addresses) = SocketListener::new::<WebSocketAddress>(
300 std::mem::take(&mut config.listen_addresses),
301 config.reuse_port,
302 config.nodelay,
303 );
304
305 Ok((
306 Self {
307 listener,
308 config,
309 context,
310 dial_addresses,
311 canceled: HashSet::new(),
312 opened_raw: HashMap::new(),
313 pending_open: HashMap::new(),
314 pending_dials: HashMap::new(),
315 pending_inbound_connections: HashMap::new(),
316 pending_connections: FuturesUnordered::new(),
317 pending_raw_connections: FuturesUnordered::new(),
318 },
319 listen_addresses,
320 ))
321 }
322}
323
324impl Transport for WebSocketTransport {
325 fn dial(&mut self, connection_id: ConnectionId, address: Multiaddr) -> crate::Result<()> {
326 let yamux_config = self.config.yamux_config.clone();
327 let keypair = self.context.keypair.clone();
328 let (ws_address, peer) = Self::multiaddr_into_url(address.clone())?;
329 let connection_open_timeout = self.config.connection_open_timeout;
330 let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
331 let max_write_buffer_size = self.config.noise_write_buffer_size;
332 let dial_addresses = self.dial_addresses.clone();
333 let nodelay = self.config.nodelay;
334
335 self.pending_dials.insert(connection_id, address.clone());
336
337 tracing::debug!(target: LOG_TARGET, ?connection_id, ?address, "open connection");
338
339 let future = async move {
340 let (_, stream) = WebSocketTransport::dial_peer(
341 address.clone(),
342 dial_addresses,
343 connection_open_timeout,
344 nodelay,
345 )
346 .await
347 .map_err(|error| (connection_id, error))?;
348
349 WebSocketConnection::open_connection(
350 connection_id,
351 keypair,
352 stream,
353 address,
354 peer,
355 ws_address,
356 yamux_config,
357 max_read_ahead_factor,
358 max_write_buffer_size,
359 )
360 .await
361 .map_err(|error| (connection_id, error.into()))
362 };
363
364 self.pending_connections.push(Box::pin(async move {
365 match tokio::time::timeout(connection_open_timeout, future).await {
366 Err(_) => Err((connection_id, DialError::Timeout)),
367 Ok(Err(error)) => Err(error),
368 Ok(Ok(result)) => Ok(result),
369 }
370 }));
371
372 Ok(())
373 }
374
375 fn accept(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
376 let context = self
377 .pending_open
378 .remove(&connection_id)
379 .ok_or(Error::ConnectionDoesntExist(connection_id))?;
380 let protocol_set = self.context.protocol_set(connection_id);
381 let bandwidth_sink = self.context.bandwidth_sink.clone();
382 let substream_open_timeout = self.config.substream_open_timeout;
383
384 tracing::trace!(
385 target: LOG_TARGET,
386 ?connection_id,
387 "start connection",
388 );
389
390 self.context.executor.run(Box::pin(async move {
391 if let Err(error) = WebSocketConnection::new(
392 context,
393 protocol_set,
394 bandwidth_sink,
395 substream_open_timeout,
396 )
397 .start()
398 .await
399 {
400 tracing::debug!(
401 target: LOG_TARGET,
402 ?connection_id,
403 ?error,
404 "connection exited with error",
405 );
406 }
407 }));
408
409 Ok(())
410 }
411
412 fn reject(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
413 self.pending_open
414 .remove(&connection_id)
415 .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(()))
416 }
417
418 fn accept_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
419 let pending = self
420 .pending_inbound_connections
421 .remove(&connection_id)
422 .ok_or(Error::ConnectionDoesntExist(connection_id))?;
423
424 self.on_inbound_connection(connection_id, pending.connection, pending.address);
425
426 Ok(())
427 }
428
429 fn reject_pending(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
430 self.pending_open
431 .remove(&connection_id)
432 .map_or(Err(Error::ConnectionDoesntExist(connection_id)), |_| Ok(()))
433 }
434
435 fn open(
436 &mut self,
437 connection_id: ConnectionId,
438 addresses: Vec<Multiaddr>,
439 ) -> crate::Result<()> {
440 let num_addresses = addresses.len();
441 let mut futures: FuturesUnordered<_> = addresses
442 .into_iter()
443 .map(|address| {
444 let connection_open_timeout = self.config.connection_open_timeout;
445 let dial_addresses = self.dial_addresses.clone();
446 let nodelay = self.config.nodelay;
447
448 async move {
449 WebSocketTransport::dial_peer(
450 address.clone(),
451 dial_addresses,
452 connection_open_timeout,
453 nodelay,
454 )
455 .await
456 .map_err(|error| (address, error))
457 }
458 })
459 .collect();
460
461 self.pending_raw_connections.push(Box::pin(async move {
462 let mut errors = Vec::with_capacity(num_addresses);
463
464 while let Some(result) = futures.next().await {
465 match result {
466 Ok((address, stream)) => return Ok((connection_id, address, stream)),
467 Err(error) => {
468 tracing::debug!(
469 target: LOG_TARGET,
470 ?connection_id,
471 ?error,
472 "failed to open connection",
473 );
474 errors.push(error)
475 }
476 }
477 }
478
479 Err((connection_id, errors))
480 }));
481
482 Ok(())
483 }
484
485 fn negotiate(&mut self, connection_id: ConnectionId) -> crate::Result<()> {
486 let (stream, address) = self
487 .opened_raw
488 .remove(&connection_id)
489 .ok_or(Error::ConnectionDoesntExist(connection_id))?;
490
491 let peer = match address.iter().find(|protocol| std::matches!(protocol, Protocol::P2p(_))) {
492 Some(Protocol::P2p(multihash)) => PeerId::from_multihash(multihash)?,
493 _ => return Err(Error::InvalidState),
494 };
495 let yamux_config = self.config.yamux_config.clone();
496 let max_read_ahead_factor = self.config.noise_read_ahead_frame_count;
497 let max_write_buffer_size = self.config.noise_write_buffer_size;
498 let connection_open_timeout = self.config.connection_open_timeout;
499 let keypair = self.context.keypair.clone();
500
501 tracing::trace!(
502 target: LOG_TARGET,
503 ?peer,
504 ?connection_id,
505 ?address,
506 "negotiate connection",
507 );
508
509 self.pending_dials.insert(connection_id, address.clone());
510 self.pending_connections.push(Box::pin(async move {
511 match tokio::time::timeout(connection_open_timeout, async move {
512 WebSocketConnection::negotiate_connection(
513 stream,
514 Some(peer),
515 Role::Dialer,
516 address,
517 connection_id,
518 keypair,
519 yamux_config,
520 max_read_ahead_factor,
521 max_write_buffer_size,
522 )
523 .await
524 .map_err(|error| (connection_id, error.into()))
525 })
526 .await
527 {
528 Err(_) => Err((connection_id, DialError::Timeout)),
529 Ok(Err(error)) => Err(error),
530 Ok(Ok(connection)) => Ok(connection),
531 }
532 }));
533
534 Ok(())
535 }
536
537 fn cancel(&mut self, connection_id: ConnectionId) {
538 self.canceled.insert(connection_id);
539 }
540}
541
542impl Stream for WebSocketTransport {
543 type Item = TransportEvent;
544
545 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
546 if let Poll::Ready(Some(connection)) = self.listener.poll_next_unpin(cx) {
547 return match connection {
548 Err(_) => Poll::Ready(None),
549 Ok((connection, address)) => {
550 let connection_id = self.context.next_connection_id();
551
552 self.pending_inbound_connections.insert(
553 connection_id,
554 PendingInboundConnection {
555 connection,
556 address,
557 },
558 );
559
560 Poll::Ready(Some(TransportEvent::PendingInboundConnection {
561 connection_id,
562 }))
563 }
564 };
565 }
566
567 while let Poll::Ready(Some(result)) = self.pending_raw_connections.poll_next_unpin(cx) {
568 match result {
569 Ok((connection_id, address, stream)) => {
570 tracing::trace!(
571 target: LOG_TARGET,
572 ?connection_id,
573 ?address,
574 canceled = self.canceled.contains(&connection_id),
575 "connection opened",
576 );
577
578 if !self.canceled.remove(&connection_id) {
579 self.opened_raw.insert(connection_id, (stream, address.clone()));
580
581 return Poll::Ready(Some(TransportEvent::ConnectionOpened {
582 connection_id,
583 address,
584 }));
585 }
586 }
587 Err((connection_id, errors)) =>
588 if !self.canceled.remove(&connection_id) {
589 return Poll::Ready(Some(TransportEvent::OpenFailure {
590 connection_id,
591 errors,
592 }));
593 },
594 }
595 }
596
597 while let Poll::Ready(Some(connection)) = self.pending_connections.poll_next_unpin(cx) {
598 match connection {
599 Ok(connection) => {
600 let peer = connection.peer();
601 let endpoint = connection.endpoint();
602 self.pending_open.insert(connection.connection_id(), connection);
603
604 return Poll::Ready(Some(TransportEvent::ConnectionEstablished {
605 peer,
606 endpoint,
607 }));
608 }
609 Err((connection_id, error)) => {
610 if let Some(address) = self.pending_dials.remove(&connection_id) {
611 return Poll::Ready(Some(TransportEvent::DialFailure {
612 connection_id,
613 address,
614 error,
615 }));
616 } else {
617 tracing::debug!(target: LOG_TARGET, ?error, ?connection_id, "Pending inbound connection failed");
618 }
619 }
620 }
621 }
622
623 Poll::Pending
624 }
625}