hyper_util/client/legacy/connect/http.rs

1use std::error::Error as StdError;
2use std::fmt;
3use std::future::Future;
4use std::io;
5use std::marker::PhantomData;
6use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7use std::pin::Pin;
8use std::sync::Arc;
9use std::task::{self, Poll};
10use std::time::Duration;
11
12use futures_util::future::Either;
13use http::uri::{Scheme, Uri};
14use pin_project_lite::pin_project;
15use socket2::TcpKeepalive;
16use tokio::net::{TcpSocket, TcpStream};
17use tokio::time::Sleep;
18use tracing::{debug, trace, warn};
19
20use super::dns::{self, resolve, GaiResolver, Resolve};
21use super::{Connected, Connection};
22use crate::rt::TokioIo;
23
/// A connector for the `http` scheme.
///
/// Performs DNS resolution in a thread pool, and then connects over TCP.
///
/// # Note
///
/// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes
/// transport information such as the remote socket address used.
#[derive(Clone)]
pub struct HttpConnector<R = GaiResolver> {
    // Settings shared between clones. Setters go through `config_mut`, which
    // uses `Arc::make_mut`, so mutating one clone never affects the others.
    config: Arc<Config>,
    // Resolver used for hosts that are not already IP literals.
    resolver: R,
}
37
/// Extra information about the transport when an HttpConnector is used.
///
/// # Example
///
/// ```
/// # fn doc(res: http::Response<()>) {
/// use hyper_util::client::legacy::connect::HttpInfo;
///
/// // res = http::Response
/// res
///     .extensions()
///     .get::<HttpInfo>()
///     .map(|info| {
///         println!("remote addr = {}", info.remote_addr());
///     });
/// # }
/// ```
///
/// # Note
///
/// If a different connector is used besides [`HttpConnector`](HttpConnector),
/// this value will not exist in the extensions. Consult that specific
/// connector to see what "extra" information it might provide to responses.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    // Address of the peer, as reported by the connected TCP stream.
    remote_addr: SocketAddr,
    // Local address the TCP stream ended up bound to.
    local_addr: SocketAddr,
}
66
/// Settings applied to every connection attempt made by an `HttpConnector`.
///
/// Defaults are assigned in `HttpConnector::new_with_resolver`; each field is
/// changed through the corresponding `HttpConnector` setter.
#[derive(Clone)]
struct Config {
    // Overall connect budget; divided among the addresses tried in sequence.
    connect_timeout: Option<Duration>,
    // When true only `http` URIs are accepted; when false any URI with a scheme is.
    enforce_http: bool,
    // Delay before the fallback address family is raced; `None` disables the race.
    happy_eyeballs_timeout: Option<Duration>,
    // SO_KEEPALIVE time / interval / retries.
    tcp_keepalive_config: TcpKeepaliveConfig,
    // Local addresses to bind to, selected by the destination's address family.
    local_address_ipv4: Option<Ipv4Addr>,
    local_address_ipv6: Option<Ipv6Addr>,
    // Value passed to `set_nodelay` once the stream is connected.
    nodelay: bool,
    // Whether SO_REUSEADDR is requested on each socket.
    reuse_address: bool,
    // SO_SNDBUF / SO_RCVBUF; `None` leaves the operating-system default.
    send_buffer_size: Option<usize>,
    recv_buffer_size: Option<usize>,
    // Device name for SO_BINDTODEVICE (Android, Fuchsia and Linux only).
    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
    interface: Option<String>,
}
82
/// User-requested TCP keepalive options; every field defaults to `None` (unset).
#[derive(Default, Debug, Clone, Copy)]
struct TcpKeepaliveConfig {
    // Idle time before the first keepalive probe is sent.
    time: Option<Duration>,
    // Time between unacknowledged probes (ignored on some platforms).
    interval: Option<Duration>,
    // Number of unacknowledged probes before the peer is considered gone
    // (ignored on some platforms).
    retries: Option<u32>,
}
89
90impl TcpKeepaliveConfig {
91    /// Converts into a `socket2::TcpKeealive` if there is any keep alive configuration.
92    fn into_tcpkeepalive(self) -> Option<TcpKeepalive> {
93        let mut dirty = false;
94        let mut ka = TcpKeepalive::new();
95        if let Some(time) = self.time {
96            ka = ka.with_time(time);
97            dirty = true
98        }
99        if let Some(interval) = self.interval {
100            ka = Self::ka_with_interval(ka, interval, &mut dirty)
101        };
102        if let Some(retries) = self.retries {
103            ka = Self::ka_with_retries(ka, retries, &mut dirty)
104        };
105        if dirty {
106            Some(ka)
107        } else {
108            None
109        }
110    }
111
112    #[cfg(not(any(
113        target_os = "aix",
114        target_os = "openbsd",
115        target_os = "redox",
116        target_os = "solaris"
117    )))]
118    fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive {
119        *dirty = true;
120        ka.with_interval(interval)
121    }
122
123    #[cfg(any(
124        target_os = "aix",
125        target_os = "openbsd",
126        target_os = "redox",
127        target_os = "solaris"
128    ))]
129    fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive {
130        ka // no-op as keepalive interval is not supported on this platform
131    }
132
133    #[cfg(not(any(
134        target_os = "aix",
135        target_os = "openbsd",
136        target_os = "redox",
137        target_os = "solaris",
138        target_os = "windows"
139    )))]
140    fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive {
141        *dirty = true;
142        ka.with_retries(retries)
143    }
144
145    #[cfg(any(
146        target_os = "aix",
147        target_os = "openbsd",
148        target_os = "redox",
149        target_os = "solaris",
150        target_os = "windows"
151    ))]
152    fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive {
153        ka // no-op as keepalive retries is not supported on this platform
154    }
155}
156
157// ===== impl HttpConnector =====
158
159impl HttpConnector {
160    /// Construct a new HttpConnector.
161    pub fn new() -> HttpConnector {
162        HttpConnector::new_with_resolver(GaiResolver::new())
163    }
164}
165
166impl<R> HttpConnector<R> {
167    /// Construct a new HttpConnector.
168    ///
169    /// Takes a [`Resolver`](crate::client::connect::dns#resolvers-are-services) to handle DNS lookups.
170    pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
171        HttpConnector {
172            config: Arc::new(Config {
173                connect_timeout: None,
174                enforce_http: true,
175                happy_eyeballs_timeout: Some(Duration::from_millis(300)),
176                tcp_keepalive_config: TcpKeepaliveConfig::default(),
177                local_address_ipv4: None,
178                local_address_ipv6: None,
179                nodelay: false,
180                reuse_address: false,
181                send_buffer_size: None,
182                recv_buffer_size: None,
183                #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
184                interface: None,
185            }),
186            resolver,
187        }
188    }
189
190    /// Option to enforce all `Uri`s have the `http` scheme.
191    ///
192    /// Enabled by default.
193    #[inline]
194    pub fn enforce_http(&mut self, is_enforced: bool) {
195        self.config_mut().enforce_http = is_enforced;
196    }
197
198    /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration
199    /// to remain idle before sending TCP keepalive probes.
200    ///
201    /// If `None`, keepalive is disabled.
202    ///
203    /// Default is `None`.
204    #[inline]
205    pub fn set_keepalive(&mut self, time: Option<Duration>) {
206        self.config_mut().tcp_keepalive_config.time = time;
207    }
208
209    /// Set the duration between two successive TCP keepalive retransmissions,
210    /// if acknowledgement to the previous keepalive transmission is not received.
211    #[inline]
212    pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) {
213        self.config_mut().tcp_keepalive_config.interval = interval;
214    }
215
216    /// Set the number of retransmissions to be carried out before declaring that remote end is not available.
217    #[inline]
218    pub fn set_keepalive_retries(&mut self, retries: Option<u32>) {
219        self.config_mut().tcp_keepalive_config.retries = retries;
220    }
221
222    /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
223    ///
224    /// Default is `false`.
225    #[inline]
226    pub fn set_nodelay(&mut self, nodelay: bool) {
227        self.config_mut().nodelay = nodelay;
228    }
229
230    /// Sets the value of the SO_SNDBUF option on the socket.
231    #[inline]
232    pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
233        self.config_mut().send_buffer_size = size;
234    }
235
236    /// Sets the value of the SO_RCVBUF option on the socket.
237    #[inline]
238    pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
239        self.config_mut().recv_buffer_size = size;
240    }
241
242    /// Set that all sockets are bound to the configured address before connection.
243    ///
244    /// If `None`, the sockets will not be bound.
245    ///
246    /// Default is `None`.
247    #[inline]
248    pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
249        let (v4, v6) = match addr {
250            Some(IpAddr::V4(a)) => (Some(a), None),
251            Some(IpAddr::V6(a)) => (None, Some(a)),
252            _ => (None, None),
253        };
254
255        let cfg = self.config_mut();
256
257        cfg.local_address_ipv4 = v4;
258        cfg.local_address_ipv6 = v6;
259    }
260
261    /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's
262    /// preferences) before connection.
263    #[inline]
264    pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
265        let cfg = self.config_mut();
266
267        cfg.local_address_ipv4 = Some(addr_ipv4);
268        cfg.local_address_ipv6 = Some(addr_ipv6);
269    }
270
271    /// Set the connect timeout.
272    ///
273    /// If a domain resolves to multiple IP addresses, the timeout will be
274    /// evenly divided across them.
275    ///
276    /// Default is `None`.
277    #[inline]
278    pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
279        self.config_mut().connect_timeout = dur;
280    }
281
282    /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
283    ///
284    /// If hostname resolves to both IPv4 and IPv6 addresses and connection
285    /// cannot be established using preferred address family before timeout
286    /// elapses, then connector will in parallel attempt connection using other
287    /// address family.
288    ///
289    /// If `None`, parallel connection attempts are disabled.
290    ///
291    /// Default is 300 milliseconds.
292    ///
293    /// [RFC 6555]: https://tools.ietf.org/html/rfc6555
294    #[inline]
295    pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
296        self.config_mut().happy_eyeballs_timeout = dur;
297    }
298
299    /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`.
300    ///
301    /// Default is `false`.
302    #[inline]
303    pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
304        self.config_mut().reuse_address = reuse_address;
305        self
306    }
307
308    /// Sets the value for the `SO_BINDTODEVICE` option on this socket.
309    ///
310    /// If a socket is bound to an interface, only packets received from that particular
311    /// interface are processed by the socket. Note that this only works for some socket
312    /// types, particularly AF_INET sockets.
313    ///
314    /// On Linux it can be used to specify a [VRF], but the binary needs
315    /// to either have `CAP_NET_RAW` or to be run as root.
316    ///
317    /// This function is only available on Android、Fuchsia and Linux.
318    ///
319    /// [VRF]: https://www.kernel.org/doc/Documentation/networking/vrf.txt
320    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
321    #[inline]
322    pub fn set_interface<S: Into<String>>(&mut self, interface: S) -> &mut Self {
323        self.config_mut().interface = Some(interface.into());
324        self
325    }
326
327    // private
328
329    fn config_mut(&mut self) -> &mut Config {
330        // If the are HttpConnector clones, this will clone the inner
331        // config. So mutating the config won't ever affect previous
332        // clones.
333        Arc::make_mut(&mut self.config)
334    }
335}
336
/// Error message when `enforce_http` is on and the URI scheme is not `http`.
const INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
/// Error message when `enforce_http` is off and the URI has no scheme.
const INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
/// Error message when the URI has no host.
const INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
340
341// R: Debug required for now to allow adding it to debug output later...
342impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
343    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
344        f.debug_struct("HttpConnector").finish()
345    }
346}
347
348impl<R> tower_service::Service<Uri> for HttpConnector<R>
349where
350    R: Resolve + Clone + Send + Sync + 'static,
351    R::Future: Send,
352{
353    type Response = TokioIo<TcpStream>;
354    type Error = ConnectError;
355    type Future = HttpConnecting<R>;
356
357    fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
358        futures_util::ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
359        Poll::Ready(Ok(()))
360    }
361
362    fn call(&mut self, dst: Uri) -> Self::Future {
363        let mut self_ = self.clone();
364        HttpConnecting {
365            fut: Box::pin(async move { self_.call_async(dst).await }),
366            _marker: PhantomData,
367        }
368    }
369}
370
371fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
372    trace!(
373        "Http::connect; scheme={:?}, host={:?}, port={:?}",
374        dst.scheme(),
375        dst.host(),
376        dst.port(),
377    );
378
379    if config.enforce_http {
380        if dst.scheme() != Some(&Scheme::HTTP) {
381            return Err(ConnectError {
382                msg: INVALID_NOT_HTTP.into(),
383                cause: None,
384            });
385        }
386    } else if dst.scheme().is_none() {
387        return Err(ConnectError {
388            msg: INVALID_MISSING_SCHEME.into(),
389            cause: None,
390        });
391    }
392
393    let host = match dst.host() {
394        Some(s) => s,
395        None => {
396            return Err(ConnectError {
397                msg: INVALID_MISSING_HOST.into(),
398                cause: None,
399            })
400        }
401    };
402    let port = match dst.port() {
403        Some(port) => port.as_u16(),
404        None => {
405            if dst.scheme() == Some(&Scheme::HTTPS) {
406                443
407            } else {
408                80
409            }
410        }
411    };
412
413    Ok((host, port))
414}
415
416impl<R> HttpConnector<R>
417where
418    R: Resolve,
419{
420    async fn call_async(&mut self, dst: Uri) -> Result<TokioIo<TcpStream>, ConnectError> {
421        let config = &self.config;
422
423        let (host, port) = get_host_port(config, &dst)?;
424        let host = host.trim_start_matches('[').trim_end_matches(']');
425
426        // If the host is already an IP addr (v4 or v6),
427        // skip resolving the dns and start connecting right away.
428        let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
429            addrs
430        } else {
431            let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
432                .await
433                .map_err(ConnectError::dns)?;
434            let addrs = addrs
435                .map(|mut addr| {
436                    set_port(&mut addr, port, dst.port().is_some());
437
438                    addr
439                })
440                .collect();
441            dns::SocketAddrs::new(addrs)
442        };
443
444        let c = ConnectingTcp::new(addrs, config);
445
446        let sock = c.connect().await?;
447
448        if let Err(e) = sock.set_nodelay(config.nodelay) {
449            warn!("tcp set_nodelay error: {}", e);
450        }
451
452        Ok(TokioIo::new(sock))
453    }
454}
455
456impl Connection for TcpStream {
457    fn connected(&self) -> Connected {
458        let connected = Connected::new();
459        if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) {
460            connected.extra(HttpInfo {
461                remote_addr,
462                local_addr,
463            })
464        } else {
465            connected
466        }
467    }
468}
469
470// Implement `Connection` for generic `TokioIo<T>` so that external crates can
471// implement their own `HttpConnector` with `TokioIo<CustomTcpStream>`.
472impl<T> Connection for TokioIo<T>
473where
474    T: Connection,
475{
476    fn connected(&self) -> Connected {
477        self.inner().connected()
478    }
479}
480
481impl HttpInfo {
482    /// Get the remote address of the transport used.
483    pub fn remote_addr(&self) -> SocketAddr {
484        self.remote_addr
485    }
486
487    /// Get the local address of the transport used.
488    pub fn local_addr(&self) -> SocketAddr {
489        self.local_addr
490    }
491}
492
pin_project! {
    // Not publicly exported (so missing_docs doesn't trigger).
    //
    // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
    // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
    // (and thus we can change the type in the future).
    //
    // Produced by `HttpConnector::call`; resolves to a `ConnectResult`.
    #[must_use = "futures do nothing unless polled"]
    #[allow(missing_debug_implementations)]
    pub struct HttpConnecting<R> {
        // The boxed `call_async` future that performs the whole connect.
        #[pin]
        fut: BoxConnecting,
        // Carries the resolver type parameter without storing a value of it.
        _marker: PhantomData<R>,
    }
}
507
// Output of a connection attempt made by `HttpConnector`.
type ConnectResult = Result<TokioIo<TcpStream>, ConnectError>;
// Type-erased, `Send` future producing a `ConnectResult`.
type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
510
511impl<R: Resolve> Future for HttpConnecting<R> {
512    type Output = ConnectResult;
513
514    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
515        self.project().fut.poll(cx)
516    }
517}
518
// Not publicly exported (so missing_docs doesn't trigger).
//
// Error produced by `HttpConnector`: a short description plus an optional
// underlying cause, which is exposed through `Error::source`.
pub struct ConnectError {
    // Summary such as "dns error" or "tcp connect error".
    msg: Box<str>,
    // Underlying error, when there is one.
    cause: Option<Box<dyn StdError + Send + Sync>>,
}
524
525impl ConnectError {
526    fn new<S, E>(msg: S, cause: E) -> ConnectError
527    where
528        S: Into<Box<str>>,
529        E: Into<Box<dyn StdError + Send + Sync>>,
530    {
531        ConnectError {
532            msg: msg.into(),
533            cause: Some(cause.into()),
534        }
535    }
536
537    fn dns<E>(cause: E) -> ConnectError
538    where
539        E: Into<Box<dyn StdError + Send + Sync>>,
540    {
541        ConnectError::new("dns error", cause)
542    }
543
544    fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
545    where
546        S: Into<Box<str>>,
547        E: Into<Box<dyn StdError + Send + Sync>>,
548    {
549        move |cause| ConnectError::new(msg, cause)
550    }
551}
552
553impl fmt::Debug for ConnectError {
554    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
555        if let Some(ref cause) = self.cause {
556            f.debug_tuple("ConnectError")
557                .field(&self.msg)
558                .field(cause)
559                .finish()
560        } else {
561            self.msg.fmt(f)
562        }
563    }
564}
565
566impl fmt::Display for ConnectError {
567    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
568        f.write_str(&self.msg)?;
569
570        if let Some(ref cause) = self.cause {
571            write!(f, ": {}", cause)?;
572        }
573
574        Ok(())
575    }
576}
577
578impl StdError for ConnectError {
579    fn source(&self) -> Option<&(dyn StdError + 'static)> {
580        self.cause.as_ref().map(|e| &**e as _)
581    }
582}
583
/// Connection plan for one request: a preferred group of addresses and, when
/// Happy Eyeballs applies, a delayed fallback group.
struct ConnectingTcp<'a> {
    // Addresses tried first (or all addresses when there is no fallback).
    preferred: ConnectingTcpRemote,
    // Other address family, raced after a delay; `None` when not applicable.
    fallback: Option<ConnectingTcpFallback>,
    // Socket settings applied to every attempt.
    config: &'a Config,
}
589
590impl<'a> ConnectingTcp<'a> {
591    fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
592        if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
593            let (preferred_addrs, fallback_addrs) = remote_addrs
594                .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
595            if fallback_addrs.is_empty() {
596                return ConnectingTcp {
597                    preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
598                    fallback: None,
599                    config,
600                };
601            }
602
603            ConnectingTcp {
604                preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
605                fallback: Some(ConnectingTcpFallback {
606                    delay: tokio::time::sleep(fallback_timeout),
607                    remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
608                }),
609                config,
610            }
611        } else {
612            ConnectingTcp {
613                preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
614                fallback: None,
615                config,
616            }
617        }
618    }
619}
620
/// The delayed half of a Happy Eyeballs race.
struct ConnectingTcpFallback {
    // Created in `ConnectingTcp::new`; the fallback attempt is polled
    // alongside the preferred one once this elapses.
    delay: Sleep,
    // Fallback-family addresses, tried one after another.
    remote: ConnectingTcpRemote,
}
625
/// A group of addresses attempted sequentially.
struct ConnectingTcpRemote {
    // Candidate addresses, consumed as they are tried.
    addrs: dns::SocketAddrs,
    // Per-address timeout (the overall timeout divided by the address count).
    connect_timeout: Option<Duration>,
}
630
631impl ConnectingTcpRemote {
632    fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
633        let connect_timeout = connect_timeout.and_then(|t| t.checked_div(addrs.len() as u32));
634
635        Self {
636            addrs,
637            connect_timeout,
638        }
639    }
640}
641
642impl ConnectingTcpRemote {
643    async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
644        let mut err = None;
645        for addr in &mut self.addrs {
646            debug!("connecting to {}", addr);
647            match connect(&addr, config, self.connect_timeout)?.await {
648                Ok(tcp) => {
649                    debug!("connected to {}", addr);
650                    return Ok(tcp);
651                }
652                Err(e) => {
653                    trace!("connect error for {}: {:?}", addr, e);
654                    err = Some(e);
655                }
656            }
657        }
658
659        match err {
660            Some(e) => Err(e),
661            None => Err(ConnectError::new(
662                "tcp connect error",
663                std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
664            )),
665        }
666    }
667}
668
669fn bind_local_address(
670    socket: &socket2::Socket,
671    dst_addr: &SocketAddr,
672    local_addr_ipv4: &Option<Ipv4Addr>,
673    local_addr_ipv6: &Option<Ipv6Addr>,
674) -> io::Result<()> {
675    match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
676        (SocketAddr::V4(_), Some(addr), _) => {
677            socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
678        }
679        (SocketAddr::V6(_), _, Some(addr)) => {
680            socket.bind(&SocketAddr::new((*addr).into(), 0).into())?;
681        }
682        _ => {
683            if cfg!(windows) {
684                // Windows requires a socket be bound before calling connect
685                let any: SocketAddr = match *dst_addr {
686                    SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
687                    SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
688                };
689                socket.bind(&any.into())?;
690            }
691        }
692    }
693
694    Ok(())
695}
696
697fn connect(
698    addr: &SocketAddr,
699    config: &Config,
700    connect_timeout: Option<Duration>,
701) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
702    // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
703    // keepalive timeout, it would be nice to use that instead of socket2,
704    // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance...
705    use socket2::{Domain, Protocol, Socket, Type};
706
707    let domain = Domain::for_address(*addr);
708    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
709        .map_err(ConnectError::m("tcp open error"))?;
710
711    // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
712    // responsible for ensuring O_NONBLOCK is set.
713    socket
714        .set_nonblocking(true)
715        .map_err(ConnectError::m("tcp set_nonblocking error"))?;
716
717    if let Some(tcp_keepalive) = &config.tcp_keepalive_config.into_tcpkeepalive() {
718        if let Err(e) = socket.set_tcp_keepalive(tcp_keepalive) {
719            warn!("tcp set_keepalive error: {}", e);
720        }
721    }
722
723    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
724    // That this only works for some socket types, particularly AF_INET sockets.
725    if let Some(interface) = &config.interface {
726        socket
727            .bind_device(Some(interface.as_bytes()))
728            .map_err(ConnectError::m("tcp bind interface error"))?;
729    }
730
731    bind_local_address(
732        &socket,
733        addr,
734        &config.local_address_ipv4,
735        &config.local_address_ipv6,
736    )
737    .map_err(ConnectError::m("tcp bind local error"))?;
738
739    #[cfg(unix)]
740    let socket = unsafe {
741        // Safety: `from_raw_fd` is only safe to call if ownership of the raw
742        // file descriptor is transferred. Since we call `into_raw_fd` on the
743        // socket2 socket, it gives up ownership of the fd and will not close
744        // it, so this is safe.
745        use std::os::unix::io::{FromRawFd, IntoRawFd};
746        TcpSocket::from_raw_fd(socket.into_raw_fd())
747    };
748    #[cfg(windows)]
749    let socket = unsafe {
750        // Safety: `from_raw_socket` is only safe to call if ownership of the raw
751        // Windows SOCKET is transferred. Since we call `into_raw_socket` on the
752        // socket2 socket, it gives up ownership of the SOCKET and will not close
753        // it, so this is safe.
754        use std::os::windows::io::{FromRawSocket, IntoRawSocket};
755        TcpSocket::from_raw_socket(socket.into_raw_socket())
756    };
757
758    if config.reuse_address {
759        if let Err(e) = socket.set_reuseaddr(true) {
760            warn!("tcp set_reuse_address error: {}", e);
761        }
762    }
763
764    if let Some(size) = config.send_buffer_size {
765        if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
766            warn!("tcp set_buffer_size error: {}", e);
767        }
768    }
769
770    if let Some(size) = config.recv_buffer_size {
771        if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(u32::MAX)) {
772            warn!("tcp set_recv_buffer_size error: {}", e);
773        }
774    }
775
776    let connect = socket.connect(*addr);
777    Ok(async move {
778        match connect_timeout {
779            Some(dur) => match tokio::time::timeout(dur, connect).await {
780                Ok(Ok(s)) => Ok(s),
781                Ok(Err(e)) => Err(e),
782                Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
783            },
784            None => connect.await,
785        }
786        .map_err(ConnectError::m("tcp connect error"))
787    })
788}
789
impl ConnectingTcp<'_> {
    /// Runs the connection plan built by `ConnectingTcp::new`.
    ///
    /// Without a fallback group, the preferred addresses are tried one after
    /// another. With a fallback group, this runs the RFC 6555 ("Happy
    /// Eyeballs") race:
    ///
    /// 1. The preferred attempt runs alone until it completes or the fallback
    ///    delay elapses.
    /// 2. If the delay elapses first, the preferred and fallback attempts are
    ///    polled together, and whichever completes first is taken.
    /// 3. If that first result is a success, it is returned. If it is an
    ///    error, the other attempt is awaited and its result is returned. When
    ///    the preferred attempt fails before the delay, the fallback starts
    ///    immediately instead of waiting out the rest of the delay.
    async fn connect(mut self) -> Result<TcpStream, ConnectError> {
        match self.fallback {
            None => self.preferred.connect(self.config).await,
            Some(mut fallback) => {
                let preferred_fut = self.preferred.connect(self.config);
                futures_util::pin_mut!(preferred_fut);

                // Created here but not polled until the preferred attempt fails
                // or the delay elapses.
                let fallback_fut = fallback.remote.connect(self.config);
                futures_util::pin_mut!(fallback_fut);

                let fallback_delay = fallback.delay;
                futures_util::pin_mut!(fallback_delay);

                let (result, future) =
                    match futures_util::future::select(preferred_fut, fallback_delay).await {
                        // Preferred finished before the delay; the delay is dropped and
                        // the fallback is kept in case the preferred result is an error.
                        Either::Left((result, _fallback_delay)) => {
                            (result, Either::Right(fallback_fut))
                        }
                        Either::Right(((), preferred_fut)) => {
                            // Delay is done, start polling both the preferred and the fallback
                            futures_util::future::select(preferred_fut, fallback_fut)
                                .await
                                .factor_first()
                        }
                    };

                if result.is_err() {
                    // Fallback to the remaining future (could be preferred or fallback)
                    // if we get an error
                    future.await
                } else {
                    result
                }
            }
        }
    }
}
828
/// Chooses the port a resolved address should use.
///
/// An explicit port in the URI (`explicit == true`) always wins. Otherwise a
/// non-zero port supplied by the resolver (e.g. a custom DNS resolver) is
/// kept, and a zero port is replaced by `host_port`, the scheme's default.
fn set_port(addr: &mut SocketAddr, host_port: u16, explicit: bool) {
    let keep_resolver_port = !explicit && addr.port() != 0;
    if !keep_resolver_port {
        addr.set_port(host_port);
    }
}
837
838#[cfg(test)]
839mod tests {
840    use std::io;
841    use std::net::SocketAddr;
842
843    use ::http::Uri;
844
845    use crate::client::legacy::connect::http::TcpKeepaliveConfig;
846
847    use super::super::sealed::{Connect, ConnectSvc};
848    use super::{Config, ConnectError, HttpConnector};
849
850    use super::set_port;
851
852    async fn connect<C>(
853        connector: C,
854        dst: Uri,
855    ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
856    where
857        C: Connect,
858    {
859        connector.connect(super::super::sealed::Internal, dst).await
860    }
861
862    #[tokio::test]
863    #[cfg_attr(miri, ignore)]
864    async fn test_errors_enforce_http() {
865        let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
866        let connector = HttpConnector::new();
867
868        let err = connect(connector, dst).await.unwrap_err();
869        assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
870    }
871
872    #[cfg(any(target_os = "linux", target_os = "macos"))]
873    fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
874        use std::net::{IpAddr, TcpListener};
875
876        let mut ip_v4 = None;
877        let mut ip_v6 = None;
878
879        let ips = pnet_datalink::interfaces()
880            .into_iter()
881            .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));
882
883        for ip in ips {
884            match ip {
885                IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
886                IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
887                _ => (),
888            }
889
890            if ip_v4.is_some() && ip_v6.is_some() {
891                break;
892            }
893        }
894
895        (ip_v4, ip_v6)
896    }
897
898    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
899    fn default_interface() -> Option<String> {
900        pnet_datalink::interfaces()
901            .iter()
902            .find(|e| e.is_up() && !e.is_loopback() && !e.ips.is_empty())
903            .map(|e| e.name.clone())
904    }
905
906    #[tokio::test]
907    #[cfg_attr(miri, ignore)]
908    async fn test_errors_missing_scheme() {
909        let dst = "example.domain".parse().unwrap();
910        let mut connector = HttpConnector::new();
911        connector.enforce_http(false);
912
913        let err = connect(connector, dst).await.unwrap_err();
914        assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
915    }
916
917    // NOTE: pnet crate that we use in this test doesn't compile on Windows
918    #[cfg(any(target_os = "linux", target_os = "macos"))]
919    #[cfg_attr(miri, ignore)]
920    #[tokio::test]
921    async fn local_address() {
922        use std::net::{IpAddr, TcpListener};
923
924        let (bind_ip_v4, bind_ip_v6) = get_local_ips();
925        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
926        let port = server4.local_addr().unwrap().port();
927        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();
928
929        let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
930            let mut connector = HttpConnector::new();
931
932            match (bind_ip_v4, bind_ip_v6) {
933                (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
934                (Some(v4), None) => connector.set_local_address(Some(v4.into())),
935                (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
936                _ => unreachable!(),
937            }
938
939            connect(connector, dst.parse().unwrap()).await.unwrap();
940
941            let (_, client_addr) = server.accept().unwrap();
942
943            assert_eq!(client_addr.ip(), expected_ip);
944        };
945
946        if let Some(ip) = bind_ip_v4 {
947            assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
948        }
949
950        if let Some(ip) = bind_ip_v6 {
951            assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
952        }
953    }
954
955    // NOTE: pnet crate that we use in this test doesn't compile on Windows
956    #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))]
957    #[tokio::test]
958    #[ignore = "setting `SO_BINDTODEVICE` requires the `CAP_NET_RAW` capability (works when running as root)"]
959    async fn interface() {
960        use socket2::{Domain, Protocol, Socket, Type};
961        use std::net::TcpListener;
962
963        let interface: Option<String> = default_interface();
964
965        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
966        let port = server4.local_addr().unwrap().port();
967
968        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();
969
970        let assert_interface_name =
971            |dst: String,
972             server: TcpListener,
973             bind_iface: Option<String>,
974             expected_interface: Option<String>| async move {
975                let mut connector = HttpConnector::new();
976                if let Some(iface) = bind_iface {
977                    connector.set_interface(iface);
978                }
979
980                connect(connector, dst.parse().unwrap()).await.unwrap();
981                let domain = Domain::for_address(server.local_addr().unwrap());
982                let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)).unwrap();
983
984                assert_eq!(
985                    socket.device().unwrap().as_deref(),
986                    expected_interface.as_deref().map(|val| val.as_bytes())
987                );
988            };
989
990        assert_interface_name(
991            format!("http://127.0.0.1:{}", port),
992            server4,
993            interface.clone(),
994            interface.clone(),
995        )
996        .await;
997        assert_interface_name(
998            format!("http://[::1]:{}", port),
999            server6,
1000            interface.clone(),
1001            interface.clone(),
1002        )
1003        .await;
1004    }
1005
    /// Exercises `ConnectingTcp`'s Happy Eyeballs behaviour against a table of
    /// address lists. For each scenario it checks which IP family the final
    /// connection uses, and that the elapsed time is within ±150ms of the
    /// expected delay.
    ///
    /// Unconditionally ignored for now (see the `TODO`), and additionally
    /// ignored unless the `__internal_happy_eyeballs_tests` feature is enabled.
    #[test]
    #[ignore] // TODO
    #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
    fn client_happy_eyeballs() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
        use std::time::{Duration, Instant};

        use super::dns;
        use super::ConnectingTcp;

        // Loopback listeners for both families on the same port; every scenario
        // address is paired with this port. `_server6` is held only so that
        // connections to `[::1]` on that port can succeed.
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = server4.local_addr().unwrap();
        let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        // Expected delays: local connects count as instantaneous, while the cost
        // of each "unreachable" address is measured on this host.
        let local_timeout = Duration::default();
        let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
        let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
        // Used as `happy_eyeballs_timeout` below: longer than either unreachable
        // attempt, plus a 250ms margin.
        let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
            + Duration::from_millis(250);

        // Each scenario is (addresses in resolution order, expected IP family of
        // the established connection, expected elapsed time, needs IPv6 access).
        let scenarios = &[
            // Fast primary, without fallback.
            (&[local_ipv4_addr()][..], 4, local_timeout, false),
            (&[local_ipv6_addr()][..], 6, local_timeout, false),
            // Fast primary, with (unused) fallback.
            (
                &[local_ipv4_addr(), local_ipv6_addr()][..],
                4,
                local_timeout,
                false,
            ),
            (
                &[local_ipv6_addr(), local_ipv4_addr()][..],
                6,
                local_timeout,
                false,
            ),
            // Unreachable + fast primary, without fallback.
            (
                &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                unreachable_v6_timeout,
                false,
            ),
            // Unreachable + fast primary, with (unused) fallback.
            (
                &[
                    unreachable_ipv4_addr(),
                    local_ipv4_addr(),
                    local_ipv6_addr(),
                ][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[
                    unreachable_ipv6_addr(),
                    local_ipv6_addr(),
                    local_ipv4_addr(),
                ][..],
                6,
                unreachable_v6_timeout,
                true,
            ),
            // Slow primary, with (used) fallback.
            (
                &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout,
                true,
            ),
            // Slow primary, with (used) unreachable + fast fallback.
            (
                &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout + unreachable_v6_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout + unreachable_v4_timeout,
                true,
            ),
        ];

        // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network.
        // Otherwise, connection to "slow" IPv6 address will error-out immediately.
        let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;

        for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
            if needs_ipv6_access && !ipv6_accessible {
                continue;
            }

            // Pair every scenario address with the listener port and race them
            // through `ConnectingTcp`, with Happy Eyeballs enabled via
            // `happy_eyeballs_timeout`. Timing starts right before connecting.
            let (start, stream) = rt
                .block_on(async move {
                    let addrs = hosts
                        .iter()
                        .map(|host| (host.clone(), addr.port()).into())
                        .collect();
                    let cfg = Config {
                        local_address_ipv4: None,
                        local_address_ipv6: None,
                        connect_timeout: None,
                        tcp_keepalive_config: TcpKeepaliveConfig::default(),
                        happy_eyeballs_timeout: Some(fallback_timeout),
                        nodelay: false,
                        reuse_address: false,
                        enforce_http: false,
                        send_buffer_size: None,
                        recv_buffer_size: None,
                        #[cfg(any(
                            target_os = "android",
                            target_os = "fuchsia",
                            target_os = "linux"
                        ))]
                        interface: None,
                    };
                    let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
                    let start = Instant::now();
                    Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
                })
                .unwrap();
            // IP family (4 or 6) of the peer we ended up connected to.
            let res = if stream.peer_addr().unwrap().is_ipv4() {
                4
            } else {
                6
            };
            let duration = start.elapsed();

            // Allow actual duration to be +/- 150ms off.
            let min_duration = if timeout >= Duration::from_millis(150) {
                timeout - Duration::from_millis(150)
            } else {
                Duration::default()
            };
            let max_duration = timeout + Duration::from_millis(150);

            assert_eq!(res, family);
            assert!(duration >= min_duration);
            assert!(duration <= max_duration);
        }

        // IPv4 loopback address, where `server4` listens.
        fn local_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 1).into()
        }

        // IPv6 loopback address, where `_server6` listens.
        fn local_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
        }

        // No listener is bound here by this test; presumably connects fail
        // quickly. The measured duration is used as the expected delay.
        fn unreachable_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 2).into()
        }

        // No listener is bound here by this test; presumably connects fail
        // quickly. The measured duration is used as the expected delay.
        fn unreachable_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
        }

        // Used to simulate a primary address whose connect attempt stalls long
        // enough for the Happy Eyeballs fallback to kick in.
        fn slow_ipv4_addr() -> IpAddr {
            // RFC 6890 reserved IPv4 address.
            Ipv4Addr::new(198, 18, 0, 25).into()
        }

        // IPv6 counterpart of `slow_ipv4_addr`.
        fn slow_ipv6_addr() -> IpAddr {
            // RFC 6890 reserved IPv6 address.
            Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
        }

        // Makes a blocking TCP connect to `addr:80` with a 1s timeout. Returns
        // whether the address counts as reachable (connected, or timed out
        // rather than failing with another error) and how long the attempt took.
        fn measure_connect(addr: IpAddr) -> (bool, Duration) {
            let start = Instant::now();
            let result =
                std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));

            let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
            let duration = start.elapsed();
            (reachable, duration)
        }
    }
1203
1204    use std::time::Duration;
1205
1206    #[test]
1207    fn no_tcp_keepalive_config() {
1208        assert!(TcpKeepaliveConfig::default().into_tcpkeepalive().is_none());
1209    }
1210
1211    #[test]
1212    fn tcp_keepalive_time_config() {
1213        let mut kac = TcpKeepaliveConfig::default();
1214        kac.time = Some(Duration::from_secs(60));
1215        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
1216            assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)"));
1217        } else {
1218            panic!("test failed");
1219        }
1220    }
1221
1222    #[cfg(not(any(target_os = "openbsd", target_os = "redox", target_os = "solaris")))]
1223    #[test]
1224    fn tcp_keepalive_interval_config() {
1225        let mut kac = TcpKeepaliveConfig::default();
1226        kac.interval = Some(Duration::from_secs(1));
1227        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
1228            assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)"));
1229        } else {
1230            panic!("test failed");
1231        }
1232    }
1233
1234    #[cfg(not(any(
1235        target_os = "openbsd",
1236        target_os = "redox",
1237        target_os = "solaris",
1238        target_os = "windows"
1239    )))]
1240    #[test]
1241    fn tcp_keepalive_retries_config() {
1242        let mut kac = TcpKeepaliveConfig::default();
1243        kac.retries = Some(3);
1244        if let Some(tcp_keepalive) = kac.into_tcpkeepalive() {
1245            assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)"));
1246        } else {
1247            panic!("test failed");
1248        }
1249    }
1250
1251    #[test]
1252    fn test_set_port() {
1253        // Respect explicit ports no matter what the resolved port is.
1254        let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881));
1255        set_port(&mut addr, 42, true);
1256        assert_eq!(addr.port(), 42);
1257
1258        // Ignore default  host port, and use the socket port instead.
1259        let mut addr = SocketAddr::from(([0, 0, 0, 0], 6881));
1260        set_port(&mut addr, 443, false);
1261        assert_eq!(addr.port(), 6881);
1262
1263        // Use the default port if the resolved port is `0`.
1264        let mut addr = SocketAddr::from(([0, 0, 0, 0], 0));
1265        set_port(&mut addr, 443, false);
1266        assert_eq!(addr.port(), 443);
1267    }
1268}