1use std::collections::HashSet;
2
3use self::CharacterClass::{Ascii, InvalidChars, ValidChars};
4
/// A set of characters with a fast path for ASCII.
///
/// A character `c` in U+0001..=U+0080 is stored as bit `c - 1` across two
/// 64-bit masks; every other character (including U+0000) lives in a
/// `HashSet`.
#[derive(PartialEq, Eq, Clone, Default, Debug)]
pub struct CharSet {
    /// Bits for U+0001..=U+0040 (bit `c - 1`).
    low_mask: u64,
    /// Bits for U+0041..=U+0080 (bit `c - 65`).
    high_mask: u64,
    /// Characters outside the range covered by the two masks.
    non_ascii: HashSet<char>,
}

impl CharSet {
    /// Creates an empty set.
    pub fn new() -> Self {
        Self {
            low_mask: 0,
            high_mask: 0,
            non_ascii: HashSet::new(),
        }
    }

    /// Adds `char` to the set.
    pub fn insert(&mut self, char: char) {
        let val = Self::bit_index(char);

        if val > 127 {
            self.non_ascii.insert(char);
        } else if val > 63 {
            self.high_mask |= 1u64 << (val - 64);
        } else {
            self.low_mask |= 1u64 << val;
        }
    }

    /// Returns `true` if `char` has been inserted into the set.
    pub fn contains(&self, char: char) -> bool {
        let val = Self::bit_index(char);

        if val > 127 {
            self.non_ascii.contains(&char)
        } else if val > 63 {
            self.high_mask & (1u64 << (val - 64)) != 0
        } else {
            self.low_mask & (1u64 << val) != 0
        }
    }

    /// Maps a character to its mask bit position, `c - 1`.
    ///
    /// `wrapping_sub` sends U+0000 to `u32::MAX`, so NUL is routed to the
    /// hash set instead of underflowing (a plain `- 1` panics in debug
    /// builds).
    #[inline]
    fn bit_index(char: char) -> u32 {
        (char as u32).wrapping_sub(1)
    }
}
49
50#[derive(PartialEq, Eq, Clone, Debug)]
51pub enum CharacterClass {
52 Ascii(u64, u64, bool),
53 ValidChars(CharSet),
54 InvalidChars(CharSet),
55}
56
57impl CharacterClass {
58 pub fn any() -> Self {
59 Ascii(u64::max_value(), u64::max_value(), true)
60 }
61
62 pub fn valid(string: &str) -> Self {
63 ValidChars(Self::str_to_set(string))
64 }
65
66 pub fn invalid(string: &str) -> Self {
67 InvalidChars(Self::str_to_set(string))
68 }
69
70 pub fn valid_char(char: char) -> Self {
71 let val = char as u32 - 1;
72
73 if val > 127 {
74 ValidChars(Self::char_to_set(char))
75 } else if val > 63 {
76 Ascii(1 << (val - 64), 0, false)
77 } else {
78 Ascii(0, 1 << val, false)
79 }
80 }
81
82 pub fn invalid_char(char: char) -> Self {
83 let val = char as u32 - 1;
84
85 if val > 127 {
86 InvalidChars(Self::char_to_set(char))
87 } else if val > 63 {
88 Ascii(u64::max_value() ^ (1 << (val - 64)), u64::max_value(), true)
89 } else {
90 Ascii(u64::max_value(), u64::max_value() ^ (1 << val), true)
91 }
92 }
93
94 pub fn matches(&self, char: char) -> bool {
95 match *self {
96 ValidChars(ref valid) => valid.contains(char),
97 InvalidChars(ref invalid) => !invalid.contains(char),
98 Ascii(high, low, unicode) => {
99 let val = char as u32 - 1;
100 if val > 127 {
101 unicode
102 } else if val > 63 {
103 high & (1 << (val - 64)) != 0
104 } else {
105 low & (1 << val) != 0
106 }
107 }
108 }
109 }
110
111 fn char_to_set(char: char) -> CharSet {
112 let mut set = CharSet::new();
113 set.insert(char);
114 set
115 }
116
117 fn str_to_set(string: &str) -> CharSet {
118 let mut set = CharSet::new();
119 for char in string.chars() {
120 set.insert(char);
121 }
122 set
123 }
124}
125
/// One in-flight path through the automaton together with the captures it
/// has recorded so far.
#[derive(Clone)]
struct Thread {
    /// Index of the state this thread currently occupies.
    state: usize,
    /// Closed captures as `(begin, end)` byte offsets into the input.
    captures: Vec<(usize, usize)>,
    /// Start offset of the capture that is currently open, if any.
    capture_begin: Option<usize>,
}

