1#![doc(html_root_url = "https://docs.rs/prost-derive/0.12.6")]
2#![recursion_limit = "4096"]
4
5extern crate alloc;
6extern crate proc_macro;
7
8use anyhow::{bail, Error};
9use itertools::Itertools;
10use proc_macro::TokenStream;
11use proc_macro2::Span;
12use quote::quote;
13use syn::{
14 punctuated::Punctuated, Data, DataEnum, DataStruct, DeriveInput, Expr, Fields, FieldsNamed,
15 FieldsUnnamed, Ident, Index, Variant,
16};
17
18mod field;
19use crate::field::Field;
20
21fn try_message(input: TokenStream) -> Result<TokenStream, Error> {
22 let input: DeriveInput = syn::parse(input)?;
23
24 let ident = input.ident;
25
26 syn::custom_keyword!(skip_debug);
27 let skip_debug = input
28 .attrs
29 .into_iter()
30 .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
31
32 let variant_data = match input.data {
33 Data::Struct(variant_data) => variant_data,
34 Data::Enum(..) => bail!("Message can not be derived for an enum"),
35 Data::Union(..) => bail!("Message can not be derived for a union"),
36 };
37
38 let generics = &input.generics;
39 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
40
41 let (is_struct, fields) = match variant_data {
42 DataStruct {
43 fields: Fields::Named(FieldsNamed { named: fields, .. }),
44 ..
45 } => (true, fields.into_iter().collect()),
46 DataStruct {
47 fields:
48 Fields::Unnamed(FieldsUnnamed {
49 unnamed: fields, ..
50 }),
51 ..
52 } => (false, fields.into_iter().collect()),
53 DataStruct {
54 fields: Fields::Unit,
55 ..
56 } => (false, Vec::new()),
57 };
58
59 let mut next_tag: u32 = 1;
60 let mut fields = fields
61 .into_iter()
62 .enumerate()
63 .flat_map(|(i, field)| {
64 let field_ident = field.ident.map(|x| quote!(#x)).unwrap_or_else(|| {
65 let index = Index {
66 index: i as u32,
67 span: Span::call_site(),
68 };
69 quote!(#index)
70 });
71 match Field::new(field.attrs, Some(next_tag)) {
72 Ok(Some(field)) => {
73 next_tag = field.tags().iter().max().map(|t| t + 1).unwrap_or(next_tag);
74 Some(Ok((field_ident, field)))
75 }
76 Ok(None) => None,
77 Err(err) => Some(Err(
78 err.context(format!("invalid message field {}.{}", ident, field_ident))
79 )),
80 }
81 })
82 .collect::<Result<Vec<_>, _>>()?;
83
84 let unsorted_fields = fields.clone();
86
87 fields.sort_by_key(|(_, field)| field.tags().into_iter().min().unwrap());
92 let fields = fields;
93
94 let mut tags = fields
95 .iter()
96 .flat_map(|(_, field)| field.tags())
97 .collect::<Vec<_>>();
98 let num_tags = tags.len();
99 tags.sort_unstable();
100 tags.dedup();
101 if tags.len() != num_tags {
102 bail!("message {} has fields with duplicate tags", ident);
103 }
104
105 let encoded_len = fields
106 .iter()
107 .map(|(field_ident, field)| field.encoded_len(quote!(self.#field_ident)));
108
109 let encode = fields
110 .iter()
111 .map(|(field_ident, field)| field.encode(quote!(self.#field_ident)));
112
113 let merge = fields.iter().map(|(field_ident, field)| {
114 let merge = field.merge(quote!(value));
115 let tags = field.tags().into_iter().map(|tag| quote!(#tag));
116 let tags = Itertools::intersperse(tags, quote!(|));
117
118 quote! {
119 #(#tags)* => {
120 let mut value = &mut self.#field_ident;
121 #merge.map_err(|mut error| {
122 error.push(STRUCT_NAME, stringify!(#field_ident));
123 error
124 })
125 },
126 }
127 });
128
129 let struct_name = if fields.is_empty() {
130 quote!()
131 } else {
132 quote!(
133 const STRUCT_NAME: &'static str = stringify!(#ident);
134 )
135 };
136
137 let clear = fields
138 .iter()
139 .map(|(field_ident, field)| field.clear(quote!(self.#field_ident)));
140
141 let default = if is_struct {
142 let default = fields.iter().map(|(field_ident, field)| {
143 let value = field.default();
144 quote!(#field_ident: #value,)
145 });
146 quote! {#ident {
147 #(#default)*
148 }}
149 } else {
150 let default = fields.iter().map(|(_, field)| {
151 let value = field.default();
152 quote!(#value,)
153 });
154 quote! {#ident (
155 #(#default)*
156 )}
157 };
158
159 let methods = fields
160 .iter()
161 .flat_map(|(field_ident, field)| field.methods(field_ident))
162 .collect::<Vec<_>>();
163 let methods = if methods.is_empty() {
164 quote!()
165 } else {
166 quote! {
167 #[allow(dead_code)]
168 impl #impl_generics #ident #ty_generics #where_clause {
169 #(#methods)*
170 }
171 }
172 };
173
174 let expanded = quote! {
175 impl #impl_generics ::prost::Message for #ident #ty_generics #where_clause {
176 #[allow(unused_variables)]
177 fn encode_raw<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
178 #(#encode)*
179 }
180
181 #[allow(unused_variables)]
182 fn merge_field<B>(
183 &mut self,
184 tag: u32,
185 wire_type: ::prost::encoding::WireType,
186 buf: &mut B,
187 ctx: ::prost::encoding::DecodeContext,
188 ) -> ::core::result::Result<(), ::prost::DecodeError>
189 where B: ::prost::bytes::Buf {
190 #struct_name
191 match tag {
192 #(#merge)*
193 _ => ::prost::encoding::skip_field(wire_type, tag, buf, ctx),
194 }
195 }
196
197 #[inline]
198 fn encoded_len(&self) -> usize {
199 0 #(+ #encoded_len)*
200 }
201
202 fn clear(&mut self) {
203 #(#clear;)*
204 }
205 }
206
207 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
208 fn default() -> Self {
209 #default
210 }
211 }
212 };
213 let expanded = if skip_debug {
214 expanded
215 } else {
216 let debugs = unsorted_fields.iter().map(|(field_ident, field)| {
217 let wrapper = field.debug(quote!(self.#field_ident));
218 let call = if is_struct {
219 quote!(builder.field(stringify!(#field_ident), &wrapper))
220 } else {
221 quote!(builder.field(&wrapper))
222 };
223 quote! {
224 let builder = {
225 let wrapper = #wrapper;
226 #call
227 };
228 }
229 });
230 let debug_builder = if is_struct {
231 quote!(f.debug_struct(stringify!(#ident)))
232 } else {
233 quote!(f.debug_tuple(stringify!(#ident)))
234 };
235 quote! {
236 #expanded
237
238 impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
239 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
240 let mut builder = #debug_builder;
241 #(#debugs;)*
242 builder.finish()
243 }
244 }
245 }
246 };
247
248 let expanded = quote! {
249 #expanded
250
251 #methods
252 };
253
254 Ok(expanded.into())
255}
256
257#[proc_macro_derive(Message, attributes(prost))]
258pub fn message(input: TokenStream) -> TokenStream {
259 try_message(input).unwrap()
260}
261
262fn try_enumeration(input: TokenStream) -> Result<TokenStream, Error> {
263 let input: DeriveInput = syn::parse(input)?;
264 let ident = input.ident;
265
266 let generics = &input.generics;
267 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
268
269 let punctuated_variants = match input.data {
270 Data::Enum(DataEnum { variants, .. }) => variants,
271 Data::Struct(_) => bail!("Enumeration can not be derived for a struct"),
272 Data::Union(..) => bail!("Enumeration can not be derived for a union"),
273 };
274
275 let mut variants: Vec<(Ident, Expr)> = Vec::new();
277 for Variant {
278 ident,
279 fields,
280 discriminant,
281 ..
282 } in punctuated_variants
283 {
284 match fields {
285 Fields::Unit => (),
286 Fields::Named(_) | Fields::Unnamed(_) => {
287 bail!("Enumeration variants may not have fields")
288 }
289 }
290
291 match discriminant {
292 Some((_, expr)) => variants.push((ident, expr)),
293 None => bail!("Enumeration variants must have a discriminant"),
294 }
295 }
296
297 if variants.is_empty() {
298 panic!("Enumeration must have at least one variant");
299 }
300
301 let default = variants[0].0.clone();
302
303 let is_valid = variants.iter().map(|(_, value)| quote!(#value => true));
304 let from = variants
305 .iter()
306 .map(|(variant, value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)));
307
308 let try_from = variants
309 .iter()
310 .map(|(variant, value)| quote!(#value => ::core::result::Result::Ok(#ident::#variant)));
311
312 let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident);
313 let from_i32_doc = format!(
314 "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.",
315 ident
316 );
317
318 let expanded = quote! {
319 impl #impl_generics #ident #ty_generics #where_clause {
320 #[doc=#is_valid_doc]
321 pub fn is_valid(value: i32) -> bool {
322 match value {
323 #(#is_valid,)*
324 _ => false,
325 }
326 }
327
328 #[deprecated = "Use the TryFrom<i32> implementation instead"]
329 #[doc=#from_i32_doc]
330 pub fn from_i32(value: i32) -> ::core::option::Option<#ident> {
331 match value {
332 #(#from,)*
333 _ => ::core::option::Option::None,
334 }
335 }
336 }
337
338 impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause {
339 fn default() -> #ident {
340 #ident::#default
341 }
342 }
343
344 impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause {
345 fn from(value: #ident) -> i32 {
346 value as i32
347 }
348 }
349
350 impl #impl_generics ::core::convert::TryFrom::<i32> for #ident #ty_generics #where_clause {
351 type Error = ::prost::DecodeError;
352
353 fn try_from(value: i32) -> ::core::result::Result<#ident, ::prost::DecodeError> {
354 match value {
355 #(#try_from,)*
356 _ => ::core::result::Result::Err(::prost::DecodeError::new("invalid enumeration value")),
357 }
358 }
359 }
360 };
361
362 Ok(expanded.into())
363}
364
365#[proc_macro_derive(Enumeration, attributes(prost))]
366pub fn enumeration(input: TokenStream) -> TokenStream {
367 try_enumeration(input).unwrap()
368}
369
370fn try_oneof(input: TokenStream) -> Result<TokenStream, Error> {
371 let input: DeriveInput = syn::parse(input)?;
372
373 let ident = input.ident;
374
375 syn::custom_keyword!(skip_debug);
376 let skip_debug = input
377 .attrs
378 .into_iter()
379 .any(|a| a.path().is_ident("prost") && a.parse_args::<skip_debug>().is_ok());
380
381 let variants = match input.data {
382 Data::Enum(DataEnum { variants, .. }) => variants,
383 Data::Struct(..) => bail!("Oneof can not be derived for a struct"),
384 Data::Union(..) => bail!("Oneof can not be derived for a union"),
385 };
386
387 let generics = &input.generics;
388 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
389
390 let mut fields: Vec<(Ident, Field)> = Vec::new();
392 for Variant {
393 attrs,
394 ident: variant_ident,
395 fields: variant_fields,
396 ..
397 } in variants
398 {
399 let variant_fields = match variant_fields {
400 Fields::Unit => Punctuated::new(),
401 Fields::Named(FieldsNamed { named: fields, .. })
402 | Fields::Unnamed(FieldsUnnamed {
403 unnamed: fields, ..
404 }) => fields,
405 };
406 if variant_fields.len() != 1 {
407 bail!("Oneof enum variants must have a single field");
408 }
409 match Field::new_oneof(attrs)? {
410 Some(field) => fields.push((variant_ident, field)),
411 None => bail!("invalid oneof variant: oneof variants may not be ignored"),
412 }
413 }
414
415 let mut tags = fields
416 .iter()
417 .flat_map(|(variant_ident, field)| -> Result<u32, Error> {
418 if field.tags().len() > 1 {
419 bail!(
420 "invalid oneof variant {}::{}: oneof variants may only have a single tag",
421 ident,
422 variant_ident
423 );
424 }
425 Ok(field.tags()[0])
426 })
427 .collect::<Vec<_>>();
428 tags.sort_unstable();
429 tags.dedup();
430 if tags.len() != fields.len() {
431 panic!("invalid oneof {}: variants have duplicate tags", ident);
432 }
433
434 let encode = fields.iter().map(|(variant_ident, field)| {
435 let encode = field.encode(quote!(*value));
436 quote!(#ident::#variant_ident(ref value) => { #encode })
437 });
438
439 let merge = fields.iter().map(|(variant_ident, field)| {
440 let tag = field.tags()[0];
441 let merge = field.merge(quote!(value));
442 quote! {
443 #tag => {
444 match field {
445 ::core::option::Option::Some(#ident::#variant_ident(ref mut value)) => {
446 #merge
447 },
448 _ => {
449 let mut owned_value = ::core::default::Default::default();
450 let value = &mut owned_value;
451 #merge.map(|_| *field = ::core::option::Option::Some(#ident::#variant_ident(owned_value)))
452 },
453 }
454 }
455 }
456 });
457
458 let encoded_len = fields.iter().map(|(variant_ident, field)| {
459 let encoded_len = field.encoded_len(quote!(*value));
460 quote!(#ident::#variant_ident(ref value) => #encoded_len)
461 });
462
463 let expanded = quote! {
464 impl #impl_generics #ident #ty_generics #where_clause {
465 pub fn encode<B>(&self, buf: &mut B) where B: ::prost::bytes::BufMut {
467 match *self {
468 #(#encode,)*
469 }
470 }
471
472 pub fn merge<B>(
474 field: &mut ::core::option::Option<#ident #ty_generics>,
475 tag: u32,
476 wire_type: ::prost::encoding::WireType,
477 buf: &mut B,
478 ctx: ::prost::encoding::DecodeContext,
479 ) -> ::core::result::Result<(), ::prost::DecodeError>
480 where B: ::prost::bytes::Buf {
481 match tag {
482 #(#merge,)*
483 _ => unreachable!(concat!("invalid ", stringify!(#ident), " tag: {}"), tag),
484 }
485 }
486
487 #[inline]
489 pub fn encoded_len(&self) -> usize {
490 match *self {
491 #(#encoded_len,)*
492 }
493 }
494 }
495
496 };
497 let expanded = if skip_debug {
498 expanded
499 } else {
500 let debug = fields.iter().map(|(variant_ident, field)| {
501 let wrapper = field.debug(quote!(*value));
502 quote!(#ident::#variant_ident(ref value) => {
503 let wrapper = #wrapper;
504 f.debug_tuple(stringify!(#variant_ident))
505 .field(&wrapper)
506 .finish()
507 })
508 });
509 quote! {
510 #expanded
511
512 impl #impl_generics ::core::fmt::Debug for #ident #ty_generics #where_clause {
513 fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
514 match *self {
515 #(#debug,)*
516 }
517 }
518 }
519 }
520 };
521
522 Ok(expanded.into())
523}
524
525#[proc_macro_derive(Oneof, attributes(prost))]
526pub fn oneof(input: TokenStream) -> TokenStream {
527 try_oneof(input).unwrap()
528}