1use tokio::io::{AsyncRead, AsyncWrite};
3
4use tungstenite::{
5 client::uri_mode, error::Error, handshake::client::Response, protocol::WebSocketConfig,
6};
7
8use crate::{client_async_with_config, IntoClientRequest, WebSocketStream};
9
10pub use crate::stream::MaybeTlsStream;
11
/// Selects the TLS backend used by [`client_async_tls_with_config`] to wrap a stream.
///
/// Which variants exist depends on the enabled crate features.
#[non_exhaustive]
#[derive(Clone)]
pub enum Connector {
    /// No TLS: the stream is used as-is. A request whose URI requires TLS fails with
    /// `UrlError::TlsFeatureNotEnabled` when this variant is selected.
    Plain,
    /// Use the given `native-tls` connector for the TLS handshake.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls_crate::TlsConnector),
    /// Use the given `rustls` client configuration for the TLS handshake.
    #[cfg(feature = "__rustls-tls")]
    Rustls(std::sync::Arc<rustls::ClientConfig>),
}
27
28mod encryption {
29 #[cfg(feature = "native-tls")]
30 pub mod native_tls {
31 use native_tls_crate::TlsConnector;
32 use tokio_native_tls::TlsConnector as TokioTlsConnector;
33
34 use tokio::io::{AsyncRead, AsyncWrite};
35
36 use tungstenite::{error::TlsError, stream::Mode, Error};
37
38 use crate::stream::MaybeTlsStream;
39
40 pub async fn wrap_stream<S>(
41 socket: S,
42 domain: String,
43 mode: Mode,
44 tls_connector: Option<TlsConnector>,
45 ) -> Result<MaybeTlsStream<S>, Error>
46 where
47 S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
48 {
49 match mode {
50 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
51 Mode::Tls => {
52 let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok);
53 let connector = try_connector.map_err(TlsError::Native)?;
54 let stream = TokioTlsConnector::from(connector);
55 let connected = stream.connect(&domain, socket).await;
56 match connected {
57 Err(e) => Err(Error::Tls(e.into())),
58 Ok(s) => Ok(MaybeTlsStream::NativeTls(s)),
59 }
60 }
61 }
62 }
63 }
64
65 #[cfg(feature = "__rustls-tls")]
66 pub mod rustls {
67 pub use rustls::ClientConfig;
68 use rustls::{RootCertStore, ServerName};
69 use tokio_rustls::TlsConnector as TokioTlsConnector;
70
71 use std::{convert::TryFrom, sync::Arc};
72 use tokio::io::{AsyncRead, AsyncWrite};
73
74 use tungstenite::{error::TlsError, stream::Mode, Error};
75
76 use crate::stream::MaybeTlsStream;
77
78 pub async fn wrap_stream<S>(
79 socket: S,
80 domain: String,
81 mode: Mode,
82 tls_connector: Option<Arc<ClientConfig>>,
83 ) -> Result<MaybeTlsStream<S>, Error>
84 where
85 S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
86 {
87 match mode {
88 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
89 Mode::Tls => {
90 let config = match tls_connector {
91 Some(config) => config,
92 None => {
93 #[allow(unused_mut)]
94 let mut root_store = RootCertStore::empty();
95 #[cfg(feature = "rustls-tls-native-roots")]
96 {
97 let native_certs = rustls_native_certs::load_native_certs()?;
98 let der_certs: Vec<Vec<u8>> =
99 native_certs.into_iter().map(|cert| cert.0).collect();
100 let total_number = der_certs.len();
101 let (number_added, number_ignored) =
102 root_store.add_parsable_certificates(&der_certs);
103 log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
104 }
105 #[cfg(feature = "rustls-tls-webpki-roots")]
106 {
107 root_store.add_trust_anchors(
108 webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| {
109 rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
110 ta.subject,
111 ta.spki,
112 ta.name_constraints,
113 )
114 })
115 );
116 }
117
118 Arc::new(
119 ClientConfig::builder()
120 .with_safe_defaults()
121 .with_root_certificates(root_store)
122 .with_no_client_auth(),
123 )
124 }
125 };
126 let domain = ServerName::try_from(domain.as_str())
127 .map_err(|_| TlsError::InvalidDnsName)?;
128 let stream = TokioTlsConnector::from(config);
129 let connected = stream.connect(domain, socket).await;
130
131 match connected {
132 Err(e) => Err(Error::Io(e)),
133 Ok(s) => Ok(MaybeTlsStream::Rustls(s)),
134 }
135 }
136 }
137 }
138 }
139
140 pub mod plain {
141 use tokio::io::{AsyncRead, AsyncWrite};
142
143 use tungstenite::{
144 error::{Error, UrlError},
145 stream::Mode,
146 };
147
148 use crate::stream::MaybeTlsStream;
149
150 pub async fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>, Error>
151 where
152 S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
153 {
154 match mode {
155 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
156 Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
157 }
158 }
159 }
160}
161
#[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
/// Performs the WebSocket client handshake over `stream`, adding TLS when the request URI
/// requires it.
///
/// This is [`client_async_tls_with_config`] with the default WebSocket configuration and the
/// default (feature-selected) TLS connector.
///
/// # Errors
///
/// Propagates every error from [`client_async_tls_with_config`].
pub async fn client_async_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<MaybeTlsStream<S>>, Response), Error>
where
    R: IntoClientRequest + Unpin,
    S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
    MaybeTlsStream<S>: Unpin,
{
    let default_config: Option<WebSocketConfig> = None;
    let default_connector: Option<Connector> = None;
    client_async_tls_with_config(request, stream, default_config, default_connector).await
}
176
177pub async fn client_async_tls_with_config<R, S>(
183 request: R,
184 stream: S,
185 config: Option<WebSocketConfig>,
186 connector: Option<Connector>,
187) -> Result<(WebSocketStream<MaybeTlsStream<S>>, Response), Error>
188where
189 R: IntoClientRequest + Unpin,
190 S: 'static + AsyncRead + AsyncWrite + Send + Unpin,
191 MaybeTlsStream<S>: Unpin,
192{
193 let request = request.into_client_request()?;
194
195 #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
196 let domain = crate::domain(&request)?;
197
198 let mode = uri_mode(request.uri())?;
200
201 let stream = match connector {
202 Some(conn) => match conn {
203 #[cfg(feature = "native-tls")]
204 Connector::NativeTls(conn) => {
205 self::encryption::native_tls::wrap_stream(stream, domain, mode, Some(conn)).await
206 }
207 #[cfg(feature = "__rustls-tls")]
208 Connector::Rustls(conn) => {
209 self::encryption::rustls::wrap_stream(stream, domain, mode, Some(conn)).await
210 }
211 Connector::Plain => self::encryption::plain::wrap_stream(stream, mode).await,
212 },
213 None => {
214 #[cfg(feature = "native-tls")]
215 {
216 self::encryption::native_tls::wrap_stream(stream, domain, mode, None).await
217 }
218 #[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
219 {
220 self::encryption::rustls::wrap_stream(stream, domain, mode, None).await
221 }
222 #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
223 {
224 self::encryption::plain::wrap_stream(stream, mode).await
225 }
226 }
227 }?;
228
229 client_async_with_config(request, stream, config).await
230}