if_watch/
linux.rs

1use crate::{IfEvent, IpNet, Ipv4Net, Ipv6Net};
2use fnv::FnvHashSet;
3use futures::ready;
4use futures::stream::{FusedStream, Stream, TryStreamExt};
5use futures::StreamExt;
6use rtnetlink::constants::{RTMGRP_IPV4_IFADDR, RTMGRP_IPV6_IFADDR};
7use rtnetlink::packet::address::nlas::Nla;
8use rtnetlink::packet::{AddressMessage, RtnlMessage};
9use rtnetlink::proto::{Connection, NetlinkPayload};
10use rtnetlink::sys::{AsyncSocket, SocketAddr};
11use std::collections::VecDeque;
12use std::future::Future;
13use std::io::{Error, ErrorKind, Result};
14use std::net::{Ipv4Addr, Ipv6Addr};
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
#[cfg(feature = "tokio")]
pub mod tokio {
    //! Interface watcher backed by `rtnetlink`'s Tokio socket,
    //! [`TokioSocket`](rtnetlink::sys::TokioSocket).

    /// Watches for interface changes.
    pub type IfWatcher = super::IfWatcher<rtnetlink::sys::TokioSocket>;
}
26
#[cfg(feature = "smol")]
pub mod smol {
    //! Interface watcher backed by `rtnetlink`'s smol socket,
    //! [`SmolSocket`](rtnetlink::sys::SmolSocket).

    /// Watches for interface changes.
    pub type IfWatcher = super::IfWatcher<rtnetlink::sys::SmolSocket>;
}
35
36pub struct IfWatcher<T> {
37    conn: Connection<RtnlMessage, T>,
38    messages: Pin<Box<dyn Stream<Item = Result<RtnlMessage>> + Send>>,
39    addrs: FnvHashSet<IpNet>,
40    queue: VecDeque<IfEvent>,
41}
42
43impl<T> std::fmt::Debug for IfWatcher<T> {
44    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
45        f.debug_struct("IfWatcher")
46            .field("addrs", &self.addrs)
47            .finish_non_exhaustive()
48    }
49}
50
51impl<T> IfWatcher<T>
52where
53    T: AsyncSocket + Unpin,
54{
55    /// Create a watcher.
56    pub fn new() -> Result<Self> {
57        let (mut conn, handle, messages) = rtnetlink::new_connection_with_socket::<T>()?;
58        let groups = RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR;
59        let addr = SocketAddr::new(0, groups);
60        conn.socket_mut().socket_mut().bind(&addr)?;
61        let get_addrs_stream = handle
62            .address()
63            .get()
64            .execute()
65            .map_ok(RtnlMessage::NewAddress)
66            .map_err(|err| Error::new(ErrorKind::Other, err));
67        let msg_stream = messages.filter_map(|(msg, _)| async {
68            match msg.payload {
69                NetlinkPayload::Error(err) => Some(Err(err.to_io())),
70                NetlinkPayload::InnerMessage(msg) => Some(Ok(msg)),
71                _ => None,
72            }
73        });
74        let messages = get_addrs_stream.chain(msg_stream).boxed();
75        let addrs = FnvHashSet::default();
76        let queue = VecDeque::default();
77        Ok(Self {
78            conn,
79            messages,
80            addrs,
81            queue,
82        })
83    }
84
85    /// Iterate over current networks.
86    pub fn iter(&self) -> impl Iterator<Item = &IpNet> {
87        self.addrs.iter()
88    }
89
90    fn add_address(&mut self, msg: AddressMessage) {
91        for net in iter_nets(msg) {
92            if self.addrs.insert(net) {
93                self.queue.push_back(IfEvent::Up(net));
94            }
95        }
96    }
97
98    fn rem_address(&mut self, msg: AddressMessage) {
99        for net in iter_nets(msg) {
100            if self.addrs.remove(&net) {
101                self.queue.push_back(IfEvent::Down(net));
102            }
103        }
104    }
105
106    /// Poll for an address change event.
107    pub fn poll_if_event(&mut self, cx: &mut Context) -> Poll<Result<IfEvent>> {
108        loop {
109            if let Some(event) = self.queue.pop_front() {
110                return Poll::Ready(Ok(event));
111            }
112            if Pin::new(&mut self.conn).poll(cx).is_ready() {
113                return Poll::Ready(Err(socket_err()));
114            }
115            let message = ready!(self.messages.poll_next_unpin(cx)).ok_or_else(socket_err)??;
116            match message {
117                RtnlMessage::NewAddress(msg) => self.add_address(msg),
118                RtnlMessage::DelAddress(msg) => self.rem_address(msg),
119                _ => {}
120            }
121        }
122    }
123}
124
/// Builds the error returned once the netlink connection, or the message
/// stream feeding the watcher, has finished.
fn socket_err() -> Error {
    let kind = ErrorKind::BrokenPipe;
    Error::new(kind, "rtnetlink socket closed")
}
128
129fn iter_nets(msg: AddressMessage) -> impl Iterator<Item = IpNet> {
130    let prefix = msg.header.prefix_len;
131    let family = msg.header.family;
132    msg.nlas.into_iter().filter_map(move |nla| {
133        if let Nla::Address(octets) = nla {
134            match family {
135                2 => {
136                    let mut addr = [0; 4];
137                    addr.copy_from_slice(&octets);
138                    Some(IpNet::V4(
139                        Ipv4Net::new(Ipv4Addr::from(addr), prefix).unwrap(),
140                    ))
141                }
142                10 => {
143                    let mut addr = [0; 16];
144                    addr.copy_from_slice(&octets);
145                    Some(IpNet::V6(
146                        Ipv6Net::new(Ipv6Addr::from(addr), prefix).unwrap(),
147                    ))
148                }
149                _ => None,
150            }
151        } else {
152            None
153        }
154    })
155}
156
157impl<T> Stream for IfWatcher<T>
158where
159    T: AsyncSocket + Unpin,
160{
161    type Item = Result<IfEvent>;
162    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
163        Pin::into_inner(self).poll_if_event(cx).map(Some)
164    }
165}
166
167impl<T> FusedStream for IfWatcher<T>
168where
169    T: AsyncSocket + AsyncSocket + Unpin,
170{
171    fn is_terminated(&self) -> bool {
172        false
173    }
174}