netlink_proto/
connection.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    fmt::Debug,
5    io,
6    pin::Pin,
7    task::{Context, Poll},
8};
9
10use futures::{
11    channel::mpsc::{UnboundedReceiver, UnboundedSender},
12    Future,
13    Sink,
14    Stream,
15};
16use log::{error, warn};
17use netlink_packet_core::{
18    NetlinkDeserializable,
19    NetlinkMessage,
20    NetlinkPayload,
21    NetlinkSerializable,
22};
23
24use crate::{
25    codecs::{NetlinkCodec, NetlinkMessageCodec},
26    framed::NetlinkFramed,
27    sys::{AsyncSocket, SocketAddr},
28    Protocol,
29    Request,
30    Response,
31};
32
33#[cfg(feature = "tokio_socket")]
34use netlink_sys::TokioSocket as DefaultSocket;
35#[cfg(not(feature = "tokio_socket"))]
36type DefaultSocket = ();
37
38/// Connection to a Netlink socket, running in the background.
39///
40/// [`ConnectionHandle`](struct.ConnectionHandle.html) are used to pass new requests to the
41/// `Connection`, that in turn, sends them through the netlink socket.
42pub struct Connection<T, S = DefaultSocket, C = NetlinkCodec>
43where
44    T: Debug + NetlinkSerializable + NetlinkDeserializable,
45{
46    socket: NetlinkFramed<T, S, C>,
47
48    protocol: Protocol<T, UnboundedSender<NetlinkMessage<T>>>,
49
50    /// Channel used by the user to pass requests to the connection.
51    requests_rx: Option<UnboundedReceiver<Request<T>>>,
52
53    /// Channel used to transmit to the ConnectionHandle the unsolicited messages received from the
54    /// socket (multicast messages for instance).
55    unsolicited_messages_tx: Option<UnboundedSender<(NetlinkMessage<T>, SocketAddr)>>,
56
57    socket_closed: bool,
58}
59
60impl<T, S, C> Connection<T, S, C>
61where
62    T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin,
63    S: AsyncSocket,
64    C: NetlinkMessageCodec,
65{
66    pub(crate) fn new(
67        requests_rx: UnboundedReceiver<Request<T>>,
68        unsolicited_messages_tx: UnboundedSender<(NetlinkMessage<T>, SocketAddr)>,
69        protocol: isize,
70    ) -> io::Result<Self> {
71        let socket = S::new(protocol)?;
72        Ok(Connection {
73            socket: NetlinkFramed::new(socket),
74            protocol: Protocol::new(),
75            requests_rx: Some(requests_rx),
76            unsolicited_messages_tx: Some(unsolicited_messages_tx),
77            socket_closed: false,
78        })
79    }
80
81    pub fn socket_mut(&mut self) -> &mut S {
82        self.socket.get_mut()
83    }
84
85    pub fn poll_send_messages(&mut self, cx: &mut Context) {
86        trace!("poll_send_messages called");
87        let Connection {
88            ref mut socket,
89            ref mut protocol,
90            ..
91        } = self;
92        let mut socket = Pin::new(socket);
93
94        while !protocol.outgoing_messages.is_empty() {
95            trace!("found outgoing message to send checking if socket is ready");
96            if let Poll::Ready(Err(e)) = Pin::as_mut(&mut socket).poll_ready(cx) {
97                // Sink errors are usually not recoverable. The socket
98                // probably shut down.
99                warn!("netlink socket shut down: {:?}", e);
100                self.socket_closed = true;
101                return;
102            }
103
104            let (mut message, addr) = protocol.outgoing_messages.pop_front().unwrap();
105            message.finalize();
106
107            trace!("sending outgoing message");
108            if let Err(e) = Pin::as_mut(&mut socket).start_send((message, addr)) {
109                error!("failed to send message: {:?}", e);
110                self.socket_closed = true;
111                return;
112            }
113        }
114
115        trace!("poll_send_messages done");
116        self.poll_flush(cx)
117    }
118
119    pub fn poll_flush(&mut self, cx: &mut Context) {
120        trace!("poll_flush called");
121        if let Poll::Ready(Err(e)) = Pin::new(&mut self.socket).poll_flush(cx) {
122            warn!("error flushing netlink socket: {:?}", e);
123            self.socket_closed = true;
124        }
125    }
126
127    pub fn poll_read_messages(&mut self, cx: &mut Context) {
128        trace!("poll_read_messages called");
129        let mut socket = Pin::new(&mut self.socket);
130
131        loop {
132            trace!("polling socket");
133            match socket.as_mut().poll_next(cx) {
134                Poll::Ready(Some((message, addr))) => {
135                    trace!("read datagram from socket");
136                    self.protocol.handle_message(message, addr);
137                }
138                Poll::Ready(None) => {
139                    warn!("netlink socket stream shut down");
140                    self.socket_closed = true;
141                    return;
142                }
143                Poll::Pending => {
144                    trace!("no datagram read from socket");
145                    return;
146                }
147            }
148        }
149    }
150
151    pub fn poll_requests(&mut self, cx: &mut Context) {
152        trace!("poll_requests called");
153        if let Some(mut stream) = self.requests_rx.as_mut() {
154            loop {
155                match Pin::new(&mut stream).poll_next(cx) {
156                    Poll::Ready(Some(request)) => self.protocol.request(request),
157                    Poll::Ready(None) => break,
158                    Poll::Pending => return,
159                }
160            }
161            let _ = self.requests_rx.take();
162            trace!("no new requests to handle poll_requests done");
163        }
164    }
165
166    pub fn forward_unsolicited_messages(&mut self) {
167        if self.unsolicited_messages_tx.is_none() {
168            while let Some((message, source)) = self.protocol.incoming_requests.pop_front() {
169                warn!(
170                    "ignoring unsolicited message {:?} from {:?}",
171                    message, source
172                );
173            }
174            return;
175        }
176
177        trace!("forward_unsolicited_messages called");
178        let mut ready = false;
179
180        let Connection {
181            ref mut protocol,
182            ref mut unsolicited_messages_tx,
183            ..
184        } = self;
185
186        while let Some((message, source)) = protocol.incoming_requests.pop_front() {
187            if unsolicited_messages_tx
188                .as_mut()
189                .unwrap()
190                .unbounded_send((message, source))
191                .is_err()
192            {
193                // The channel is unbounded so the only error that can
194                // occur is that the channel is closed because the
195                // receiver was dropped
196                warn!("failed to forward message to connection handle: channel closed");
197                ready = true;
198                break;
199            }
200        }
201
202        if ready {
203            // The channel is closed so we can drop the sender.
204            let _ = self.unsolicited_messages_tx.take();
205            // purge `protocol.incoming_requests`
206            self.forward_unsolicited_messages();
207        }
208
209        trace!("forward_unsolicited_messages done");
210    }
211
212    pub fn forward_responses(&mut self) {
213        trace!("forward_responses called");
214        let protocol = &mut self.protocol;
215
216        while let Some(response) = protocol.incoming_responses.pop_front() {
217            let Response {
218                message,
219                done,
220                metadata: tx,
221            } = response;
222            if done {
223                use NetlinkPayload::*;
224                match &message.payload {
225                    // Since `self.protocol` set the `done` flag here,
226                    // we know it has already dropped the request and
227                    // its associated metadata, ie the UnboundedSender
228                    // used to forward messages back to the
229                    // ConnectionHandle. By just continuing we're
230                    // dropping the last instance of that sender,
231                    // hence closing the channel and signaling the
232                    // handle that no more messages are expected.
233                    Noop | Done | Ack(_) => {
234                        trace!("not forwarding Noop/Ack/Done message to the handle");
235                        continue;
236                    }
237                    // I'm not sure how we should handle overrun messages
238                    Overrun(_) => unimplemented!("overrun is not handled yet"),
239                    // We need to forward error messages and messages
240                    // that are part of the netlink subprotocol,
241                    // because only the user knows how they want to
242                    // handle them.
243                    Error(_) | InnerMessage(_) => {}
244                }
245            }
246
247            trace!("forwarding response to the handle");
248            if tx.unbounded_send(message).is_err() {
249                // With an unboundedsender, an error can
250                // only happen if the receiver is closed.
251                warn!("failed to forward response back to the handle");
252            }
253        }
254        trace!("forward_responses done");
255    }
256
257    pub fn should_shut_down(&self) -> bool {
258        self.socket_closed || (self.unsolicited_messages_tx.is_none() && self.requests_rx.is_none())
259    }
260}
261
262impl<T, S, C> Future for Connection<T, S, C>
263where
264    T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin,
265    S: AsyncSocket,
266    C: NetlinkMessageCodec,
267{
268    type Output = ();
269
270    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
271        trace!("polling Connection");
272        let pinned = self.get_mut();
273
274        debug!("reading incoming messages");
275        pinned.poll_read_messages(cx);
276
277        debug!("forwarding unsolicited messages to the connection handle");
278        pinned.forward_unsolicited_messages();
279
280        debug!("forwaring responses to previous requests to the connection handle");
281        pinned.forward_responses();
282
283        debug!("handling requests");
284        pinned.poll_requests(cx);
285
286        debug!("sending messages");
287        pinned.poll_send_messages(cx);
288
289        trace!("done polling Connection");
290
291        if pinned.should_shut_down() {
292            Poll::Ready(())
293        } else {
294            Poll::Pending
295        }
296    }
297}