use std::collections::HashSet;
use itertools::Itertools;
use quote::quote;
use syn::{spanned::Spanned, Result};
use super::*;
pub(crate) fn impl_message_wrapper_enum(info: &OrchestraInfo) -> Result<proc_macro2::TokenStream> {
let consumes = info.any_message();
let consumes_variant = info.variant_names();
let outgoing = &info.outgoing_ty;
let message_wrapper = &info.message_wrapper;
let (outgoing_from_impl, outgoing_decl) = if let Some(outgoing) = outgoing {
let outgoing_variant = outgoing.get_ident().ok_or_else(|| {
syn::Error::new(
outgoing.span(),
"Missing identifier to use as enum variant for outgoing.",
)
})?;
(
quote! {
impl ::std::convert::From< #outgoing > for #message_wrapper {
fn from(message: #outgoing) -> Self {
#message_wrapper :: #outgoing_variant ( message )
}
}
},
quote! {
#outgoing_variant ( #outgoing ) ,
},
)
} else {
(TokenStream::new(), TokenStream::new())
};
let mut ts = quote! {
#[allow(missing_docs)]
#[derive(Debug)]
pub enum #message_wrapper {
#(
#consumes_variant ( #consumes ),
)*
#outgoing_decl
Empty,
}
impl ::std::convert::From< () > for #message_wrapper {
fn from(_: ()) -> Self {
#message_wrapper :: Empty
}
}
impl ::std::convert::From< #message_wrapper > for () {
fn from(message: #message_wrapper) -> Self {
match message {
#message_wrapper :: Empty => (),
_ => panic!("Message is not of type {}", stringify!(#message_wrapper)),
}
}
}
#(
impl ::std::convert::From< #consumes > for #message_wrapper {
fn from(message: #consumes) -> Self {
#message_wrapper :: #consumes_variant ( message )
}
}
impl ::std::convert::TryFrom< #message_wrapper > for #consumes {
type Error = ();
fn try_from(message: #message_wrapper) -> ::std::result::Result<Self, Self::Error> {
match message {
#message_wrapper :: #consumes_variant ( inner ) => Ok(inner),
_ => Err(()),
}
}
}
)*
#outgoing_from_impl
};
let outgoing = HashSet::<&Path>::from_iter(
info.subsystems().iter().map(|ssf| ssf.messages_to_send.iter()).flatten(),
);
let incoming = HashSet::<&Path>::from_iter(
info.subsystems().iter().filter_map(|ssf| ssf.message_to_consume.as_ref()),
);
fn cmp<'p, 'q>(a: &'p &&Path, b: &'q &&Path) -> std::cmp::Ordering {
a.span()
.start()
.partial_cmp(&b.span().start())
.unwrap_or(std::cmp::Ordering::Equal)
}
if cfg!(feature = "deny_unconsumed_messages") {
for sbnr in outgoing.difference(&incoming).sorted_by(cmp) {
ts.extend(
syn::Error::new(
sbnr.span(),
format!(
"Message `{}` is sent but never received",
sbnr.get_ident()
.expect("Message is a path that must end in an identifier. qed")
),
)
.to_compile_error(),
);
}
}
if cfg!(feature = "deny_unsent_messages") {
for rbns in incoming.difference(&outgoing).sorted_by(cmp) {
ts.extend(
syn::Error::new(
rbns.span(),
format!(
"Message `{}` is received but never sent",
rbns.get_ident()
.expect("Message is a path that must end in an identifier. qed")
),
)
.to_compile_error(),
);
}
}
Ok(ts)
}