route_recognizer/
nfa.rs

1use std::collections::HashSet;
2
3use self::CharacterClass::{Ascii, InvalidChars, ValidChars};
4
/// A set of characters optimised for input that is mostly ASCII.
///
/// Every character is assigned a *slot* equal to its code point minus one.
/// Slots `0..=63` are tracked in `low_mask`, slots `64..=127` in `high_mask`,
/// and every other character is stored in the `non_ascii` hash set.
#[derive(PartialEq, Eq, Clone, Default, Debug)]
pub struct CharSet {
    /// Bit `n` is set when the character with slot `n` is present.
    low_mask: u64,
    /// Bit `n` is set when the character with slot `n + 64` is present.
    high_mask: u64,
    /// Members whose slot does not fit in either mask.
    non_ascii: HashSet<char>,
}

impl CharSet {
    /// Creates an empty set.
    pub fn new() -> Self {
        Self {
            low_mask: 0,
            high_mask: 0,
            non_ascii: HashSet::new(),
        }
    }

    /// Returns the slot of `ch`: its code point minus one.
    ///
    /// The subtraction wraps on purpose. A plain `- 1` underflows for NUL
    /// (`'\0'`), which panics when overflow checks are enabled. Wrapping sends
    /// NUL to `u32::MAX`, i.e. into the `non_ascii` bucket, in every build
    /// profile.
    fn slot_of(ch: char) -> u32 {
        u32::from(ch).wrapping_sub(1)
    }

    /// Adds `char` to the set. Inserting a character twice has no further
    /// effect.
    pub fn insert(&mut self, char: char) {
        let slot = Self::slot_of(char);

        if slot < 64 {
            self.low_mask |= 1_u64 << slot;
        } else if slot < 128 {
            self.high_mask |= 1_u64 << (slot - 64);
        } else {
            self.non_ascii.insert(char);
        }
    }

    /// Returns `true` when `char` has previously been inserted.
    pub fn contains(&self, char: char) -> bool {
        let slot = Self::slot_of(char);

        if slot < 64 {
            (self.low_mask & (1_u64 << slot)) != 0
        } else if slot < 128 {
            (self.high_mask & (1_u64 << (slot - 64))) != 0
        } else {
            self.non_ascii.contains(&char)
        }
    }
}
49
50#[derive(PartialEq, Eq, Clone, Debug)]
51pub enum CharacterClass {
52    Ascii(u64, u64, bool),
53    ValidChars(CharSet),
54    InvalidChars(CharSet),
55}
56
57impl CharacterClass {
58    pub fn any() -> Self {
59        Ascii(u64::max_value(), u64::max_value(), true)
60    }
61
62    pub fn valid(string: &str) -> Self {
63        ValidChars(Self::str_to_set(string))
64    }
65
66    pub fn invalid(string: &str) -> Self {
67        InvalidChars(Self::str_to_set(string))
68    }
69
70    pub fn valid_char(char: char) -> Self {
71        let val = char as u32 - 1;
72
73        if val > 127 {
74            ValidChars(Self::char_to_set(char))
75        } else if val > 63 {
76            Ascii(1 << (val - 64), 0, false)
77        } else {
78            Ascii(0, 1 << val, false)
79        }
80    }
81
82    pub fn invalid_char(char: char) -> Self {
83        let val = char as u32 - 1;
84
85        if val > 127 {
86            InvalidChars(Self::char_to_set(char))
87        } else if val > 63 {
88            Ascii(u64::max_value() ^ (1 << (val - 64)), u64::max_value(), true)
89        } else {
90            Ascii(u64::max_value(), u64::max_value() ^ (1 << val), true)
91        }
92    }
93
94    pub fn matches(&self, char: char) -> bool {
95        match *self {
96            ValidChars(ref valid) => valid.contains(char),
97            InvalidChars(ref invalid) => !invalid.contains(char),
98            Ascii(high, low, unicode) => {
99                let val = char as u32 - 1;
100                if val > 127 {
101                    unicode
102                } else if val > 63 {
103                    high & (1 << (val - 64)) != 0
104                } else {
105                    low & (1 << val) != 0
106                }
107            }
108        }
109    }
110
111    fn char_to_set(char: char) -> CharSet {
112        let mut set = CharSet::new();
113        set.insert(char);
114        set
115    }
116
117    fn str_to_set(string: &str) -> CharSet {
118        let mut set = CharSet::new();
119        for char in string.chars() {
120            set.insert(char);
121        }
122        set
123    }
124}
125
/// One candidate path through the automaton while a string is being
/// processed.
#[derive(Clone)]
struct Thread {
    /// Index of the state this path currently sits in.
    state: usize,
    /// Completed captures as `(begin, end)` byte offsets into the input.
    captures: Vec<(usize, usize)>,
    /// Start offset of the capture currently being recorded, if any.
    capture_begin: Option<usize>,
}

