simple_dns/dns/
packet.rs

1use std::{
2    collections::HashMap,
3    io::{Cursor, Seek, Write},
4};
5
6use crate::{header_buffer, rdata::OPT, RCODE};
7
8use super::{Header, PacketFlag, Question, ResourceRecord, WireFormat, OPCODE};
9
10/// Represents a DNS message packet
11///
12/// When working with EDNS packets, use [Packet::opt] and [Packet::opt_mut] to add or access [OPT] packet information
13#[derive(Debug, Clone)]
14pub struct Packet<'a> {
15    /// Packet header
16    header: Header<'a>,
17    /// Questions section
18    pub questions: Vec<Question<'a>>,
19    /// Answers section
20    pub answers: Vec<ResourceRecord<'a>>,
21    /// Name servers section
22    pub name_servers: Vec<ResourceRecord<'a>>,
23    /// Aditional records section.  
24    /// DO NOT use this field to add OPT record, use [`Packet::opt_mut`] instead
25    pub additional_records: Vec<ResourceRecord<'a>>,
26}
27
28impl<'a> Packet<'a> {
29    /// Creates a new empty packet with a query header
30    pub fn new_query(id: u16) -> Self {
31        Self {
32            header: Header::new_query(id),
33            questions: Vec::new(),
34            answers: Vec::new(),
35            name_servers: Vec::new(),
36            additional_records: Vec::new(),
37        }
38    }
39
40    /// Creates a new empty packet with a reply header
41    pub fn new_reply(id: u16) -> Self {
42        Self {
43            header: Header::new_reply(id, OPCODE::StandardQuery),
44            questions: Vec::new(),
45            answers: Vec::new(),
46            name_servers: Vec::new(),
47            additional_records: Vec::new(),
48        }
49    }
50
51    /// Get packet id
52    pub fn id(&self) -> u16 {
53        self.header.id
54    }
55
56    /// Set packet id
57    pub fn set_id(&mut self, id: u16) {
58        self.header.id = id;
59    }
60
61    /// Set flags in the packet
62    pub fn set_flags(&mut self, flags: PacketFlag) {
63        self.header.set_flags(flags);
64    }
65
66    /// Remove flags present in the packet
67    pub fn remove_flags(&mut self, flags: PacketFlag) {
68        self.header.remove_flags(flags)
69    }
70
71    /// Check if the packet has flags set
72    pub fn has_flags(&self, flags: PacketFlag) -> bool {
73        self.header.has_flags(flags)
74    }
75
76    /// Get this packet [RCODE] information
77    pub fn rcode(&self) -> RCODE {
78        self.header.response_code
79    }
80
81    /// Get a mutable reference for  this packet [RCODE] information
82    /// Warning, if the [RCODE] value is greater than 15 (4 bits), you MUST provide an [OPT]
83    /// resource record through the [Packet::opt_mut] function
84    pub fn rcode_mut(&mut self) -> &mut RCODE {
85        &mut self.header.response_code
86    }
87
88    /// Get this packet [OPCODE] information
89    pub fn opcode(&self) -> OPCODE {
90        self.header.opcode
91    }
92
93    /// Get a mutable reference for this packet [OPCODE] information
94    pub fn opcode_mut(&mut self) -> &mut OPCODE {
95        &mut self.header.opcode
96    }
97
98    /// Get the [OPT] resource record for this packet, if present
99    pub fn opt(&self) -> Option<&OPT<'a>> {
100        self.header.opt.as_ref()
101    }
102
103    /// Get a mutable reference for this packet [OPT] resource record.  
104    pub fn opt_mut(&mut self) -> &mut Option<OPT<'a>> {
105        &mut self.header.opt
106    }
107
108    /// Changes this packet into a reply packet by replacing its header
109    pub fn into_reply(mut self) -> Self {
110        self.header = Header::new_reply(self.header.id, self.header.opcode);
111        self
112    }
113
114    /// Parses a packet from a slice of bytes
115    pub fn parse(data: &'a [u8]) -> crate::Result<Self> {
116        let mut header = Header::parse(data)?;
117
118        let mut offset = 12;
119        let questions = Self::parse_section(data, &mut offset, header_buffer::questions(data)?)?;
120        let answers = Self::parse_section(data, &mut offset, header_buffer::answers(data)?)?;
121        let name_servers =
122            Self::parse_section(data, &mut offset, header_buffer::name_servers(data)?)?;
123        let mut additional_records: Vec<ResourceRecord> =
124            Self::parse_section(data, &mut offset, header_buffer::additional_records(data)?)?;
125
126        header.extract_info_from_opt_rr(
127            additional_records
128                .iter()
129                .position(|rr| rr.rdata.type_code() == crate::TYPE::OPT)
130                .map(|i| additional_records.remove(i)),
131        );
132
133        Ok(Self {
134            header,
135            questions,
136            answers,
137            name_servers,
138            additional_records,
139        })
140    }
141
142    fn parse_section<T: WireFormat<'a>>(
143        data: &'a [u8],
144        offset: &mut usize,
145        items_count: u16,
146    ) -> crate::Result<Vec<T>> {
147        let mut section_items = Vec::with_capacity(items_count as usize);
148
149        for _ in 0..items_count {
150            section_items.push(T::parse(data, offset)?);
151        }
152
153        Ok(section_items)
154    }
155
156    /// Creates a new [Vec`<u8>`](`Vec<T>`) and write the contents of this package in wire format
157    ///
158    /// This call will allocate a `Vec<u8>` of 900 bytes, which is enough for a jumbo UDP packet
159    pub fn build_bytes_vec(&self) -> crate::Result<Vec<u8>> {
160        let mut out = Cursor::new(Vec::with_capacity(900));
161
162        self.write_to(&mut out)?;
163
164        Ok(out.into_inner())
165    }
166
167    /// Creates a new [Vec`<u8>`](`Vec<T>`) and write the contents of this package in wire format
168    /// with compression enabled
169    ///
170    /// This call will allocate a `Vec<u8>` of 900 bytes, which is enough for a jumbo UDP packet
171    pub fn build_bytes_vec_compressed(&self) -> crate::Result<Vec<u8>> {
172        let mut out = Cursor::new(Vec::with_capacity(900));
173        self.write_compressed_to(&mut out)?;
174
175        Ok(out.into_inner())
176    }
177
178    /// Write the contents of this package in wire format into the provided writer
179    pub fn write_to<T: Write>(&self, out: &mut T) -> crate::Result<()> {
180        self.write_header(out)?;
181
182        for e in &self.questions {
183            e.write_to(out)?;
184        }
185        for e in &self.answers {
186            e.write_to(out)?;
187        }
188        for e in &self.name_servers {
189            e.write_to(out)?;
190        }
191
192        if let Some(rr) = self.header.opt_rr() {
193            rr.write_to(out)?;
194        }
195
196        for e in &self.additional_records {
197            e.write_to(out)?;
198        }
199
200        out.flush()?;
201        Ok(())
202    }
203
204    /// Write the contents of this package in wire format with enabled compression into the provided writer
205    pub fn write_compressed_to<T: Write + Seek>(&self, out: &mut T) -> crate::Result<()> {
206        self.write_header(out)?;
207
208        let mut name_refs = HashMap::new();
209        for e in &self.questions {
210            e.write_compressed_to(out, &mut name_refs)?;
211        }
212        for e in &self.answers {
213            e.write_compressed_to(out, &mut name_refs)?;
214        }
215        for e in &self.name_servers {
216            e.write_compressed_to(out, &mut name_refs)?;
217        }
218
219        if let Some(rr) = self.header.opt_rr() {
220            rr.write_to(out)?;
221        }
222
223        for e in &self.additional_records {
224            e.write_compressed_to(out, &mut name_refs)?;
225        }
226        out.flush()?;
227
228        Ok(())
229    }
230
231    fn write_header<T: Write>(&self, out: &mut T) -> crate::Result<()> {
232        self.header.write_to(
233            out,
234            self.questions.len() as u16,
235            self.answers.len() as u16,
236            self.name_servers.len() as u16,
237            self.additional_records.len() as u16 + u16::from(self.header.opt.is_some()),
238        )
239    }
240}
241
#[cfg(test)]
mod tests {
    use crate::{
        dns::{CLASS, TYPE},
        SimpleDnsError,
    };

    use super::*;
    use std::convert::TryInto;

    /// Parsing an empty slice must report missing data instead of panicking
    #[test]
    fn parse_without_data_should_not_panic() {
        let result = Packet::parse(&[]);

        assert!(matches!(result, Err(SimpleDnsError::InsufficientData)));
    }

    /// A query with two questions must survive a build/parse round trip
    #[test]
    fn build_query_correct() {
        let names = ["_srv._udp.local", "_srv2._udp.local"];

        let mut query = Packet::new_query(1);
        for &name in names.iter() {
            query.questions.push(Question::new(
                name.try_into().unwrap(),
                TYPE::TXT.into(),
                CLASS::IN.into(),
                false,
            ));
        }

        let bytes = query.build_bytes_vec().unwrap();

        let parsed = Packet::parse(&bytes);
        assert!(parsed.is_ok());

        let parsed = parsed.unwrap();
        assert_eq!(2, parsed.questions.len());
        for (expected, question) in names.iter().zip(&parsed.questions) {
            assert_eq!(*expected, question.qname.to_string());
        }
    }
}