simple_dns/dns/
resource_record.rs

1use crate::{QCLASS, QTYPE};
2
3use super::{name::Label, rdata::RData, Name, WireFormat, CLASS, TYPE};
4use core::fmt::Debug;
5use std::{collections::HashMap, convert::TryInto, hash::Hash};
6
/// Bit masks packed into the 16-bit CLASS field of a resource record.
mod flag {
    /// Most significant bit of the CLASS field; set when a record carries the cache-flush flag.
    pub const CACHE_FLUSH: u16 = 1 << 15;
}
10/// Resource Records are used to represent the answer, authority, and additional sections in DNS packets.
11#[derive(Debug, Eq, Clone)]
12pub struct ResourceRecord<'a> {
13    /// A [`Name`] to which this resource record pertains.
14    pub name: Name<'a>,
15    /// A [`CLASS`] that defines the class of the rdata field
16    pub class: CLASS,
17    /// The time interval (in seconds) that the resource record may becached before it should be discarded.  
18    /// Zero values are interpreted to mean that the RR can only be used for the transaction in progress, and should not be cached.
19    pub ttl: u32,
20    /// A [`RData`] with the contents of this resource record
21    pub rdata: RData<'a>,
22
23    /// Indicates if this RR is a cache flush
24    pub cache_flush: bool,
25}
26
27impl<'a> ResourceRecord<'a> {
28    /// Creates a new ResourceRecord
29    pub fn new(name: Name<'a>, class: CLASS, ttl: u32, rdata: RData<'a>) -> Self {
30        Self {
31            name,
32            class,
33            ttl,
34            rdata,
35            cache_flush: false,
36        }
37    }
38
39    /// Consume self and change the cache_flush bit
40    pub fn with_cache_flush(mut self, cache_flush: bool) -> Self {
41        self.cache_flush = cache_flush;
42        self
43    }
44
45    /// Returns a cloned self with cache_flush = true
46    pub fn to_cache_flush_record(&self) -> Self {
47        self.clone().with_cache_flush(true)
48    }
49
50    /// Return true if current resource match given query class
51    pub fn match_qclass(&self, qclass: QCLASS) -> bool {
52        match qclass {
53            QCLASS::CLASS(class) => class == self.class,
54            QCLASS::ANY => true,
55        }
56    }
57
58    /// Return true if current resource match given query type
59    pub fn match_qtype(&self, qtype: QTYPE) -> bool {
60        let type_code = self.rdata.type_code();
61        match qtype {
62            QTYPE::ANY => true,
63            QTYPE::IXFR => false,
64            QTYPE::AXFR => true, // TODO: figure out what to do here
65            QTYPE::MAILB => type_code == TYPE::MR || type_code == TYPE::MB || type_code == TYPE::MG,
66            QTYPE::MAILA => type_code == TYPE::MX,
67            QTYPE::TYPE(ty) => ty == type_code,
68        }
69    }
70
71    /// Transforms the inner data into its owned type
72    pub fn into_owned<'b>(self) -> ResourceRecord<'b> {
73        ResourceRecord {
74            name: self.name.into_owned(),
75            class: self.class,
76            ttl: self.ttl,
77            rdata: self.rdata.into_owned(),
78            cache_flush: self.cache_flush,
79        }
80    }
81
82    fn write_common<T: std::io::Write>(&self, out: &mut T) -> crate::Result<()> {
83        out.write_all(&u16::from(self.rdata.type_code()).to_be_bytes())?;
84
85        if let RData::OPT(ref opt) = self.rdata {
86            out.write_all(&opt.udp_packet_size.to_be_bytes())?;
87        } else {
88            let class = if self.cache_flush {
89                ((self.class as u16) | flag::CACHE_FLUSH).to_be_bytes()
90            } else {
91                (self.class as u16).to_be_bytes()
92            };
93
94            out.write_all(&class)?;
95        }
96
97        out.write_all(&self.ttl.to_be_bytes())
98            .map_err(crate::SimpleDnsError::from)
99    }
100}
101
102impl<'a> WireFormat<'a> for ResourceRecord<'a> {
103    fn parse(data: &'a [u8], position: &mut usize) -> crate::Result<Self>
104    where
105        Self: Sized,
106    {
107        let name = Name::parse(data, position)?;
108        if *position + 8 > data.len() {
109            return Err(crate::SimpleDnsError::InsufficientData);
110        }
111
112        let class_value = u16::from_be_bytes(data[*position + 2..*position + 4].try_into()?);
113        let ttl = u32::from_be_bytes(data[*position + 4..*position + 8].try_into()?);
114        let rdata = RData::parse(data, position)?;
115
116        if rdata.type_code() == TYPE::OPT {
117            Ok(Self {
118                name,
119                class: CLASS::IN,
120                ttl,
121                rdata,
122                cache_flush: false,
123            })
124        } else {
125            let cache_flush = class_value & flag::CACHE_FLUSH == flag::CACHE_FLUSH;
126            let class = (class_value & !flag::CACHE_FLUSH).try_into()?;
127
128            Ok(Self {
129                name,
130                class,
131                ttl,
132                rdata,
133                cache_flush,
134            })
135        }
136    }
137
138    fn len(&self) -> usize {
139        self.name.len() + self.rdata.len() + 10
140    }
141
142    fn write_to<T: std::io::Write>(&self, out: &mut T) -> crate::Result<()> {
143        self.name.write_to(out)?;
144        self.write_common(out)?;
145        out.write_all(&(self.rdata.len() as u16).to_be_bytes())?;
146        self.rdata.write_to(out)
147    }
148
149    fn write_compressed_to<T: std::io::Write + std::io::Seek>(
150        &'a self,
151        out: &mut T,
152        name_refs: &mut HashMap<&'a [Label<'a>], usize>,
153    ) -> crate::Result<()> {
154        self.name.write_compressed_to(out, name_refs)?;
155        self.write_common(out)?;
156
157        let len_position = out.stream_position()?;
158        out.write_all(&[0, 0])?;
159
160        self.rdata.write_compressed_to(out, name_refs)?;
161        let end = out.stream_position()?;
162
163        out.seek(std::io::SeekFrom::Start(len_position))?;
164        out.write_all(&((end - len_position - 2) as u16).to_be_bytes())?;
165        out.seek(std::io::SeekFrom::End(0))?;
166        Ok(())
167    }
168}
169
170impl<'a> Hash for ResourceRecord<'a> {
171    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
172        self.name.hash(state);
173        self.class.hash(state);
174        self.rdata.hash(state);
175    }
176}
177
178impl<'a> PartialEq for ResourceRecord<'a> {
179    fn eq(&self, other: &Self) -> bool {
180        self.name == other.name && self.class == other.class && self.rdata == other.rdata
181    }
182}
183
#[cfg(test)]
mod tests {
    use std::{
        collections::hash_map::DefaultHasher,
        hash::{Hash, Hasher},
        io::Cursor,
    };

