orchestra_proc_macro/
impl_message_wrapper.rs

// Copyright (C) 2021 Parity Technologies (UK) Ltd.
// SPDX-License-Identifier: Apache-2.0

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use std::collections::HashSet;
17
18use itertools::Itertools;
19use quote::quote;
20use syn::{spanned::Spanned, Result};
21
22use super::*;
23
24/// Generates the wrapper type enum.
25pub(crate) fn impl_message_wrapper_enum(info: &OrchestraInfo) -> Result<proc_macro2::TokenStream> {
26	let consumes = info.any_message();
27	let consumes_variant = info.variant_names();
28
29	let outgoing = &info.outgoing_ty;
30
31	let message_wrapper = &info.message_wrapper;
32
33	let (outgoing_from_impl, outgoing_decl) = if let Some(outgoing) = outgoing {
34		let outgoing_variant = outgoing.get_ident().ok_or_else(|| {
35			syn::Error::new(
36				outgoing.span(),
37				"Missing identifier to use as enum variant for outgoing.",
38			)
39		})?;
40		(
41			quote! {
42				impl ::std::convert::From< #outgoing > for #message_wrapper {
43					fn from(message: #outgoing) -> Self {
44						#message_wrapper :: #outgoing_variant ( message )
45					}
46				}
47			},
48			quote! {
49				#outgoing_variant ( #outgoing ) ,
50			},
51		)
52	} else {
53		(TokenStream::new(), TokenStream::new())
54	};
55
56	let mut ts = quote! {
57		/// Generated message type wrapper over all possible messages
58		/// used by any subsystem.
59		#[allow(missing_docs)]
60		#[derive(Debug)]
61		pub enum #message_wrapper {
62			#(
63				#consumes_variant ( #consumes ),
64			)*
65			#outgoing_decl
66			// dummy message type
67			Empty,
68		}
69
70		impl ::std::convert::From< () > for #message_wrapper {
71			fn from(_: ()) -> Self {
72				#message_wrapper :: Empty
73			}
74		}
75		impl ::std::convert::From< #message_wrapper > for () {
76			fn from(message: #message_wrapper) -> Self {
77				match message {
78					#message_wrapper :: Empty => (),
79					_ => panic!("Message is not of type {}", stringify!(#message_wrapper)),
80				}
81			}
82		}
83
84		#(
85		impl ::std::convert::From< #consumes > for #message_wrapper {
86			fn from(message: #consumes) -> Self {
87				#message_wrapper :: #consumes_variant ( message )
88			}
89		}
90
91		impl ::std::convert::TryFrom< #message_wrapper > for #consumes {
92			type Error = ();
93			fn try_from(message: #message_wrapper) -> ::std::result::Result<Self, Self::Error> {
94				match message {
95					#message_wrapper :: #consumes_variant ( inner ) => Ok(inner),
96					_ => Err(()),
97				}
98			}
99		}
100		)*
101
102		#outgoing_from_impl
103	};
104
105	// TODO it's not perfect, if the same type is used with different paths
106	// the detection will fail
107	let outgoing = HashSet::<&Path>::from_iter(
108		info.subsystems().iter().map(|ssf| ssf.messages_to_send.iter()).flatten(),
109	);
110	let incoming = HashSet::<&Path>::from_iter(
111		info.subsystems().iter().filter_map(|ssf| ssf.message_to_consume.as_ref()),
112	);
113
114	// Try to maintain the ordering according to the span start in the declaration.
115	fn cmp<'p, 'q>(a: &'p &&Path, b: &'q &&Path) -> std::cmp::Ordering {
116		a.span()
117			.start()
118			.partial_cmp(&b.span().start())
119			.unwrap_or(std::cmp::Ordering::Equal)
120	}
121
122	// sent but not received
123	if cfg!(feature = "deny_unconsumed_messages") {
124		for sbnr in outgoing.difference(&incoming).sorted_by(cmp) {
125			ts.extend(
126				syn::Error::new(
127					sbnr.span(),
128					format!(
129						"Message `{}` is sent but never received",
130						sbnr.get_ident()
131							.expect("Message is a path that must end in an identifier. qed")
132					),
133				)
134				.to_compile_error(),
135			);
136		}
137	}
138
139	// received but not sent
140
141	if cfg!(feature = "deny_unsent_messages") {
142		for rbns in incoming.difference(&outgoing).sorted_by(cmp) {
143			ts.extend(
144				syn::Error::new(
145					rbns.span(),
146					format!(
147						"Message `{}` is received but never sent",
148						rbns.get_ident()
149							.expect("Message is a path that must end in an identifier. qed")
150					),
151				)
152				.to_compile_error(),
153			);
154		}
155	}
156	Ok(ts)
157}