simple_dns/dns/
name.rs

1use std::{
2    borrow::Cow,
3    collections::HashMap,
4    convert::{TryFrom, TryInto},
5    fmt::Display,
6    hash::Hash,
7};
8
9use super::{WireFormat, MAX_LABEL_LENGTH, MAX_NAME_LENGTH};
10
/// High bits of a length octet that mark it as the first byte of a compression pointer.
const POINTER_MASK: u8 = 0b1100_0000;
/// [`POINTER_MASK`] positioned in the high byte of a 16-bit pointer; the remaining
/// 14 bits hold the pointed-to offset.
const POINTER_MASK_U16: u16 = (POINTER_MASK as u16) << 8;

// NOTE: extended labels are not implemented yet; their masks would be:
// const EXTENDED_LABEL: u8 = 0b0100_0000;
// const EXTENDED_LABEL_U16: u16 = 0b0100_0000_0000_0000;
17
18/// A Name represents a domain-name, which consists of character strings separated by dots.  
19/// Each section of a name is called label  
20/// ex: `google.com` consists of two labels `google` and `com`
21#[derive(Eq, Clone)]
22pub struct Name<'a> {
23    labels: Vec<Label<'a>>,
24}
25
26impl<'a> Name<'a> {
27    /// Creates a new validated Name
28    pub fn new(name: &'a str) -> crate::Result<Self> {
29        let labels = NameSpliter::new(name.as_bytes())
30            .map(Label::new)
31            .collect::<Result<Vec<Label>, _>>()?;
32
33        let name = Self { labels };
34
35        if name.len() > MAX_NAME_LENGTH {
36            Err(crate::SimpleDnsError::InvalidServiceName)
37        } else {
38            Ok(name)
39        }
40    }
41
42    /// Create a new Name without checking for size limits
43    pub fn new_unchecked(name: &'a str) -> Self {
44        let labels = NameSpliter::new(name.as_bytes())
45            .map(Label::new_unchecked)
46            .collect();
47
48        Self { labels }
49    }
50
51    /// Verify if name ends with .local.
52    pub fn is_link_local(&self) -> bool {
53        match self.iter().last() {
54            Some(label) => b"local".eq_ignore_ascii_case(&label.data),
55            None => false,
56        }
57    }
58
59    /// Returns an Iter of this Name Labels
60    pub fn iter(&'a self) -> std::slice::Iter<Label<'a>> {
61        self.labels.iter()
62    }
63
64    /// Returns true if self is a subdomain of other
65    pub fn is_subdomain_of(&self, other: &Name) -> bool {
66        self.labels.len() > other.labels.len()
67            && other
68                .iter()
69                .rev()
70                .zip(self.iter().rev())
71                .all(|(o, s)| *o == *s)
72    }
73
74    /// Returns the subdomain part of self, based on `domain`.
75    /// If self is not a subdomain of `domain`, returns None
76    ///
77    /// Example:
78    /// ```
79    /// # use simple_dns::Name;
80    /// let name = Name::new_unchecked("sub.domain.local");
81    /// let domain = Name::new_unchecked("domain.local");
82    ///
83    /// assert!(domain.without(&name).is_none());
84    ///
85    /// let sub = name.without(&domain).unwrap();
86    /// assert_eq!(sub.to_string(), "sub")
87    /// ```
88    pub fn without(&self, domain: &Name) -> Option<Name> {
89        if self.is_subdomain_of(domain) {
90            let labels = self.labels[..self.labels.len() - domain.labels.len()].to_vec();
91
92            Some(Name { labels })
93        } else {
94            None
95        }
96    }
97
98    /// Transforms the inner data into its owned type
99    pub fn into_owned<'b>(self) -> Name<'b> {
100        Name {
101            labels: self.labels.into_iter().map(|l| l.into_owned()).collect(),
102        }
103    }
104
105    /// Get the labels that compose this name
106    pub fn get_labels(&'_ self) -> &'_ [Label<'_>] {
107        &self.labels[..]
108    }
109
110    fn plain_append<T: std::io::Write>(&self, out: &mut T) -> crate::Result<()> {
111        for label in self.iter() {
112            out.write_all(&[label.len() as u8])?;
113            out.write_all(&label.data)?;
114        }
115
116        out.write_all(&[0])?;
117        Ok(())
118    }
119
120    fn compress_append<T: std::io::Write + std::io::Seek>(
121        &'a self,
122        out: &mut T,
123        name_refs: &mut HashMap<&'a [Label<'a>], usize>,
124    ) -> crate::Result<()> {
125        for (i, label) in self.iter().enumerate() {
126            match name_refs.entry(&self.labels[i..]) {
127                std::collections::hash_map::Entry::Occupied(e) => {
128                    let p = *e.get() as u16;
129                    out.write_all(&(p | POINTER_MASK_U16).to_be_bytes())?;
130
131                    return Ok(());
132                }
133                std::collections::hash_map::Entry::Vacant(e) => {
134                    e.insert(out.stream_position()? as usize);
135                    out.write_all(&[label.len() as u8])?;
136                    out.write_all(&label.data)?;
137                }
138            }
139        }
140
141        out.write_all(&[0])?;
142        Ok(())
143    }
144}
145
146impl<'a> WireFormat<'a> for Name<'a> {
147    fn parse(data: &'a [u8], position: &mut usize) -> crate::Result<Self>
148    where
149        Self: Sized,
150    {
151        let mut following_compression_pointer = false;
152        let mut labels = Vec::new();
153
154        let mut pointer_position = *position;
155
156        // avoid invalid data caused oom
157        let mut name_size = 0usize;
158
159        loop {
160            if *position >= data.len() {
161                return Err(crate::SimpleDnsError::InsufficientData);
162            }
163
164            // domain name max size is 255
165            if name_size >= MAX_NAME_LENGTH {
166                return Err(crate::SimpleDnsError::InvalidDnsPacket);
167            }
168
169            match data[pointer_position] {
170                0 => {
171                    *position += 1;
172                    break;
173                }
174                len if len & POINTER_MASK == POINTER_MASK => {
175                    if !following_compression_pointer {
176                        *position += 1;
177                    }
178
179                    following_compression_pointer = true;
180                    if pointer_position + 2 > data.len() {
181                        return Err(crate::SimpleDnsError::InsufficientData);
182                    }
183
184                    // avoid pointer forward (RFC 1035)
185                    let pointer = (u16::from_be_bytes(
186                        data[pointer_position..pointer_position + 2].try_into()?,
187                    ) & !POINTER_MASK_U16) as usize;
188                    if pointer >= pointer_position {
189                        return Err(crate::SimpleDnsError::InvalidDnsPacket);
190                    }
191                    pointer_position = pointer;
192                }
193                len => {
194                    name_size += 1 + len as usize;
195                    if pointer_position + 1 + len as usize > data.len() {
196                        return Err(crate::SimpleDnsError::InsufficientData);
197                    }
198
199                    labels.push(Label::new(
200                        &data[pointer_position + 1..pointer_position + 1 + len as usize],
201                    )?);
202
203                    if !following_compression_pointer {
204                        *position += len as usize + 1;
205                    }
206                    pointer_position += len as usize + 1;
207                }
208            }
209        }
210
211        Ok(Self { labels })
212    }
213
214    fn write_to<T: std::io::Write>(&self, out: &mut T) -> crate::Result<()> {
215        self.plain_append(out)
216    }
217
218    fn write_compressed_to<T: std::io::Write + std::io::Seek>(
219        &'a self,
220        out: &mut T,
221        name_refs: &mut HashMap<&'a [Label<'a>], usize>,
222    ) -> crate::Result<()> {
223        self.compress_append(out, name_refs)
224    }
225
226    fn len(&self) -> usize {
227        self.labels
228            .iter()
229            .map(|label| label.len() + 1)
230            .sum::<usize>()
231            + 1
232        // self.total_size
233    }
234}
235
236impl<'a> TryFrom<&'a str> for Name<'a> {
237    type Error = crate::SimpleDnsError;
238
239    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
240        Name::new(value)
241    }
242}
243
244impl<'a> Display for Name<'a> {
245    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
246        for (i, label) in self.iter().enumerate() {
247            if i != 0 {
248                f.write_str(".")?;
249            }
250
251            f.write_fmt(format_args!("{}", label))?;
252        }
253
254        Ok(())
255    }
256}
257
258impl<'a> std::fmt::Debug for Name<'a> {
259    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        f.debug_tuple("Name")
261            .field(&format!("{}", self))
262            .field(&format!("{}", self.len()))
263            .finish()
264    }
265}
266
267impl<'a> PartialEq for Name<'a> {
268    fn eq(&self, other: &Self) -> bool {
269        self.labels == other.labels
270    }
271}
272
273impl<'a> Hash for Name<'a> {
274    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
275        self.labels.hash(state);
276    }
277}
278
/// Splits the textual form of a domain name into raw labels.
///
/// Labels are separated by `.`. A dot preceded by a backslash (`\.`) is kept as a
/// literal dot inside the label and the backslash is dropped. A trailing dot ends
/// the name without producing an empty label.
struct NameSpliter<'a> {
    /// The complete textual name being split.
    bytes: &'a [u8],
    /// Index in `bytes` where the next label starts.
    current: usize,
}

