orchestra_proc_macro/
impl_message_wrapper.rs1use std::collections::HashSet;
17
18use itertools::Itertools;
19use quote::quote;
20use syn::{spanned::Spanned, Result};
21
22use super::*;
23
24pub(crate) fn impl_message_wrapper_enum(info: &OrchestraInfo) -> Result<proc_macro2::TokenStream> {
26 let consumes = info.any_message();
27 let consumes_variant = info.variant_names();
28
29 let outgoing = &info.outgoing_ty;
30
31 let message_wrapper = &info.message_wrapper;
32
33 let (outgoing_from_impl, outgoing_decl) = if let Some(outgoing) = outgoing {
34 let outgoing_variant = outgoing.get_ident().ok_or_else(|| {
35 syn::Error::new(
36 outgoing.span(),
37 "Missing identifier to use as enum variant for outgoing.",
38 )
39 })?;
40 (
41 quote! {
42 impl ::std::convert::From< #outgoing > for #message_wrapper {
43 fn from(message: #outgoing) -> Self {
44 #message_wrapper :: #outgoing_variant ( message )
45 }
46 }
47 },
48 quote! {
49 #outgoing_variant ( #outgoing ) ,
50 },
51 )
52 } else {
53 (TokenStream::new(), TokenStream::new())
54 };
55
56 let mut ts = quote! {
57 #[allow(missing_docs)]
60 #[derive(Debug)]
61 pub enum #message_wrapper {
62 #(
63 #consumes_variant ( #consumes ),
64 )*
65 #outgoing_decl
66 Empty,
68 }
69
70 impl ::std::convert::From< () > for #message_wrapper {
71 fn from(_: ()) -> Self {
72 #message_wrapper :: Empty
73 }
74 }
75 impl ::std::convert::From< #message_wrapper > for () {
76 fn from(message: #message_wrapper) -> Self {
77 match message {
78 #message_wrapper :: Empty => (),
79 _ => panic!("Message is not of type {}", stringify!(#message_wrapper)),
80 }
81 }
82 }
83
84 #(
85 impl ::std::convert::From< #consumes > for #message_wrapper {
86 fn from(message: #consumes) -> Self {
87 #message_wrapper :: #consumes_variant ( message )
88 }
89 }
90
91 impl ::std::convert::TryFrom< #message_wrapper > for #consumes {
92 type Error = ();
93 fn try_from(message: #message_wrapper) -> ::std::result::Result<Self, Self::Error> {
94 match message {
95 #message_wrapper :: #consumes_variant ( inner ) => Ok(inner),
96 _ => Err(()),
97 }
98 }
99 }
100 )*
101
102 #outgoing_from_impl
103 };
104
105 let outgoing = HashSet::<&Path>::from_iter(
108 info.subsystems().iter().map(|ssf| ssf.messages_to_send.iter()).flatten(),
109 );
110 let incoming = HashSet::<&Path>::from_iter(
111 info.subsystems().iter().filter_map(|ssf| ssf.message_to_consume.as_ref()),
112 );
113
114 fn cmp<'p, 'q>(a: &'p &&Path, b: &'q &&Path) -> std::cmp::Ordering {
116 a.span()
117 .start()
118 .partial_cmp(&b.span().start())
119 .unwrap_or(std::cmp::Ordering::Equal)
120 }
121
122 if cfg!(feature = "deny_unconsumed_messages") {
124 for sbnr in outgoing.difference(&incoming).sorted_by(cmp) {
125 ts.extend(
126 syn::Error::new(
127 sbnr.span(),
128 format!(
129 "Message `{}` is sent but never received",
130 sbnr.get_ident()
131 .expect("Message is a path that must end in an identifier. qed")
132 ),
133 )
134 .to_compile_error(),
135 );
136 }
137 }
138
139 if cfg!(feature = "deny_unsent_messages") {
142 for rbns in incoming.difference(&outgoing).sorted_by(cmp) {
143 ts.extend(
144 syn::Error::new(
145 rbns.span(),
146 format!(
147 "Message `{}` is received but never sent",
148 rbns.get_ident()
149 .expect("Message is a path that must end in an identifier. qed")
150 ),
151 )
152 .to_compile_error(),
153 );
154 }
155 }
156 Ok(ts)
157}