    use crate::{dns::rdata::NULL, rdata::TXT};

    use super::*;

    /// Builds an IN-class `_srv._udp.local` record with a TTL of 10 around `rdata`.
    fn srv_record(rdata: RData<'_>, cache_flush: bool) -> ResourceRecord<'_> {
        ResourceRecord {
            name: "_srv._udp.local".try_into().unwrap(),
            class: CLASS::IN,
            ttl: 10,
            rdata,
            cache_flush,
        }
    }

    /// Builds the `_srv.local` TXT record used by the equality and hashing tests.
    fn txt_record() -> ResourceRecord<'static> {
        ResourceRecord::new(
            Name::new_unchecked("_srv.local"),
            CLASS::IN,
            10,
            RData::TXT(TXT::new().with_string("text").unwrap()),
        )
    }

    /// Hashes a record with the standard library's default hasher.
    fn hash_of(record: &ResourceRecord) -> u64 {
        let mut state = DefaultHasher::new();
        record.hash(&mut state);
        state.finish()
    }

    #[test]
    fn test_parse() {
        let wire = b"\x04_srv\x04_udp\x05local\x00\x00\x01\x00\x01\x00\x00\x00\x0a\x00\x04\xff\xff\xff\xff";
        let mut cursor = 0;
        let record = ResourceRecord::parse(&wire[..], &mut cursor).unwrap();

        assert_eq!(record.name.to_string(), "_srv._udp.local");
        assert_eq!(record.class, CLASS::IN);
        assert_eq!(record.ttl, 10);
        assert_eq!(record.rdata.len(), 4);
        assert!(!record.cache_flush);

        if let RData::A(a) = record.rdata {
            assert_eq!(a.address, 4294967295);
        } else {
            panic!("invalid rdata");
        }
    }

