use core::ops::Range;
use anyhow::Context;
use byteorder::{ByteOrder, NativeEndian};
use crate::{
traits::{Emitable, Parseable},
DecodeError,
};
type Field = Range<usize>;
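// Flag bits carried in an attribute's 16-bit type field: NLA_F_NESTED
// marks a payload that is itself a list of attributes, and
// NLA_F_NET_BYTEORDER marks a value stored in network (big-endian)
// byte order. NLA_TYPE_MASK strips both, leaving the bare attribute
// kind. Attributes are padded to 4-byte boundaries and start with a
// 4-byte header.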
pub const NLA_F_NESTED: u16 = 0x8000;
pub const NLA_F_NET_BYTEORDER: u16 = 0x4000;
pub const NLA_TYPE_MASK: u16 = !(NLA_F_NET_BYTEORDER | NLA_F_NESTED);
pub const NLA_ALIGNTO: usize = 4;
pub const NLA_HEADER_SIZE: usize = 4;
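// Round a length up to the next NLA_ALIGNTO (4-byte) boundary. Note
// that the macro is `#[macro_export]`ed but refers to `NLA_ALIGNTO` by
// its bare name, so callers outside this module must have that
// constant in scope.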
#[macro_export]
macro_rules! nla_align {
($len: expr) => {
($len + NLA_ALIGNTO - 1) & !(NLA_ALIGNTO - 1)
};
}
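// Wire layout of an attribute: a 16-bit native-endian length field,
// then a 16-bit type field, then `length - 4` bytes of value.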
const LENGTH: Field = 0..2;
const TYPE: Field = 2..4;
#[allow(non_snake_case)]
fn VALUE(length: usize) -> Field {
TYPE.end..TYPE.end + length
}
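/// A typed view over a byte buffer containing a single netlink
/// attribute (header and value).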
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct NlaBuffer<T: AsRef<[u8]>> {
buffer: T,
}
impl<T: AsRef<[u8]>> NlaBuffer<T> {
pub fn new(buffer: T) -> NlaBuffer<T> {
NlaBuffer { buffer }
}
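/// Like `new`, but checks that the buffer is big enough to hold the
/// attribute declared by its length field.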
pub fn new_checked(buffer: T) -> Result<NlaBuffer<T>, DecodeError> {
let buffer = Self::new(buffer);
buffer.check_buffer_length().context("invalid NLA buffer")?;
Ok(buffer)
}
pub fn check_buffer_length(&self) -> Result<(), DecodeError> {
let len = self.buffer.as_ref().len();
if len < TYPE.end {
Err(format!(
"buffer has length {}, but an NLA header is {} bytes",
len, TYPE.end
)
.into())
} else if len < self.length() as usize {
Err(format!(
"buffer has length {}, but the NLA is {} bytes",
len,
self.length()
)
.into())
} else if (self.length() as usize) < TYPE.end {
Err(format!(
"NLA has invalid length: {} (should be at least {} bytes)",
self.length(),
TYPE.end,
)
.into())
} else {
Ok(())
}
}
pub fn into_inner(self) -> T {
self.buffer
}
pub fn inner(&self) -> &T {
&self.buffer
}
pub fn inner_mut(&mut self) -> &mut T {
&mut self.buffer
}
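/// Return the attribute kind, with the NLA_F_NESTED and
/// NLA_F_NET_BYTEORDER flag bits masked out.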
pub fn kind(&self) -> u16 {
let data = self.buffer.as_ref();
NativeEndian::read_u16(&data[TYPE]) & NLA_TYPE_MASK
}
pub fn nested_flag(&self) -> bool {
let data = self.buffer.as_ref();
(NativeEndian::read_u16(&data[TYPE]) & NLA_F_NESTED) != 0
}
pub fn network_byte_order_flag(&self) -> bool {
let data = self.buffer.as_ref();
(NativeEndian::read_u16(&data[TYPE]) & NLA_F_NET_BYTEORDER) != 0
}
pub fn length(&self) -> u16 {
let data = self.buffer.as_ref();
NativeEndian::read_u16(&data[LENGTH])
}
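/// Length of the attribute's value. This assumes a valid buffer (see
/// `new_checked`): if the length field is smaller than the header
/// size, the subtraction underflows and panics in debug builds.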
pub fn value_length(&self) -> usize {
self.length() as usize - TYPE.end
}
}
impl<T: AsRef<[u8]> + AsMut<[u8]>> NlaBuffer<T> {
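/// Set the attribute kind. The flag bits are cleared; use
/// `set_nested_flag` and `set_network_byte_order_flag` to set them.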
pub fn set_kind(&mut self, kind: u16) {
let data = self.buffer.as_mut();
NativeEndian::write_u16(&mut data[TYPE], kind & NLA_TYPE_MASK)
}
pub fn set_nested_flag(&mut self) {
let kind = self.kind();
let data = self.buffer.as_mut();
NativeEndian::write_u16(&mut data[TYPE], kind | NLA_F_NESTED)
}
pub fn set_network_byte_order_flag(&mut self) {
let kind = self.kind();
let data = self.buffer.as_mut();
NativeEndian::write_u16(&mut data[TYPE], kind | NLA_F_NET_BYTEORDER)
}
pub fn set_length(&mut self, length: u16) {
let data = self.buffer.as_mut();
NativeEndian::write_u16(&mut data[LENGTH], length)
}
}
impl<'buffer, T: AsRef<[u8]> + ?Sized> NlaBuffer<&'buffer T> {
pub fn value(&self) -> &[u8] {
&self.buffer.as_ref()[VALUE(self.value_length())]
}
}
impl<'buffer, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> NlaBuffer<&'buffer mut T> {
pub fn value_mut(&mut self) -> &mut [u8] {
let length = VALUE(self.value_length());
&mut self.buffer.as_mut()[length]
}
}
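/// A generic attribute that stores its kind (including any flag bits)
/// and its raw value bytes. Useful as a fallback for attribute types a
/// parser does not recognize.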
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct DefaultNla {
kind: u16,
value: Vec<u8>,
}
impl DefaultNla {
pub fn new(kind: u16, value: Vec<u8>) -> Self {
Self { kind, value }
}
}
impl Nla for DefaultNla {
fn value_len(&self) -> usize {
self.value.len()
}
fn kind(&self) -> u16 {
self.kind
}
fn emit_value(&self, buffer: &mut [u8]) {
buffer.copy_from_slice(self.value.as_slice());
}
}
impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'buffer T>>
for DefaultNla
{
fn parse(buf: &NlaBuffer<&'buffer T>) -> Result<Self, DecodeError> {
let mut kind = buf.kind();
if buf.network_byte_order_flag() {
kind |= NLA_F_NET_BYTEORDER;
}
if buf.nested_flag() {
kind |= NLA_F_NESTED;
}
Ok(DefaultNla {
kind,
value: buf.value().to_vec(),
})
}
}
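/// An attribute that knows its kind, the size of its value, and how to
/// serialize that value. The blanket `Emitable` impl below takes care
/// of the header, the flag bits, and the alignment padding.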
pub trait Nla {
fn value_len(&self) -> usize;
fn kind(&self) -> u16;
fn emit_value(&self, buffer: &mut [u8]);
#[inline]
fn is_nested(&self) -> bool {
(self.kind() & NLA_F_NESTED) != 0
}
#[inline]
fn is_network_byteorder(&self) -> bool {
(self.kind() & NLA_F_NET_BYTEORDER) != 0
}
}
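// Blanket impl: any Nla can be emitted. The emitted buffer includes
// the alignment padding, but the length field written into the header
// does not, as required by the netlink format.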
impl<T: Nla> Emitable for T {
fn buffer_len(&self) -> usize {
nla_align!(self.value_len()) + NLA_HEADER_SIZE
}
fn emit(&self, buffer: &mut [u8]) {
let mut buffer = NlaBuffer::new(buffer);
buffer.set_kind(self.kind());
if self.is_network_byteorder() {
buffer.set_network_byte_order_flag()
}
if self.is_nested() {
buffer.set_nested_flag()
}
buffer.set_length(self.value_len() as u16 + NLA_HEADER_SIZE as u16);
self.emit_value(buffer.value_mut());
let padding = nla_align!(self.value_len()) - self.value_len();
for i in 0..padding {
buffer.inner_mut()[NLA_HEADER_SIZE + self.value_len() + i] = 0;
}
}
}
impl<'a, T: Nla> Emitable for &'a [T] {
fn buffer_len(&self) -> usize {
self.iter().fold(0, |acc, nla| {
assert_eq!(nla.buffer_len() % NLA_ALIGNTO, 0);
acc + nla.buffer_len()
})
}
fn emit(&self, buffer: &mut [u8]) {
let mut start = 0;
let mut end: usize;
for nla in self.iter() {
let attr_len = nla.buffer_len();
assert_eq!(attr_len % NLA_ALIGNTO, 0);
end = start + attr_len;
nla.emit(&mut buffer[start..end]);
start = end;
}
}
}
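/// An iterator over the attributes packed in a buffer. Each item is a
/// checked `NlaBuffer`; on the first malformed attribute the error is
/// yielded once and iteration stops.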
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NlasIterator<T> {
position: usize,
buffer: T,
}
impl<T> NlasIterator<T> {
pub fn new(buffer: T) -> Self {
NlasIterator {
position: 0,
buffer,
}
}
}
impl<'buffer, T: AsRef<[u8]> + ?Sized + 'buffer> Iterator
for NlasIterator<&'buffer T>
{
type Item = Result<NlaBuffer<&'buffer [u8]>, DecodeError>;
fn next(&mut self) -> Option<Self::Item> {
if self.position >= self.buffer.as_ref().len() {
return None;
}
match NlaBuffer::new_checked(&self.buffer.as_ref()[self.position..]) {
Ok(nla_buffer) => {
self.position += nla_align!(nla_buffer.length() as usize);
Some(Ok(nla_buffer))
}
Err(e) => {
self.position = self.buffer.as_ref().len();
Some(Err(e))
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn network_byteorder() {
static TEST_ATTRIBUTE: &[u8] =
&[0x08, 0x00, 0x06, 0x40, 0x00, 0x00, 0x0e, 0x10];
let buffer = NlaBuffer::new(TEST_ATTRIBUTE);
let buffer_is_net = buffer.network_byte_order_flag();
let buffer_is_nest = buffer.nested_flag();
let nla = DefaultNla::parse(&buffer).unwrap();
let mut emitted_buffer = vec![0; nla.buffer_len()];
nla.emit(&mut emitted_buffer);
let attr_is_net = nla.is_network_byteorder();
let attr_is_nest = nla.is_nested();
let emit = NlaBuffer::new(emitted_buffer);
let emit_is_net = emit.network_byte_order_flag();
let emit_is_nest = emit.nested_flag();
assert_eq!(
[buffer_is_net, buffer_is_nest],
[attr_is_net, attr_is_nest]
);
assert_eq!([attr_is_net, attr_is_nest], [emit_is_net, emit_is_nest]);
}
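// A minimal round-trip sketch: emit a DefaultNla, then parse the
// emitted bytes back and check that nothing is lost. The kind and
// value bytes are arbitrary placeholders.
#[test]
fn default_nla_round_trip() {
let nla = DefaultNla::new(2, vec![0xde, 0xad]);
// Two value bytes padded to four, plus the four-byte header.
assert_eq!(nla.buffer_len(), 8);
let mut bytes = vec![0xff; nla.buffer_len()];
nla.emit(&mut bytes);
// emit() must zero the trailing padding.
assert_eq!(&bytes[6..], &[0x00, 0x00]);
let buffer = NlaBuffer::new_checked(&bytes[..]).unwrap();
// The length field counts the header and the value, not the padding.
assert_eq!(buffer.length(), 6);
assert_eq!(DefaultNla::parse(&buffer).unwrap(), nla);
}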
// Return usize::MAX through a function call so that the overflow in
// nla_align! is checked at runtime instead of being rejected at
// compile time.
fn get_len() -> usize {
usize::MAX
}
#[test]
fn test_align() {
assert_eq!(nla_align!(13), 16);
assert_eq!(nla_align!(16), 16);
assert_eq!(nla_align!(0), 0);
assert_eq!(nla_align!(1), 4);
assert_eq!(nla_align!(get_len() - 4), usize::MAX - 3);
}
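// A sketch of walking a buffer holding two concatenated attributes;
// the kinds and values are arbitrary placeholders.
#[test]
fn nlas_iterator() {
let nlas: &[DefaultNla] = &[
DefaultNla::new(1, vec![0x01]),
DefaultNla::new(2, vec![0x02, 0x03, 0x04, 0x05]),
];
let mut bytes = vec![0u8; nlas.buffer_len()];
nlas.emit(&mut bytes);
let mut iter = NlasIterator::new(&bytes[..]);
let first = iter.next().unwrap().unwrap();
assert_eq!((first.kind(), first.value()), (1, &[0x01][..]));
let second = iter.next().unwrap().unwrap();
assert_eq!((second.kind(), second.value()), (2, &[0x02, 0x03, 0x04, 0x05][..]));
assert!(iter.next().is_none());
}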
#[test]
#[should_panic]
fn test_align_overflow() {
// The addition inside nla_align! overflows usize, so this panics (in
// debug builds) before any comparison could happen.
let _ = nla_align!(get_len() - 3);
}
}