impl Thread {
    /// Creates a thread positioned at state 0 with no captures.
    pub(crate) fn new() -> Self {
        Thread {
            capture_begin: None,
            captures: Vec::new(),
            state: 0,
        }
    }

    /// Marks `start` as the beginning of a new capture.
    #[inline]
    pub(crate) fn start_capture(&mut self, start: usize) {
        self.capture_begin = Some(start);
    }

    /// Closes the open capture at `end` and records it.
    ///
    /// # Panics
    ///
    /// Panics if no capture is currently open.
    #[inline]
    pub(crate) fn end_capture(&mut self, end: usize) {
        let begin = self.capture_begin.take().unwrap();
        self.captures.push((begin, end));
    }

    /// Resolves every recorded capture to the matching slice of `source`,
    /// in the order the captures were closed.
    pub(crate) fn extract<'a>(&self, source: &'a str) -> Vec<&'a str> {
        let mut slices = Vec::with_capacity(self.captures.len());
        for &(begin, end) in &self.captures {
            slices.push(&source[begin..end]);
        }
        slices
    }
}
160
161#[derive(Clone, Debug)]
162pub struct State<T> {
163    pub index: usize,
164    pub chars: CharacterClass,
165    pub next_states: Vec<usize>,
166    pub acceptance: bool,
167    pub start_capture: bool,
168    pub end_capture: bool,
169    pub metadata: Option<T>,
170}
171
172impl<T> PartialEq for State<T> {
173    fn eq(&self, other: &Self) -> bool {
174        self.index == other.index
175    }
176}
177
178impl<T> State<T> {
179    pub fn new(index: usize, chars: CharacterClass) -> Self {
180        Self {
181            index,
182            chars,
183            next_states: Vec::new(),
184            acceptance: false,
185            start_capture: false,
186            end_capture: false,
187            metadata: None,
188        }
189    }
190}
191
/// Result of successfully running an input string through the automaton.
#[derive(Debug)]
pub struct Match<'a> {
    /// Index of the state the winning thread ended in.
    pub state: usize,
    /// Captured slices of the input string.
    pub captures: Vec<&'a str>,
}

