prost_derive/field/scalar.rs

1use std::fmt;
2
3use anyhow::{anyhow, bail, Error};
4use proc_macro2::{Span, TokenStream};
5use quote::{quote, ToTokens, TokenStreamExt};
6use syn::{parse_str, Expr, ExprLit, Ident, Index, Lit, LitByteStr, Meta, MetaNameValue, Path};
7
8use crate::field::{bool_attr, set_option, tag_attr, Label};
9
10/// A scalar protobuf field.
11#[derive(Clone)]
12pub struct Field {
13    pub ty: Ty,
14    pub kind: Kind,
15    pub tag: u32,
16}
17
18impl Field {
19    pub fn new(attrs: &[Meta], inferred_tag: Option<u32>) -> Result<Option<Field>, Error> {
20        let mut ty = None;
21        let mut label = None;
22        let mut packed = None;
23        let mut default = None;
24        let mut tag = None;
25
26        let mut unknown_attrs = Vec::new();
27
28        for attr in attrs {
29            if let Some(t) = Ty::from_attr(attr)? {
30                set_option(&mut ty, t, "duplicate type attributes")?;
31            } else if let Some(p) = bool_attr("packed", attr)? {
32                set_option(&mut packed, p, "duplicate packed attributes")?;
33            } else if let Some(t) = tag_attr(attr)? {
34                set_option(&mut tag, t, "duplicate tag attributes")?;
35            } else if let Some(l) = Label::from_attr(attr) {
36                set_option(&mut label, l, "duplicate label attributes")?;
37            } else if let Some(d) = DefaultValue::from_attr(attr)? {
38                set_option(&mut default, d, "duplicate default attributes")?;
39            } else {
40                unknown_attrs.push(attr);
41            }
42        }
43
44        let ty = match ty {
45            Some(ty) => ty,
46            None => return Ok(None),
47        };
48
49        match unknown_attrs.len() {
50            0 => (),
51            1 => bail!("unknown attribute: {:?}", unknown_attrs[0]),
52            _ => bail!("unknown attributes: {:?}", unknown_attrs),
53        }
54
55        let tag = match tag.or(inferred_tag) {
56            Some(tag) => tag,
57            None => bail!("missing tag attribute"),
58        };
59
60        let has_default = default.is_some();
61        let default = default.map_or_else(
62            || Ok(DefaultValue::new(&ty)),
63            |lit| DefaultValue::from_lit(&ty, lit),
64        )?;
65
66        let kind = match (label, packed, has_default) {
67            (None, Some(true), _)
68            | (Some(Label::Optional), Some(true), _)
69            | (Some(Label::Required), Some(true), _) => {
70                bail!("packed attribute may only be applied to repeated fields");
71            }
72            (Some(Label::Repeated), Some(true), _) if !ty.is_numeric() => {
73                bail!("packed attribute may only be applied to numeric types");
74            }
75            (Some(Label::Repeated), _, true) => {
76                bail!("repeated fields may not have a default value");
77            }
78
79            (None, _, _) => Kind::Plain(default),
80            (Some(Label::Optional), _, _) => Kind::Optional(default),
81            (Some(Label::Required), _, _) => Kind::Required(default),
82            (Some(Label::Repeated), packed, false) if packed.unwrap_or_else(|| ty.is_numeric()) => {
83                Kind::Packed
84            }
85            (Some(Label::Repeated), _, false) => Kind::Repeated,
86        };
87
88        Ok(Some(Field { ty, kind, tag }))
89    }
90
91    pub fn new_oneof(attrs: &[Meta]) -> Result<Option<Field>, Error> {
92        if let Some(mut field) = Field::new(attrs, None)? {
93            match field.kind {
94                Kind::Plain(default) => {
95                    field.kind = Kind::Required(default);
96                    Ok(Some(field))
97                }
98                Kind::Optional(..) => bail!("invalid optional attribute on oneof field"),
99                Kind::Required(..) => bail!("invalid required attribute on oneof field"),
100                Kind::Packed | Kind::Repeated => bail!("invalid repeated attribute on oneof field"),
101            }
102        } else {
103            Ok(None)
104        }
105    }
106
107    pub fn encode(&self, ident: TokenStream) -> TokenStream {
108        let module = self.ty.module();
109        let encode_fn = match self.kind {
110            Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encode),
111            Kind::Repeated => quote!(encode_repeated),
112            Kind::Packed => quote!(encode_packed),
113        };
114        let encode_fn = quote!(::prost::encoding::#module::#encode_fn);
115        let tag = self.tag;
116
117        match self.kind {
118            Kind::Plain(ref default) => {
119                let default = default.typed();
120                quote! {
121                    if #ident != #default {
122                        #encode_fn(#tag, &#ident, buf);
123                    }
124                }
125            }
126            Kind::Optional(..) => quote! {
127                if let ::core::option::Option::Some(ref value) = #ident {
128                    #encode_fn(#tag, value, buf);
129                }
130            },
131            Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! {
132                #encode_fn(#tag, &#ident, buf);
133            },
134        }
135    }
136
137    /// Returns an expression which evaluates to the result of merging a decoded
138    /// scalar value into the field.
139    pub fn merge(&self, ident: TokenStream) -> TokenStream {
140        let module = self.ty.module();
141        let merge_fn = match self.kind {
142            Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(merge),
143            Kind::Repeated | Kind::Packed => quote!(merge_repeated),
144        };
145        let merge_fn = quote!(::prost::encoding::#module::#merge_fn);
146
147        match self.kind {
148            Kind::Plain(..) | Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! {
149                #merge_fn(wire_type, #ident, buf, ctx)
150            },
151            Kind::Optional(..) => quote! {
152                #merge_fn(wire_type,
153                          #ident.get_or_insert_with(::core::default::Default::default),
154                          buf,
155                          ctx)
156            },
157        }
158    }
159
160    /// Returns an expression which evaluates to the encoded length of the field.
161    pub fn encoded_len(&self, ident: TokenStream) -> TokenStream {
162        let module = self.ty.module();
163        let encoded_len_fn = match self.kind {
164            Kind::Plain(..) | Kind::Optional(..) | Kind::Required(..) => quote!(encoded_len),
165            Kind::Repeated => quote!(encoded_len_repeated),
166            Kind::Packed => quote!(encoded_len_packed),
167        };
168        let encoded_len_fn = quote!(::prost::encoding::#module::#encoded_len_fn);
169        let tag = self.tag;
170
171        match self.kind {
172            Kind::Plain(ref default) => {
173                let default = default.typed();
174                quote! {
175                    if #ident != #default {
176                        #encoded_len_fn(#tag, &#ident)
177                    } else {
178                        0
179                    }
180                }
181            }
182            Kind::Optional(..) => quote! {
183                #ident.as_ref().map_or(0, |value| #encoded_len_fn(#tag, value))
184            },
185            Kind::Required(..) | Kind::Repeated | Kind::Packed => quote! {
186                #encoded_len_fn(#tag, &#ident)
187            },
188        }
189    }
190
191    pub fn clear(&self, ident: TokenStream) -> TokenStream {
192        match self.kind {
193            Kind::Plain(ref default) | Kind::Required(ref default) => {
194                let default = default.typed();
195                match self.ty {
196                    Ty::String | Ty::Bytes(..) => quote!(#ident.clear()),
197                    _ => quote!(#ident = #default),
198                }
199            }
200            Kind::Optional(_) => quote!(#ident = ::core::option::Option::None),
201            Kind::Repeated | Kind::Packed => quote!(#ident.clear()),
202        }
203    }
204
205    /// Returns an expression which evaluates to the default value of the field.
206    pub fn default(&self) -> TokenStream {
207        match self.kind {
208            Kind::Plain(ref value) | Kind::Required(ref value) => value.owned(),
209            Kind::Optional(_) => quote!(::core::option::Option::None),
210            Kind::Repeated | Kind::Packed => quote!(::prost::alloc::vec::Vec::new()),
211        }
212    }
213
214    /// An inner debug wrapper, around the base type.
215    fn debug_inner(&self, wrap_name: TokenStream) -> TokenStream {
216        if let Ty::Enumeration(ref ty) = self.ty {
217            quote! {
218                struct #wrap_name<'a>(&'a i32);
219                impl<'a> ::core::fmt::Debug for #wrap_name<'a> {
220                    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
221                        let res: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(*self.0);
222                        match res {
223                            Err(_) => ::core::fmt::Debug::fmt(&self.0, f),
224                            Ok(en) => ::core::fmt::Debug::fmt(&en, f),
225                        }
226                    }
227                }
228            }
229        } else {
230            quote! {
231                #[allow(non_snake_case)]
232                fn #wrap_name<T>(v: T) -> T { v }
233            }
234        }
235    }
236
237    /// Returns a fragment for formatting the field `ident` in `Debug`.
238    pub fn debug(&self, wrapper_name: TokenStream) -> TokenStream {
239        let wrapper = self.debug_inner(quote!(Inner));
240        let inner_ty = self.ty.rust_type();
241        match self.kind {
242            Kind::Plain(_) | Kind::Required(_) => self.debug_inner(wrapper_name),
243            Kind::Optional(_) => quote! {
244                struct #wrapper_name<'a>(&'a ::core::option::Option<#inner_ty>);
245                impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
246                    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
247                        #wrapper
248                        ::core::fmt::Debug::fmt(&self.0.as_ref().map(Inner), f)
249                    }
250                }
251            },
252            Kind::Repeated | Kind::Packed => {
253                quote! {
254                    struct #wrapper_name<'a>(&'a ::prost::alloc::vec::Vec<#inner_ty>);
255                    impl<'a> ::core::fmt::Debug for #wrapper_name<'a> {
256                        fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
257                            let mut vec_builder = f.debug_list();
258                            for v in self.0 {
259                                #wrapper
260                                vec_builder.entry(&Inner(v));
261                            }
262                            vec_builder.finish()
263                        }
264                    }
265                }
266            }
267        }
268    }
269
270    /// Returns methods to embed in the message.
271    pub fn methods(&self, ident: &TokenStream) -> Option<TokenStream> {
272        let mut ident_str = ident.to_string();
273        if ident_str.starts_with("r#") {
274            ident_str = ident_str.split_off(2);
275        }
276
277        // Prepend `get_` for getter methods of tuple structs.
278        let get = match syn::parse_str::<Index>(&ident_str) {
279            Ok(index) => {
280                let get = Ident::new(&format!("get_{}", index.index), Span::call_site());
281                quote!(#get)
282            }
283            Err(_) => quote!(#ident),
284        };
285
286        if let Ty::Enumeration(ref ty) = self.ty {
287            let set = Ident::new(&format!("set_{}", ident_str), Span::call_site());
288            let set_doc = format!("Sets `{}` to the provided enum value.", ident_str);
289            Some(match self.kind {
290                Kind::Plain(ref default) | Kind::Required(ref default) => {
291                    let get_doc = format!(
292                        "Returns the enum value of `{}`, \
293                         or the default if the field is set to an invalid enum value.",
294                        ident_str,
295                    );
296                    quote! {
297                        #[doc=#get_doc]
298                        pub fn #get(&self) -> #ty {
299                            ::core::convert::TryFrom::try_from(self.#ident).unwrap_or(#default)
300                        }
301
302                        #[doc=#set_doc]
303                        pub fn #set(&mut self, value: #ty) {
304                            self.#ident = value as i32;
305                        }
306                    }
307                }
308                Kind::Optional(ref default) => {
309                    let get_doc = format!(
310                        "Returns the enum value of `{}`, \
311                         or the default if the field is unset or set to an invalid enum value.",
312                        ident_str,
313                    );
314                    quote! {
315                        #[doc=#get_doc]
316                        pub fn #get(&self) -> #ty {
317                            self.#ident.and_then(|x| {
318                                let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
319                                result.ok()
320                            }).unwrap_or(#default)
321                        }
322
323                        #[doc=#set_doc]
324                        pub fn #set(&mut self, value: #ty) {
325                            self.#ident = ::core::option::Option::Some(value as i32);
326                        }
327                    }
328                }
329                Kind::Repeated | Kind::Packed => {
330                    let iter_doc = format!(
331                        "Returns an iterator which yields the valid enum values contained in `{}`.",
332                        ident_str,
333                    );
334                    let push = Ident::new(&format!("push_{}", ident_str), Span::call_site());
335                    let push_doc = format!("Appends the provided enum value to `{}`.", ident_str);
336                    quote! {
337                        #[doc=#iter_doc]
338                        pub fn #get(&self) -> ::core::iter::FilterMap<
339                            ::core::iter::Cloned<::core::slice::Iter<i32>>,
340                            fn(i32) -> ::core::option::Option<#ty>,
341                        > {
342                            self.#ident.iter().cloned().filter_map(|x| {
343                                let result: ::core::result::Result<#ty, _> = ::core::convert::TryFrom::try_from(x);
344                                result.ok()
345                            })
346                        }
347                        #[doc=#push_doc]
348                        pub fn #push(&mut self, value: #ty) {
349                            self.#ident.push(value as i32);
350                        }
351                    }
352                }
353            })
354        } else if let Kind::Optional(ref default) = self.kind {
355            let ty = self.ty.rust_ref_type();
356
357            let match_some = if self.ty.is_numeric() {
358                quote!(::core::option::Option::Some(val) => val,)
359            } else {
360                quote!(::core::option::Option::Some(ref val) => &val[..],)
361            };
362
363            let get_doc = format!(
364                "Returns the value of `{0}`, or the default value if `{0}` is unset.",
365                ident_str,
366            );
367
368            Some(quote! {
369                #[doc=#get_doc]
370                pub fn #get(&self) -> #ty {
371                    match self.#ident {
372                        #match_some
373                        ::core::option::Option::None => #default,
374                    }
375                }
376            })
377        } else {
378            None
379        }
380    }
381}
382
383/// A scalar protobuf field type.
384#[derive(Clone, PartialEq, Eq)]
385pub enum Ty {
386    Double,
387    Float,
388    Int32,
389    Int64,
390    Uint32,
391    Uint64,
392    Sint32,
393    Sint64,
394    Fixed32,
395    Fixed64,
396    Sfixed32,
397    Sfixed64,
398    Bool,
399    String,
400    Bytes(BytesTy),
401    Enumeration(Path),
402}
403
/// The Rust type used to store a protobuf `bytes` field.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum BytesTy {
    /// Stored as `Vec<u8>` (selected by `bytes` or `bytes = "vec"`).
    Vec,
    /// Stored as `bytes::Bytes` (selected by `bytes = "bytes"`).
    Bytes,
}
409
410impl BytesTy {
411    fn try_from_str(s: &str) -> Result<Self, Error> {
412        match s {
413            "vec" => Ok(BytesTy::Vec),
414            "bytes" => Ok(BytesTy::Bytes),
415            _ => bail!("Invalid bytes type: {}", s),
416        }
417    }
418
419    fn rust_type(&self) -> TokenStream {
420        match self {
421            BytesTy::Vec => quote! { ::prost::alloc::vec::Vec<u8> },
422            BytesTy::Bytes => quote! { ::prost::bytes::Bytes },
423        }
424    }
425}
426
427impl Ty {
428    pub fn from_attr(attr: &Meta) -> Result<Option<Ty>, Error> {
429        let ty = match *attr {
430            Meta::Path(ref name) if name.is_ident("float") => Ty::Float,
431            Meta::Path(ref name) if name.is_ident("double") => Ty::Double,
432            Meta::Path(ref name) if name.is_ident("int32") => Ty::Int32,
433            Meta::Path(ref name) if name.is_ident("int64") => Ty::Int64,
434            Meta::Path(ref name) if name.is_ident("uint32") => Ty::Uint32,
435            Meta::Path(ref name) if name.is_ident("uint64") => Ty::Uint64,
436            Meta::Path(ref name) if name.is_ident("sint32") => Ty::Sint32,
437            Meta::Path(ref name) if name.is_ident("sint64") => Ty::Sint64,
438            Meta::Path(ref name) if name.is_ident("fixed32") => Ty::Fixed32,
439            Meta::Path(ref name) if name.is_ident("fixed64") => Ty::Fixed64,
440            Meta::Path(ref name) if name.is_ident("sfixed32") => Ty::Sfixed32,
441            Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64,
442            Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool,
443            Meta::Path(ref name) if name.is_ident("string") => Ty::String,
444            Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec),
445            Meta::NameValue(MetaNameValue {
446                ref path,
447                value:
448                    Expr::Lit(ExprLit {
449                        lit: Lit::Str(ref l),
450                        ..
451                    }),
452                ..
453            }) if path.is_ident("bytes") => Ty::Bytes(BytesTy::try_from_str(&l.value())?),
454            Meta::NameValue(MetaNameValue {
455                ref path,
456                value:
457                    Expr::Lit(ExprLit {
458                        lit: Lit::Str(ref l),
459                        ..
460                    }),
461                ..
462            }) if path.is_ident("enumeration") => Ty::Enumeration(parse_str::<Path>(&l.value())?),
463            Meta::List(ref meta_list) if meta_list.path.is_ident("enumeration") => {
464                Ty::Enumeration(meta_list.parse_args::<Path>()?)
465            }
466            _ => return Ok(None),
467        };
468        Ok(Some(ty))
469    }
470
471    pub fn from_str(s: &str) -> Result<Ty, Error> {
472        let enumeration_len = "enumeration".len();
473        let error = Err(anyhow!("invalid type: {}", s));
474        let ty = match s.trim() {
475            "float" => Ty::Float,
476            "double" => Ty::Double,
477            "int32" => Ty::Int32,
478            "int64" => Ty::Int64,
479            "uint32" => Ty::Uint32,
480            "uint64" => Ty::Uint64,
481            "sint32" => Ty::Sint32,
482            "sint64" => Ty::Sint64,
483            "fixed32" => Ty::Fixed32,
484            "fixed64" => Ty::Fixed64,
485            "sfixed32" => Ty::Sfixed32,
486            "sfixed64" => Ty::Sfixed64,
487            "bool" => Ty::Bool,
488            "string" => Ty::String,
489            "bytes" => Ty::Bytes(BytesTy::Vec),
490            s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => {
491                let s = &s[enumeration_len..].trim();
492                match s.chars().next() {
493                    Some('<') | Some('(') => (),
494                    _ => return error,
495                }
496                match s.chars().next_back() {
497                    Some('>') | Some(')') => (),
498                    _ => return error,
499                }
500
501                Ty::Enumeration(parse_str::<Path>(s[1..s.len() - 1].trim())?)
502            }
503            _ => return error,
504        };
505        Ok(ty)
506    }
507
508    /// Returns the type as it appears in protobuf field declarations.
509    pub fn as_str(&self) -> &'static str {
510        match *self {
511            Ty::Double => "double",
512            Ty::Float => "float",
513            Ty::Int32 => "int32",
514            Ty::Int64 => "int64",
515            Ty::Uint32 => "uint32",
516            Ty::Uint64 => "uint64",
517            Ty::Sint32 => "sint32",
518            Ty::Sint64 => "sint64",
519            Ty::Fixed32 => "fixed32",
520            Ty::Fixed64 => "fixed64",
521            Ty::Sfixed32 => "sfixed32",
522            Ty::Sfixed64 => "sfixed64",
523            Ty::Bool => "bool",
524            Ty::String => "string",
525            Ty::Bytes(..) => "bytes",
526            Ty::Enumeration(..) => "enum",
527        }
528    }
529
530    // TODO: rename to 'owned_type'.
531    pub fn rust_type(&self) -> TokenStream {
532        match self {
533            Ty::String => quote!(::prost::alloc::string::String),
534            Ty::Bytes(ty) => ty.rust_type(),
535            _ => self.rust_ref_type(),
536        }
537    }
538
539    // TODO: rename to 'ref_type'
540    pub fn rust_ref_type(&self) -> TokenStream {
541        match *self {
542            Ty::Double => quote!(f64),
543            Ty::Float => quote!(f32),
544            Ty::Int32 => quote!(i32),
545            Ty::Int64 => quote!(i64),
546            Ty::Uint32 => quote!(u32),
547            Ty::Uint64 => quote!(u64),
548            Ty::Sint32 => quote!(i32),
549            Ty::Sint64 => quote!(i64),
550            Ty::Fixed32 => quote!(u32),
551            Ty::Fixed64 => quote!(u64),
552            Ty::Sfixed32 => quote!(i32),
553            Ty::Sfixed64 => quote!(i64),
554            Ty::Bool => quote!(bool),
555            Ty::String => quote!(&str),
556            Ty::Bytes(..) => quote!(&[u8]),
557            Ty::Enumeration(..) => quote!(i32),
558        }
559    }
560
561    pub fn module(&self) -> Ident {
562        match *self {
563            Ty::Enumeration(..) => Ident::new("int32", Span::call_site()),
564            _ => Ident::new(self.as_str(), Span::call_site()),
565        }
566    }
567
568    /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`).
569    pub fn is_numeric(&self) -> bool {
570        !matches!(self, Ty::String | Ty::Bytes(..))
571    }
572}
573
574impl fmt::Debug for Ty {
575    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
576        f.write_str(self.as_str())
577    }
578}
579
580impl fmt::Display for Ty {
581    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
582        f.write_str(self.as_str())
583    }
584}
585
586/// Scalar Protobuf field types.
587#[derive(Clone)]
588pub enum Kind {
589    /// A plain proto3 scalar field.
590    Plain(DefaultValue),
591    /// An optional scalar field.
592    Optional(DefaultValue),
593    /// A required proto2 scalar field.
594    Required(DefaultValue),
595    /// A repeated scalar field.
596    Repeated,
597    /// A packed repeated scalar field.
598    Packed,
599}
600
601/// Scalar Protobuf field default value.
602#[derive(Clone, Debug)]
603pub enum DefaultValue {
604    F64(f64),
605    F32(f32),
606    I32(i32),
607    I64(i64),
608    U32(u32),
609    U64(u64),
610    Bool(bool),
611    String(String),
612    Bytes(Vec<u8>),
613    Enumeration(TokenStream),
614    Path(Path),
615}
616
617impl DefaultValue {
618    pub fn from_attr(attr: &Meta) -> Result<Option<Lit>, Error> {
619        if !attr.path().is_ident("default") {
620            Ok(None)
621        } else if let Meta::NameValue(MetaNameValue {
622            value: Expr::Lit(ExprLit { ref lit, .. }),
623            ..
624        }) = *attr
625        {
626            Ok(Some(lit.clone()))
627        } else {
628            bail!("invalid default value attribute: {:?}", attr)
629        }
630    }
631
632    pub fn from_lit(ty: &Ty, lit: Lit) -> Result<DefaultValue, Error> {
633        let is_i32 = *ty == Ty::Int32 || *ty == Ty::Sint32 || *ty == Ty::Sfixed32;
634        let is_i64 = *ty == Ty::Int64 || *ty == Ty::Sint64 || *ty == Ty::Sfixed64;
635
636        let is_u32 = *ty == Ty::Uint32 || *ty == Ty::Fixed32;
637        let is_u64 = *ty == Ty::Uint64 || *ty == Ty::Fixed64;
638
639        let empty_or_is = |expected, actual: &str| expected == actual || actual.is_empty();
640
641        let default = match lit {
642            Lit::Int(ref lit) if is_i32 && empty_or_is("i32", lit.suffix()) => {
643                DefaultValue::I32(lit.base10_parse()?)
644            }
645            Lit::Int(ref lit) if is_i64 && empty_or_is("i64", lit.suffix()) => {
646                DefaultValue::I64(lit.base10_parse()?)
647            }
648            Lit::Int(ref lit) if is_u32 && empty_or_is("u32", lit.suffix()) => {
649                DefaultValue::U32(lit.base10_parse()?)
650            }
651            Lit::Int(ref lit) if is_u64 && empty_or_is("u64", lit.suffix()) => {
652                DefaultValue::U64(lit.base10_parse()?)
653            }
654
655            Lit::Float(ref lit) if *ty == Ty::Float && empty_or_is("f32", lit.suffix()) => {
656                DefaultValue::F32(lit.base10_parse()?)
657            }
658            Lit::Int(ref lit) if *ty == Ty::Float => DefaultValue::F32(lit.base10_parse()?),
659
660            Lit::Float(ref lit) if *ty == Ty::Double && empty_or_is("f64", lit.suffix()) => {
661                DefaultValue::F64(lit.base10_parse()?)
662            }
663            Lit::Int(ref lit) if *ty == Ty::Double => DefaultValue::F64(lit.base10_parse()?),
664
665            Lit::Bool(ref lit) if *ty == Ty::Bool => DefaultValue::Bool(lit.value),
666            Lit::Str(ref lit) if *ty == Ty::String => DefaultValue::String(lit.value()),
667            Lit::ByteStr(ref lit)
668                if *ty == Ty::Bytes(BytesTy::Bytes) || *ty == Ty::Bytes(BytesTy::Vec) =>
669            {
670                DefaultValue::Bytes(lit.value())
671            }
672
673            Lit::Str(ref lit) => {
674                let value = lit.value();
675                let value = value.trim();
676
677                if let Ty::Enumeration(ref path) = *ty {
678                    let variant = Ident::new(value, Span::call_site());
679                    return Ok(DefaultValue::Enumeration(quote!(#path::#variant)));
680                }
681
682                // Parse special floating point values.
683                if *ty == Ty::Float {
684                    match value {
685                        "inf" => {
686                            return Ok(DefaultValue::Path(parse_str::<Path>(
687                                "::core::f32::INFINITY",
688                            )?));
689                        }
690                        "-inf" => {
691                            return Ok(DefaultValue::Path(parse_str::<Path>(
692                                "::core::f32::NEG_INFINITY",
693                            )?));
694                        }
695                        "nan" => {
696                            return Ok(DefaultValue::Path(parse_str::<Path>("::core::f32::NAN")?));
697                        }
698                        _ => (),
699                    }
700                }
701                if *ty == Ty::Double {
702                    match value {
703                        "inf" => {
704                            return Ok(DefaultValue::Path(parse_str::<Path>(
705                                "::core::f64::INFINITY",
706                            )?));
707                        }
708                        "-inf" => {
709                            return Ok(DefaultValue::Path(parse_str::<Path>(
710                                "::core::f64::NEG_INFINITY",
711                            )?));
712                        }
713                        "nan" => {
714                            return Ok(DefaultValue::Path(parse_str::<Path>("::core::f64::NAN")?));
715                        }
716                        _ => (),
717                    }
718                }
719
720                // Rust doesn't have a negative literals, so they have to be parsed specially.
721                if let Some(Ok(lit)) = value.strip_prefix('-').map(syn::parse_str::<Lit>) {
722                    match lit {
723                        Lit::Int(ref lit) if is_i32 && empty_or_is("i32", lit.suffix()) => {
724                            // Initially parse into an i64, so that i32::MIN does not overflow.
725                            let value: i64 = -lit.base10_parse()?;
726                            return Ok(i32::try_from(value).map(DefaultValue::I32)?);
727                        }
728                        Lit::Int(ref lit) if is_i64 && empty_or_is("i64", lit.suffix()) => {
729                            // Initially parse into an i128, so that i64::MIN does not overflow.
730                            let value: i128 = -lit.base10_parse()?;
731                            return Ok(i64::try_from(value).map(DefaultValue::I64)?);
732                        }
733                        Lit::Float(ref lit)
734                            if *ty == Ty::Float && empty_or_is("f32", lit.suffix()) =>
735                        {
736                            return Ok(DefaultValue::F32(-lit.base10_parse()?));
737                        }
738                        Lit::Float(ref lit)
739                            if *ty == Ty::Double && empty_or_is("f64", lit.suffix()) =>
740                        {
741                            return Ok(DefaultValue::F64(-lit.base10_parse()?));
742                        }
743                        Lit::Int(ref lit) if *ty == Ty::Float && lit.suffix().is_empty() => {
744                            return Ok(DefaultValue::F32(-lit.base10_parse()?));
745                        }
746                        Lit::Int(ref lit) if *ty == Ty::Double && lit.suffix().is_empty() => {
747                            return Ok(DefaultValue::F64(-lit.base10_parse()?));
748                        }
749                        _ => (),
750                    }
751                }
752                match syn::parse_str::<Lit>(value) {
753                    Ok(Lit::Str(_)) => (),
754                    Ok(lit) => return DefaultValue::from_lit(ty, lit),
755                    _ => (),
756                }
757                bail!("invalid default value: {}", quote!(#value));
758            }
759            _ => bail!("invalid default value: {}", quote!(#lit)),
760        };
761
762        Ok(default)
763    }
764
765    pub fn new(ty: &Ty) -> DefaultValue {
766        match *ty {
767            Ty::Float => DefaultValue::F32(0.0),
768            Ty::Double => DefaultValue::F64(0.0),
769            Ty::Int32 | Ty::Sint32 | Ty::Sfixed32 => DefaultValue::I32(0),
770            Ty::Int64 | Ty::Sint64 | Ty::Sfixed64 => DefaultValue::I64(0),
771            Ty::Uint32 | Ty::Fixed32 => DefaultValue::U32(0),
772            Ty::Uint64 | Ty::Fixed64 => DefaultValue::U64(0),
773
774            Ty::Bool => DefaultValue::Bool(false),
775            Ty::String => DefaultValue::String(String::new()),
776            Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()),
777            Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())),
778        }
779    }
780
781    pub fn owned(&self) -> TokenStream {
782        match *self {
783            DefaultValue::String(ref value) if value.is_empty() => {
784                quote!(::prost::alloc::string::String::new())
785            }
786            DefaultValue::String(ref value) => quote!(#value.into()),
787            DefaultValue::Bytes(ref value) if value.is_empty() => {
788                quote!(::core::default::Default::default())
789            }
790            DefaultValue::Bytes(ref value) => {
791                let lit = LitByteStr::new(value, Span::call_site());
792                quote!(#lit.as_ref().into())
793            }
794
795            ref other => other.typed(),
796        }
797    }
798
799    pub fn typed(&self) -> TokenStream {
800        if let DefaultValue::Enumeration(_) = *self {
801            quote!(#self as i32)
802        } else {
803            quote!(#self)
804        }
805    }
806}
807
808impl ToTokens for DefaultValue {
809    fn to_tokens(&self, tokens: &mut TokenStream) {
810        match *self {
811            DefaultValue::F64(value) => value.to_tokens(tokens),
812            DefaultValue::F32(value) => value.to_tokens(tokens),
813            DefaultValue::I32(value) => value.to_tokens(tokens),
814            DefaultValue::I64(value) => value.to_tokens(tokens),
815            DefaultValue::U32(value) => value.to_tokens(tokens),
816            DefaultValue::U64(value) => value.to_tokens(tokens),
817            DefaultValue::Bool(value) => value.to_tokens(tokens),
818            DefaultValue::String(ref value) => value.to_tokens(tokens),
819            DefaultValue::Bytes(ref value) => {
820                let byte_str = LitByteStr::new(value, Span::call_site());
821                tokens.append_all(quote!(#byte_str as &[u8]));
822            }
823            DefaultValue::Enumeration(ref value) => value.to_tokens(tokens),
824            DefaultValue::Path(ref value) => value.to_tokens(tokens),
825        }
826    }
827}