1use std::{
4 fmt::Debug,
5 io,
6 pin::Pin,
7 task::{Context, Poll},
8};
9
10use futures::{
11 channel::mpsc::{UnboundedReceiver, UnboundedSender},
12 Future,
13 Sink,
14 Stream,
15};
16use log::{error, warn};
17use netlink_packet_core::{
18 NetlinkDeserializable,
19 NetlinkMessage,
20 NetlinkPayload,
21 NetlinkSerializable,
22};
23
24use crate::{
25 codecs::{NetlinkCodec, NetlinkMessageCodec},
26 framed::NetlinkFramed,
27 sys::{AsyncSocket, SocketAddr},
28 Protocol,
29 Request,
30 Response,
31};
32
// Without the `tokio_socket` feature no concrete socket implementation is
// available to serve as the default for `Connection`'s `S` parameter, so `()`
// only fills that slot. NOTE(review): presumably `()` does not implement
// `AsyncSocket`, meaning callers must name a socket type explicitly in this
// configuration — TODO confirm.
#[cfg(not(feature = "tokio_socket"))]
type DefaultSocket = ();
// With `tokio_socket` enabled, connections default to the tokio-backed socket.
#[cfg(feature = "tokio_socket")]
use netlink_sys::TokioSocket as DefaultSocket;
37
38pub struct Connection<T, S = DefaultSocket, C = NetlinkCodec>
43where
44 T: Debug + NetlinkSerializable + NetlinkDeserializable,
45{
46 socket: NetlinkFramed<T, S, C>,
47
48 protocol: Protocol<T, UnboundedSender<NetlinkMessage<T>>>,
49
50 requests_rx: Option<UnboundedReceiver<Request<T>>>,
52
53 unsolicited_messages_tx: Option<UnboundedSender<(NetlinkMessage<T>, SocketAddr)>>,
56
57 socket_closed: bool,
58}
59
60impl<T, S, C> Connection<T, S, C>
61where
62 T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin,
63 S: AsyncSocket,
64 C: NetlinkMessageCodec,
65{
66 pub(crate) fn new(
67 requests_rx: UnboundedReceiver<Request<T>>,
68 unsolicited_messages_tx: UnboundedSender<(NetlinkMessage<T>, SocketAddr)>,
69 protocol: isize,
70 ) -> io::Result<Self> {
71 let socket = S::new(protocol)?;
72 Ok(Connection {
73 socket: NetlinkFramed::new(socket),
74 protocol: Protocol::new(),
75 requests_rx: Some(requests_rx),
76 unsolicited_messages_tx: Some(unsolicited_messages_tx),
77 socket_closed: false,
78 })
79 }
80
81 pub fn socket_mut(&mut self) -> &mut S {
82 self.socket.get_mut()
83 }
84
85 pub fn poll_send_messages(&mut self, cx: &mut Context) {
86 trace!("poll_send_messages called");
87 let Connection {
88 ref mut socket,
89 ref mut protocol,
90 ..
91 } = self;
92 let mut socket = Pin::new(socket);
93
94 while !protocol.outgoing_messages.is_empty() {
95 trace!("found outgoing message to send checking if socket is ready");
96 if let Poll::Ready(Err(e)) = Pin::as_mut(&mut socket).poll_ready(cx) {
97 warn!("netlink socket shut down: {:?}", e);
100 self.socket_closed = true;
101 return;
102 }
103
104 let (mut message, addr) = protocol.outgoing_messages.pop_front().unwrap();
105 message.finalize();
106
107 trace!("sending outgoing message");
108 if let Err(e) = Pin::as_mut(&mut socket).start_send((message, addr)) {
109 error!("failed to send message: {:?}", e);
110 self.socket_closed = true;
111 return;
112 }
113 }
114
115 trace!("poll_send_messages done");
116 self.poll_flush(cx)
117 }
118
119 pub fn poll_flush(&mut self, cx: &mut Context) {
120 trace!("poll_flush called");
121 if let Poll::Ready(Err(e)) = Pin::new(&mut self.socket).poll_flush(cx) {
122 warn!("error flushing netlink socket: {:?}", e);
123 self.socket_closed = true;
124 }
125 }
126
127 pub fn poll_read_messages(&mut self, cx: &mut Context) {
128 trace!("poll_read_messages called");
129 let mut socket = Pin::new(&mut self.socket);
130
131 loop {
132 trace!("polling socket");
133 match socket.as_mut().poll_next(cx) {
134 Poll::Ready(Some((message, addr))) => {
135 trace!("read datagram from socket");
136 self.protocol.handle_message(message, addr);
137 }
138 Poll::Ready(None) => {
139 warn!("netlink socket stream shut down");
140 self.socket_closed = true;
141 return;
142 }
143 Poll::Pending => {
144 trace!("no datagram read from socket");
145 return;
146 }
147 }
148 }
149 }
150
151 pub fn poll_requests(&mut self, cx: &mut Context) {
152 trace!("poll_requests called");
153 if let Some(mut stream) = self.requests_rx.as_mut() {
154 loop {
155 match Pin::new(&mut stream).poll_next(cx) {
156 Poll::Ready(Some(request)) => self.protocol.request(request),
157 Poll::Ready(None) => break,
158 Poll::Pending => return,
159 }
160 }
161 let _ = self.requests_rx.take();
162 trace!("no new requests to handle poll_requests done");
163 }
164 }
165
166 pub fn forward_unsolicited_messages(&mut self) {
167 if self.unsolicited_messages_tx.is_none() {
168 while let Some((message, source)) = self.protocol.incoming_requests.pop_front() {
169 warn!(
170 "ignoring unsolicited message {:?} from {:?}",
171 message, source
172 );
173 }
174 return;
175 }
176
177 trace!("forward_unsolicited_messages called");
178 let mut ready = false;
179
180 let Connection {
181 ref mut protocol,
182 ref mut unsolicited_messages_tx,
183 ..
184 } = self;
185
186 while let Some((message, source)) = protocol.incoming_requests.pop_front() {
187 if unsolicited_messages_tx
188 .as_mut()
189 .unwrap()
190 .unbounded_send((message, source))
191 .is_err()
192 {
193 warn!("failed to forward message to connection handle: channel closed");
197 ready = true;
198 break;
199 }
200 }
201
202 if ready {
203 let _ = self.unsolicited_messages_tx.take();
205 self.forward_unsolicited_messages();
207 }
208
209 trace!("forward_unsolicited_messages done");
210 }
211
212 pub fn forward_responses(&mut self) {
213 trace!("forward_responses called");
214 let protocol = &mut self.protocol;
215
216 while let Some(response) = protocol.incoming_responses.pop_front() {
217 let Response {
218 message,
219 done,
220 metadata: tx,
221 } = response;
222 if done {
223 use NetlinkPayload::*;
224 match &message.payload {
225 Noop | Done | Ack(_) => {
234 trace!("not forwarding Noop/Ack/Done message to the handle");
235 continue;
236 }
237 Overrun(_) => unimplemented!("overrun is not handled yet"),
239 Error(_) | InnerMessage(_) => {}
244 }
245 }
246
247 trace!("forwarding response to the handle");
248 if tx.unbounded_send(message).is_err() {
249 warn!("failed to forward response back to the handle");
252 }
253 }
254 trace!("forward_responses done");
255 }
256
257 pub fn should_shut_down(&self) -> bool {
258 self.socket_closed || (self.unsolicited_messages_tx.is_none() && self.requests_rx.is_none())
259 }
260}
261
262impl<T, S, C> Future for Connection<T, S, C>
263where
264 T: Debug + NetlinkSerializable + NetlinkDeserializable + Unpin,
265 S: AsyncSocket,
266 C: NetlinkMessageCodec,
267{
268 type Output = ();
269
270 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
271 trace!("polling Connection");
272 let pinned = self.get_mut();
273
274 debug!("reading incoming messages");
275 pinned.poll_read_messages(cx);
276
277 debug!("forwarding unsolicited messages to the connection handle");
278 pinned.forward_unsolicited_messages();
279
280 debug!("forwaring responses to previous requests to the connection handle");
281 pinned.forward_responses();
282
283 debug!("handling requests");
284 pinned.poll_requests(cx);
285
286 debug!("sending messages");
287 pinned.poll_send_messages(cx);
288
289 trace!("done polling Connection");
290
291 if pinned.should_shut_down() {
292 Poll::Ready(())
293 } else {
294 Poll::Pending
295 }
296 }
297}