    #[test]
    fn test_empty_rdata() {
        let record = ResourceRecord {
            name: "_srv._udp.local".try_into().unwrap(),
            class: CLASS::NONE,
            ttl: 0,
            rdata: RData::Empty(TYPE::A),
            cache_flush: false,
        };

        assert_eq!(record.rdata.type_code(), TYPE::A);
        assert_eq!(record.rdata.len(), 0);

        let mut wire = Vec::new();
        record.write_to(&mut wire).expect("failed to write");

        let reparsed = ResourceRecord::parse(&wire, &mut 0).expect("failed to parse");
        assert_eq!(reparsed.rdata.type_code(), TYPE::A);
        assert_eq!(reparsed.rdata.len(), 0);
        assert!(matches!(reparsed.rdata, RData::Empty(TYPE::A)));
    }

    #[test]
    fn test_cache_flush_parse() {
        let wire = b"\x04_srv\x04_udp\x05local\x00\x00\x01\x80\x01\x00\x00\x00\x0a\x00\x04\xff\xff\xff\xff";
        let mut cursor = 0;
        let record = ResourceRecord::parse(&wire[..], &mut cursor).unwrap();

        assert_eq!(record.class, CLASS::IN);
        assert!(record.cache_flush);
    }

    #[test]
    fn test_write() {
        let payload = [255u8; 4];
        let record = srv_record(RData::NULL(0, NULL::new(&payload).unwrap()), false);

        let mut sink = Cursor::new(Vec::new());
        assert!(record.write_to(&mut sink).is_ok());

        let written = sink.get_ref();
        assert_eq!(
            &written[..],
            b"\x04_srv\x04_udp\x05local\x00\x00\x00\x00\x01\x00\x00\x00\x0a\x00\x04\xff\xff\xff\xff"
        );
        assert_eq!(record.len(), written.len());
    }

    #[test]
    fn test_append_to_vec_cache_flush() {
        let payload = [255u8; 4];
        let record = srv_record(RData::NULL(0, NULL::new(&payload).unwrap()), true);

        let mut sink = Cursor::new(Vec::new());
        assert!(record.write_to(&mut sink).is_ok());

        let written = sink.get_ref();
        assert_eq!(
            &written[..],
            b"\x04_srv\x04_udp\x05local\x00\x00\x00\x80\x01\x00\x00\x00\x0a\x00\x04\xff\xff\xff\xff"
        );
        assert_eq!(record.len(), written.len());
    }

    #[test]
    fn test_match_qclass() {
        let record = srv_record(RData::NULL(0, NULL::new(&[255u8; 4]).unwrap()), false);

        assert!(record.match_qclass(QCLASS::ANY));
        assert!(record.match_qclass(CLASS::IN.into()));
        assert!(!record.match_qclass(CLASS::CS.into()));
    }

    #[test]
    fn test_match_qtype() {
        let record = srv_record(RData::A(crate::rdata::A { address: 0 }), false);

        assert!(record.match_qtype(QTYPE::ANY));
        assert!(record.match_qtype(TYPE::A.into()));
        assert!(!record.match_qtype(TYPE::WKS.into()));
    }

    #[test]
    fn test_eq() {
        let (first, second) = (txt_record(), txt_record());

        assert_eq!(first, second);
        assert_eq!(hash_of(&first), hash_of(&second));
    }

    #[test]
    fn test_hash_ignore_ttl() {
        let baseline = txt_record();
        let mut longer_lived = txt_record();
        assert_eq!(hash_of(&baseline), hash_of(&longer_lived));

        longer_lived.ttl = 50;
        assert_eq!(hash_of(&baseline), hash_of(&longer_lived));
    }

    #[test]
    fn parse_sample_files() -> Result<(), Box<dyn std::error::Error>> {
        for entry in std::fs::read_dir("samples/zonefile")? {
            let bytes = std::fs::read(entry?.path())?;
            let mut offset = 0;
            while offset < bytes.len() {
                crate::ResourceRecord::parse(&bytes, &mut offset)?;
            }
        }

        Ok(())
    }
}
374}