netlink_proto/protocol/
protocol.rs1use std::{
4 collections::{hash_map, HashMap, VecDeque},
5 fmt::Debug,
6};
7
8use netlink_packet_core::{
9 constants::*,
10 NetlinkDeserializable,
11 NetlinkMessage,
12 NetlinkPayload,
13 NetlinkSerializable,
14};
15
16use super::Request;
17use crate::sys::SocketAddr;
18
/// Key under which a request is tracked: its sequence number together with
/// the port number of the peer it was sent to.
#[derive(Debug)]
#[derive(PartialEq, Eq, Hash)]
struct RequestId {
    /// Sequence number carried in the message header.
    sequence_number: u32,
    /// Port number of the peer (destination of the request, source of the reply).
    port: u32,
}
24
25impl RequestId {
26 fn new(sequence_number: u32, port: u32) -> Self {
27 Self {
28 sequence_number,
29 port,
30 }
31 }
32}
33
34#[derive(Debug, Eq, PartialEq)]
35pub(crate) struct Response<T, M> {
36 pub done: bool,
37 pub message: NetlinkMessage<T>,
38 pub metadata: M,
39}
40
/// Per-request state kept while a request may still receive responses.
#[derive(Debug)]
struct PendingRequest<M> {
    /// Set when the request was sent with `NLM_F_ACK`; a single
    /// (non-multipart) reply then does not end the request on its own.
    expecting_ack: bool,
    /// Caller-supplied data handed back with each response to this request.
    metadata: M,
}
46
47#[derive(Debug, Default)]
48pub(crate) struct Protocol<T, M> {
49 sequence_id: u32,
51
52 pending_requests: HashMap<RequestId, PendingRequest<M>>,
55
56 pub incoming_responses: VecDeque<Response<T, M>>,
58
59 pub incoming_requests: VecDeque<(NetlinkMessage<T>, SocketAddr)>,
61
62 pub outgoing_messages: VecDeque<(NetlinkMessage<T>, SocketAddr)>,
64}
65
66impl<T, M> Protocol<T, M>
67where
68 T: Debug + NetlinkSerializable + NetlinkDeserializable,
69 M: Debug + Clone,
70{
71 pub fn new() -> Self {
72 Self {
73 sequence_id: 0,
74 pending_requests: HashMap::new(),
75 incoming_responses: VecDeque::new(),
76 incoming_requests: VecDeque::new(),
77 outgoing_messages: VecDeque::new(),
78 }
79 }
80
81 pub fn handle_message(&mut self, message: NetlinkMessage<T>, source: SocketAddr) {
82 let request_id = RequestId::new(message.header.sequence_number, source.port_number());
83 debug!("handling messages (request id = {:?})", request_id);
84 if let hash_map::Entry::Occupied(entry) = self.pending_requests.entry(request_id) {
85 Self::handle_response(&mut self.incoming_responses, entry, message);
86 } else {
87 self.incoming_requests.push_back((message, source));
88 }
89 }
90
91 fn handle_response(
92 incoming_responses: &mut VecDeque<Response<T, M>>,
93 entry: hash_map::OccupiedEntry<RequestId, PendingRequest<M>>,
94 message: NetlinkMessage<T>,
95 ) {
96 let entry_key;
97 let mut request_id = entry.key();
98 debug!("handling response to request {:?}", request_id);
99
100 let done = match message.payload {
104 NetlinkPayload::InnerMessage(_)
105 if message.header.flags & NLM_F_MULTIPART == NLM_F_MULTIPART =>
106 {
107 false
108 }
109 NetlinkPayload::InnerMessage(_) => !entry.get().expecting_ack,
110 _ => true,
111 };
112
113 let metadata = if done {
114 trace!("request {:?} fully processed", request_id);
115 let (k, v) = entry.remove_entry();
116 entry_key = k;
117 request_id = &entry_key;
118 v.metadata
119 } else {
120 trace!("more responses to request {:?} may come", request_id);
121 entry.get().metadata.clone()
122 };
123
124 let response = Response::<T, M> {
125 done,
126 message,
127 metadata,
128 };
129 incoming_responses.push_back(response);
130 debug!("done handling response to request {:?}", request_id);
131 }
132
133 pub fn request(&mut self, request: Request<T, M>) {
134 let Request {
135 mut message,
136 metadata,
137 destination,
138 } = request;
139
140 self.set_sequence_id(&mut message);
141 let request_id = RequestId::new(self.sequence_id, destination.port_number());
142 let flags = message.header.flags;
143 self.outgoing_messages.push_back((message, destination));
144
145 let expecting_ack = flags & NLM_F_ACK == NLM_F_ACK;
153 if flags & NLM_F_REQUEST == NLM_F_REQUEST
154 || flags & NLM_F_ECHO == NLM_F_ECHO
155 || expecting_ack
156 {
157 self.pending_requests.insert(
158 request_id,
159 PendingRequest {
160 expecting_ack,
161 metadata,
162 },
163 );
164 }
165 }
166
167 fn set_sequence_id(&mut self, message: &mut NetlinkMessage<T>) {
168 self.sequence_id += 1;
169 message.header.sequence_number = self.sequence_id;
170 }
171}