impl<'a> Match<'a> {
    /// Bundles a final state index with its captured slices.
    pub fn new(state: usize, captures: Vec<&'_ str>) -> Match<'_> {
        Match { captures, state }
    }
}
203
204#[derive(Clone, Default, Debug)]
205pub struct NFA<T> {
206    states: Vec<State<T>>,
207    start_capture: Vec<bool>,
208    end_capture: Vec<bool>,
209    acceptance: Vec<bool>,
210}
211
212impl<T> NFA<T> {
213    pub fn new() -> Self {
214        let root = State::new(0, CharacterClass::any());
215        Self {
216            states: vec![root],
217            start_capture: vec![false],
218            end_capture: vec![false],
219            acceptance: vec![false],
220        }
221    }
222
223    pub fn process<'a, I, F>(&self, string: &'a str, mut ord: F) -> Result<Match<'a>, String>
224    where
225        I: Ord,
226        F: FnMut(usize) -> I,
227    {
228        let mut threads = vec![Thread::new()];
229
230        for (i, char) in string.char_indices() {
231            let next_threads = self.process_char(threads, char, i);
232
233            if next_threads.is_empty() {
234                return Err(format!("Couldn't process {}", string));
235            }
236
237            threads = next_threads;
238        }
239
240        let returned = threads
241            .into_iter()
242            .filter(|thread| self.get(thread.state).acceptance);
243
244        let thread = returned
245            .fold(None, |prev, y| {
246                let y_v = ord(y.state);
247                match prev {
248                    None => Some((y_v, y)),
249                    Some((x_v, x)) => {
250                        if x_v < y_v {
251                            Some((y_v, y))
252                        } else {
253                            Some((x_v, x))
254                        }
255                    }
256                }
257            })
258            .map(|p| p.1);
259
260        match thread {
261            None => Err("The string was exhausted before reaching an \
262                         acceptance state"
263                .to_string()),
264            Some(mut thread) => {
265                if thread.capture_begin.is_some() {
266                    thread.end_capture(string.len());
267                }
268                let state = self.get(thread.state);
269                Ok(Match::new(state.index, thread.extract(string)))
270            }
271        }
272    }
273
274    #[inline]
275    fn process_char(&self, threads: Vec<Thread>, char: char, pos: usize) -> Vec<Thread> {
276        let mut returned = Vec::with_capacity(threads.len());
277
278        for mut thread in threads {
279            let current_state = self.get(thread.state);
280
281            let mut count = 0;
282            let mut found_state = 0;
283
284            for &index in &current_state.next_states {
285                let state = &self.states[index];
286
287                if state.chars.matches(char) {
288                    count += 1;
289                    found_state = index;
290                }
291            }
292
293            if count == 1 {
294                thread.state = found_state;
295                capture(self, &mut thread, current_state.index, found_state, pos);
296                returned.push(thread);
297                continue;
298            }
299
300            for &index in &current_state.next_states {
301                let state = &self.states[index];
302                if state.chars.matches(char) {
303                    let mut thread = fork_thread(&thread, state);
304                    capture(self, &mut thread, current_state.index, index, pos);
305                    returned.push(thread);
306                }
307            }
308        }
309
310        returned
311    }
312
313    #[inline]
314    pub fn get(&self, state: usize) -> &State<T> {
315        &self.states[state]
316    }
317
318    pub fn get_mut(&mut self, state: usize) -> &mut State<T> {
319        &mut self.states[state]
320    }
321
322    pub fn put(&mut self, index: usize, chars: CharacterClass) -> usize {
323        {
324            let state = self.get(index);
325
326            for &index in &state.next_states {
327                let state = self.get(index);
328                if state.chars == chars {
329                    return index;
330                }
331            }
332        }
333
334        let state = self.new_state(chars);
335        self.get_mut(index).next_states.push(state);
336        state
337    }
338
339    pub fn put_state(&mut self, index: usize, child: usize) {
340        if !self.states[index].next_states.contains(&child) {
341            self.get_mut(index).next_states.push(child);
342        }
343    }
344
345    pub fn acceptance(&mut self, index: usize) {
346        self.get_mut(index).acceptance = true;
347        self.acceptance[index] = true;
348    }
349
350    pub fn start_capture(&mut self, index: usize) {
351        self.get_mut(index).start_capture = true;
352        self.start_capture[index] = true;
353    }
354
355    pub fn end_capture(&mut self, index: usize) {
356        self.get_mut(index).end_capture = true;
357        self.end_capture[index] = true;
358    }
359
360    pub fn metadata(&mut self, index: usize, metadata: T) {
361        self.get_mut(index).metadata = Some(metadata);
362    }
363
364    fn new_state(&mut self, chars: CharacterClass) -> usize {
365        let index = self.states.len();
366        let state = State::new(index, chars);
367        self.states.push(state);
368
369        self.acceptance.push(false);
370        self.start_capture.push(false);
371        self.end_capture.push(false);
372
373        index
374    }
375}
376
377#[inline]
378fn fork_thread<T>(thread: &Thread, state: &State<T>) -> Thread {
379    let mut new_trace = thread.clone();
380    new_trace.state = state.index;
381    new_trace
382}
383
384#[inline]
385fn capture<T>(
386    nfa: &NFA<T>,
387    thread: &mut Thread,
388    current_state: usize,
389    next_state: usize,
390    pos: usize,
391) {
392    if thread.capture_begin == None && nfa.start_capture[next_state] {
393        thread.start_capture(pos);
394    }
395
396    if thread.capture_begin != None && nfa.end_capture[current_state] && next_state > current_state
397    {
398        thread.end_capture(pos);
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::{CharSet, CharacterClass, NFA};

    /// Shorthand for a class that matches only `char`.
    fn valid(char: char) -> CharacterClass {
        CharacterClass::valid_char(char)
    }

    /// Shorthand for a class that matches everything except `char`.
    fn invalid(char: char) -> CharacterClass {
        CharacterClass::invalid_char(char)
    }

    /// A single linear path is followed to its accepting state.
    #[test]
    fn basic_test() {
        let mut nfa = NFA::<()>::new();
        let h = nfa.put(0, CharacterClass::valid("h"));
        let he = nfa.put(h, CharacterClass::valid("e"));
        let hel = nfa.put(he, CharacterClass::valid("l"));
        let hell = nfa.put(hel, CharacterClass::valid("l"));
        let hello = nfa.put(hell, CharacterClass::valid("o"));
        nfa.acceptance(hello);

        let outcome = nfa.process("hello", |state| state);

        assert!(
            outcome.unwrap().state == hello,
            "You didn't get the right final state"
        );
    }

    /// When two paths accept, the one with the greater sort key wins.
    #[test]
    fn multiple_solutions() {
        let mut nfa = NFA::<()>::new();
        let literal_n = nfa.put(0, CharacterClass::valid("n"));
        let literal_e = nfa.put(literal_n, CharacterClass::valid("e"));
        let literal_w = nfa.put(literal_e, CharacterClass::valid("w"));
        nfa.acceptance(literal_w);

        let wild_1 = nfa.put(0, CharacterClass::invalid(""));
        let wild_2 = nfa.put(wild_1, CharacterClass::invalid(""));
        let wild_3 = nfa.put(wild_2, CharacterClass::invalid(""));
        nfa.acceptance(wild_3);

        let outcome = nfa.process("new", |state| state);

        assert!(
            outcome.unwrap().state == wild_3,
            "The two states were not found"
        );
    }

    /// Branching literal paths that share a prefix.
    #[test]
    fn multiple_paths() {
        let mut nfa = NFA::<()>::new();
        let t = nfa.put(0, CharacterClass::valid("t"));
        let th = nfa.put(t, CharacterClass::valid("h"));
        let tho = nfa.put(th, CharacterClass::valid("o"));
        let thom = nfa.put(tho, CharacterClass::valid("m"));
        let thoma = nfa.put(thom, CharacterClass::valid("a"));
        let thomas_state = nfa.put(thoma, CharacterClass::valid("s"));

        let to = nfa.put(t, CharacterClass::valid("o"));
        let tom_state = nfa.put(to, CharacterClass::valid("m"));

        nfa.acceptance(thomas_state);
        nfa.acceptance(tom_state);

        let thomas = nfa.process("thomas", |state| state);
        let tom = nfa.process("tom", |state| state);
        let thom_only = nfa.process("thom", |state| state);
        let nope = nfa.process("nope", |state| state);

        assert!(
            thomas.unwrap().state == thomas_state,
            "thomas was parsed correctly"
        );
        assert!(tom.unwrap().state == tom_state, "tom was parsed correctly");
        assert!(
            thom_only.is_err(),
            "thom didn't reach an acceptance state"
        );
        assert!(nope.is_err(), "nope wasn't parsed");
    }

    /// A self-loop on the final state accepts one or more characters.
    #[test]
    fn repetitions() {
        let mut nfa = NFA::<()>::new();
        let p = nfa.put(0, CharacterClass::valid("p"));
        let po = nfa.put(p, CharacterClass::valid("o"));
        let pos = nfa.put(po, CharacterClass::valid("s"));
        let post = nfa.put(pos, CharacterClass::valid("t"));
        let posts = nfa.put(post, CharacterClass::valid("s"));
        let slash = nfa.put(posts, CharacterClass::valid("/"));
        let segment = nfa.put(slash, CharacterClass::invalid("/")); // posts/[^/]
        nfa.put_state(segment, segment);

        nfa.acceptance(segment);

        let single = nfa.process("posts/1", |state| state);
        let word = nfa.process("posts/new", |state| state);
        let missing = nfa.process("posts/", |state| state);

        assert!(single.unwrap().state == segment, "posts/1 was parsed");
        assert!(word.unwrap().state == segment, "posts/new was parsed");
        assert!(missing.is_err(), "posts/ was invalid");
    }

    /// A literal path and a wildcard loop both match; the greater key wins.
    #[test]
    fn repetitions_with_ambiguous() {
        let mut nfa = NFA::<()>::new();
        let p = nfa.put(0, CharacterClass::valid("p"));
        let po = nfa.put(p, CharacterClass::valid("o"));
        let pos = nfa.put(po, CharacterClass::valid("s"));
        let post = nfa.put(pos, CharacterClass::valid("t"));
        let posts = nfa.put(post, CharacterClass::valid("s"));
        let slash = nfa.put(posts, CharacterClass::valid("/"));
        let segment = nfa.put(slash, CharacterClass::invalid("/")); // posts/[^/]
        let literal_n = nfa.put(slash, CharacterClass::valid("n")); // posts/n
        let literal_ne = nfa.put(literal_n, CharacterClass::valid("e")); // posts/ne
        let literal_new = nfa.put(literal_ne, CharacterClass::valid("w")); // posts/new

        nfa.put_state(segment, segment);

        nfa.acceptance(segment);
        nfa.acceptance(literal_new);

        let single = nfa.process("posts/1", |state| state);
        let both = nfa.process("posts/new", |state| state);
        let missing = nfa.process("posts/", |state| state);

        assert!(single.unwrap().state == segment, "posts/1 was parsed");
        assert!(
            both.unwrap().state == literal_new,
            "posts/new was ambiguous"
        );
        assert!(missing.is_err(), "posts/ was invalid");
    }

    /// A capture that spans the whole input is closed at end of string.
    #[test]
    fn captures() {
        let mut nfa = NFA::<()>::new();
        let first = nfa.put(0, CharacterClass::valid("n"));
        let second = nfa.put(first, CharacterClass::valid("e"));
        let third = nfa.put(second, CharacterClass::valid("w"));

        nfa.acceptance(third);
        nfa.start_capture(first);
        nfa.end_capture(third);

        let outcome = nfa.process("new", |state| state);

        assert_eq!(outcome.unwrap().captures, vec!["new"]);
    }

    /// A capture in the middle of the input is closed when its state is left.
    #[test]
    fn capture_mid_match() {
        let mut nfa = NFA::<()>::new();
        let p = nfa.put(0, valid('p'));
        let slash_1 = nfa.put(p, valid('/'));
        let segment = nfa.put(slash_1, invalid('/'));
        let slash_2 = nfa.put(segment, valid('/'));
        let tail = nfa.put(slash_2, valid('c'));

        nfa.put_state(segment, segment);
        nfa.acceptance(tail);
        nfa.start_capture(segment);
        nfa.end_capture(segment);

        let outcome = nfa.process("p/123/c", |state| state);

        assert_eq!(outcome.unwrap().captures, vec!["123"]);
    }

    /// Two captures: one closed mid-string, one closed at end of string.
    #[test]
    fn capture_multiple_captures() {
        let mut nfa = NFA::<()>::new();
        let p = nfa.put(0, valid('p'));
        let slash_1 = nfa.put(p, valid('/'));
        let segment_1 = nfa.put(slash_1, invalid('/'));
        let slash_2 = nfa.put(segment_1, valid('/'));
        let c = nfa.put(slash_2, valid('c'));
        let slash_3 = nfa.put(c, valid('/'));
        let segment_2 = nfa.put(slash_3, invalid('/'));

        nfa.put_state(segment_1, segment_1);
        nfa.put_state(segment_2, segment_2);
        nfa.acceptance(segment_2);

        nfa.start_capture(segment_1);
        nfa.end_capture(segment_1);

        nfa.start_capture(segment_2);
        nfa.end_capture(segment_2);

        let outcome = nfa.process("p/123/c/456", |state| state);
        assert_eq!(outcome.unwrap().captures, vec!["123", "456"]);
    }

    /// Membership checks covering both bit masks and the hash-set fallback.
    #[test]
    fn test_ascii_set() {
        let mut chars = CharSet::new();
        chars.insert('?');
        chars.insert('a');
        chars.insert('é');

        assert!(chars.contains('?'), "The set contains char 63");
        assert!(chars.contains('a'), "The set contains char 97");
        assert!(chars.contains('é'), "The set contains char 233");
        assert!(!chars.contains('q'), "The set does not contain q");
        assert!(!chars.contains('ü'), "The set does not contain ü");
    }
}
601}