1use std::io::{Read, Write};
3
4use crate::{
5 client::{client_with_config, uri_mode, IntoClientRequest},
6 error::UrlError,
7 handshake::client::Response,
8 protocol::WebSocketConfig,
9 stream::MaybeTlsStream,
10 ClientHandshake, Error, HandshakeError, Result, WebSocket,
11};
12
/// A connector that selects how a stream is wrapped when establishing a connection.
///
/// The TLS variants are only compiled in when the corresponding crate feature
/// (`native-tls` or `__rustls-tls`) is enabled; `Plain` is always available.
#[non_exhaustive]
// NOTE(review): presumably `Debug` is not derived because the wrapped connector
// types are not guaranteed to implement it — confirm before removing this allow.
#[allow(missing_debug_implementations)]
pub enum Connector {
    /// Plain (non-TLS) connector.
    Plain,
    /// `native-tls` TLS connector.
    #[cfg(feature = "native-tls")]
    NativeTls(native_tls_crate::TlsConnector),
    /// `rustls` TLS connector, carrying a shared client configuration.
    #[cfg(feature = "__rustls-tls")]
    Rustls(std::sync::Arc<rustls::ClientConfig>),
}
28
29mod encryption {
30 #[cfg(feature = "native-tls")]
31 pub mod native_tls {
32 use native_tls_crate::{HandshakeError as TlsHandshakeError, TlsConnector};
33
34 use std::io::{Read, Write};
35
36 use crate::{
37 error::TlsError,
38 stream::{MaybeTlsStream, Mode},
39 Error, Result,
40 };
41
42 pub fn wrap_stream<S>(
43 socket: S,
44 domain: &str,
45 mode: Mode,
46 tls_connector: Option<TlsConnector>,
47 ) -> Result<MaybeTlsStream<S>>
48 where
49 S: Read + Write,
50 {
51 match mode {
52 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
53 Mode::Tls => {
54 let try_connector = tls_connector.map_or_else(TlsConnector::new, Ok);
55 let connector = try_connector.map_err(TlsError::Native)?;
56 let connected = connector.connect(domain, socket);
57 match connected {
58 Err(e) => match e {
59 TlsHandshakeError::Failure(f) => Err(Error::Tls(f.into())),
60 TlsHandshakeError::WouldBlock(_) => {
61 panic!("Bug: TLS handshake not blocked")
62 }
63 },
64 Ok(s) => Ok(MaybeTlsStream::NativeTls(s)),
65 }
66 }
67 }
68 }
69 }
70
71 #[cfg(feature = "__rustls-tls")]
72 pub mod rustls {
73 use rustls::{ClientConfig, ClientConnection, RootCertStore, ServerName, StreamOwned};
74
75 use std::{
76 convert::TryFrom,
77 io::{Read, Write},
78 sync::Arc,
79 };
80
81 use crate::{
82 error::TlsError,
83 stream::{MaybeTlsStream, Mode},
84 Result,
85 };
86
87 pub fn wrap_stream<S>(
88 socket: S,
89 domain: &str,
90 mode: Mode,
91 tls_connector: Option<Arc<ClientConfig>>,
92 ) -> Result<MaybeTlsStream<S>>
93 where
94 S: Read + Write,
95 {
96 match mode {
97 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
98 Mode::Tls => {
99 let config = match tls_connector {
100 Some(config) => config,
101 None => {
102 #[allow(unused_mut)]
103 let mut root_store = RootCertStore::empty();
104
105 #[cfg(feature = "rustls-tls-native-roots")]
106 {
107 let native_certs = rustls_native_certs::load_native_certs()?;
108 let der_certs: Vec<Vec<u8>> =
109 native_certs.into_iter().map(|cert| cert.0).collect();
110 let total_number = der_certs.len();
111 let (number_added, number_ignored) =
112 root_store.add_parsable_certificates(&der_certs);
113 log::debug!("Added {number_added}/{total_number} native root certificates (ignored {number_ignored})");
114 }
115 #[cfg(feature = "rustls-tls-webpki-roots")]
116 {
117 root_store.add_server_trust_anchors(
118 webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
119 rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
120 ta.subject,
121 ta.spki,
122 ta.name_constraints,
123 )
124 })
125 );
126 }
127
128 Arc::new(
129 ClientConfig::builder()
130 .with_safe_defaults()
131 .with_root_certificates(root_store)
132 .with_no_client_auth(),
133 )
134 }
135 };
136 let domain =
137 ServerName::try_from(domain).map_err(|_| TlsError::InvalidDnsName)?;
138 let client = ClientConnection::new(config, domain).map_err(TlsError::Rustls)?;
139 let stream = StreamOwned::new(client, socket);
140
141 Ok(MaybeTlsStream::Rustls(stream))
142 }
143 }
144 }
145 }
146
147 pub mod plain {
148 use std::io::{Read, Write};
149
150 use crate::{
151 error::UrlError,
152 stream::{MaybeTlsStream, Mode},
153 Error, Result,
154 };
155
156 pub fn wrap_stream<S>(socket: S, mode: Mode) -> Result<MaybeTlsStream<S>>
157 where
158 S: Read + Write,
159 {
160 match mode {
161 Mode::Plain => Ok(MaybeTlsStream::Plain(socket)),
162 Mode::Tls => Err(Error::Url(UrlError::TlsFeatureNotEnabled)),
163 }
164 }
165 }
166}
167
/// Error type of the TLS-aware client functions in this file: a `HandshakeError`
/// for a client handshake running over a `MaybeTlsStream<S>`.
type TlsHandshakeError<S> = HandshakeError<ClientHandshake<MaybeTlsStream<S>>>;
169
/// Creates a WebSocket client connection from a request and a stream,
/// wrapping the stream in TLS when the request URI calls for it.
///
/// This is a shorthand for [`client_tls_with_config`] with neither a
/// WebSocket configuration nor a connector supplied (both passed as `None`).
///
/// # Errors
///
/// Returns whatever [`client_tls_with_config`] returns for the same request
/// and stream.
pub fn client_tls<R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
where
    R: IntoClientRequest,
    S: Read + Write,
{
    client_tls_with_config(request, stream, None, None)
}
182
183pub fn client_tls_with_config<R, S>(
189 request: R,
190 stream: S,
191 config: Option<WebSocketConfig>,
192 connector: Option<Connector>,
193) -> Result<(WebSocket<MaybeTlsStream<S>>, Response), TlsHandshakeError<S>>
194where
195 R: IntoClientRequest,
196 S: Read + Write,
197{
198 let request = request.into_client_request()?;
199
200 #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
201 let domain = match request.uri().host() {
202 Some(d) => Ok(d.to_string()),
203 None => Err(Error::Url(UrlError::NoHostName)),
204 }?;
205
206 let mode = uri_mode(request.uri())?;
207
208 let stream = match connector {
209 Some(conn) => match conn {
210 #[cfg(feature = "native-tls")]
211 Connector::NativeTls(conn) => {
212 self::encryption::native_tls::wrap_stream(stream, &domain, mode, Some(conn))
213 }
214 #[cfg(feature = "__rustls-tls")]
215 Connector::Rustls(conn) => {
216 self::encryption::rustls::wrap_stream(stream, &domain, mode, Some(conn))
217 }
218 Connector::Plain => self::encryption::plain::wrap_stream(stream, mode),
219 },
220 None => {
221 #[cfg(feature = "native-tls")]
222 {
223 self::encryption::native_tls::wrap_stream(stream, &domain, mode, None)
224 }
225 #[cfg(all(feature = "__rustls-tls", not(feature = "native-tls")))]
226 {
227 self::encryption::rustls::wrap_stream(stream, &domain, mode, None)
228 }
229 #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
230 {
231 self::encryption::plain::wrap_stream(stream, mode)
232 }
233 }
234 }?;
235
236 client_with_config(request, stream, config)
237}