impl<'a> NameSpliter<'a> {
    /// Creates a splitter positioned at the start of `bytes`.
    fn new(bytes: &'a [u8]) -> Self {
        Self { bytes, current: 0 }
    }
}

impl<'a> Iterator for NameSpliter<'a> {
    type Item = Cow<'a, [u8]>;

    fn next(&mut self) -> Option<Self::Item> {
        // Pieces of the current label that were separated by escaped dots.
        let mut escaped_parts: Vec<&'a [u8]> = Vec::new();

        for i in self.current..self.bytes.len() {
            // A dot at the very start of a label is not treated as a separator; it
            // stays part of the label.
            if self.bytes[i] == b'.' && i > self.current {
                let start = std::mem::replace(&mut self.current, i + 1);
                if self.bytes[i - 1] == b'\\' {
                    escaped_parts.push(&self.bytes[start..i - 1]);
                    continue;
                }

                return Some(join_slices(escaped_parts, &self.bytes[start..i]));
            }
        }

        // Emit the final label. This includes text ending in an escaped dot
        // (`foo\.`): there `current` is already at the end but parts are pending.
        if self.current < self.bytes.len() || !escaped_parts.is_empty() {
            let start = std::mem::replace(&mut self.current, self.bytes.len());
            Some(join_slices(escaped_parts, &self.bytes[start..]))
        } else {
            None
        }
    }
}

