netlink_proto/
framed.rs

1// SPDX-License-Identifier: MIT
2
3use bytes::BytesMut;
4use std::{
5    fmt::Debug,
6    io,
7    marker::PhantomData,
8    pin::Pin,
9    task::{Context, Poll},
10};
11
12use futures::{Sink, Stream};
13use log::error;
14
15use crate::{
16    codecs::NetlinkMessageCodec,
17    sys::{AsyncSocket, SocketAddr},
18};
19use netlink_packet_core::{NetlinkDeserializable, NetlinkMessage, NetlinkSerializable};
20
21pub struct NetlinkFramed<T, S, C> {
22    socket: S,
23    // see https://doc.rust-lang.org/nomicon/phantom-data.html
24    // "invariant" seems like the safe choice; using `fn(T) -> T`
25    // should make it invariant but still Send+Sync.
26    msg_type: PhantomData<fn(T) -> T>, // invariant
27    codec: PhantomData<fn(C) -> C>,    // invariant
28    reader: BytesMut,
29    writer: BytesMut,
30    in_addr: SocketAddr,
31    out_addr: SocketAddr,
32    flushed: bool,
33}
34
35impl<T, S, C> Stream for NetlinkFramed<T, S, C>
36where
37    T: NetlinkDeserializable + Debug,
38    S: AsyncSocket,
39    C: NetlinkMessageCodec,
40{
41    type Item = (NetlinkMessage<T>, SocketAddr);
42
43    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
44        let Self {
45            ref mut socket,
46            ref mut in_addr,
47            ref mut reader,
48            ..
49        } = Pin::get_mut(self);
50
51        loop {
52            match C::decode::<T>(reader) {
53                Ok(Some(item)) => return Poll::Ready(Some((item, *in_addr))),
54                Ok(None) => {}
55                Err(e) => {
56                    error!("unrecoverable error in decoder: {:?}", e);
57                    return Poll::Ready(None);
58                }
59            }
60
61            reader.clear();
62            reader.reserve(INITIAL_READER_CAPACITY);
63
64            *in_addr = match ready!(socket.poll_recv_from(cx, reader)) {
65                Ok(addr) => addr,
66                Err(e) => {
67                    error!("failed to read from netlink socket: {:?}", e);
68                    return Poll::Ready(None);
69                }
70            };
71        }
72    }
73}
74
75impl<T, S, C> Sink<(NetlinkMessage<T>, SocketAddr)> for NetlinkFramed<T, S, C>
76where
77    T: NetlinkSerializable + Debug,
78    S: AsyncSocket,
79    C: NetlinkMessageCodec,
80{
81    type Error = io::Error;
82
83    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84        if !self.flushed {
85            match self.poll_flush(cx)? {
86                Poll::Ready(()) => {}
87                Poll::Pending => return Poll::Pending,
88            }
89        }
90
91        Poll::Ready(Ok(()))
92    }
93
94    fn start_send(
95        self: Pin<&mut Self>,
96        item: (NetlinkMessage<T>, SocketAddr),
97    ) -> Result<(), Self::Error> {
98        trace!("sending frame");
99        let (frame, out_addr) = item;
100        let pin = self.get_mut();
101        C::encode(frame, &mut pin.writer)?;
102        pin.out_addr = out_addr;
103        pin.flushed = false;
104        trace!("frame encoded; length={}", pin.writer.len());
105        Ok(())
106    }
107
108    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109        if self.flushed {
110            return Poll::Ready(Ok(()));
111        }
112
113        trace!("flushing frame; length={}", self.writer.len());
114        let Self {
115            ref mut socket,
116            ref mut out_addr,
117            ref mut writer,
118            ..
119        } = *self;
120
121        let n = ready!(socket.poll_send_to(cx, writer, out_addr))?;
122        trace!("written {}", n);
123
124        let wrote_all = n == self.writer.len();
125        self.writer.clear();
126        self.flushed = true;
127
128        let res = if wrote_all {
129            Ok(())
130        } else {
131            Err(io::Error::new(
132                io::ErrorKind::Other,
133                "failed to write entire datagram to socket",
134            ))
135        };
136
137        Poll::Ready(res)
138    }
139
140    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
141        ready!(self.poll_flush(cx))?;
142        Poll::Ready(Ok(()))
143    }
144}
145
// The theoretical maximum netlink packet size is 32 KiB since Linux 4.9
// (16 KiB before). See:
// https://git.kernel.org/pub/scm/linux/kernel/git/davem/net-next.git/commit/?id=d35c99ff77ecb2eb239731b799386f3b3637a31e

/// Capacity of the read buffer at creation, and the capacity reserved in it
/// again before every receive.
///
/// At 64 KiB it is twice the 32 KiB maximum noted above, which leaves headroom
/// above that limit.
const INITIAL_READER_CAPACITY: usize = 64 * 1024;
/// Capacity pre-allocated for the write buffer when the framer is created.
const INITIAL_WRITER_CAPACITY: usize = 8 * 1024;
151
152impl<T, S, C> NetlinkFramed<T, S, C> {
153    /// Create a new `NetlinkFramed` backed by the given socket and codec.
154    ///
155    /// See struct level documentation for more details.
156    pub fn new(socket: S) -> Self {
157        Self {
158            socket,
159            msg_type: PhantomData,
160            codec: PhantomData,
161            out_addr: SocketAddr::new(0, 0),
162            in_addr: SocketAddr::new(0, 0),
163            reader: BytesMut::with_capacity(INITIAL_READER_CAPACITY),
164            writer: BytesMut::with_capacity(INITIAL_WRITER_CAPACITY),
165            flushed: true,
166        }
167    }
168
169    /// Returns a reference to the underlying I/O stream wrapped by `Framed`.
170    ///
171    /// # Note
172    ///
173    /// Care should be taken to not tamper with the underlying stream of data
174    /// coming in as it may corrupt the stream of frames otherwise being worked
175    /// with.
176    pub fn get_ref(&self) -> &S {
177        &self.socket
178    }
179
180    /// Returns a mutable reference to the underlying I/O stream wrapped by
181    /// `Framed`.
182    ///
183    /// # Note
184    ///
185    /// Care should be taken to not tamper with the underlying stream of data
186    /// coming in as it may corrupt the stream of frames otherwise being worked
187    /// with.
188    pub fn get_mut(&mut self) -> &mut S {
189        &mut self.socket
190    }
191
192    /// Consumes the `Framed`, returning its underlying I/O stream.
193    pub fn into_inner(self) -> S {
194        self.socket
195    }
196}