1use crate::{QCLASS, QTYPE};
2
3use super::{name::Label, rdata::RData, Name, WireFormat, CLASS, TYPE};
4use core::fmt::Debug;
5use std::{collections::HashMap, convert::TryInto, hash::Hash};
6
/// Bit masks for flags packed into resource-record header fields.
mod flag {
    /// High bit of the 16-bit CLASS field; when set, the record carries the
    /// cache-flush flag (`ResourceRecord::cache_flush`).
    pub const CACHE_FLUSH: u16 = 1 << 15;
}
/// A DNS resource record: owner name, class, TTL and typed record data.
///
/// Equality and hashing are implemented by hand and consider only `name`,
/// `class` and `rdata`; `ttl` and `cache_flush` are ignored by both.
#[derive(Debug, Eq, Clone)]
pub struct ResourceRecord<'a> {
    /// Owner name of the record.
    pub name: Name<'a>,
    /// Record class. For `OPT` records, parsing sets this to `CLASS::IN`
    /// because the wire CLASS slot carries the OPT UDP packet size instead.
    pub class: CLASS,
    /// Raw 32-bit TTL field, read from and written to the wire unchanged.
    pub ttl: u32,
    /// Record data; its `type_code()` supplies the TYPE field when writing.
    pub rdata: RData<'a>,

    /// Whether the cache-flush bit (the high bit of the wire CLASS field) is set.
    pub cache_flush: bool,
}
26
27impl<'a> ResourceRecord<'a> {
28 pub fn new(name: Name<'a>, class: CLASS, ttl: u32, rdata: RData<'a>) -> Self {
30 Self {
31 name,
32 class,
33 ttl,
34 rdata,
35 cache_flush: false,
36 }
37 }
38
39 pub fn with_cache_flush(mut self, cache_flush: bool) -> Self {
41 self.cache_flush = cache_flush;
42 self
43 }
44
45 pub fn to_cache_flush_record(&self) -> Self {
47 self.clone().with_cache_flush(true)
48 }
49
50 pub fn match_qclass(&self, qclass: QCLASS) -> bool {
52 match qclass {
53 QCLASS::CLASS(class) => class == self.class,
54 QCLASS::ANY => true,
55 }
56 }
57
58 pub fn match_qtype(&self, qtype: QTYPE) -> bool {
60 let type_code = self.rdata.type_code();
61 match qtype {
62 QTYPE::ANY => true,
63 QTYPE::IXFR => false,
64 QTYPE::AXFR => true, QTYPE::MAILB => type_code == TYPE::MR || type_code == TYPE::MB || type_code == TYPE::MG,
66 QTYPE::MAILA => type_code == TYPE::MX,
67 QTYPE::TYPE(ty) => ty == type_code,
68 }
69 }
70
71 pub fn into_owned<'b>(self) -> ResourceRecord<'b> {
73 ResourceRecord {
74 name: self.name.into_owned(),
75 class: self.class,
76 ttl: self.ttl,
77 rdata: self.rdata.into_owned(),
78 cache_flush: self.cache_flush,
79 }
80 }
81
82 fn write_common<T: std::io::Write>(&self, out: &mut T) -> crate::Result<()> {
83 out.write_all(&u16::from(self.rdata.type_code()).to_be_bytes())?;
84
85 if let RData::OPT(ref opt) = self.rdata {
86 out.write_all(&opt.udp_packet_size.to_be_bytes())?;
87 } else {
88 let class = if self.cache_flush {
89 ((self.class as u16) | flag::CACHE_FLUSH).to_be_bytes()
90 } else {
91 (self.class as u16).to_be_bytes()
92 };
93
94 out.write_all(&class)?;
95 }
96
97 out.write_all(&self.ttl.to_be_bytes())
98 .map_err(crate::SimpleDnsError::from)
99 }
100}
101
102impl<'a> WireFormat<'a> for ResourceRecord<'a> {
103 fn parse(data: &'a [u8], position: &mut usize) -> crate::Result<Self>
104 where
105 Self: Sized,
106 {
107 let name = Name::parse(data, position)?;
108 if *position + 8 > data.len() {
109 return Err(crate::SimpleDnsError::InsufficientData);
110 }
111
112 let class_value = u16::from_be_bytes(data[*position + 2..*position + 4].try_into()?);
113 let ttl = u32::from_be_bytes(data[*position + 4..*position + 8].try_into()?);
114 let rdata = RData::parse(data, position)?;
115
116 if rdata.type_code() == TYPE::OPT {
117 Ok(Self {
118 name,
119 class: CLASS::IN,
120 ttl,
121 rdata,
122 cache_flush: false,
123 })
124 } else {
125 let cache_flush = class_value & flag::CACHE_FLUSH == flag::CACHE_FLUSH;
126 let class = (class_value & !flag::CACHE_FLUSH).try_into()?;
127
128 Ok(Self {
129 name,
130 class,
131 ttl,
132 rdata,
133 cache_flush,
134 })
135 }
136 }
137
138 fn len(&self) -> usize {
139 self.name.len() + self.rdata.len() + 10
140 }
141
142 fn write_to<T: std::io::Write>(&self, out: &mut T) -> crate::Result<()> {
143 self.name.write_to(out)?;
144 self.write_common(out)?;
145 out.write_all(&(self.rdata.len() as u16).to_be_bytes())?;
146 self.rdata.write_to(out)
147 }
148
149 fn write_compressed_to<T: std::io::Write + std::io::Seek>(
150 &'a self,
151 out: &mut T,
152 name_refs: &mut HashMap<&'a [Label<'a>], usize>,
153 ) -> crate::Result<()> {
154 self.name.write_compressed_to(out, name_refs)?;
155 self.write_common(out)?;
156
157 let len_position = out.stream_position()?;
158 out.write_all(&[0, 0])?;
159
160 self.rdata.write_compressed_to(out, name_refs)?;
161 let end = out.stream_position()?;
162
163 out.seek(std::io::SeekFrom::Start(len_position))?;
164 out.write_all(&((end - len_position - 2) as u16).to_be_bytes())?;
165 out.seek(std::io::SeekFrom::End(0))?;
166 Ok(())
167 }
168}
169
170impl<'a> Hash for ResourceRecord<'a> {
171 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
172 self.name.hash(state);
173 self.class.hash(state);
174 self.rdata.hash(state);
175 }
176}
177
178impl<'a> PartialEq for ResourceRecord<'a> {
179 fn eq(&self, other: &Self) -> bool {
180 self.name == other.name && self.class == other.class && self.rdata == other.rdata
181 }
182}
183
#[cfg(test)]
mod tests {
    //! Unit tests for `ResourceRecord`: parsing, serialization, class/type
    //! matching, cache-flush handling, equality and hashing.

