1mod stream;
28
29use std::io;
30use std::net::SocketAddr;
31use std::time::Duration;
32
33use base64::Engine;
34use futures_util::io::{BufReader, BufWriter};
35use jsonrpsee_core::client::{MaybeSend, ReceivedMessage, TransportReceiverT, TransportSenderT};
36use jsonrpsee_core::TEN_MB_SIZE_BYTES;
37use jsonrpsee_core::{async_trait, Cow};
38use soketto::connection::Error::Utf8;
39use soketto::data::ByteSlice125;
40use soketto::handshake::client::{Client as WsHandshakeClient, ServerResponse};
41use soketto::{connection, Data, Incoming};
42use thiserror::Error;
43use tokio::net::TcpStream;
44use tokio_util::compat::{Compat, TokioAsyncReadCompatExt};
45
46pub use http::{uri::InvalidUri, HeaderMap, HeaderValue, Uri};
47pub use soketto::handshake::client::Header;
48pub use stream::EitherStream;
49pub use tokio::io::{AsyncRead, AsyncWrite};
50pub use url::Url;
51
/// `tracing` target attached to every log event emitted by this module.
const LOG_TARGET: &str = "jsonrpsee-client";
53
/// Custom TLS configuration for `wss` connections: a complete [`rustls::ClientConfig`]
/// that is used as-is instead of the native certificate store.
#[cfg(feature = "tls")]
pub type CustomCertStore = rustls::ClientConfig;
57
/// Source of the TLS client configuration used for `wss` connections.
// `Custom` embeds a whole `rustls::ClientConfig`, making it far larger than `Native`.
// Boxing it would change the public variant's payload type, so the lint is silenced instead.
#[allow(clippy::large_enum_variant)]
#[cfg(feature = "tls")]
#[derive(Debug, Clone)]
pub enum CertificateStore {
    /// Use the platform verifier. Only available with the `tls-rustls-platform-verifier`
    /// feature; otherwise building the TLS connector fails with
    /// `WsHandshakeError::CertificateStore`.
    Native,
    /// Use the supplied rustls client configuration unchanged.
    Custom(CustomCertStore),
}
69
70#[derive(Debug)]
72pub struct Sender<T> {
73 inner: connection::Sender<BufReader<BufWriter<T>>>,
74 max_request_size: u32,
75}
76
77#[derive(Debug)]
79pub struct Receiver<T> {
80 inner: connection::Receiver<BufReader<BufWriter<T>>>,
81}
82
83#[derive(Debug)]
85pub struct WsTransportClientBuilder {
86 #[cfg(feature = "tls")]
87 pub certificate_store: CertificateStore,
89 pub connection_timeout: Duration,
91 pub headers: http::HeaderMap,
93 pub max_request_size: u32,
95 pub max_response_size: u32,
97 pub max_redirections: usize,
99 pub tcp_no_delay: bool,
101}
102
103impl Default for WsTransportClientBuilder {
104 fn default() -> Self {
105 Self {
106 #[cfg(feature = "tls")]
107 certificate_store: CertificateStore::Native,
108 max_request_size: TEN_MB_SIZE_BYTES,
109 max_response_size: TEN_MB_SIZE_BYTES,
110 connection_timeout: Duration::from_secs(10),
111 headers: http::HeaderMap::new(),
112 max_redirections: 5,
113 tcp_no_delay: true,
114 }
115 }
116}
117
118impl WsTransportClientBuilder {
119 #[cfg(feature = "tls")]
125 pub fn with_custom_cert_store(mut self, cfg: CustomCertStore) -> Self {
126 self.certificate_store = CertificateStore::Custom(cfg);
127 self
128 }
129
130 pub fn max_request_size(mut self, size: u32) -> Self {
132 self.max_request_size = size;
133 self
134 }
135
136 pub fn max_response_size(mut self, size: u32) -> Self {
138 self.max_response_size = size;
139 self
140 }
141
142 pub fn connection_timeout(mut self, timeout: Duration) -> Self {
144 self.connection_timeout = timeout;
145 self
146 }
147
148 pub fn set_headers(mut self, headers: http::HeaderMap) -> Self {
152 self.headers = headers;
153 self
154 }
155
156 pub fn max_redirections(mut self, redirect: usize) -> Self {
159 self.max_redirections = redirect;
160 self
161 }
162}
163
/// Transport security of a target, derived from its URL scheme.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Mode {
    /// `ws://`: the handshake runs directly over TCP.
    Plain,
    /// `wss://`: the handshake runs over TLS on top of TCP.
    Tls,
}
172
173#[derive(Debug, Error)]
178pub enum WsHandshakeError {
179 #[error("Failed to load system certs: {0}")]
181 CertificateStore(io::Error),
182
183 #[error("Invalid URL: {0}")]
185 Url(Cow<'static, str>),
186
187 #[error("Error when opening the TCP socket: {0}")]
189 Io(io::Error),
190
191 #[error("{0}")]
193 Transport(#[source] soketto::handshake::Error),
194
195 #[error("Connection rejected with status code: {status_code}")]
197 Rejected {
198 status_code: u16,
200 },
201
202 #[error("Connection redirected with status code: {status_code} and location: {location}")]
204 Redirected {
205 status_code: u16,
207 location: String,
209 },
210
211 #[error("Connection timeout exceeded: {0:?}")]
213 Timeout(Duration),
214
215 #[error("Failed to resolve IP addresses for this hostname: {0}")]
217 ResolutionFailed(io::Error),
218
219 #[error("No IP address found for this hostname: {0}")]
221 NoAddressFound(String),
222}
223
224#[derive(Debug, Error)]
226pub enum WsError {
227 #[error("{0}")]
229 Connection(#[source] soketto::connection::Error),
230 #[error("The message was too large")]
232 MessageTooLarge,
233}
234
235#[async_trait]
236impl<T> TransportSenderT for Sender<T>
237where
238 T: futures_util::io::AsyncRead + futures_util::io::AsyncWrite + Unpin + MaybeSend + 'static,
239{
240 type Error = WsError;
241
242 async fn send(&mut self, body: String) -> Result<(), Self::Error> {
245 if body.len() > self.max_request_size as usize {
246 return Err(WsError::MessageTooLarge);
247 }
248
249 self.inner.send_text(body).await?;
250 self.inner.flush().await?;
251 Ok(())
252 }
253
254 async fn send_ping(&mut self) -> Result<(), Self::Error> {
257 tracing::debug!(target: LOG_TARGET, "Send ping");
258 let slice: &[u8] = &[];
260 let byte_slice = ByteSlice125::try_from(slice).expect("Empty slice should fit into ByteSlice125");
262
263 self.inner.send_ping(byte_slice).await?;
264 self.inner.flush().await?;
265 Ok(())
266 }
267
268 async fn close(&mut self) -> Result<(), WsError> {
270 self.inner.close().await.map_err(Into::into)
271 }
272}
273
274#[async_trait]
275impl<T> TransportReceiverT for Receiver<T>
276where
277 T: futures_util::io::AsyncRead + futures_util::io::AsyncWrite + Unpin + MaybeSend + 'static,
278{
279 type Error = WsError;
280
281 async fn receive(&mut self) -> Result<ReceivedMessage, Self::Error> {
283 loop {
284 let mut message = Vec::new();
285 let recv = self.inner.receive(&mut message).await?;
286
287 match recv {
288 Incoming::Data(Data::Text(_)) => {
289 let s = String::from_utf8(message).map_err(|err| WsError::Connection(Utf8(err.utf8_error())))?;
290 break Ok(ReceivedMessage::Text(s));
291 }
292 Incoming::Data(Data::Binary(_)) => break Ok(ReceivedMessage::Bytes(message)),
293 Incoming::Pong(_) => break Ok(ReceivedMessage::Pong),
294 _ => continue,
295 }
296 }
297 }
298}
299
300impl WsTransportClientBuilder {
301 pub async fn build(
305 self,
306 uri: Url,
307 ) -> Result<(Sender<Compat<EitherStream>>, Receiver<Compat<EitherStream>>), WsHandshakeError> {
308 self.try_connect_over_tcp(uri).await
309 }
310
311 pub async fn build_with_stream<T>(
313 self,
314 uri: Url,
315 data_stream: T,
316 ) -> Result<(Sender<Compat<T>>, Receiver<Compat<T>>), WsHandshakeError>
317 where
318 T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
319 {
320 let target: Target = uri.try_into()?;
321 self.try_connect(&target, data_stream.compat()).await
322 }
323
324 #[cfg(feature = "tls")]
325 fn tls_connector(&self, target: &Target) -> Result<Option<tokio_rustls::TlsConnector>, WsHandshakeError> {
326 let _ = rustls::crypto::ring::default_provider().install_default();
331
332 let connector = match target._mode {
333 Mode::Tls => Some(build_tls_config(&self.certificate_store)?),
334 Mode::Plain => None,
335 };
336 Ok(connector)
337 }
338
339 async fn try_connect_over_tcp(
341 &self,
342 uri: Url,
343 ) -> Result<(Sender<Compat<EitherStream>>, Receiver<Compat<EitherStream>>), WsHandshakeError> {
344 let mut target: Target = uri.clone().try_into()?;
345 let mut err = None;
346
347 #[cfg(feature = "tls")]
349 let mut connector = self.tls_connector(&target)?;
350
351 let mut target_sockaddrs = uri.socket_addrs(|| None).map_err(WsHandshakeError::ResolutionFailed)?;
353
354 for _ in 0..self.max_redirections {
355 tracing::debug!(target: LOG_TARGET, "Connecting to target: {:?}", target);
356
357 let sockaddrs = std::mem::take(&mut target_sockaddrs);
358
359 for sockaddr in &sockaddrs {
360 #[cfg(feature = "tls")]
361 let tcp_stream = match connect(
362 *sockaddr,
363 self.connection_timeout,
364 &target.host,
365 connector.as_ref(),
366 self.tcp_no_delay,
367 )
368 .await
369 {
370 Ok(stream) => stream,
371 Err(e) => {
372 tracing::debug!(target: LOG_TARGET, "Failed to connect to sockaddr: {:?}", sockaddr);
373 err = Some(Err(e));
374 continue;
375 }
376 };
377
378 #[cfg(not(feature = "tls"))]
379 let tcp_stream = match connect(*sockaddr, self.connection_timeout).await {
380 Ok(stream) => stream,
381 Err(e) => {
382 tracing::debug!(target: LOG_TARGET, "Failed to connect to sockaddr: {:?}", sockaddr);
383 err = Some(Err(e));
384 continue;
385 }
386 };
387
388 match self.try_connect(&target, tcp_stream.compat()).await {
389 Ok(result) => return Ok(result),
390
391 Err(WsHandshakeError::Redirected { status_code, location }) => {
392 tracing::debug!(target: LOG_TARGET, "Redirection: status_code: {}, location: {}", status_code, location);
393 match Url::parse(&location) {
394 Ok(uri) => {
396 target_sockaddrs = uri.socket_addrs(|| None).map_err(|e| {
398 tracing::debug!(target: LOG_TARGET, "Redirection failed: {:?}", e);
399 e
400 })?;
401
402 target = uri.try_into().map_err(|e| {
403 tracing::debug!(target: LOG_TARGET, "Redirection failed: {:?}", e);
404 e
405 })?;
406
407 #[cfg(feature = "tls")]
409 match target._mode {
410 Mode::Tls if connector.is_none() => {
411 connector = Some(build_tls_config(&self.certificate_store)?);
412 }
413 Mode::Tls => (),
414 Mode::Plain => {
416 connector = None;
417 }
418 };
419 }
420
421 Err(url::ParseError::RelativeUrlWithoutBase) => {
423 if location.starts_with('/') {
425 target.path_and_query = location;
426 } else {
427 match target.path_and_query.rfind('/') {
428 Some(offset) => target.path_and_query.replace_range(offset + 1.., &location),
429 None => {
430 let e = format!("path_and_query: {location}; this is a bug it must contain `/` please open issue");
431 err = Some(Err(WsHandshakeError::Url(e.into())));
432 continue;
433 }
434 };
435 }
436 target_sockaddrs = sockaddrs;
437 break;
438 }
439
440 Err(e) => {
441 err = Some(Err(WsHandshakeError::Url(e.to_string().into())));
442 }
443 };
444 }
445
446 Err(e) => {
447 err = Some(Err(e));
448 }
449 };
450 }
451 }
452 err.unwrap_or(Err(WsHandshakeError::NoAddressFound(target.host)))
453 }
454
455 async fn try_connect<T>(
457 &self,
458 target: &Target,
459 data_stream: T,
460 ) -> Result<(Sender<T>, Receiver<T>), WsHandshakeError>
461 where
462 T: futures_util::AsyncRead + futures_util::AsyncWrite + Unpin,
463 {
464 let mut client = WsHandshakeClient::new(
465 BufReader::new(BufWriter::new(data_stream)),
466 &target.host_header,
467 &target.path_and_query,
468 );
469
470 let headers: Vec<_> = match &target.basic_auth {
471 Some(basic_auth) if !self.headers.contains_key(http::header::AUTHORIZATION) => {
472 let it1 =
473 self.headers.iter().map(|(key, value)| Header { name: key.as_str(), value: value.as_bytes() });
474 let it2 = std::iter::once(Header {
475 name: http::header::AUTHORIZATION.as_str(),
476 value: basic_auth.as_bytes(),
477 });
478
479 it1.chain(it2).collect()
480 }
481 _ => {
482 self.headers.iter().map(|(key, value)| Header { name: key.as_str(), value: value.as_bytes() }).collect()
483 }
484 };
485
486 client.set_headers(&headers);
487
488 match client.handshake().await {
490 Ok(ServerResponse::Accepted { .. }) => {
491 tracing::debug!(target: LOG_TARGET, "Connection established to target: {:?}", target);
492 let mut builder = client.into_builder();
493 builder.set_max_message_size(self.max_response_size as usize);
494 let (sender, receiver) = builder.finish();
495 Ok((Sender { inner: sender, max_request_size: self.max_request_size }, Receiver { inner: receiver }))
496 }
497
498 Ok(ServerResponse::Rejected { status_code }) => {
499 tracing::debug!(target: LOG_TARGET, "Connection rejected: {:?}", status_code);
500 Err(WsHandshakeError::Rejected { status_code })
501 }
502
503 Ok(ServerResponse::Redirect { status_code, location }) => {
504 tracing::debug!(target: LOG_TARGET, "Redirection: status_code: {}, location: {}", status_code, location);
505 Err(WsHandshakeError::Redirected { status_code, location })
506 }
507
508 Err(e) => Err(e.into()),
509 }
510 }
511}
512
513#[cfg(feature = "tls")]
514async fn connect(
515 sockaddr: SocketAddr,
516 timeout_dur: Duration,
517 host: &str,
518 tls_connector: Option<&tokio_rustls::TlsConnector>,
519 tcp_no_delay: bool,
520) -> Result<EitherStream, WsHandshakeError> {
521 let socket = TcpStream::connect(sockaddr);
522 let timeout = tokio::time::sleep(timeout_dur);
523 tokio::select! {
524 socket = socket => {
525 let socket = socket?;
526 if let Err(err) = socket.set_nodelay(tcp_no_delay) {
527 tracing::warn!(target: LOG_TARGET, "set nodelay failed: {:?}", err);
528 }
529 match tls_connector {
530 None => Ok(EitherStream::Plain(socket)),
531 Some(connector) => {
532 let server_name: rustls_pki_types::ServerName = host.try_into().map_err(|e| WsHandshakeError::Url(format!("Invalid host: {host} {e:?}").into()))?;
533 let tls_stream = connector.connect(server_name.to_owned(), socket).await?;
534 Ok(EitherStream::Tls(tls_stream))
535 }
536 }
537 }
538 _ = timeout => Err(WsHandshakeError::Timeout(timeout_dur))
539 }
540}
541
542#[cfg(not(feature = "tls"))]
543async fn connect(sockaddr: SocketAddr, timeout_dur: Duration) -> Result<EitherStream, WsHandshakeError> {
544 let socket = TcpStream::connect(sockaddr);
545 let timeout = tokio::time::sleep(timeout_dur);
546 tokio::select! {
547 socket = socket => {
548 let socket = socket?;
549 if let Err(err) = socket.set_nodelay(true) {
550 tracing::warn!(target: LOG_TARGET, "set nodelay failed: {:?}", err);
551 }
552 Ok(EitherStream::Plain(socket))
553 }
554 _ = timeout => Err(WsHandshakeError::Timeout(timeout_dur))
555 }
556}
557
558impl From<io::Error> for WsHandshakeError {
559 fn from(err: io::Error) -> WsHandshakeError {
560 WsHandshakeError::Io(err)
561 }
562}
563
564impl From<soketto::handshake::Error> for WsHandshakeError {
565 fn from(err: soketto::handshake::Error) -> WsHandshakeError {
566 WsHandshakeError::Transport(err)
567 }
568}
569
570impl From<soketto::connection::Error> for WsError {
571 fn from(err: soketto::connection::Error) -> Self {
572 WsError::Connection(err)
573 }
574}
575
576#[derive(Debug, Clone, PartialEq, Eq)]
578pub(crate) struct Target {
579 host: String,
581 host_header: String,
583 _mode: Mode,
585 path_and_query: String,
587 basic_auth: Option<HeaderValue>,
589}
590
591impl TryFrom<url::Url> for Target {
592 type Error = WsHandshakeError;
593
594 fn try_from(url: Url) -> Result<Self, Self::Error> {
595 let _mode = match url.scheme() {
596 "ws" => Mode::Plain,
597 #[cfg(feature = "tls")]
598 "wss" => Mode::Tls,
599 invalid_scheme => {
600 #[cfg(feature = "tls")]
601 let err = format!("`{invalid_scheme}` not supported, expects 'ws' or 'wss'");
602 #[cfg(not(feature = "tls"))]
603 let err = format!("`{invalid_scheme}` not supported, expects 'ws' ('wss' requires the tls feature)");
604 return Err(WsHandshakeError::Url(err.into()));
605 }
606 };
607 let host = url.host_str().map(ToOwned::to_owned).ok_or_else(|| WsHandshakeError::Url("Invalid host".into()))?;
608
609 let mut path_and_query = url.path().to_owned();
610 if let Some(query) = url.query() {
611 path_and_query.push('?');
612 path_and_query.push_str(query);
613 }
614
615 let basic_auth = if let Some(pwd) = url.password() {
616 let digest = base64::engine::general_purpose::STANDARD.encode(format!("{}:{}", url.username(), pwd));
617 let val = HeaderValue::from_str(&format!("Basic {digest}"))
618 .map_err(|_| WsHandshakeError::Url("Header value `authorization basic user:pwd` invalid".into()))?;
619
620 Some(val)
621 } else {
622 None
623 };
624
625 let host_header = if let Some(port) = url.port() { format!("{host}:{port}") } else { host.to_string() };
626
627 Ok(Self { host, host_header, _mode, path_and_query: path_and_query.to_string(), basic_auth })
628 }
629}
630
/// Turns the configured [`CertificateStore`] into a TLS connector.
///
/// # Errors
///
/// Returns [`WsHandshakeError::CertificateStore`] when the native store is requested but the
/// `tls-rustls-platform-verifier` feature is disabled.
#[cfg(feature = "tls")]
fn build_tls_config(cert_store: &CertificateStore) -> Result<tokio_rustls::TlsConnector, WsHandshakeError> {
    let client_config = match cert_store {
        CertificateStore::Custom(custom) => custom.clone(),
        #[cfg(feature = "tls-rustls-platform-verifier")]
        CertificateStore::Native => rustls_platform_verifier::tls_config(),
        #[cfg(not(feature = "tls-rustls-platform-verifier"))]
        CertificateStore::Native => {
            let unsupported = io::Error::new(
                io::ErrorKind::Other,
                "Native certificate store not supported, either call `Builder::with_custom_cert_store` or enable the `tls-rustls-platform-verifier` feature.",
            );
            return Err(WsHandshakeError::CertificateStore(unsupported));
        }
    };

    let shared_config = std::sync::Arc::new(client_config);
    Ok(tokio_rustls::TlsConnector::from(shared_config))
}
649
#[cfg(test)]
mod tests {
    use http::HeaderValue;

    use super::{Mode, Target, Url, WsHandshakeError};

    /// Asserts that every field of `target` equals the expected values.
    fn assert_ws_target(
        target: Target,
        host: &str,
        host_header: &str,
        mode: Mode,
        path_and_query: &str,
        basic_auth: Option<HeaderValue>,
    ) {
        let expected = Target {
            host: host.to_owned(),
            host_header: host_header.to_owned(),
            _mode: mode,
            path_and_query: path_and_query.to_owned(),
            basic_auth,
        };
        assert_eq!(target, expected);
    }

    /// Parses `uri` into a [`Target`], mapping URL syntax errors to `WsHandshakeError::Url`.
    fn parse_target(uri: &str) -> Result<Target, WsHandshakeError> {
        let url = Url::parse(uri).map_err(|e| WsHandshakeError::Url(e.to_string().into()))?;
        Target::try_from(url)
    }

    #[test]
    fn ws_works_with_port() {
        let target = parse_target("ws://127.0.0.1:9933").expect("valid ws url");
        assert_ws_target(target, "127.0.0.1", "127.0.0.1:9933", Mode::Plain, "/", None);
    }

    #[cfg(feature = "tls")]
    #[test]
    fn wss_works_with_port() {
        let target = parse_target("wss://kusama-rpc.polkadot.io:9999").expect("valid wss url");
        assert_ws_target(target, "kusama-rpc.polkadot.io", "kusama-rpc.polkadot.io:9999", Mode::Tls, "/", None);
    }

    #[cfg(not(feature = "tls"))]
    #[test]
    fn wss_fails_without_tls_feature() {
        assert!(matches!(parse_target("wss://kusama-rpc.polkadot.io"), Err(WsHandshakeError::Url(_))));
    }

    #[test]
    fn faulty_url_scheme() {
        assert!(matches!(parse_target("http://kusama-rpc.polkadot.io:443"), Err(WsHandshakeError::Url(_))));
    }

    #[test]
    fn faulty_port() {
        for uri in ["ws://127.0.0.1:-43", "ws://127.0.0.1:99999"] {
            assert!(matches!(parse_target(uri), Err(WsHandshakeError::Url(_))), "{uri} should be rejected");
        }
    }

    #[test]
    fn url_with_path_works() {
        let target = parse_target("ws://127.0.0.1/my-special-path").expect("valid url with path");
        assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/my-special-path", None);
    }

    #[test]
    fn url_with_query_works() {
        let target = parse_target("ws://127.0.0.1/my?name1=value1&name2=value2").expect("valid url with query");
        assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/my?name1=value1&name2=value2", None);
    }

    #[test]
    fn url_with_fragment_is_ignored() {
        let target = parse_target("ws://127.0.0.1:/my.htm#ignore").expect("valid url with fragment");
        assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/my.htm", None);
    }

    #[cfg(feature = "tls")]
    #[test]
    fn wss_default_port_is_omitted() {
        let target = parse_target("wss://127.0.0.1:443").expect("valid wss url");
        assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Tls, "/", None);
    }

    #[test]
    fn ws_default_port_is_omitted() {
        let target = parse_target("ws://127.0.0.1:80").expect("valid ws url");
        assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/", None);
    }

    #[test]
    fn ws_with_username_and_password() {
        use base64::Engine;

        let target = parse_target("ws://user:pwd@127.0.0.1").expect("valid url with credentials");
        let encoded = base64::engine::general_purpose::STANDARD.encode("user:pwd");
        let expected_auth = HeaderValue::from_str(&format!("Basic {encoded}")).expect("valid header value");

        assert_ws_target(target, "127.0.0.1", "127.0.0.1", Mode::Plain, "/", Some(expected_auth));
    }
}