simple_dns/dns/header.rs

1use std::{convert::TryInto, io::Write};
2
3use crate::{rdata::OPT, ResourceRecord};
4
5use super::{PacketFlag, OPCODE, RCODE};
6
/// Bit masks selecting the fields packed into the 16-bit DNS header flags word.
pub(crate) mod masks {
    /// OPCODE field: the four bits 11 through 14.
    pub const OPCODE_MASK: u16 = 0b1111 << 11;
    /// Single reserved bit (bit 6); `Header::parse` rejects headers where it is set.
    pub const RESERVED_MASK: u16 = 1 << 6;
    /// RCODE field: the four least-significant bits.
    pub const RESPONSE_CODE_MASK: u16 = 0b1111;
}
12/// Contains general information about the packet
13#[derive(Debug, Clone)]
14pub(crate) struct Header<'a> {
15    /// The identification of the packet, must be defined when querying
16    pub id: u16,
17    /// Indicates the type of query in this packet
18    pub opcode: OPCODE,
19    /// [RCODE](`RCODE`) indicates the response code for this packet
20    pub response_code: RCODE,
21
22    pub z_flags: PacketFlag,
23
24    pub opt: Option<OPT<'a>>,
25}
26
27impl<'a> Header<'a> {
28    /// Creates a new header for a query packet
29    pub fn new_query(id: u16) -> Self {
30        Self {
31            id,
32            opcode: OPCODE::StandardQuery,
33            response_code: RCODE::NoError,
34            z_flags: PacketFlag::empty(),
35            opt: None,
36        }
37    }
38
39    /// Creates a new header for a reply packet
40    pub fn new_reply(id: u16, opcode: OPCODE) -> Self {
41        Self {
42            id,
43            opcode,
44            response_code: RCODE::NoError,
45            z_flags: PacketFlag::RESPONSE,
46            opt: None,
47        }
48    }
49
50    pub fn set_flags(&mut self, flags: PacketFlag) {
51        self.z_flags |= flags;
52    }
53
54    pub fn remove_flags(&mut self, flags: PacketFlag) {
55        self.z_flags.remove(flags);
56    }
57
58    pub fn has_flags(&self, flags: PacketFlag) -> bool {
59        self.z_flags.contains(flags)
60    }
61
62    /// Parse a slice of 12 bytes into a Packet header
63    pub fn parse(data: &[u8]) -> crate::Result<Self> {
64        if data.len() < 12 {
65            return Err(crate::SimpleDnsError::InsufficientData);
66        }
67
68        let flags = u16::from_be_bytes(data[2..4].try_into()?);
69        if flags & masks::RESERVED_MASK != 0 {
70            return Err(crate::SimpleDnsError::InvalidHeaderData);
71        }
72
73        let header = Self {
74            id: u16::from_be_bytes(data[..2].try_into()?),
75            opcode: ((flags & masks::OPCODE_MASK) >> masks::OPCODE_MASK.trailing_zeros()).into(),
76            response_code: (flags & masks::RESPONSE_CODE_MASK).into(),
77            z_flags: PacketFlag::from_bits_truncate(flags),
78            opt: None,
79        };
80        Ok(header)
81    }
82
83    /// Writes this header to a buffer of 12 bytes
84    pub fn write_to<T: Write>(
85        &self,
86        buffer: &mut T,
87        questions: u16,
88        answers: u16,
89        name_servers: u16,
90        additional_records: u16,
91    ) -> crate::Result<()> {
92        buffer.write_all(&self.id.to_be_bytes())?;
93        buffer.write_all(&self.get_flags().to_be_bytes())?;
94        buffer.write_all(&questions.to_be_bytes())?;
95        buffer.write_all(&answers.to_be_bytes())?;
96        buffer.write_all(&name_servers.to_be_bytes())?;
97        buffer.write_all(&additional_records.to_be_bytes())?;
98
99        Ok(())
100    }
101
102    fn get_flags(&self) -> u16 {
103        let mut flags = self.z_flags.bits();
104
105        flags |= (self.opcode as u16) << masks::OPCODE_MASK.trailing_zeros();
106        flags |= self.response_code as u16 & masks::RESPONSE_CODE_MASK;
107
108        flags
109    }
110
111    pub(crate) fn opt_rr(&self) -> Option<ResourceRecord> {
112        self.opt.as_ref().map(|opt| {
113            ResourceRecord::new(
114                crate::Name::new_unchecked("."),
115                crate::CLASS::IN,
116                opt.encode_ttl(self),
117                crate::rdata::RData::OPT(opt.clone()),
118            )
119        })
120    }
121
122    pub(crate) fn extract_info_from_opt_rr(&mut self, opt_rr: Option<ResourceRecord<'a>>) {
123        if let Some(opt) = opt_rr {
124            self.response_code = OPT::extract_rcode_from_ttl(opt.ttl, self);
125            self.opt = match opt.rdata {
126                crate::rdata::RData::OPT(opt) => Some(opt),
127                _ => unreachable!(),
128            };
129        }
130    }
131}
132
#[cfg(test)]
mod tests {
    use crate::header_buffer;

