asn1_rs_derive/
container.rs

1use proc_macro2::{Literal, Span, TokenStream};
2use quote::{quote, ToTokens};
3use syn::{
4    parse::ParseStream, parse_quote, spanned::Spanned, Attribute, DataStruct, DeriveInput, Field,
5    Fields, Ident, Lifetime, LitInt, Meta, Type, WherePredicate,
6};
7
/// Kind of ASN.1 container a derive macro generates code for.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ContainerType {
    /// Newtype (single-element tuple struct) wrapping another ASN.1 type.
    Alias,
    /// ASN.1 `SEQUENCE`.
    Sequence,
    /// ASN.1 `SET`.
    Set,
}
14
15impl ToTokens for ContainerType {
16    fn to_tokens(&self, tokens: &mut TokenStream) {
17        let s = match self {
18            ContainerType::Alias => quote! {},
19            ContainerType::Sequence => quote! { asn1_rs::Tag::Sequence },
20            ContainerType::Set => quote! { asn1_rs::Tag::Set },
21        };
22        s.to_tokens(tokens)
23    }
24}
25
/// Encoding rules targeted by generated parsers: selects `FromBer::from_ber` or `FromDer::from_der`.
#[derive(Copy, Clone, Debug, PartialEq)]
enum Asn1Type {
    /// Basic Encoding Rules.
    Ber,
    /// Distinguished Encoding Rules.
    Der,
}
31
/// Tagging mode requested by `#[tag_explicit(...)]` or `#[tag_implicit(...)]`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Asn1TagKind {
    /// `EXPLICIT` tagging (set by `#[tag_explicit(...)]`).
    Explicit,
    /// `IMPLICIT` tagging (set by `#[tag_implicit(...)]`).
    Implicit,
}
37
38impl ToTokens for Asn1TagKind {
39    fn to_tokens(&self, tokens: &mut TokenStream) {
40        let s = match self {
41            Asn1TagKind::Explicit => quote! { asn1_rs::Explicit },
42            Asn1TagKind::Implicit => quote! { asn1_rs::Implicit },
43        };
44        s.to_tokens(tokens)
45    }
46}
47
/// ASN.1 tag class, as written in tag attributes.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Asn1TagClass {
    /// `UNIVERSAL` class.
    Universal,
    /// `APPLICATION` class.
    Application,
    /// Context-specific class (the default when no class is written).
    ContextSpecific,
    /// `PRIVATE` class.
    Private,
}
55
56impl ToTokens for Asn1TagClass {
57    fn to_tokens(&self, tokens: &mut TokenStream) {
58        let s = match self {
59            Asn1TagClass::Application => quote! { asn1_rs::Class::APPLICATION },
60            Asn1TagClass::ContextSpecific => quote! { asn1_rs::Class::CONTEXT_SPECIFIC },
61            Asn1TagClass::Private => quote! { asn1_rs::Class::PRIVATE },
62            Asn1TagClass::Universal => quote! { asn1_rs::Class::UNIVERSAL },
63        };
64        s.to_tokens(tokens)
65    }
66}
67
68pub struct Container {
69    pub container_type: ContainerType,
70    pub fields: Vec<FieldInfo>,
71    pub where_predicates: Vec<WherePredicate>,
72    pub error: Option<Attribute>,
73
74    is_any: bool,
75}
76
77impl Container {
78    pub fn from_datastruct(
79        ds: &DataStruct,
80        ast: &DeriveInput,
81        container_type: ContainerType,
82    ) -> Self {
83        let mut is_any = false;
84        match (container_type, &ds.fields) {
85            (ContainerType::Alias, Fields::Unnamed(f)) => {
86                if f.unnamed.len() != 1 {
87                    panic!("Alias: only tuple fields with one element are supported");
88                }
89                match &f.unnamed[0].ty {
90                    Type::Path(type_path)
91                        if type_path
92                            .clone()
93                            .into_token_stream()
94                            .to_string()
95                            .starts_with("Any") =>
96                    {
97                        is_any = true;
98                    }
99                    _ => (),
100                }
101            }
102            (ContainerType::Alias, _) => panic!("BER/DER alias must be used with tuple strucs"),
103            (_, Fields::Unnamed(_)) => panic!("BER/DER sequence cannot be used on tuple structs"),
104            _ => (),
105        }
106
107        let fields = ds.fields.iter().map(FieldInfo::from).collect();
108
109        // get lifetimes from generics
110        let lfts: Vec<_> = ast.generics.lifetimes().collect();
111        let mut where_predicates = Vec::new();
112        if !lfts.is_empty() {
113            // input slice must outlive all lifetimes from Self
114            let lft = Lifetime::new("'ber", Span::call_site());
115            let wh: WherePredicate = parse_quote! { #lft: #(#lfts)+* };
116            where_predicates.push(wh);
117        };
118
119        // get custom attributes on container
120        let error = ast
121            .attrs
122            .iter()
123            .find(|attr| {
124                attr.meta
125                    .path()
126                    .is_ident(&Ident::new("error", Span::call_site()))
127            })
128            .cloned();
129
130        Container {
131            container_type,
132            fields,
133            where_predicates,
134            error,
135            is_any,
136        }
137    }
138
139    pub fn gen_tryfrom(&self) -> TokenStream {
140        let field_names = &self.fields.iter().map(|f| &f.name).collect::<Vec<_>>();
141        let parse_content =
142            derive_ber_sequence_content(&self.fields, Asn1Type::Ber, self.error.is_some());
143        let lifetime = Lifetime::new("'ber", Span::call_site());
144        let wh = &self.where_predicates;
145        let error = if let Some(attr) = &self.error {
146            get_attribute_meta(attr).expect("Invalid error attribute format")
147        } else {
148            quote! { asn1_rs::Error }
149        };
150
151        let fn_content = if self.container_type == ContainerType::Alias {
152            // special case: is this an alias for Any
153            if self.is_any {
154                quote! { Ok(Self(any)) }
155            } else {
156                quote! {
157                    let res = TryFrom::try_from(any)?;
158                    Ok(Self(res))
159                }
160            }
161        } else {
162            quote! {
163                use asn1_rs::nom::*;
164                any.tag().assert_eq(Self::TAG)?;
165
166                // no need to parse sequence, we already have content
167                let i = any.data;
168                //
169                #parse_content
170                //
171                let _ = i; // XXX check if empty?
172                Ok(Self{#(#field_names),*})
173            }
174        };
175        // note: `gen impl` in synstructure takes care of appending extra where clauses if any, and removing
176        // the `where` statement if there are none.
177        quote! {
178            use asn1_rs::{Any, FromBer};
179            use core::convert::TryFrom;
180
181            gen impl<#lifetime> TryFrom<Any<#lifetime>> for @Self where #(#wh)+* {
182                type Error = #error;
183
184                fn try_from(any: Any<#lifetime>) -> asn1_rs::Result<Self, #error> {
185                    #fn_content
186                }
187            }
188        }
189    }
190
191    pub fn gen_tagged(&self) -> TokenStream {
192        let tag = if self.container_type == ContainerType::Alias {
193            // special case: is this an alias for Any
194            if self.is_any {
195                return quote! {};
196            }
197            // find type of sub-item
198            let ty = &self.fields[0].type_;
199            quote! { <#ty as asn1_rs::Tagged>::TAG }
200        } else {
201            let container_type = self.container_type;
202            quote! { #container_type }
203        };
204        quote! {
205            gen impl<'ber> asn1_rs::Tagged for @Self {
206                const TAG: asn1_rs::Tag = #tag;
207            }
208        }
209    }
210
211    pub fn gen_checkconstraints(&self) -> TokenStream {
212        let lifetime = Lifetime::new("'ber", Span::call_site());
213        let wh = &self.where_predicates;
214        // let parse_content = derive_ber_sequence_content(&field_names, Asn1Type::Der);
215
216        let fn_content = if self.container_type == ContainerType::Alias {
217            // special case: is this an alias for Any
218            if self.is_any {
219                return quote! {};
220            }
221            let ty = &self.fields[0].type_;
222            quote! {
223                any.tag().assert_eq(Self::TAG)?;
224                <#ty>::check_constraints(any)
225            }
226        } else {
227            let check_fields: Vec<_> = self
228                .fields
229                .iter()
230                .map(|field| {
231                    let ty = &field.type_;
232                    quote! {
233                        let (rem, any) = Any::from_der(rem)?;
234                        <#ty as CheckDerConstraints>::check_constraints(&any)?;
235                    }
236                })
237                .collect();
238            quote! {
239                any.tag().assert_eq(Self::TAG)?;
240                let rem = &any.data;
241                #(#check_fields)*
242                Ok(())
243            }
244        };
245
246        // note: `gen impl` in synstructure takes care of appending extra where clauses if any, and removing
247        // the `where` statement if there are none.
248        quote! {
249            use asn1_rs::{CheckDerConstraints, Tagged};
250            gen impl<#lifetime> CheckDerConstraints for @Self where #(#wh)+* {
251                fn check_constraints(any: &Any) -> asn1_rs::Result<()> {
252                    #fn_content
253                }
254            }
255        }
256    }
257
258    pub fn gen_fromder(&self) -> TokenStream {
259        let lifetime = Lifetime::new("'ber", Span::call_site());
260        let wh = &self.where_predicates;
261        let field_names = &self.fields.iter().map(|f| &f.name).collect::<Vec<_>>();
262        let parse_content =
263            derive_ber_sequence_content(&self.fields, Asn1Type::Der, self.error.is_some());
264        let error = if let Some(attr) = &self.error {
265            get_attribute_meta(attr).expect("Invalid error attribute format")
266        } else {
267            quote! { asn1_rs::Error }
268        };
269
270        let fn_content = if self.container_type == ContainerType::Alias {
271            // special case: is this an alias for Any
272            if self.is_any {
273                quote! {
274                    let (rem, any) = asn1_rs::Any::from_der(bytes).map_err(asn1_rs::nom::Err::convert)?;
275                    Ok((rem,Self(any)))
276                }
277            } else {
278                quote! {
279                    let (rem, any) = asn1_rs::Any::from_der(bytes).map_err(asn1_rs::nom::Err::convert)?;
280                    any.header.assert_tag(Self::TAG).map_err(|e| asn1_rs::nom::Err::Error(e.into()))?;
281                    let res = TryFrom::try_from(any)?;
282                    Ok((rem,Self(res)))
283                }
284            }
285        } else {
286            quote! {
287                let (rem, any) = asn1_rs::Any::from_der(bytes).map_err(asn1_rs::nom::Err::convert)?;
288                any.header.assert_tag(Self::TAG).map_err(|e| asn1_rs::nom::Err::Error(e.into()))?;
289                let i = any.data;
290                //
291                #parse_content
292                //
293                // let _ = i; // XXX check if empty?
294                Ok((rem,Self{#(#field_names),*}))
295            }
296        };
297        // note: `gen impl` in synstructure takes care of appending extra where clauses if any, and removing
298        // the `where` statement if there are none.
299        quote! {
300            use asn1_rs::FromDer;
301
302            gen impl<#lifetime> asn1_rs::FromDer<#lifetime, #error> for @Self where #(#wh)+* {
303                fn from_der(bytes: &#lifetime [u8]) -> asn1_rs::ParseResult<#lifetime, Self, #error> {
304                    #fn_content
305                }
306            }
307        }
308    }
309}
310
311#[derive(Debug)]
312pub struct FieldInfo {
313    pub name: Ident,
314    pub type_: Type,
315    pub default: Option<TokenStream>,
316    pub optional: bool,
317    pub tag: Option<(Asn1TagKind, Asn1TagClass, u16)>,
318    pub map_err: Option<TokenStream>,
319}
320
321impl From<&Field> for FieldInfo {
322    fn from(field: &Field) -> Self {
323        // parse attributes and keep supported ones
324        let mut optional = false;
325        let mut tag = None;
326        let mut map_err = None;
327        let mut default = None;
328        let name = field
329            .ident
330            .as_ref()
331            .map_or_else(|| Ident::new("_", Span::call_site()), |s| s.clone());
332        for attr in &field.attrs {
333            let ident = match attr.meta.path().get_ident() {
334                Some(ident) => ident.to_string(),
335                None => continue,
336            };
337            match ident.as_str() {
338                "map_err" => {
339                    let expr: syn::Expr = attr.parse_args().expect("could not parse map_err");
340                    map_err = Some(quote! { #expr });
341                }
342                "default" => {
343                    let expr: syn::Expr = attr.parse_args().expect("could not parse default");
344                    default = Some(quote! { #expr });
345                    optional = true;
346                }
347                "optional" => optional = true,
348                "tag_explicit" => {
349                    if tag.is_some() {
350                        panic!("tag cannot be set twice!");
351                    }
352                    let (class, value) = attr.parse_args_with(parse_tag_args).unwrap();
353                    tag = Some((Asn1TagKind::Explicit, class, value));
354                }
355                "tag_implicit" => {
356                    if tag.is_some() {
357                        panic!("tag cannot be set twice!");
358                    }
359                    let (class, value) = attr.parse_args_with(parse_tag_args).unwrap();
360                    tag = Some((Asn1TagKind::Implicit, class, value));
361                }
362                // ignore unknown attributes
363                _ => (),
364            }
365        }
366        FieldInfo {
367            name,
368            type_: field.ty.clone(),
369            default,
370            optional,
371            tag,
372            map_err,
373        }
374    }
375}
376
377fn parse_tag_args(stream: ParseStream) -> Result<(Asn1TagClass, u16), syn::Error> {
378    let tag_class: Option<Ident> = stream.parse()?;
379    let tag_class = if let Some(ident) = tag_class {
380        let s = ident.to_string().to_uppercase();
381        match s.as_str() {
382            "UNIVERSAL" => Asn1TagClass::Universal,
383            "CONTEXT-SPECIFIC" => Asn1TagClass::ContextSpecific,
384            "APPLICATION" => Asn1TagClass::Application,
385            "PRIVATE" => Asn1TagClass::Private,
386            _ => {
387                return Err(syn::Error::new(stream.span(), "Invalid tag class"));
388            }
389        }
390    } else {
391        Asn1TagClass::ContextSpecific
392    };
393    let lit: LitInt = stream.parse()?;
394    let value = lit.base10_parse::<u16>()?;
395    Ok((tag_class, value))
396}
397
398fn derive_ber_sequence_content(
399    fields: &[FieldInfo],
400    asn1_type: Asn1Type,
401    custom_errors: bool,
402) -> TokenStream {
403    let field_parsers: Vec<_> = fields
404        .iter()
405        .map(|f| get_field_parser(f, asn1_type, custom_errors))
406        .collect();
407
408    quote! {
409        #(#field_parsers)*
410    }
411}
412
413fn get_field_parser(f: &FieldInfo, asn1_type: Asn1Type, custom_errors: bool) -> TokenStream {
414    let from = match asn1_type {
415        Asn1Type::Ber => quote! {FromBer::from_ber},
416        Asn1Type::Der => quote! {FromDer::from_der},
417    };
418    let name = &f.name;
419    let default = f
420        .default
421        .as_ref()
422        // use a type hint, otherwise compiler will not know what type provides .unwrap_or
423        .map(|x| quote! {let #name: Option<_> = #name; let #name = #name.unwrap_or(#x);});
424    let map_err = if let Some(tt) = f.map_err.as_ref() {
425        if asn1_type == Asn1Type::Ber {
426            Some(quote! {
427                .map_err(|err| err.map(#tt))
428                .map_err(asn1_rs::from_nom_error::<_, Self::Error>)
429            })
430        } else {
431            // Some(quote! { .map_err(|err| nom::Err::convert(#tt)) })
432            Some(quote! { .map_err(|err| err.map(#tt)) })
433        }
434    } else {
435        // add mapping functions only if custom errors are used
436        if custom_errors {
437            if asn1_type == Asn1Type::Ber {
438                Some(quote! { .map_err(asn1_rs::from_nom_error::<_, Self::Error>) })
439            } else {
440                Some(quote! { .map_err(nom::Err::convert) })
441            }
442        } else {
443            None
444        }
445    };
446    if let Some((tag_kind, class, n)) = f.tag {
447        let tag = Literal::u16_unsuffixed(n);
448        // test if tagged + optional
449        if f.optional {
450            return quote! {
451                let (i, #name) = {
452                    if i.is_empty() {
453                        (i, None)
454                    } else {
455                        let (_, header): (_, asn1_rs::Header) = #from(i)#map_err?;
456                        if header.tag().0 == #tag {
457                            let (i, t): (_, asn1_rs::TaggedValue::<_, _, #tag_kind, {#class}, #tag>) = #from(i)#map_err?;
458                            (i, Some(t.into_inner()))
459                        } else {
460                            (i, None)
461                        }
462                    }
463                };
464                #default
465            };
466        } else {
467            // tagged, but not OPTIONAL
468            return quote! {
469                let (i, #name) = {
470                    let (i, t): (_, asn1_rs::TaggedValue::<_, _, #tag_kind, {#class}, #tag>) = #from(i)#map_err?;
471                    (i, t.into_inner())
472                };
473                #default
474            };
475        }
476    } else {
477        // neither tagged nor optional
478        quote! {
479            let (i, #name) = #from(i)#map_err?;
480            #default
481        }
482    }
483}
484
485fn get_attribute_meta(attr: &Attribute) -> Result<TokenStream, syn::Error> {
486    if let Meta::List(meta) = &attr.meta {
487        let content = &meta.tokens;
488        Ok(quote! { #content })
489    } else {
490        Err(syn::Error::new(
491            attr.span(),
492            "Invalid error attribute format",
493        ))
494    }
495}