1use bytes::BytesMut;
4use std::{
5 fmt::Debug,
6 io,
7 marker::PhantomData,
8 pin::Pin,
9 task::{Context, Poll},
10};
11
12use futures::{Sink, Stream};
13use log::error;
14
15use crate::{
16 codecs::NetlinkMessageCodec,
17 sys::{AsyncSocket, SocketAddr},
18};
19use netlink_packet_core::{NetlinkDeserializable, NetlinkMessage, NetlinkSerializable};
20
21pub struct NetlinkFramed<T, S, C> {
22 socket: S,
23 msg_type: PhantomData<fn(T) -> T>, codec: PhantomData<fn(C) -> C>, reader: BytesMut,
29 writer: BytesMut,
30 in_addr: SocketAddr,
31 out_addr: SocketAddr,
32 flushed: bool,
33}
34
35impl<T, S, C> Stream for NetlinkFramed<T, S, C>
36where
37 T: NetlinkDeserializable + Debug,
38 S: AsyncSocket,
39 C: NetlinkMessageCodec,
40{
41 type Item = (NetlinkMessage<T>, SocketAddr);
42
43 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
44 let Self {
45 ref mut socket,
46 ref mut in_addr,
47 ref mut reader,
48 ..
49 } = Pin::get_mut(self);
50
51 loop {
52 match C::decode::<T>(reader) {
53 Ok(Some(item)) => return Poll::Ready(Some((item, *in_addr))),
54 Ok(None) => {}
55 Err(e) => {
56 error!("unrecoverable error in decoder: {:?}", e);
57 return Poll::Ready(None);
58 }
59 }
60
61 reader.clear();
62 reader.reserve(INITIAL_READER_CAPACITY);
63
64 *in_addr = match ready!(socket.poll_recv_from(cx, reader)) {
65 Ok(addr) => addr,
66 Err(e) => {
67 error!("failed to read from netlink socket: {:?}", e);
68 return Poll::Ready(None);
69 }
70 };
71 }
72 }
73}
74
75impl<T, S, C> Sink<(NetlinkMessage<T>, SocketAddr)> for NetlinkFramed<T, S, C>
76where
77 T: NetlinkSerializable + Debug,
78 S: AsyncSocket,
79 C: NetlinkMessageCodec,
80{
81 type Error = io::Error;
82
83 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
84 if !self.flushed {
85 match self.poll_flush(cx)? {
86 Poll::Ready(()) => {}
87 Poll::Pending => return Poll::Pending,
88 }
89 }
90
91 Poll::Ready(Ok(()))
92 }
93
94 fn start_send(
95 self: Pin<&mut Self>,
96 item: (NetlinkMessage<T>, SocketAddr),
97 ) -> Result<(), Self::Error> {
98 trace!("sending frame");
99 let (frame, out_addr) = item;
100 let pin = self.get_mut();
101 C::encode(frame, &mut pin.writer)?;
102 pin.out_addr = out_addr;
103 pin.flushed = false;
104 trace!("frame encoded; length={}", pin.writer.len());
105 Ok(())
106 }
107
108 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109 if self.flushed {
110 return Poll::Ready(Ok(()));
111 }
112
113 trace!("flushing frame; length={}", self.writer.len());
114 let Self {
115 ref mut socket,
116 ref mut out_addr,
117 ref mut writer,
118 ..
119 } = *self;
120
121 let n = ready!(socket.poll_send_to(cx, writer, out_addr))?;
122 trace!("written {}", n);
123
124 let wrote_all = n == self.writer.len();
125 self.writer.clear();
126 self.flushed = true;
127
128 let res = if wrote_all {
129 Ok(())
130 } else {
131 Err(io::Error::new(
132 io::ErrorKind::Other,
133 "failed to write entire datagram to socket",
134 ))
135 };
136
137 Poll::Ready(res)
138 }
139
140 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
141 ready!(self.poll_flush(cx))?;
142 Poll::Ready(Ok(()))
143 }
144}
145
/// Initial capacity of the read buffer (64 KiB); also the amount reserved in it
/// before every receive.
const INITIAL_READER_CAPACITY: usize = 1 << 16;
/// Initial capacity of the write buffer (8 KiB) that outgoing frames are
/// encoded into.
const INITIAL_WRITER_CAPACITY: usize = 1 << 13;
151
152impl<T, S, C> NetlinkFramed<T, S, C> {
153 pub fn new(socket: S) -> Self {
157 Self {
158 socket,
159 msg_type: PhantomData,
160 codec: PhantomData,
161 out_addr: SocketAddr::new(0, 0),
162 in_addr: SocketAddr::new(0, 0),
163 reader: BytesMut::with_capacity(INITIAL_READER_CAPACITY),
164 writer: BytesMut::with_capacity(INITIAL_WRITER_CAPACITY),
165 flushed: true,
166 }
167 }
168
169 pub fn get_ref(&self) -> &S {
177 &self.socket
178 }
179
180 pub fn get_mut(&mut self) -> &mut S {
189 &mut self.socket
190 }
191
192 pub fn into_inner(self) -> S {
194 self.socket
195 }
196}