1use crate::{IfEvent, IpNet, Ipv4Net, Ipv6Net};
2use fnv::FnvHashSet;
3use futures::ready;
4use futures::stream::{FusedStream, Stream, TryStreamExt};
5use futures::StreamExt;
6use rtnetlink::constants::{RTMGRP_IPV4_IFADDR, RTMGRP_IPV6_IFADDR};
7use rtnetlink::packet::address::nlas::Nla;
8use rtnetlink::packet::{AddressMessage, RtnlMessage};
9use rtnetlink::proto::{Connection, NetlinkPayload};
10use rtnetlink::sys::{AsyncSocket, SocketAddr};
11use std::collections::VecDeque;
12use std::future::Future;
13use std::io::{Error, ErrorKind, Result};
14use std::net::{Ipv4Addr, Ipv6Addr};
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
#[cfg(feature = "tokio")]
pub mod tokio {
    /// Interface watcher backed by `rtnetlink`'s tokio socket implementation.
    pub type IfWatcher = super::IfWatcher<rtnetlink::sys::TokioSocket>;
}
26
#[cfg(feature = "smol")]
pub mod smol {
    /// Interface watcher backed by `rtnetlink`'s smol socket implementation.
    pub type IfWatcher = super::IfWatcher<rtnetlink::sys::SmolSocket>;
}
35
/// Watches IPv4 and IPv6 interface addresses over rtnetlink and reports
/// additions and removals as [`IfEvent`]s.
///
/// `T` is the async netlink socket type; the feature-gated `tokio` and `smol`
/// modules provide ready-made aliases.
pub struct IfWatcher<T> {
    /// Netlink connection future. It is polled on every `poll_if_event`
    /// iteration, and its completion is treated as the socket closing.
    conn: Connection<RtnlMessage, T>,
    /// Address dump requested at construction, chained with the multicast
    /// notification messages received afterwards.
    messages: Pin<Box<dyn Stream<Item = Result<RtnlMessage>> + Send>>,
    /// Networks seen in `NewAddress` messages and not yet removed by a
    /// matching `DelAddress` message.
    addrs: FnvHashSet<IpNet>,
    /// Events that have been produced but not yet handed to the caller.
    queue: VecDeque<IfEvent>,
}
42
43impl<T> std::fmt::Debug for IfWatcher<T> {
44 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
45 f.debug_struct("IfWatcher")
46 .field("addrs", &self.addrs)
47 .finish_non_exhaustive()
48 }
49}
50
51impl<T> IfWatcher<T>
52where
53 T: AsyncSocket + Unpin,
54{
55 pub fn new() -> Result<Self> {
57 let (mut conn, handle, messages) = rtnetlink::new_connection_with_socket::<T>()?;
58 let groups = RTMGRP_IPV4_IFADDR | RTMGRP_IPV6_IFADDR;
59 let addr = SocketAddr::new(0, groups);
60 conn.socket_mut().socket_mut().bind(&addr)?;
61 let get_addrs_stream = handle
62 .address()
63 .get()
64 .execute()
65 .map_ok(RtnlMessage::NewAddress)
66 .map_err(|err| Error::new(ErrorKind::Other, err));
67 let msg_stream = messages.filter_map(|(msg, _)| async {
68 match msg.payload {
69 NetlinkPayload::Error(err) => Some(Err(err.to_io())),
70 NetlinkPayload::InnerMessage(msg) => Some(Ok(msg)),
71 _ => None,
72 }
73 });
74 let messages = get_addrs_stream.chain(msg_stream).boxed();
75 let addrs = FnvHashSet::default();
76 let queue = VecDeque::default();
77 Ok(Self {
78 conn,
79 messages,
80 addrs,
81 queue,
82 })
83 }
84
85 pub fn iter(&self) -> impl Iterator<Item = &IpNet> {
87 self.addrs.iter()
88 }
89
90 fn add_address(&mut self, msg: AddressMessage) {
91 for net in iter_nets(msg) {
92 if self.addrs.insert(net) {
93 self.queue.push_back(IfEvent::Up(net));
94 }
95 }
96 }
97
98 fn rem_address(&mut self, msg: AddressMessage) {
99 for net in iter_nets(msg) {
100 if self.addrs.remove(&net) {
101 self.queue.push_back(IfEvent::Down(net));
102 }
103 }
104 }
105
106 pub fn poll_if_event(&mut self, cx: &mut Context) -> Poll<Result<IfEvent>> {
108 loop {
109 if let Some(event) = self.queue.pop_front() {
110 return Poll::Ready(Ok(event));
111 }
112 if Pin::new(&mut self.conn).poll(cx).is_ready() {
113 return Poll::Ready(Err(socket_err()));
114 }
115 let message = ready!(self.messages.poll_next_unpin(cx)).ok_or_else(socket_err)??;
116 match message {
117 RtnlMessage::NewAddress(msg) => self.add_address(msg),
118 RtnlMessage::DelAddress(msg) => self.rem_address(msg),
119 _ => {}
120 }
121 }
122 }
123}
124
/// Error reported once the rtnetlink connection or its message stream has
/// ended.
fn socket_err() -> std::io::Error {
    let kind = ErrorKind::BrokenPipe;
    Error::new(kind, "rtnetlink socket closed")
}
128
129fn iter_nets(msg: AddressMessage) -> impl Iterator<Item = IpNet> {
130 let prefix = msg.header.prefix_len;
131 let family = msg.header.family;
132 msg.nlas.into_iter().filter_map(move |nla| {
133 if let Nla::Address(octets) = nla {
134 match family {
135 2 => {
136 let mut addr = [0; 4];
137 addr.copy_from_slice(&octets);
138 Some(IpNet::V4(
139 Ipv4Net::new(Ipv4Addr::from(addr), prefix).unwrap(),
140 ))
141 }
142 10 => {
143 let mut addr = [0; 16];
144 addr.copy_from_slice(&octets);
145 Some(IpNet::V6(
146 Ipv6Net::new(Ipv6Addr::from(addr), prefix).unwrap(),
147 ))
148 }
149 _ => None,
150 }
151 } else {
152 None
153 }
154 })
155}
156
157impl<T> Stream for IfWatcher<T>
158where
159 T: AsyncSocket + Unpin,
160{
161 type Item = Result<IfEvent>;
162 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
163 Pin::into_inner(self).poll_if_event(cx).map(Some)
164 }
165}
166
167impl<T> FusedStream for IfWatcher<T>
168where
169 T: AsyncSocket + AsyncSocket + Unpin,
170{
171 fn is_terminated(&self) -> bool {
172 false
173 }
174}