impl Thread {
    /// Creates a thread at state 0 with no captures.
    pub(crate) fn new() -> Self {
        Thread {
            capture_begin: None,
            captures: vec![],
            state: 0,
        }
    }

    /// Opens a capture starting at byte offset `start`.
    #[inline]
    pub(crate) fn start_capture(&mut self, start: usize) {
        self.capture_begin = Some(start);
    }

    /// Closes the open capture at byte offset `end` and records it.
    ///
    /// # Panics
    ///
    /// Panics if no capture is open.
    #[inline]
    pub(crate) fn end_capture(&mut self, end: usize) {
        let begin = self.capture_begin.unwrap();
        self.captures.push((begin, end));
        self.capture_begin = None;
    }

    /// Slices every recorded capture out of `source`, in recording order.
    pub(crate) fn extract<'a>(&self, source: &'a str) -> Vec<&'a str> {
        let mut pieces = Vec::with_capacity(self.captures.len());
        for &(begin, end) in &self.captures {
            pieces.push(&source[begin..end]);
        }
        pieces
    }
}
160
161#[derive(Clone, Debug)]
162pub struct State<T> {
163 pub index: usize,
164 pub chars: CharacterClass,
165 pub next_states: Vec<usize>,
166 pub acceptance: bool,
167 pub start_capture: bool,
168 pub end_capture: bool,
169 pub metadata: Option<T>,
170}
171
172impl<T> PartialEq for State<T> {
173 fn eq(&self, other: &Self) -> bool {
174 self.index == other.index
175 }
176}
177
178impl<T> State<T> {
179 pub fn new(index: usize, chars: CharacterClass) -> Self {
180 Self {
181 index,
182 chars,
183 next_states: Vec::new(),
184 acceptance: false,
185 start_capture: false,
186 end_capture: false,
187 metadata: None,
188 }
189 }
190}
191
/// The result of a successful `NFA::process` run.
#[derive(Debug)]
pub struct Match<'a> {
    /// Index of the accepting state the input ended in.
    pub state: usize,
    /// Captured substrings of the input, in the order they were recorded.
    pub captures: Vec<&'a str>,
}

