prost_derive/
lib.rs

1#![doc(html_root_url = "https://docs.rs/prost-derive/0.12.6")]
2// The `quote!` macro requires deep recursion.
3#![recursion_limit = "4096"]
4
5extern crate alloc;
6extern crate proc_macro;
7
8use anyhow::{bail, Error};
9use itertools::Itertools;
10use proc_macro::TokenStream;
11use proc_macro2::Span;
12use quote::quote;
13use syn::{
14    punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
15    FieldsUnnamed, Ident, Index, Variant,
16};
17
18mod field;
19use crate::field::Field;
20
21fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
22    let input: DeriveInput = syn::parse(input)?;
23
24    let ident = input.ident;
25
26    syn::custom_keyword!(skip_debug);
27    let skip_debug = input
28        .attrs
29        .into_iter()
30        .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
31
32    let variant_data = match input.data {
33        Data::Struct(variant_data) => variant_data,
34        Data::Enum(..) => bail!("Message can not be derived for an enum"),
35        Data::Union(..) => bail!("Message can not be derived for a union"),
36    };
37
38    let generics = &input.generics;
39    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
40
41    let (is_struct, fields) = match variant_data {
42        DataStruct {
43            fields: Fields::Named(FieldsNamed { named: fields, .. }),
44            ..
45        } => (true, fields.into_iter().collect()),
46        DataStruct {
47            fields:
48                Fields::Unnamed(FieldsUnnamed {
49                    unnamed: fields, ..
50                }),
51            ..
52        } => (false, fields.into_iter().collect()),
53        DataStruct {
54            fields: Fields::Unit,
55            ..
56        } => (false, Vec::new()),
57    };
58
59    let mut next_tag: u32 = 1;
60    let mut fields = fields
61        .into_iter()
62        .enumerate()
63        .flat_map(|(i, field)| {
64            let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| {
65                let index = Index {
66                    index: i as u32,
67                    span: Span::call_site(),
68                };
69                quote!(#index)
70            });
71            match Field::new(field.attrs, Some(next_tag)) {
72                Ok(Some(field)) => {
73                    next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
74                    Some(Ok((field_ident, field)))
75                }
76                Ok(None) => None,
77                Err(err) => Some(Err(
78                    err.context(format!("invalid message field {}.{}", ident, field_ident))
79                )),
80            }
81        })
82        .collect::<Result<Vec<_>, _>>()?;
83
84    // We want Debug to be in declaration order
85    let unsorted_fields = fields.clone();
86
87    // Sort the fields by tag number so that fields will be encoded in tag order.
88    // TODO: This encodes oneof fields in the position of their lowest tag,
89    // regardless of the currently occupied variant, is that consequential?
90    // See: https://developers.google.com/protocol-buffers/docs/encoding#order
91    fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
92    let fields = fields;
93
94    let mut tags = fields
95        .iter()
96        .flat_map(|(_, field)| field.tags())
97        .collect::<Vec<_>>();
98    let num_tags = tags.len();
99    tags.sort_unstable();
100    tags.dedup();
101    if tags.len() != num_tags {
102        bail!("message {} has fields with duplicate tags", ident);
103    }
104
105    let encoded_len = fields
106        .iter()
107        .map(|(field_ident, field)| field.encoded_len(quote!(self.#field_ident)));
108
109    let encode = fields
110        .iter()
111        .map(|(field_ident, field)| field.encode(quote!(self.#field_ident)));
112
113    let merge = fields.iter().map(|(field_ident, field)| {
114        let merge = field.merge(quote!(value));
115        let tags = field.tags().into_iter().map(|tag| quote!(#tag));
116        let tags = Itertools::intersperse(tags, quote!(|));
117
118        quote! {
119            #(#tags)* => {
120                let mut value = &mut self.#field_ident;
121                #merge.map_err(|mut error| {
122                    error.push(STRUCT_NAME, stringify!(#field_ident));
123                    error
124                })
125            },
126        }
127    });
128
129    let struct_name = if fields.is_empty() {
130        quote!()
131    } else {
132        quote!(
133            const STRUCT_NAME: &'static str = stringify!(#ident);
134        )
135    };
136
137    let clear = fields
138        .iter()
139        .map(|(field_ident, field)| field.clear(quote!(self.#field_ident)));
140
141    let default = if is_struct {
142        let default = fields.iter().map(|(field_ident, field)| {
143            let value = field.default();
144            quote!(#field_ident: #value,)
145        });
146        quote! {#ident {
147            #(#default)*
148        }}
149    } else {
150        let default = fields.iter().map(|(_, field)| {
151            let value = field.default();
152            quote!(#value,)
153        });
154        quote! {#ident (
155            #(#default)*
156        )}
157    };
158
159    let methods = fields
160        .iter()
161        .flat_map(|(field_ident, field)| field.methods(field_ident))
162        .collect::<Vec<_>>();
163    let methods = if methods.is_empty() {
164        quote!()
165    } else {
166        quote! {
167            #[allow(dead_code)]
168            impl #impl_generics #ident #ty_generics #where_clause {
169                #(#methods)*
170            }
171        }
172    };
173
174    let expanded = quote! {
175        impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
176            #[allow(unused_variables)]
177            fn encode_raw<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
178                #(#encode)*
179            }
180
181            #[allow(unused_variables)]
182            fn merge_field<B>(
183                &mut self,
184                tag: u32,
185                wire_type: ::prost::encoding::WireType,
186                buf: &mut B,
187                ctx: ::prost::encoding::DecodeContext,
188            ) -> ::core::result::Result<(), ::prost::DecodeError>
189            where B: ::prost::bytes::Buf {
190                #struct_name
191                match tag {
192                    #(#merge)*
193                    _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
194                }
195            }
196
197            #[inline]
198            fn encoded_len(&self) -> usize {
199                0 #(+ #encoded_len)*
200            }
201
202            fn clear(&mut self) {
203                #(#clear;)*
204            }
205        }
206
207        impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
208            fn default() -> Self {
209                #default
210            }
211        }
212    };
213    let expanded = if skip_debug {
214        expanded
215    } else {
216        let debugs = unsorted_fields.iter().map(|(field_ident, field)| {
217            let wrapper = field.debug(quote!(self.#field_ident));
218            let call = if is_struct {
219                quote!(builder.field(stringify!(#field_ident), &wrapper))
220            } else {
221                quote!(builder.field(&wrapper))
222            };
223            quote! {
224                 let builder = {
225                     let wrapper = #wrapper;
226                     #call
227                 };
228            }
229        });
230        let debug_builder = if is_struct {
231            quote!(f.debug_struct(stringify!(#ident)))
232        } else {
233            quote!(f.debug_tuple(stringify!(#ident)))
234        };
235        quote! {
236            #expanded
237
238            impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
239                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
240                    let mut builder = #debug_builder;
241                    #(#debugs;)*
242                    builder.finish()
243                }
244            }
245        }
246    };
247
248    let expanded = quote! {
249        #expanded
250
251        #methods
252    };
253
254    Ok(expanded.into())
255}
256
257#[proc_macro_derive(Message, attributes(prost))]
258pub fn message(input: TokenStream) -> TokenStream {
259    try_message(input).unwrap()
260}
261
262fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
263    let input: DeriveInput = syn::parse(input)?;
264    let ident = input.ident;
265
266    let generics = &input.generics;
267    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
268
269    let punctuated_variants = match input.data {
270        Data::Enum(DataEnum { variants, .. }) => variants,
271        Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
272        Data::Union(..) => bail!("Enumeration can not be derived for a union"),
273    };
274
275    // Map the variants into 'fields'.
276    let mut variants: Vec<(Ident, Expr)> = Vec::new();
277    for Variant {
278        ident,
279        fields,
280        discriminant,
281        ..
282    } in punctuated_variants
283    {
284        match fields {
285            Fields::Unit => (),
286            Fields::Named(_) | Fields::Unnamed(_) => {
287                bail!("Enumeration variants may not have fields")
288            }
289        }
290
291        match discriminant {
292            Some((_, expr)) => variants.push((ident, expr)),
293            None => bail!("Enumeration variants must have a discriminant"),
294        }
295    }
296
297    if variants.is_empty() {
298        panic!("Enumeration must have at least one variant");
299    }
300
301    let default = variants[0].0.clone();
302
303    let is_valid = variants.iter().map(|(_, value)| quote!(#value => true));
304    let from = variants
305        .iter()
306        .map(|(variant, value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)));
307
308    let try_from = variants
309        .iter()
310        .map(|(variant, value)| quote!(#value => ::core::result::Result::Ok(#ident::#variant)));
311
312    let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
313    let from_i32_doc = format!(
314        "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
315        ident
316    );
317
318    let expanded = quote! {
319        impl #impl_generics #ident #ty_generics #where_clause {
320            #[doc=#is_valid_doc]
321            pub fn is_valid(value: i32) -> bool {
322                match value {
323                    #(#is_valid,)*
324                    _ => false,
325                }
326            }
327
328            #[deprecated = "Use the TryFrom<i32> implementation instead"]
329            #[doc=#from_i32_doc]
330            pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
331                match value {
332                    #(#from,)*
333                    _ => ::core::option::Option::None,
334                }
335            }
336        }
337
338        impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
339            fn default() -> #ident {
340                #ident::#default
341            }
342        }
343
344        impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
345            fn from(value: #ident) -> i32 {
346                value as i32
347            }
348        }
349
350        impl #impl_generics ::core::convert::TryFrom::<i32> for #ident #ty_generics #where_clause {
351            type Error = ::prost::DecodeError;
352
353            fn try_from(value: i32) -> ::core::result::Result<#ident, ::prost::DecodeError> {
354                match value {
355                    #(#try_from,)*
356                    _ => ::core::result::Result::Err(::prost::DecodeError::new("invalid enumeration value")),
357                }
358            }
359        }
360    };
361
362    Ok(expanded.into())
363}
364
365#[proc_macro_derive(Enumeration, attributes(prost))]
366pub fn enumeration(input: TokenStream) -> TokenStream {
367    try_enumeration(input).unwrap()
368}
369
370fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
371    let input: DeriveInput = syn::parse(input)?;
372
373    let ident = input.ident;
374
375    syn::custom_keyword!(skip_debug);
376    let skip_debug = input
377        .attrs
378        .into_iter()
379        .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
380
381    let variants = match input.data {
382        Data::Enum(DataEnum { variants, .. }) => variants,
383        Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
384        Data::Union(..) => bail!("Oneof can not be derived for a union"),
385    };
386
387    let generics = &input.generics;
388    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
389
390    // Map the variants into 'fields'.
391    let mut fields: Vec<(Ident, Field)> = Vec::new();
392    for Variant {
393        attrs,
394        ident: variant_ident,
395        fields: variant_fields,
396        ..
397    } in variants
398    {
399        let variant_fields = match variant_fields {
400            Fields::Unit => Punctuated::new(),
401            Fields::Named(FieldsNamed { named: fields, .. })
402            | Fields::Unnamed(FieldsUnnamed {
403                unnamed: fields, ..
404            }) => fields,
405        };
406        if variant_fields.len() != 1 {
407            bail!("Oneof enum variants must have a single field");
408        }
409        match Field::new_oneof(attrs)? {
410            Some(field) => fields.push((variant_ident, field)),
411            None => bail!("invalid oneof variant: oneof variants may not be ignored"),
412        }
413    }
414
415    let mut tags = fields
416        .iter()
417        .flat_map(|(variant_ident, field)| -> Result<u32, Error> {
418            if field.tags().len() > 1 {
419                bail!(
420                    "invalid oneof variant {}::{}: oneof variants may only have a single tag",
421                    ident,
422                    variant_ident
423                );
424            }
425            Ok(field.tags()[0])
426        })
427        .collect::<Vec<_>>();
428    tags.sort_unstable();
429    tags.dedup();
430    if tags.len() != fields.len() {
431        panic!("invalid oneof {}: variants have duplicate tags", ident);
432    }
433
434    let encode = fields.iter().map(|(variant_ident, field)| {
435        let encode = field.encode(quote!(*value));
436        quote!(#ident::#variant_ident(ref value) => { #encode })
437    });
438
439    let merge = fields.iter().map(|(variant_ident, field)| {
440        let tag = field.tags()[0];
441        let merge = field.merge(quote!(value));
442        quote! {
443            #tag => {
444                match field {
445                    ::core::option::Option::Some(#ident::#variant_ident(ref mut value)) => {
446                        #merge
447                    },
448                    _ => {
449                        let mut owned_value = ::core::default::Default::default();
450                        let value = &mut owned_value;
451                        #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value)))
452                    },
453                }
454            }
455        }
456    });
457
458    let encoded_len = fields.iter().map(|(variant_ident, field)| {
459        let encoded_len = field.encoded_len(quote!(*value));
460        quote!(#ident::#variant_ident(ref value) => #encoded_len)
461    });
462
463    let expanded = quote! {
464        impl #impl_generics #ident #ty_generics #where_clause {
465            /// Encodes the message to a buffer.
466            pub fn encode<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
467                match *self {
468                    #(#encode,)*
469                }
470            }
471
472            /// Decodes an instance of the message from a buffer, and merges it into self.
473            pub fn merge<B>(
474                field: &mut ::core::option::Option<#ident #ty_generics>,
475                tag: u32,
476                wire_type: ::prost::encoding::WireType,
477                buf: &mut B,
478                ctx: ::prost::encoding::DecodeContext,
479            ) -> ::core::result::Result<(), ::prost::DecodeError>
480            where B: ::prost::bytes::Buf {
481                match tag {
482                    #(#merge,)*
483                    _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
484                }
485            }
486
487            /// Returns the encoded length of the message without a length delimiter.
488            #[inline]
489            pub fn encoded_len(&self) -> usize {
490                match *self {
491                    #(#encoded_len,)*
492                }
493            }
494        }
495
496    };
497    let expanded = if skip_debug {
498        expanded
499    } else {
500        let debug = fields.iter().map(|(variant_ident, field)| {
501            let wrapper = field.debug(quote!(*value));
502            quote!(#ident::#variant_ident(ref value) => {
503                let wrapper = #wrapper;
504                f.debug_tuple(stringify!(#variant_ident))
505                    .field(&wrapper)
506                    .finish()
507            })
508        });
509        quote! {
510            #expanded
511
512            impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
513                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
514                    match *self {
515                        #(#debug,)*
516                    }
517                }
518            }
519        }
520    };
521
522    Ok(expanded.into())
523}
524
525#[proc_macro_derive(Oneof, attributes(prost))]
526pub fn oneof(input: TokenStream) -> TokenStream {
527    try_oneof(input).unwrap()
528}