netlink_proto/protocol/
protocol.rs

1// SPDX-License-Identifier: MIT
2
3use std::{
4    collections::{hash_map, HashMap, VecDeque},
5    fmt::Debug,
6};
7
8use netlink_packet_core::{
9    constants::*,
10    NetlinkDeserializable,
11    NetlinkMessage,
12    NetlinkPayload,
13    NetlinkSerializable,
14};
15
16use super::Request;
17use crate::sys::SocketAddr;
18
/// Key identifying an outstanding request: the sequence number stamped in the
/// request header together with the port of the peer it was exchanged with.
#[derive(Debug, PartialEq, Eq, Hash)]
struct RequestId {
    /// Sequence number carried in the netlink header.
    sequence_number: u32,
    /// Port number of the peer's socket address.
    port: u32,
}

impl RequestId {
    /// Build a key from a sequence number and a peer port.
    fn new(sequence_number: u32, port: u32) -> Self {
        RequestId {
            port,
            sequence_number,
        }
    }
}
33
34#[derive(Debug, Eq, PartialEq)]
35pub(crate) struct Response<T, M> {
36    pub done: bool,
37    pub message: NetlinkMessage<T>,
38    pub metadata: M,
39}
40
/// Book-keeping for a request whose response(s) have not all arrived yet.
#[derive(Debug)]
struct PendingRequest<M> {
    /// Whether the request carried `NLM_F_ACK`; when set, a non-multipart
    /// inner message does not complete the request.
    expecting_ack: bool,
    /// Caller-supplied data handed back alongside every response.
    metadata: M,
}
46
47#[derive(Debug, Default)]
48pub(crate) struct Protocol<T, M> {
49    /// Counter that is incremented for each message sent
50    sequence_id: u32,
51
52    /// Requests for which we're awaiting a response. Metadata are
53    /// associated with each request.
54    pending_requests: HashMap<RequestId, PendingRequest<M>>,
55
56    /// Responses to pending requests
57    pub incoming_responses: VecDeque<Response<T, M>>,
58
59    /// Requests from remote peers
60    pub incoming_requests: VecDeque<(NetlinkMessage<T>, SocketAddr)>,
61
62    /// The messages to be sent out
63    pub outgoing_messages: VecDeque<(NetlinkMessage<T>, SocketAddr)>,
64}
65
66impl<T, M> Protocol<T, M>
67where
68    T: Debug + NetlinkSerializable + NetlinkDeserializable,
69    M: Debug + Clone,
70{
71    pub fn new() -> Self {
72        Self {
73            sequence_id: 0,
74            pending_requests: HashMap::new(),
75            incoming_responses: VecDeque::new(),
76            incoming_requests: VecDeque::new(),
77            outgoing_messages: VecDeque::new(),
78        }
79    }
80
81    pub fn handle_message(&mut self, message: NetlinkMessage<T>, source: SocketAddr) {
82        let request_id = RequestId::new(message.header.sequence_number, source.port_number());
83        debug!("handling messages (request id = {:?})", request_id);
84        if let hash_map::Entry::Occupied(entry) = self.pending_requests.entry(request_id) {
85            Self::handle_response(&mut self.incoming_responses, entry, message);
86        } else {
87            self.incoming_requests.push_back((message, source));
88        }
89    }
90
91    fn handle_response(
92        incoming_responses: &mut VecDeque<Response<T, M>>,
93        entry: hash_map::OccupiedEntry<RequestId, PendingRequest<M>>,
94        message: NetlinkMessage<T>,
95    ) {
96        let entry_key;
97        let mut request_id = entry.key();
98        debug!("handling response to request {:?}", request_id);
99
100        // A request is processed if we receive an Ack, Error,
101        // Done, Overrun, or InnerMessage without the
102        // multipart flag and we were not expecting an Ack
103        let done = match message.payload {
104            NetlinkPayload::InnerMessage(_)
105                if message.header.flags & NLM_F_MULTIPART == NLM_F_MULTIPART =>
106            {
107                false
108            }
109            NetlinkPayload::InnerMessage(_) => !entry.get().expecting_ack,
110            _ => true,
111        };
112
113        let metadata = if done {
114            trace!("request {:?} fully processed", request_id);
115            let (k, v) = entry.remove_entry();
116            entry_key = k;
117            request_id = &entry_key;
118            v.metadata
119        } else {
120            trace!("more responses to request {:?} may come", request_id);
121            entry.get().metadata.clone()
122        };
123
124        let response = Response::<T, M> {
125            done,
126            message,
127            metadata,
128        };
129        incoming_responses.push_back(response);
130        debug!("done handling response to request {:?}", request_id);
131    }
132
133    pub fn request(&mut self, request: Request<T, M>) {
134        let Request {
135            mut message,
136            metadata,
137            destination,
138        } = request;
139
140        self.set_sequence_id(&mut message);
141        let request_id = RequestId::new(self.sequence_id, destination.port_number());
142        let flags = message.header.flags;
143        self.outgoing_messages.push_back((message, destination));
144
145        // If we expect a response, we store the request id so that we
146        // can map the response to this specific request.
147        //
148        // Note that we expect responses in three cases only:
149        //  - when the request has the NLM_F_REQUEST flag
150        //  - when the request has the NLM_F_ACK flag
151        //  - when the request has the NLM_F_ECHO flag
152        let expecting_ack = flags & NLM_F_ACK == NLM_F_ACK;
153        if flags & NLM_F_REQUEST == NLM_F_REQUEST
154            || flags & NLM_F_ECHO == NLM_F_ECHO
155            || expecting_ack
156        {
157            self.pending_requests.insert(
158                request_id,
159                PendingRequest {
160                    expecting_ack,
161                    metadata,
162                },
163            );
164        }
165    }
166
167    fn set_sequence_id(&mut self, message: &mut NetlinkMessage<T>) {
168        self.sequence_id += 1;
169        message.header.sequence_number = self.sequence_id;
170    }
171}