    use std::{
        collections::hash_map::DefaultHasher,
        hash::{Hash, Hasher},
        io::Cursor,
    };

    use crate::{dns::rdata::NULL, rdata::TXT};

    use super::*;

    // Fixture layout: name `_srv._udp.local`, TYPE 0x0001 (A), CLASS 0x0001 (IN),
    // TTL 10, RDLENGTH 4, rdata ff ff ff ff.
    #[test]
    fn test_parse() {
        let bytes = b"\x04_srv\x04_udp\x05local\x00\x00\x01\x00\x01\x00\x00\x00\x0a\x00\x04\xff\xff\xff\xff";
        let rr = ResourceRecord::parse(&bytes[..], &mut 0).unwrap();

        assert_eq!("_srv._udp.local", rr.name.to_string());
        assert_eq!(CLASS::IN, rr.class);
        assert_eq!(10, rr.ttl);
        assert_eq!(4, rr.rdata.len());
        assert!(!rr.cache_flush);

        match rr.rdata {
            RData::A(a) => assert_eq!(4294967295, a.address),
            _ => panic!("invalid rdata"),
        }
    }

    // A record with empty rdata must round-trip through write_to/parse.
    #[test]
    fn test_empty_rdata() {
        let rr = ResourceRecord {
            class: CLASS::NONE,
            name: "_srv._udp.local".try_into().unwrap(),
            ttl: 0,
            rdata: RData::Empty(TYPE::A),
            cache_flush: false,
        };

        assert_eq!(rr.rdata.type_code(), TYPE::A);
        assert_eq!(rr.rdata.len(), 0);

        let mut data = Vec::new();
        rr.write_to(&mut data).expect("failed to write");

        let parsed_rr = ResourceRecord::parse(&data, &mut 0).expect("failed to parse");
        assert_eq!(parsed_rr.rdata.type_code(), TYPE::A);
        assert_eq!(parsed_rr.rdata.len(), 0);
        assert!(matches!(parsed_rr.rdata, RData::Empty(TYPE::A)));
    }

    // Same as the `test_parse` fixture, but the CLASS field is 0x8001: the high
    // bit must surface as `cache_flush` and the class must still decode as IN.
    #[test]
    fn test_cache_flush_parse() {
        let bytes = b"\x04_srv\x04_udp\x05local\x00\x00\x01\x80\x01\x00\x00\x00\x0a\x00\x04\xff\xff\xff\xff";
        let rr = ResourceRecord::parse(&bytes[..], &mut 0).unwrap();

        assert_eq!(CLASS::IN, rr.class);
        assert!(rr.cache_flush);
    }

    // TYPE is written as 0x0000 for this `RData::NULL(0, ..)` value.
    #[test]
    fn test_write() {
        let mut out = Cursor::new(Vec::new());
        let rdata = [255u8; 4];

        let rr = ResourceRecord {
            class: CLASS::IN,
            name: "_srv._udp.local".try_into().unwrap(),
            ttl: 10,
            rdata: RData::NULL(0, NULL::new(&rdata).unwrap()),
            cache_flush: false,
        };

        rr.write_to(&mut out).expect("failed to write");
        assert_eq!(
            b"\x04_srv\x04_udp\x05local\x00\x00\x00\x00\x01\x00\x00\x00\x0a\x00\x04\xff\xff\xff\xff",
            &out.get_ref()[..]
        );
        assert_eq!(out.get_ref().len(), rr.len());
    }