    use super::*;

    #[test]
    fn write_example_query() {
        let mut header = Header::new_query(u16::MAX);

        header.set_flags(PacketFlag::TRUNCATION | PacketFlag::RECURSION_DESIRED);

        let mut buf = vec![];
        header.write_to(&mut buf, 0, 0, 0, 0).unwrap();

        assert_eq!(
            b"\xff\xff\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00",
            &buf[..]
        );
    }

    #[test]
    fn parse_example_query() {
        let buffer = b"\xff\xff\x03\x00\x00\x02\x00\x02\x00\x02\x00\x02";
        let header = Header::parse(&buffer[..]).unwrap();

        assert_eq!(u16::MAX, header.id);
        assert_eq!(OPCODE::StandardQuery, header.opcode);
        assert!(!header.has_flags(
            PacketFlag::AUTHORITATIVE_ANSWER
                | PacketFlag::RECURSION_AVAILABLE
                | PacketFlag::RESPONSE
        ));
        assert!(header.has_flags(PacketFlag::TRUNCATION | PacketFlag::RECURSION_DESIRED));
        assert_eq!(RCODE::NoError, header.response_code);
        assert_eq!(2, header_buffer::additional_records(&buffer[..]).unwrap());
        assert_eq!(2, header_buffer::answers(&buffer[..]).unwrap());
        assert_eq!(2, header_buffer::name_servers(&buffer[..]).unwrap());
        assert_eq!(2, header_buffer::questions(&buffer[..]).unwrap());
    }

    #[test]
    fn parse_rejects_short_input() {
        // One byte short of a full header.
        assert!(matches!(
            Header::parse(&[0u8; 11]),
            Err(crate::SimpleDnsError::InsufficientData)
        ));
        // Exactly 12 bytes is the boundary that must succeed.
        assert!(Header::parse(&[0u8; 12]).is_ok());
    }

    #[test]
    fn parse_rejects_reserved_bit() {
        // Flags word 0x0040: only the reserved bit is set.
        let buffer = b"\x00\x01\x00\x40\x00\x00\x00\x00\x00\x00\x00\x00";
        assert!(matches!(
            Header::parse(&buffer[..]),
            Err(crate::SimpleDnsError::InvalidHeaderData)
        ));
    }

    #[test]
    fn write_then_parse_round_trip() {
        let mut header = Header::new_reply(0x1234, OPCODE::StandardQuery);
        header.set_flags(PacketFlag::AUTHORITATIVE_ANSWER);

        let mut buf = vec![];
        header.write_to(&mut buf, 1, 2, 3, 4).unwrap();
        assert_eq!(12, buf.len());

        let parsed = Header::parse(&buf[..]).expect("Header parsing failed");
        assert_eq!(0x1234, parsed.id);
        assert_eq!(OPCODE::StandardQuery, parsed.opcode);
        assert_eq!(RCODE::NoError, parsed.response_code);
        assert!(parsed.has_flags(PacketFlag::RESPONSE | PacketFlag::AUTHORITATIVE_ANSWER));
        assert!(!parsed.has_flags(PacketFlag::TRUNCATION));

        assert_eq!(1, header_buffer::questions(&buf[..]).unwrap());
        assert_eq!(2, header_buffer::answers(&buf[..]).unwrap());
        assert_eq!(3, header_buffer::name_servers(&buf[..]).unwrap());
        assert_eq!(4, header_buffer::additional_records(&buf[..]).unwrap());
    }

    #[test]
    fn read_write_questions_count() {
        let mut buffer = [0u8; 12];
        header_buffer::set_questions(&mut buffer, 1);
        assert_eq!(1, header_buffer::questions(&buffer).unwrap());
    }

    #[test]
    fn read_write_answers_count() {
        let mut buffer = [0u8; 12];
        header_buffer::set_answers(&mut buffer, 1);
        assert_eq!(1, header_buffer::answers(&buffer).unwrap());
    }

    #[test]
    fn read_write_name_servers_count() {
        let mut buffer = [0u8; 12];
        header_buffer::set_name_servers(&mut buffer, 1);
        assert_eq!(1, header_buffer::name_servers(&buffer).unwrap());
    }

    #[test]
    fn read_write_additional_records_count() {
        let mut buffer = [0u8; 12];
        header_buffer::set_additional_records(&mut buffer, 1);
        assert_eq!(1, header_buffer::additional_records(&buffer).unwrap());
    }

    #[test]
    fn big_rcode_doesnt_break_header() {
        let mut header = Header::new_reply(1, OPCODE::StandardQuery);
        header.response_code = RCODE::BADVERS;

        let mut buffer = vec![];
        header.write_to(&mut buffer, 0, 0, 0, 0).unwrap();

        assert_ne!(RCODE::BADVERS, header_buffer::rcode(&buffer[..]).unwrap());

        let header = Header::parse(&buffer[..]).expect("Header parsing failed");
        assert_eq!(RCODE::NoError, header.response_code);
        assert!(header.has_flags(PacketFlag::RESPONSE));
    }
}