netlink_sys/
tokio.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    io,
5    os::unix::io::{AsRawFd, FromRawFd, RawFd},
6    task::{Context, Poll},
7};
8
9use futures::ready;
10use log::trace;
11use tokio::io::unix::AsyncFd;
12
13use crate::{AsyncSocket, Socket, SocketAddr};
14
15/// An I/O object representing a Netlink socket.
16pub struct TokioSocket(AsyncFd<Socket>);
17
18impl FromRawFd for TokioSocket {
19    unsafe fn from_raw_fd(fd: RawFd) -> Self {
20        let socket = Socket::from_raw_fd(fd);
21        socket.set_non_blocking(true).unwrap();
22        TokioSocket(AsyncFd::new(socket).unwrap())
23    }
24}
25
26impl AsRawFd for TokioSocket {
27    fn as_raw_fd(&self) -> RawFd {
28        self.0.get_ref().as_raw_fd()
29    }
30}
31
32impl AsyncSocket for TokioSocket {
33    fn socket_ref(&self) -> &Socket {
34        self.0.get_ref()
35    }
36
37    /// Mutable access to underyling [`Socket`]
38    fn socket_mut(&mut self) -> &mut Socket {
39        self.0.get_mut()
40    }
41
42    fn new(protocol: isize) -> io::Result<Self> {
43        let socket = Socket::new(protocol)?;
44        socket.set_non_blocking(true)?;
45        Ok(Self(AsyncFd::new(socket)?))
46    }
47
48    fn poll_send(
49        &mut self,
50        cx: &mut Context<'_>,
51        buf: &[u8],
52    ) -> Poll<io::Result<usize>> {
53        loop {
54            // Check if the socket it writable. If
55            // AsyncFd::poll_write_ready returns NotReady, it will
56            // already have arranged for the current task to be
57            // notified when the socket becomes writable, so we can
58            // just return Pending
59            let mut guard = ready!(self.0.poll_write_ready(cx))?;
60
61            match guard.try_io(|inner| inner.get_ref().send(buf, 0)) {
62                Ok(x) => return Poll::Ready(x),
63                Err(_would_block) => continue,
64            }
65        }
66    }
67
68    fn poll_send_to(
69        &mut self,
70        cx: &mut Context<'_>,
71        buf: &[u8],
72        addr: &SocketAddr,
73    ) -> Poll<io::Result<usize>> {
74        loop {
75            let mut guard = ready!(self.0.poll_write_ready(cx))?;
76
77            match guard.try_io(|inner| inner.get_ref().send_to(buf, addr, 0)) {
78                Ok(x) => return Poll::Ready(x),
79                Err(_would_block) => continue,
80            }
81        }
82    }
83
84    fn poll_recv<B>(
85        &mut self,
86        cx: &mut Context<'_>,
87        buf: &mut B,
88    ) -> Poll<io::Result<()>>
89    where
90        B: bytes::BufMut,
91    {
92        loop {
93            // Check if the socket is readable. If not,
94            // AsyncFd::poll_read_ready would have arranged for the
95            // current task to be polled again when the socket becomes
96            // readable, so we can just return Pending
97            let mut guard = ready!(self.0.poll_read_ready(cx))?;
98
99            match guard.try_io(|inner| inner.get_ref().recv(buf, 0)) {
100                Ok(x) => return Poll::Ready(x.map(|_len| ())),
101                Err(_would_block) => continue,
102            }
103        }
104    }
105
106    fn poll_recv_from<B>(
107        &mut self,
108        cx: &mut Context<'_>,
109        buf: &mut B,
110    ) -> Poll<io::Result<SocketAddr>>
111    where
112        B: bytes::BufMut,
113    {
114        loop {
115            trace!("poll_recv_from called");
116            let mut guard = ready!(self.0.poll_read_ready(cx))?;
117            trace!("poll_recv_from socket is ready for reading");
118
119            match guard.try_io(|inner| inner.get_ref().recv_from(buf, 0)) {
120                Ok(x) => {
121                    trace!("poll_recv_from {:?} bytes read", x);
122                    return Poll::Ready(x.map(|(_len, addr)| addr));
123                }
124                Err(_would_block) => {
125                    trace!("poll_recv_from socket would block");
126                    continue;
127                }
128            }
129        }
130    }
131
132    fn poll_recv_from_full(
133        &mut self,
134        cx: &mut Context<'_>,
135    ) -> Poll<io::Result<(Vec<u8>, SocketAddr)>> {
136        loop {
137            trace!("poll_recv_from_full called");
138            let mut guard = ready!(self.0.poll_read_ready(cx))?;
139            trace!("poll_recv_from_full socket is ready for reading");
140
141            match guard.try_io(|inner| inner.get_ref().recv_from_full()) {
142                Ok(x) => {
143                    trace!("poll_recv_from_full {:?} bytes read", x);
144                    return Poll::Ready(x);
145                }
146                Err(_would_block) => {
147                    trace!("poll_recv_from_full socket would block");
148                    continue;
149                }
150            }
151        }
152    }
153}