impl<'a> Match<'a> {
    /// Builds a match from an accepting state index and its captures.
    ///
    /// The captures' lifetime is tied to the impl's `'a`; previously the
    /// signature used an unrelated elided lifetime and ignored `'a`.
    pub fn new(state: usize, captures: Vec<&'a str>) -> Self {
        Self { state, captures }
    }
}
203
204#[derive(Clone, Default, Debug)]
205pub struct NFA<T> {
206 states: Vec<State<T>>,
207 start_capture: Vec<bool>,
208 end_capture: Vec<bool>,
209 acceptance: Vec<bool>,
210}
211
212impl<T> NFA<T> {
213 pub fn new() -> Self {
214 let root = State::new(0, CharacterClass::any());
215 Self {
216 states: vec![root],
217 start_capture: vec![false],
218 end_capture: vec![false],
219 acceptance: vec![false],
220 }
221 }
222
223 pub fn process<'a, I, F>(&self, string: &'a str, mut ord: F) -> Result<Match<'a>, String>
224 where
225 I: Ord,
226 F: FnMut(usize) -> I,
227 {
228 let mut threads = vec![Thread::new()];
229
230 for (i, char) in string.char_indices() {
231 let next_threads = self.process_char(threads, char, i);
232
233 if next_threads.is_empty() {
234 return Err(format!("Couldn't process {}", string));
235 }
236
237 threads = next_threads;
238 }
239
240 let returned = threads
241 .into_iter()
242 .filter(|thread| self.get(thread.state).acceptance);
243
244 let thread = returned
245 .fold(None, |prev, y| {
246 let y_v = ord(y.state);
247 match prev {
248 None => Some((y_v, y)),
249 Some((x_v, x)) => {
250 if x_v < y_v {
251 Some((y_v, y))
252 } else {
253 Some((x_v, x))
254 }
255 }
256 }
257 })
258 .map(|p| p.1);
259
260 match thread {
261 None => Err("The string was exhausted before reaching an \
262 acceptance state"
263 .to_string()),
264 Some(mut thread) => {
265 if thread.capture_begin.is_some() {
266 thread.end_capture(string.len());
267 }
268 let state = self.get(thread.state);
269 Ok(Match::new(state.index, thread.extract(string)))
270 }
271 }
272 }
273
274 #[inline]
275 fn process_char(&self, threads: Vec<Thread>, char: char, pos: usize) -> Vec<Thread> {
276 let mut returned = Vec::with_capacity(threads.len());
277
278 for mut thread in threads {
279 let current_state = self.get(thread.state);
280
281 let mut count = 0;
282 let mut found_state = 0;
283
284 for &index in ¤t_state.next_states {
285 let state = &self.states[index];
286
287 if state.chars.matches(char) {
288 count += 1;
289 found_state = index;
290 }
291 }
292
293 if count == 1 {
294 thread.state = found_state;
295 capture(self, &mut thread, current_state.index, found_state, pos);
296 returned.push(thread);
297 continue;
298 }
299
300 for &index in ¤t_state.next_states {
301 let state = &self.states[index];
302 if state.chars.matches(char) {
303 let mut thread = fork_thread(&thread, state);
304 capture(self, &mut thread, current_state.index, index, pos);
305 returned.push(thread);
306 }
307 }
308 }
309
310 returned
311 }
312
313 #[inline]
314 pub fn get(&self, state: usize) -> &State<T> {
315 &self.states[state]
316 }
317
318 pub fn get_mut(&mut self, state: usize) -> &mut State<T> {
319 &mut self.states[state]
320 }
321
322 pub fn put(&mut self, index: usize, chars: CharacterClass) -> usize {
323 {
324 let state = self.get(index);
325
326 for &index in &state.next_states {
327 let state = self.get(index);
328 if state.chars == chars {
329 return index;
330 }
331 }
332 }
333
334 let state = self.new_state(chars);
335 self.get_mut(index).next_states.push(state);
336 state
337 }
338
339 pub fn put_state(&mut self, index: usize, child: usize) {
340 if !self.states[index].next_states.contains(&child) {
341 self.get_mut(index).next_states.push(child);
342 }
343 }
344
345 pub fn acceptance(&mut self, index: usize) {
346 self.get_mut(index).acceptance = true;
347 self.acceptance[index] = true;
348 }
349
350 pub fn start_capture(&mut self, index: usize) {
351 self.get_mut(index).start_capture = true;
352 self.start_capture[index] = true;
353 }
354
355 pub fn end_capture(&mut self, index: usize) {
356 self.get_mut(index).end_capture = true;
357 self.end_capture[index] = true;
358 }
359
360 pub fn metadata(&mut self, index: usize, metadata: T) {
361 self.get_mut(index).metadata = Some(metadata);
362 }
363
364 fn new_state(&mut self, chars: CharacterClass) -> usize {
365 let index = self.states.len();
366 let state = State::new(index, chars);
367 self.states.push(state);
368
369 self.acceptance.push(false);
370 self.start_capture.push(false);
371 self.end_capture.push(false);
372
373 index
374 }
375}
376
377#[inline]
378fn fork_thread<T>(thread: &Thread, state: &State<T>) -> Thread {
379 let mut new_trace = thread.clone();
380 new_trace.state = state.index;
381 new_trace
382}
383
384#[inline]
385fn capture<T>(
386 nfa: &NFA<T>,
387 thread: &mut Thread,
388 current_state: usize,
389 next_state: usize,
390 pos: usize,
391) {
392 if thread.capture_begin == None && nfa.start_capture[next_state] {
393 thread.start_capture(pos);
394 }
395
396 if thread.capture_begin != None && nfa.end_capture[current_state] && next_state > current_state
397 {
398 thread.end_capture(pos);
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::{CharSet, CharacterClass, NFA};

    #[test]
    fn basic_test() {
        let mut nfa = NFA::<()>::new();
        let a = nfa.put(0, CharacterClass::valid("h"));
        let b = nfa.put(a, CharacterClass::valid("e"));
        let c = nfa.put(b, CharacterClass::valid("l"));
        let d = nfa.put(c, CharacterClass::valid("l"));
        let e = nfa.put(d, CharacterClass::valid("o"));
        nfa.acceptance(e);

        let m = nfa.process("hello", |a| a);

        assert_eq!(m.unwrap().state, e, "You didn't get the right final state");
    }

    #[test]
    fn multiple_solutions() {
        let mut nfa = NFA::<()>::new();
        let a1 = nfa.put(0, CharacterClass::valid("n"));
        let b1 = nfa.put(a1, CharacterClass::valid("e"));
        let c1 = nfa.put(b1, CharacterClass::valid("w"));
        nfa.acceptance(c1);

        let a2 = nfa.put(0, CharacterClass::invalid(""));
        let b2 = nfa.put(a2, CharacterClass::invalid(""));
        let c2 = nfa.put(b2, CharacterClass::invalid(""));
        nfa.acceptance(c2);

        let m = nfa.process("new", |a| a);

        // Both c1 and c2 accept; with an identity ranking the higher index wins.
        assert_eq!(
            m.unwrap().state,
            c2,
            "the highest-ranked accepting state (c2) should win"
        );
    }

    #[test]
    fn multiple_paths() {
        let mut nfa = NFA::<()>::new();
        let a = nfa.put(0, CharacterClass::valid("t"));
        let b1 = nfa.put(a, CharacterClass::valid("h"));
        let c1 = nfa.put(b1, CharacterClass::valid("o"));
        let d1 = nfa.put(c1, CharacterClass::valid("m"));
        let e1 = nfa.put(d1, CharacterClass::valid("a"));
        let f1 = nfa.put(e1, CharacterClass::valid("s"));

        let b2 = nfa.put(a, CharacterClass::valid("o"));
        let c2 = nfa.put(b2, CharacterClass::valid("m"));

        nfa.acceptance(f1);
        nfa.acceptance(c2);

        let thomas = nfa.process("thomas", |a| a);
        let tom = nfa.process("tom", |a| a);
        let thom = nfa.process("thom", |a| a);
        let nope = nfa.process("nope", |a| a);

        assert_eq!(thomas.unwrap().state, f1, "thomas should end in f1");
        assert_eq!(tom.unwrap().state, c2, "tom should end in c2");
        assert!(thom.is_err(), "thom should not reach an acceptance state");
        assert!(nope.is_err(), "nope should not be parsed");
    }

    #[test]
    fn repetitions() {
        let mut nfa = NFA::<()>::new();
        let a = nfa.put(0, CharacterClass::valid("p"));
        let b = nfa.put(a, CharacterClass::valid("o"));
        let c = nfa.put(b, CharacterClass::valid("s"));
        let d = nfa.put(c, CharacterClass::valid("t"));
        let e = nfa.put(d, CharacterClass::valid("s"));
        let f = nfa.put(e, CharacterClass::valid("/"));
        let g = nfa.put(f, CharacterClass::invalid("/"));
        nfa.put_state(g, g);

        nfa.acceptance(g);

        let post = nfa.process("posts/1", |a| a);
        let new_post = nfa.process("posts/new", |a| a);
        let invalid = nfa.process("posts/", |a| a);

        assert_eq!(post.unwrap().state, g, "posts/1 should be parsed");
        assert_eq!(new_post.unwrap().state, g, "posts/new should be parsed");
        assert!(invalid.is_err(), "posts/ should be invalid");
    }

    #[test]
    fn repetitions_with_ambiguous() {
        let mut nfa = NFA::<()>::new();
        let a = nfa.put(0, CharacterClass::valid("p"));
        let b = nfa.put(a, CharacterClass::valid("o"));
        let c = nfa.put(b, CharacterClass::valid("s"));
        let d = nfa.put(c, CharacterClass::valid("t"));
        let e = nfa.put(d, CharacterClass::valid("s"));
        let f = nfa.put(e, CharacterClass::valid("/"));
        let g1 = nfa.put(f, CharacterClass::invalid("/"));
        let g2 = nfa.put(f, CharacterClass::valid("n"));
        let h2 = nfa.put(g2, CharacterClass::valid("e"));
        let i2 = nfa.put(h2, CharacterClass::valid("w"));
        nfa.put_state(g1, g1);

        nfa.acceptance(g1);
        nfa.acceptance(i2);

        let post = nfa.process("posts/1", |a| a);
        let ambiguous = nfa.process("posts/new", |a| a);
        let invalid = nfa.process("posts/", |a| a);

        assert_eq!(post.unwrap().state, g1, "posts/1 should end in g1");
        assert_eq!(
            ambiguous.unwrap().state,
            i2,
            "ambiguous posts/new should resolve to the higher-ranked i2"
        );
        assert!(invalid.is_err(), "posts/ should be invalid");
    }

    #[test]
    fn captures() {
        let mut nfa = NFA::<()>::new();
        let a = nfa.put(0, CharacterClass::valid("n"));
        let b = nfa.put(a, CharacterClass::valid("e"));
        let c = nfa.put(b, CharacterClass::valid("w"));

        nfa.acceptance(c);
        nfa.start_capture(a);
        nfa.end_capture(c);

        let post = nfa.process("new", |a| a);

        assert_eq!(post.unwrap().captures, vec!["new"]);
    }

    #[test]
    fn capture_mid_match() {
        let mut nfa = NFA::<()>::new();
        let a = nfa.put(0, valid('p'));
        let b = nfa.put(a, valid('/'));
        let c = nfa.put(b, invalid('/'));
        let d = nfa.put(c, valid('/'));
        let e = nfa.put(d, valid('c'));

        nfa.put_state(c, c);
        nfa.acceptance(e);
        nfa.start_capture(c);
        nfa.end_capture(c);

        let post = nfa.process("p/123/c", |a| a);

        assert_eq!(post.unwrap().captures, vec!["123"]);
    }

    #[test]
    fn capture_multiple_captures() {
        let mut nfa = NFA::<()>::new();
        let a = nfa.put(0, valid('p'));
        let b = nfa.put(a, valid('/'));
        let c = nfa.put(b, invalid('/'));
        let d = nfa.put(c, valid('/'));
        let e = nfa.put(d, valid('c'));
        let f = nfa.put(e, valid('/'));
        let g = nfa.put(f, invalid('/'));

        nfa.put_state(c, c);
        nfa.put_state(g, g);
        nfa.acceptance(g);

        nfa.start_capture(c);
        nfa.end_capture(c);

        nfa.start_capture(g);
        nfa.end_capture(g);

        let post = nfa.process("p/123/c/456", |a| a);
        assert_eq!(post.unwrap().captures, vec!["123", "456"]);
    }

    #[test]
    fn test_ascii_set() {
        let mut set = CharSet::new();
        set.insert('?');
        set.insert('a');
        set.insert('é');

        assert!(set.contains('?'), "The set contains char 63");
        assert!(set.contains('a'), "The set contains char 97");
        assert!(set.contains('é'), "The set contains char 233");
        assert!(!set.contains('q'), "The set does not contain q");
        assert!(!set.contains('ü'), "The set does not contain ü");
    }

    /// Shorthand for a single-character class.
    fn valid(char: char) -> CharacterClass {
        CharacterClass::valid_char(char)
    }

    /// Shorthand for an "anything but this character" class.
    fn invalid(char: char) -> CharacterClass {
        CharacterClass::invalid_char(char)
    }
}
601}