netlink_packet_core/
message.rs

1// SPDX-License-Identifier: MIT
2
3use anyhow::Context;
4use std::fmt::Debug;
5
6use crate::{
7    payload::{NLMSG_DONE, NLMSG_ERROR, NLMSG_NOOP, NLMSG_OVERRUN},
8    AckMessage,
9    DecodeError,
10    Emitable,
11    ErrorBuffer,
12    ErrorMessage,
13    NetlinkBuffer,
14    NetlinkDeserializable,
15    NetlinkHeader,
16    NetlinkPayload,
17    NetlinkSerializable,
18    Parseable,
19};
20
21/// Represent a netlink message.
22#[derive(Debug, PartialEq, Eq, Clone)]
23pub struct NetlinkMessage<I> {
24    /// Message header (this is common to all the netlink protocols)
25    pub header: NetlinkHeader,
26    /// Inner message, which depends on the netlink protocol being used.
27    pub payload: NetlinkPayload<I>,
28}
29
30impl<I> NetlinkMessage<I> {
31    /// Create a new netlink message from the given header and payload
32    pub fn new(header: NetlinkHeader, payload: NetlinkPayload<I>) -> Self {
33        NetlinkMessage { header, payload }
34    }
35
36    /// Consume this message and return its header and payload
37    pub fn into_parts(self) -> (NetlinkHeader, NetlinkPayload<I>) {
38        (self.header, self.payload)
39    }
40}
41
42impl<I> NetlinkMessage<I>
43where
44    I: NetlinkDeserializable,
45{
46    /// Parse the given buffer as a netlink message
47    pub fn deserialize(buffer: &[u8]) -> Result<Self, DecodeError> {
48        let netlink_buffer = NetlinkBuffer::new_checked(&buffer)?;
49        <Self as Parseable<NetlinkBuffer<&&[u8]>>>::parse(&netlink_buffer)
50    }
51}
52
53impl<I> NetlinkMessage<I>
54where
55    I: NetlinkSerializable,
56{
57    /// Return the length of this message in bytes
58    pub fn buffer_len(&self) -> usize {
59        <Self as Emitable>::buffer_len(self)
60    }
61
62    /// Serialize this message and write the serialized data into the
63    /// given buffer. `buffer` must big large enough for the whole
64    /// message to fit, otherwise, this method will panic. To know how
65    /// big the serialized message is, call `buffer_len()`.
66    ///
67    /// # Panic
68    ///
69    /// This method panics if the buffer is not big enough.
70    pub fn serialize(&self, buffer: &mut [u8]) {
71        self.emit(buffer)
72    }
73
74    /// Ensure the header (`NetlinkHeader`) is consistent with the payload (`NetlinkPayload`):
75    ///
76    /// - compute the payload length and set the header's length field
77    /// - check the payload type and set the header's message type field accordingly
78    ///
79    /// If you are not 100% sure the header is correct, this method should be called before calling
80    /// [`Emitable::emit()`](trait.Emitable.html#tymethod.emit), as it could panic if the header is
81    /// inconsistent with the rest of the message.
82    pub fn finalize(&mut self) {
83        self.header.length = self.buffer_len() as u32;
84        self.header.message_type = self.payload.message_type();
85    }
86}
87
88impl<'buffer, B, I> Parseable<NetlinkBuffer<&'buffer B>> for NetlinkMessage<I>
89where
90    B: AsRef<[u8]> + 'buffer,
91    I: NetlinkDeserializable,
92{
93    fn parse(buf: &NetlinkBuffer<&'buffer B>) -> Result<Self, DecodeError> {
94        use self::NetlinkPayload::*;
95
96        let header = <NetlinkHeader as Parseable<NetlinkBuffer<&'buffer B>>>::parse(buf)
97            .context("failed to parse netlink header")?;
98
99        let bytes = buf.payload();
100        let payload = match header.message_type {
101            NLMSG_ERROR => {
102                let buf =
103                    ErrorBuffer::new_checked(&bytes).context("failed to parse NLMSG_ERROR")?;
104                let msg = ErrorMessage::parse(&buf).context("failed to parse NLMSG_ERROR")?;
105                if msg.code >= 0 {
106                    Ack(msg as AckMessage)
107                } else {
108                    Error(msg)
109                }
110            }
111            NLMSG_NOOP => Noop,
112            NLMSG_DONE => Done,
113            NLMSG_OVERRUN => Overrun(bytes.to_vec()),
114            message_type => {
115                let inner_msg = I::deserialize(&header, bytes).context(format!(
116                    "Failed to parse message with type {}",
117                    message_type
118                ))?;
119                InnerMessage(inner_msg)
120            }
121        };
122        Ok(NetlinkMessage { header, payload })
123    }
124}
125
126impl<I> Emitable for NetlinkMessage<I>
127where
128    I: NetlinkSerializable,
129{
130    fn buffer_len(&self) -> usize {
131        use self::NetlinkPayload::*;
132
133        let payload_len = match self.payload {
134            Noop | Done => 0,
135            Overrun(ref bytes) => bytes.len(),
136            Error(ref msg) => msg.buffer_len(),
137            Ack(ref msg) => msg.buffer_len(),
138            InnerMessage(ref msg) => msg.buffer_len(),
139        };
140
141        self.header.buffer_len() + payload_len
142    }
143
144    fn emit(&self, buffer: &mut [u8]) {
145        use self::NetlinkPayload::*;
146
147        self.header.emit(buffer);
148
149        let buffer = &mut buffer[self.header.buffer_len()..self.header.length as usize];
150        match self.payload {
151            Noop | Done => {}
152            Overrun(ref bytes) => buffer.copy_from_slice(bytes),
153            Error(ref msg) => msg.emit(buffer),
154            Ack(ref msg) => msg.emit(buffer),
155            InnerMessage(ref msg) => msg.serialize(buffer),
156        }
157    }
158}
159
160impl<T> From<T> for NetlinkMessage<T>
161where
162    T: Into<NetlinkPayload<T>>,
163{
164    fn from(inner_message: T) -> Self {
165        NetlinkMessage {
166            header: NetlinkHeader::default(),
167            payload: inner_message.into(),
168        }
169    }
170}