/// Joins the pieces of a label that were split on escaped dots, re-inserting a `.`
/// between consecutive pieces. Borrows `slice` directly when nothing was escaped.
fn join_slices<'a>(mut slices: Vec<&'a [u8]>, slice: &'a [u8]) -> Cow<'a, [u8]> {
    if slices.is_empty() {
        Cow::Borrowed(slice)
    } else {
        slices.push(slice);
        // `join` puts a separator between every pair of pieces, even empty ones, so a
        // leading escaped dot (`\.foo`) is kept instead of being silently dropped.
        Cow::Owned(slices.join(&b'.'))
    }
}
336
/// A single label of a domain name (e.g. `google` in `google.com`), stored as raw
/// bytes that are either borrowed from the source buffer or owned.
#[derive(Clone, Hash, PartialEq, Eq)]
pub struct Label<'a> {
    data: Cow<'a, [u8]>,
}
341
342impl<'a> Label<'a> {
343    pub fn new<T: Into<Cow<'a, [u8]>>>(data: T) -> crate::Result<Self> {
344        let label = Self::new_unchecked(data);
345        if label.len() > MAX_LABEL_LENGTH {
346            Err(crate::SimpleDnsError::InvalidServiceLabel)
347        } else {
348            Ok(label)
349        }
350    }
351
352    pub fn new_unchecked<T: Into<Cow<'a, [u8]>>>(data: T) -> Self {
353        Self { data: data.into() }
354    }
355
356    pub fn len(&self) -> usize {
357        self.data.len()
358    }
359
360    pub fn into_owned<'b>(self) -> Label<'b> {
361        Label {
362            data: self.data.into_owned().into(),
363        }
364    }
365}
366
367impl<'a> Display for Label<'a> {
368    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
369        match std::str::from_utf8(&self.data) {
370            Ok(s) => f.write_str(s),
371            Err(_) => Err(std::fmt::Error),
372        }
373    }
374}
375
376impl<'a> std::fmt::Debug for Label<'a> {
377    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
378        f.debug_struct("Label")
379            .field("data", &self.to_string())
380            .finish()
381    }
382}
383
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use std::{collections::hash_map::DefaultHasher, hash::Hasher};

    use super::*;
    use crate::SimpleDnsError;

    #[test]
    fn construct_valid_names() -> Result<(), SimpleDnsError> {
        assert!(Name::new("some").is_ok());
        assert!(Name::new("some.local").is_ok());
        assert!(Name::new("some.local.").is_ok());
        assert!(Name::new("\u{1F600}.local.").is_ok());

        // `\.` escapes the dot, so this is a single label
        let escaped = Name::new("some\\.local")?;
        assert_eq!(escaped.labels.len(), 1);

        Ok(())
    }

    #[test]
    fn is_link_local() {
        assert!(!Name::new("some.example.com").unwrap().is_link_local());
        assert!(!Name::new("some.example.locale").unwrap().is_link_local());
        assert!(Name::new("some.example.local").unwrap().is_link_local());
        assert!(Name::new("some.example.local.").unwrap().is_link_local());
        assert!(Name::new("some.example.LOCAL").unwrap().is_link_local());
    }

    #[test]
    fn parse_without_compression() {
        let data =
            b"\x00\x00\x00\x01F\x03ISI\x04ARPA\x00\x03FOO\x01F\x03ISI\x04ARPA\x00\x04ARPA\x00";
        let mut position = 3;
        let name = Name::parse(data, &mut position).unwrap();
        assert_eq!("F.ISI.ARPA", name.to_string());

        let name = Name::parse(data, &mut position).unwrap();
        assert_eq!("FOO.F.ISI.ARPA", name.to_string());
    }

    #[test]
    fn parse_with_compression() {
        let data = b"\x00\x00\x00\x01F\x03ISI\x04ARPA\x00\x03FOO\xc0\x03\x03BAR\xc0\x03\x07INVALID\xc0\x1b";
        let mut offset = 3usize;

        let name = Name::parse(data, &mut offset).unwrap();
        assert_eq!("F.ISI.ARPA", name.to_string());

        let name = Name::parse(data, &mut offset).unwrap();
        assert_eq!("FOO.F.ISI.ARPA", name.to_string());

        let name = Name::parse(data, &mut offset).unwrap();
        assert_eq!("BAR.F.ISI.ARPA", name.to_string());

        // the last pointer points forward, which is rejected
        assert!(Name::parse(data, &mut offset).is_err());
    }

    #[test]
    fn test_write() {
        let mut bytes = Cursor::new(Vec::with_capacity(30));

        Name::new_unchecked("_srv._udp.local")
            .write_to(&mut bytes)
            .unwrap();

        assert_eq!(b"\x04_srv\x04_udp\x05local\x00", &bytes.get_ref()[..]);

        let mut bytes = Cursor::new(Vec::with_capacity(30));
        Name::new_unchecked("_srv._udp.local2.")
            .write_to(&mut bytes)
            .unwrap();

        assert_eq!(b"\x04_srv\x04_udp\x06local2\x00", &bytes.get_ref()[..]);
    }

    #[test]
    fn append_to_vec_with_compression() {
        let mut buf = Cursor::new(vec![0, 0, 0]);
        buf.set_position(3);

        let mut name_refs = HashMap::new();

        let f_isi_arpa = Name::new_unchecked("F.ISI.ARPA");
        f_isi_arpa
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add F.ISI.ARPA");
        let foo_f_isi_arpa = Name::new_unchecked("FOO.F.ISI.ARPA");
        foo_f_isi_arpa
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add FOO.F.ISI.ARPA");

        Name::new_unchecked("BAR.F.ISI.ARPA")
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add BAR.F.ISI.ARPA");

        let data = b"\x00\x00\x00\x01F\x03ISI\x04ARPA\x00\x03FOO\xc0\x03\x03BAR\xc0\x03";
        assert_eq!(data[..], buf.get_ref()[..]);
    }

    #[test]
    fn append_to_vec_with_compression_mult_names() {
        let mut buf = Cursor::new(vec![]);
        let mut name_refs = HashMap::new();

        let isi_arpa = Name::new_unchecked("ISI.ARPA");
        isi_arpa
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add ISI.ARPA");

        let f_isi_arpa = Name::new_unchecked("F.ISI.ARPA");
        f_isi_arpa
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add F.ISI.ARPA");
        let foo_f_isi_arpa = Name::new_unchecked("FOO.F.ISI.ARPA");
        foo_f_isi_arpa
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add FOO.F.ISI.ARPA");
        Name::new_unchecked("BAR.F.ISI.ARPA")
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add BAR.F.ISI.ARPA");

        let expected = b"\x03ISI\x04ARPA\x00\x01F\xc0\x00\x03FOO\xc0\x0a\x03BAR\xc0\x0a";
        assert_eq!(expected[..], buf.get_ref()[..]);

        let mut position = 0;
        let first = Name::parse(buf.get_ref(), &mut position).unwrap();
        assert_eq!("ISI.ARPA", first.to_string());
        let second = Name::parse(buf.get_ref(), &mut position).unwrap();
        assert_eq!("F.ISI.ARPA", second.to_string());
        let third = Name::parse(buf.get_ref(), &mut position).unwrap();
        assert_eq!("FOO.F.ISI.ARPA", third.to_string());
        let fourth = Name::parse(buf.get_ref(), &mut position).unwrap();
        assert_eq!("BAR.F.ISI.ARPA", fourth.to_string());
    }

    #[test]
    fn ensure_different_domains_are_not_compressed() {
        let mut buf = Cursor::new(vec![]);
        let mut name_refs = HashMap::new();

        let foo_bar_baz = Name::new_unchecked("FOO.BAR.BAZ");
        foo_bar_baz
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add FOO.BAR.BAZ");

        let foo_bar_buz = Name::new_unchecked("FOO.BAR.BUZ");
        foo_bar_buz
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add FOO.BAR.BUZ");

        Name::new_unchecked("FOO.BAR")
            .write_compressed_to(&mut buf, &mut name_refs)
            .expect("failed to add FOO.BAR");

        let expected = b"\x03FOO\x03BAR\x03BAZ\x00\x03FOO\x03BAR\x03BUZ\x00\x03FOO\x03BAR\x00";
        assert_eq!(expected[..], buf.get_ref()[..]);
    }

    #[test]
    fn eq_other_name() -> Result<(), SimpleDnsError> {
        assert_eq!(Name::new("example.com")?, Name::new("example.com")?);
        assert_ne!(Name::new("some.example.com")?, Name::new("example.com")?);
        assert_ne!(Name::new("example.co")?, Name::new("example.com")?);
        assert_ne!(Name::new("example.com.org")?, Name::new("example.com")?);

        let data = b"\x00\x00\x00\x01F\x03ISI\x04ARPA\x00\x03FOO\xc0\x03\x03BAR\xc0\x03";
        let mut position = 3;
        assert_eq!(Name::new("F.ISI.ARPA")?, Name::parse(data, &mut position)?);
        assert_eq!(
            Name::new("FOO.F.ISI.ARPA")?,
            Name::parse(data, &mut position)?
        );
        Ok(())
    }

    #[test]
    fn len() -> crate::Result<()> {
        let mut bytes = Cursor::new(Vec::new());
        let name_one = Name::new_unchecked("ex.com.");
        name_one.write_to(&mut bytes)?;

        assert_eq!(8, bytes.get_ref().len());
        assert_eq!(bytes.get_ref().len(), name_one.len());
        assert_eq!(8, Name::parse(bytes.get_ref(), &mut 0)?.len());

        // the second copy is reduced to a two-byte pointer
        let mut name_refs = HashMap::new();
        let mut bytes = Cursor::new(Vec::new());
        name_one.write_compressed_to(&mut bytes, &mut name_refs)?;
        name_one.write_compressed_to(&mut bytes, &mut name_refs)?;

        assert_eq!(10, bytes.get_ref().len());
        Ok(())
    }

    #[test]
    fn hash() -> crate::Result<()> {
        let data = b"\x00\x00\x00\x01F\x03ISI\x04ARPA\x00\x03FOO\xc0\x03\x03BAR\xc0\x03";

        assert_eq!(
            get_hash(&Name::new("F.ISI.ARPA")?),
            get_hash(&Name::parse(data, &mut 3)?)
        );

        assert_eq!(
            get_hash(&Name::new("FOO.F.ISI.ARPA")?),
            get_hash(&Name::parse(data, &mut 15)?)
        );

        Ok(())
    }

    fn get_hash(name: &Name) -> u64 {
        let mut hasher = DefaultHasher::default();
        name.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn is_subdomain_of() {
        assert!(Name::new_unchecked("sub.example.com")
            .is_subdomain_of(&Name::new_unchecked("example.com")));

        assert!(!Name::new_unchecked("example.com")
            .is_subdomain_of(&Name::new_unchecked("example.com")));

        assert!(Name::new_unchecked("foo.sub.example.com")
            .is_subdomain_of(&Name::new_unchecked("example.com")));

        assert!(!Name::new_unchecked("example.com")
            .is_subdomain_of(&Name::new_unchecked("example.xom")));

        assert!(!Name::new_unchecked("domain.com")
            .is_subdomain_of(&Name::new_unchecked("other.domain")));

        assert!(!Name::new_unchecked("domain.com")
            .is_subdomain_of(&Name::new_unchecked("domain.com.br")));
    }

    #[test]
    fn subtract_domain() {
        let domain = Name::new_unchecked("_srv3._tcp.local");
        assert_eq!(
            Name::new_unchecked("a._srv3._tcp.local")
                .without(&domain)
                .unwrap()
                .to_string(),
            "a"
        );

        assert!(Name::new_unchecked("unrelated").without(&domain).is_none());

        assert_eq!(
            Name::new_unchecked("some.longer.domain._srv3._tcp.local")
                .without(&domain)
                .unwrap()
                .to_string(),
            "some.longer.domain"
        );
    }
}