1use std::{
2 borrow::Cow,
3 collections::HashMap,
4 convert::{TryFrom, TryInto},
5 fmt::Display,
6 hash::Hash,
7};
8
9use super::{WireFormat, MAX_LABEL_LENGTH, MAX_NAME_LENGTH};
10
/// High two bits of a length octet; when both are set the octet starts a compression pointer
/// rather than a label.
const POINTER_MASK: u8 = 0xC0;
/// [`POINTER_MASK`] positioned in the high byte of a 16-bit big-endian pointer.
const POINTER_MASK_U16: u16 = (POINTER_MASK as u16) << 8;
13
/// A DNS domain name, stored as its sequence of labels (leftmost first).
///
/// The terminating root (zero-length) label is not stored; it is added when the name is
/// written to the wire. Equality and hashing compare the label bytes exactly (case-sensitive).
#[derive(Eq, Clone)]
pub struct Name<'a> {
    /// Labels in the order they appear in the dotted form.
    labels: Vec<Label<'a>>,
}
25
26impl<'a> Name<'a> {
27 pub fn new(name: &'a str) -> crate::Result<Self> {
29 let labels = NameSpliter::new(name.as_bytes())
30 .map(Label::new)
31 .collect::<Result<Vec<Label>, _>>()?;
32
33 let name = Self { labels };
34
35 if name.len() > MAX_NAME_LENGTH {
36 Err(crate::SimpleDnsError::InvalidServiceName)
37 } else {
38 Ok(name)
39 }
40 }
41
42 pub fn new_unchecked(name: &'a str) -> Self {
44 let labels = NameSpliter::new(name.as_bytes())
45 .map(Label::new_unchecked)
46 .collect();
47
48 Self { labels }
49 }
50
51 pub fn is_link_local(&self) -> bool {
53 match self.iter().last() {
54 Some(label) => b"local".eq_ignore_ascii_case(&label.data),
55 None => false,
56 }
57 }
58
59 pub fn iter(&'a self) -> std::slice::Iter<Label<'a>> {
61 self.labels.iter()
62 }
63
64 pub fn is_subdomain_of(&self, other: &Name) -> bool {
66 self.labels.len() > other.labels.len()
67 && other
68 .iter()
69 .rev()
70 .zip(self.iter().rev())
71 .all(|(o, s)| *o == *s)
72 }
73
74 pub fn without(&self, domain: &Name) -> Option<Name> {
89 if self.is_subdomain_of(domain) {
90 let labels = self.labels[..self.labels.len() - domain.labels.len()].to_vec();
91
92 Some(Name { labels })
93 } else {
94 None
95 }
96 }
97
98 pub fn into_owned<'b>(self) -> Name<'b> {
100 Name {
101 labels: self.labels.into_iter().map(|l| l.into_owned()).collect(),
102 }
103 }
104
105 pub fn get_labels(&'_ self) -> &'_ [Label<'_>] {
107 &self.labels[..]
108 }
109
110 fn plain_append<T: std::io::Write>(&self, out: &mut T) -> crate::Result<()> {
111 for label in self.iter() {
112 out.write_all(&[label.len() as u8])?;
113 out.write_all(&label.data)?;
114 }
115
116 out.write_all(&[0])?;
117 Ok(())
118 }
119
120 fn compress_append<T: std::io::Write + std::io::Seek>(
121 &'a self,
122 out: &mut T,
123 name_refs: &mut HashMap<&'a [Label<'a>], usize>,
124 ) -> crate::Result<()> {
125 for (i, label) in self.iter().enumerate() {
126 match name_refs.entry(&self.labels[i..]) {
127 std::collections::hash_map::Entry::Occupied(e) => {
128 let p = *e.get() as u16;
129 out.write_all(&(p | POINTER_MASK_U16).to_be_bytes())?;
130
131 return Ok(());
132 }
133 std::collections::hash_map::Entry::Vacant(e) => {
134 e.insert(out.stream_position()? as usize);
135 out.write_all(&[label.len() as u8])?;
136 out.write_all(&label.data)?;
137 }
138 }
139 }
140
141 out.write_all(&[0])?;
142 Ok(())
143 }
144}
145
146impl<'a> WireFormat<'a> for Name<'a> {
147 fn parse(data: &'a [u8], position: &mut usize) -> crate::Result<Self>
148 where
149 Self: Sized,
150 {
151 let mut following_compression_pointer = false;
152 let mut labels = Vec::new();
153
154 let mut pointer_position = *position;
155
156 let mut name_size = 0usize;
158
159 loop {
160 if *position >= data.len() {
161 return Err(crate::SimpleDnsError::InsufficientData);
162 }
163
164 if name_size >= MAX_NAME_LENGTH {
166 return Err(crate::SimpleDnsError::InvalidDnsPacket);
167 }
168
169 match data[pointer_position] {
170 0 => {
171 *position += 1;
172 break;
173 }
174 len if len & POINTER_MASK == POINTER_MASK => {
175 if !following_compression_pointer {
176 *position += 1;
177 }
178
179 following_compression_pointer = true;
180 if pointer_position + 2 > data.len() {
181 return Err(crate::SimpleDnsError::InsufficientData);
182 }
183
184 let pointer = (u16::from_be_bytes(
186 data[pointer_position..pointer_position + 2].try_into()?,
187 ) & !POINTER_MASK_U16) as usize;
188 if pointer >= pointer_position {
189 return Err(crate::SimpleDnsError::InvalidDnsPacket);
190 }
191 pointer_position = pointer;
192 }
193 len => {
194 name_size += 1 + len as usize;
195 if pointer_position + 1 + len as usize > data.len() {
196 return Err(crate::SimpleDnsError::InsufficientData);
197 }
198
199 labels.push(Label::new(
200 &data[pointer_position + 1..pointer_position + 1 + len as usize],
201 )?);
202
203 if !following_compression_pointer {
204 *position += len as usize + 1;
205 }
206 pointer_position += len as usize + 1;
207 }
208 }
209 }
210
211 Ok(Self { labels })
212 }
213
214 fn write_to<T: std::io::Write>(&self, out: &mut T) -> crate::Result<()> {
215 self.plain_append(out)
216 }
217
218 fn write_compressed_to<T: std::io::Write + std::io::Seek>(
219 &'a self,
220 out: &mut T,
221 name_refs: &mut HashMap<&'a [Label<'a>], usize>,
222 ) -> crate::Result<()> {
223 self.compress_append(out, name_refs)
224 }
225
226 fn len(&self) -> usize {
227 self.labels
228 .iter()
229 .map(|label| label.len() + 1)
230 .sum::<usize>()
231 + 1
232 }
234}
235
236impl<'a> TryFrom<&'a str> for Name<'a> {
237 type Error = crate::SimpleDnsError;
238
239 fn try_from(value: &'a str) -> Result<Self, Self::Error> {
240 Name::new(value)
241 }
242}
243
244impl<'a> Display for Name<'a> {
245 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246 for (i, label) in self.iter().enumerate() {
247 if i != 0 {
248 f.write_str(".")?;
249 }
250
251 f.write_fmt(format_args!("{}", label))?;
252 }
253
254 Ok(())
255 }
256}
257
258impl<'a> std::fmt::Debug for Name<'a> {
259 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 f.debug_tuple("Name")
261 .field(&format!("{}", self))
262 .field(&format!("{}", self.len()))
263 .finish()
264 }
265}
266
267impl<'a> PartialEq for Name<'a> {
268 fn eq(&self, other: &Self) -> bool {
269 self.labels == other.labels
270 }
271}
272
273impl<'a> Hash for Name<'a> {
274 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
275 self.labels.hash(state);
276 }
277}
278
/// Splits a dotted name, given as bytes, into its labels; used by `Name::new` and
/// `Name::new_unchecked`.
struct NameSpliter<'a> {
    /// The complete dotted name being split.
    bytes: &'a [u8],
    /// Index into `bytes` where the next label begins.
    current: usize,
}
283
284impl<'a> NameSpliter<'a> {
285 fn new(bytes: &'a [u8]) -> Self {
286 Self { bytes, current: 0 }
287 }
288}
289
290impl<'a> Iterator for NameSpliter<'a> {
291 type Item = Cow<'a, [u8]>;
292
293 fn next(&mut self) -> Option<Self::Item> {
294 let mut slices: Vec<&[u8]> = Vec::new();
295
296 for i in self.current..self.bytes.len() {
297 if self.bytes[i] == b'.' && i - self.current > 0 {
298 let current = std::mem::replace(&mut self.current, i + 1);
299 if self.bytes[i - 1] == b'\\' {
300 slices.push(&self.bytes[current..i - 1]);
301 continue;
302 }
303
304 return Some(join_slices(slices, &self.bytes[current..i]));
305 }
306 }
307
308 if self.current < self.bytes.len() {
309 let current = std::mem::replace(&mut self.current, self.bytes.len());
310 Some(join_slices(slices, &self.bytes[current..]))
311 } else {
312 None
313 }
314 }
315}
316
/// Rebuilds a label that was split at escaped dots (`\.`).
///
/// `slices` are the pieces found before each escaped dot and `slice` is the final piece; a
/// literal `.` is inserted between every pair of consecutive pieces, even empty ones. With no
/// earlier pieces, `slice` is returned borrowed without allocating.
fn join_slices<'a>(mut slices: Vec<&'a [u8]>, slice: &'a [u8]) -> Cow<'a, [u8]> {
    if slices.is_empty() {
        return Cow::Borrowed(slice);
    }

    slices.push(slice);
    // `join` places the separator by piece position. The previous fold only added it when the
    // output so far was non-empty, which dropped the escaped dot of labels like `\.a`.
    Cow::Owned(slices.join(&b'.'))
}
336
/// A single label of a [`Name`], held as raw bytes that are either borrowed or owned.
///
/// Equality and hashing are derived, so labels compare byte-wise (case-sensitive).
#[derive(Eq, PartialEq, Hash, Clone)]
pub struct Label<'a> {
    /// Label bytes, without the length octet that precedes them on the wire.
    data: Cow<'a, [u8]>,
}
341
342impl<'a> Label<'a> {
343 pub fn new<T: Into<Cow<'a, [u8]>>>(data: T) -> crate::Result<Self> {
344 let label = Self::new_unchecked(data);
345 if label.len() > MAX_LABEL_LENGTH {
346 Err(crate::SimpleDnsError::InvalidServiceLabel)
347 } else {
348 Ok(label)
349 }
350 }
351
352 pub fn new_unchecked<T: Into<Cow<'a, [u8]>>>(data: T) -> Self {
353 Self { data: data.into() }
354 }
355
356 pub fn len(&self) -> usize {
357 self.data.len()
358 }
359
360 pub fn into_owned<'b>(self) -> Label<'b> {
361 Label {
362 data: self.data.into_owned().into(),
363 }
364 }
365}
366
367impl<'a> Display for Label<'a> {
368 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
369 match std::str::from_utf8(&self.data) {
370 Ok(s) => f.write_str(s),
371 Err(_) => Err(std::fmt::Error),
372 }
373 }
374}
375
376impl<'a> std::fmt::Debug for Label<'a> {
377 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
378 f.debug_struct("Label")
379 .field("data", &self.to_string())
380 .finish()
381 }
382}
383
// Unit tests for `Name`: construction from strings, wire parsing (with and without
// compression pointers), uncompressed and compressed writing, equality, hashing and
// subdomain helpers.
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::{collections::hash_map::DefaultHasher, hash::Hasher};

    use super::*;
    use crate::SimpleDnsError;

    // Plain, trailing-dot, non-ASCII and escaped-dot inputs are all accepted.
    #[test]
    fn construct_valid_names() -> Result<(), SimpleDnsError> {
        assert!(Name::new("some").is_ok());
        assert!(Name::new("some.local").is_ok());
        assert!(Name::new("some.local.").is_ok());
        assert!(Name::new("\u{1F600}.local.").is_ok());

        // `\.` keeps the dot inside the label, so this is a single label.
        let scaped = Name::new("some\\.local")?;
        assert_eq!(scaped.labels.len(), 1);

        Ok(())
    }

    #[test]
    fn is_link_local() {
        assert!(!Name::new("some.example.com").unwrap().is_link_local());
        assert!(Name::new("some.example.local.").unwrap().is_link_local());
    }

    // The three leading zero bytes stand in for data before the names, so parsing starts at
    // offset 3; the second parse continues where the first one stopped.
    #[test]
    fn parse_without_compression() {
        let data =
            b"\x00\x00\x00\x01F\x03ISI\x04ARPA\x00\x03FOO\x01F\x03ISI\x04ARPA\x00\x04ARPA\x00";
        let mut position = 3;
        let name = Name::parse(data, &mut position).unwrap();
        assert_eq!("F.ISI.ARPA", name.to_string());

        let name = Name::parse(data, &mut position).unwrap();
        assert_eq!("FOO.F.ISI.ARPA", name.to_string());
    }

    // `\xc0\x03` is a pointer to offset 3 (`F.ISI.ARPA`). The last name ends in `\xc0\x1b`,
    // a pointer back to its own start (offset 27), forming a loop that must be rejected.
    #[test]
    fn parse_with_compression() {
        let data = b"\x00\x00\x00\x01F\x03ISI\x04ARPA\x00\x03FOO\xc0\x03\x03BAR\xc0\x03\x07INVALID\xc0\x1b";
        let mut offset = 3usize;

        let name = Name::parse(data, &mut offset).unwrap();
        assert_eq!("F.ISI.ARPA", name.to_string());

        let name = Name::parse(data, &mut offset).unwrap();
        assert_eq!("FOO.F.ISI.ARPA", name.to_string());

        let name = Name::parse(data, &mut offset).unwrap();
        assert_eq!("BAR.F.ISI.ARPA", name.to_string());

        assert!(Name::parse(data, &mut offset).is_err());
    }

    #[test]
    fn test_write() {
        let mut bytes = Cursor::new(Vec::with_capacity(30));

        Name::new_unchecked("_srv._udp.local")
            .write_to(&mut bytes)
            .unwrap();

        assert_eq!(b"\x04_srv\x04_udp\x05local\x00", &bytes.get_ref()[..]);

        // A trailing dot does not add an extra label to the encoding.
        let mut bytes = Cursor::new(Vec::with_capacity(30));
        Name::new_unchecked("_srv._udp.local2.")
            .write_to(&mut bytes)
            .unwrap();

        assert_eq!(b"\x04_srv\x04_udp\x06local2\x00", &bytes.get_ref()[..]);
    }

    // Pointers are absolute stream positions, so the 3-byte prefix shifts them to offset 3.
    #[test]
    fn append_to_vec_with_compression() {
        let mut buf = Cursor::new(vec![0, 0, 0]);
        buf.set_position(3);

        let mut name_refs = HashMap::new();

        let f_isi_arpa = Name::new_unchecked("F.ISI.ARPA");
        f_isi_arpa
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add F.ISI.ARPA");
        let foo_f_isi_arpa = Name::new_unchecked("FOO.F.ISI.ARPA");
        foo_f_isi_arpa
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add FOO.F.ISI.ARPA");

        Name::new_unchecked("BAR.F.ISI.ARPA")
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add FOO.F.ISI.ARPA");

        let data = b"\x00\x00\x00\x01F\x03ISI\x04ARPA\x00\x03FOO\xc0\x03\x03BAR\xc0\x03";
        assert_eq!(data[..], buf.get_ref()[..]);
    }

    // Each name reuses the longest suffix already written, and the result parses back.
    #[test]
    fn append_to_vec_with_compression_mult_names() {
        let mut buf = Cursor::new(vec![]);
        let mut name_refs = HashMap::new();

        let isi_arpa = Name::new_unchecked("ISI.ARPA");
        isi_arpa
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add ISI.ARPA");

        let f_isi_arpa = Name::new_unchecked("F.ISI.ARPA");
        f_isi_arpa
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add F.ISI.ARPA");
        let foo_f_isi_arpa = Name::new_unchecked("FOO.F.ISI.ARPA");
        foo_f_isi_arpa
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add F.ISI.ARPA");
        Name::new_unchecked("BAR.F.ISI.ARPA")
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add F.ISI.ARPA");

        let expected = b"\x03ISI\x04ARPA\x00\x01F\xc0\x00\x03FOO\xc0\x0a\x03BAR\xc0\x0a";
        assert_eq!(expected[..], buf.get_ref()[..]);

        let mut position = 0;
        let first = Name::parse(buf.get_ref(), &mut position).unwrap();
        assert_eq!("ISI.ARPA", first.to_string());
        let second = Name::parse(buf.get_ref(), &mut position).unwrap();
        assert_eq!("F.ISI.ARPA", second.to_string());
        let third = Name::parse(buf.get_ref(), &mut position).unwrap();
        assert_eq!("FOO.F.ISI.ARPA", third.to_string());
        let fourth = Name::parse(buf.get_ref(), &mut position).unwrap();
        assert_eq!("BAR.F.ISI.ARPA", fourth.to_string());
    }

    // Compression only reuses whole-name suffixes, so names that share leading labels but
    // end differently are written out in full.
    #[test]
    fn ensure_different_domains_are_not_compressed() {
        let mut buf = Cursor::new(vec![]);
        let mut name_refs = HashMap::new();

        let foo_bar_baz = Name::new_unchecked("FOO.BAR.BAZ");
        foo_bar_baz
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add FOO.BAR.BAZ");

        let foo_bar_buz = Name::new_unchecked("FOO.BAR.BUZ");
        foo_bar_buz
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add FOO.BAR.BUZ");

        Name::new_unchecked("FOO.BAR")
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add FOO.BAR");

        let expected = b"\x03FOO\x03BAR\x03BAZ\x00\x03FOO\x03BAR\x03BUZ\x00\x03FOO\x03BAR\x00";
        assert_eq!(expected[..], buf.get_ref()[..]);
    }

    // Names built from strings equal names parsed from the wire, including through pointers.
    #[test]
    fn eq_other_name() -> Result<(), SimpleDnsError> {
        assert_eq!(Name::new("example.com")?, Name::new("example.com")?);
        assert_ne!(Name::new("some.example.com")?, Name::new("example.com")?);
        assert_ne!(Name::new("example.co")?, Name::new("example.com")?);
        assert_ne!(Name::new("example.com.org")?, Name::new("example.com")?);

        let data = b"\x00\x00\x00\x01F\x03ISI\x04ARPA\x00\x03FOO\xc0\x03\x03BAR\xc0\x03";
        let mut position = 3;
        assert_eq!(Name::new("F.ISI.ARPA")?, Name::parse(data, &mut position)?);
        assert_eq!(
            Name::new("FOO.F.ISI.ARPA")?,
            Name::parse(data, &mut position)?
        );
        Ok(())
    }

    // `\x02ex\x03com\x00` is 8 bytes; the second compressed copy is just a 2-byte pointer.
    #[test]
    fn len() -> crate::Result<()> {
        let mut bytes = Cursor::new(Vec::new());
        let name_one = Name::new_unchecked("ex.com.");
        name_one.write_to(&mut bytes)?;

        assert_eq!(8, bytes.get_ref().len());
        assert_eq!(bytes.get_ref().len(), name_one.len());
        assert_eq!(8, Name::parse(bytes.get_ref(), &mut 0)?.len());

        let mut name_refs = HashMap::new();
        let mut bytes = Cursor::new(Vec::new());
        name_one.write_compressed_to(&mut bytes, &mut name_refs)?;
        name_one.write_compressed_to(&mut bytes, &mut name_refs)?;

        assert_eq!(10, bytes.get_ref().len());
        Ok(())
    }

    // Hashing must agree with equality for string-built and wire-parsed names.
    #[test]
    fn hash() -> crate::Result<()> {
        let data = b"\x00\x00\x00\x01F\x03ISI\x04ARPA\x00\x03FOO\xc0\x03\x03BAR\xc0\x03";

        assert_eq!(
            get_hash(&Name::new("F.ISI.ARPA")?),
            get_hash(&Name::parse(data, &mut 3)?)
        );

        assert_eq!(
            get_hash(&Name::new("FOO.F.ISI.ARPA")?),
            get_hash(&Name::parse(data, &mut 15)?)
        );

        Ok(())
    }

    // Hashes a name with the standard library's default hasher.
    fn get_hash(name: &Name) -> u64 {
        let mut hasher = DefaultHasher::default();
        name.hash(&mut hasher);
        hasher.finish()
    }

    // Only strict subdomains qualify; a name is not a subdomain of itself.
    #[test]
    fn is_subdomain_of() {
        assert!(Name::new_unchecked("sub.example.com")
            .is_subdomain_of(&Name::new_unchecked("example.com")));

        assert!(!Name::new_unchecked("example.com")
            .is_subdomain_of(&Name::new_unchecked("example.com")));

        assert!(Name::new_unchecked("foo.sub.example.com")
            .is_subdomain_of(&Name::new_unchecked("example.com")));

        assert!(!Name::new_unchecked("example.com")
            .is_subdomain_of(&Name::new_unchecked("example.xom")));

        assert!(!Name::new_unchecked("domain.com")
            .is_subdomain_of(&Name::new_unchecked("other.domain")));

        assert!(!Name::new_unchecked("domain.com")
            .is_subdomain_of(&Name::new_unchecked("domain.com.br")));
    }

    // `without` strips the parent domain's labels and returns `None` for unrelated names.
    #[test]
    fn subtract_domain() {
        let domain = Name::new_unchecked("_srv3._tcp.local");
        assert_eq!(
            Name::new_unchecked("a._srv3._tcp.local")
                .without(&domain)
                .unwrap()
                .to_string(),
            "a"
        );

        assert!(Name::new_unchecked("unrelated").without(&domain).is_none(),);

        assert_eq!(
            Name::new_unchecked("some.longer.domain._srv3._tcp.local")
                .without(&domain)
                .unwrap()
                .to_string(),
            "some.longer.domain"
        );
    }
}