    // With `cache_flush` set, the CLASS field is written as 0x8001.
    #[test]
    fn test_write_cache_flush() {
        let mut out = Cursor::new(Vec::new());
        let rdata = [255u8; 4];

        let rr = ResourceRecord {
            class: CLASS::IN,
            name: "_srv._udp.local".try_into().unwrap(),
            ttl: 10,
            rdata: RData::NULL(0, NULL::new(&rdata).unwrap()),
            cache_flush: true,
        };

        rr.write_to(&mut out).expect("failed to write");
        assert_eq!(
            b"\x04_srv\x04_udp\x05local\x00\x00\x00\x80\x01\x00\x00\x00\x0a\x00\x04\xff\xff\xff\xff",
            &out.get_ref()[..]
        );
        assert_eq!(out.get_ref().len(), rr.len());
    }

    #[test]
    fn test_match_qclass() {
        let rr = ResourceRecord {
            class: CLASS::IN,
            name: "_srv._udp.local".try_into().unwrap(),
            ttl: 10,
            rdata: RData::NULL(0, NULL::new(&[255u8; 4]).unwrap()),
            cache_flush: false,
        };

        assert!(rr.match_qclass(QCLASS::ANY));
        assert!(rr.match_qclass(CLASS::IN.into()));
        assert!(!rr.match_qclass(CLASS::CS.into()));
    }

    #[test]
    fn test_match_qtype() {
        let rr = ResourceRecord {
            class: CLASS::IN,
            name: "_srv._udp.local".try_into().unwrap(),
            ttl: 10,
            rdata: RData::A(crate::rdata::A { address: 0 }),
            cache_flush: false,
        };

        assert!(rr.match_qtype(QTYPE::ANY));
        assert!(rr.match_qtype(TYPE::A.into()));
        assert!(!rr.match_qtype(TYPE::WKS.into()));
        // Meta query types: AXFR matches everything, IXFR nothing, and the
        // mailbox groups do not include A records.
        assert!(rr.match_qtype(QTYPE::AXFR));
        assert!(!rr.match_qtype(QTYPE::IXFR));
        assert!(!rr.match_qtype(QTYPE::MAILA));
        assert!(!rr.match_qtype(QTYPE::MAILB));
    }

    // `new` clears the flag; the builders set or clear it. Equality ignores it.
    #[test]
    fn test_cache_flush_builders() {
        let rr = ResourceRecord::new(
            Name::new_unchecked("_srv.local"),
            CLASS::IN,
            10,
            RData::TXT(TXT::new().with_string("text").unwrap()),
        );
        assert!(!rr.cache_flush);

        let flushed = rr.to_cache_flush_record();
        assert!(flushed.cache_flush);
        assert_eq!(flushed, rr);

        assert!(rr.clone().with_cache_flush(true).cache_flush);
        assert!(!flushed.with_cache_flush(false).cache_flush);
    }

    #[test]
    fn test_eq() {
        let a = ResourceRecord::new(
            Name::new_unchecked("_srv.local"),
            CLASS::IN,
            10,
            RData::TXT(TXT::new().with_string("text").unwrap()),
        );
        let b = ResourceRecord::new(
            Name::new_unchecked("_srv.local"),
            CLASS::IN,
            10,
            RData::TXT(TXT::new().with_string("text").unwrap()),
        );

        assert_eq!(a, b);
        assert_eq!(get_hash(&a), get_hash(&b));
    }

    // The hash must not change when only the TTL differs.
    #[test]
    fn test_hash_ignore_ttl() {
        let a = ResourceRecord::new(
            Name::new_unchecked("_srv.local"),
            CLASS::IN,
            10,
            RData::TXT(TXT::new().with_string("text").unwrap()),
        );
        let mut b = ResourceRecord::new(
            Name::new_unchecked("_srv.local"),
            CLASS::IN,
            10,
            RData::TXT(TXT::new().with_string("text").unwrap()),
        );

        assert_eq!(get_hash(&a), get_hash(&b));
        b.ttl = 50;

        assert_eq!(get_hash(&a), get_hash(&b));
    }

    /// Hashes a record with the std `DefaultHasher`.
    fn get_hash(rr: &ResourceRecord) -> u64 {
        let mut hasher = DefaultHasher::default();
        rr.hash(&mut hasher);
        hasher.finish()
    }

    // Every file under `samples/zonefile` (relative to the test's working
    // directory) must parse as a sequence of resource records.
    #[test]
    fn parse_sample_files() -> Result<(), Box<dyn std::error::Error>> {
        for file_path in std::fs::read_dir("samples/zonefile")? {
            let data = std::fs::read(file_path?.path())?;
            let mut pos = 0;
            while pos < data.len() {
                crate::ResourceRecord::parse(&data, &mut pos)?;
            }
        }

        Ok(())
    }
}