1use proc_macro2::TokenStream;
17use quote::quote;
18use syn::{Ident, Path, Result, Type};
19
20use petgraph::{visit::EdgeRef, Direction};
21
22use super::*;
23
24pub(crate) fn impl_subsystem_types_all(info: &OrchestraInfo) -> Result<TokenStream> {
26 let mut ts = TokenStream::new();
27
28 let orchestra_name = &info.orchestra_name;
29 let span = orchestra_name.span();
30 let all_messages_wrapper = &info.message_wrapper;
31 let support_crate = info.support_crate_name();
32 let signal_ty = &info.extern_signal_ty;
33 let error_ty = &info.extern_error_ty;
34
35 let cg = graph::ConnectionGraph::construct(info.subsystems());
36 let graph = &cg.graph;
37
38 for node_index in graph.node_indices() {
40 let subsystem_name = graph[node_index].to_string();
41 let outgoing_wrapper = Ident::new(&(subsystem_name + "OutgoingMessages"), span);
42
43 let outgoing_to_consumer = graph
46 .edges_directed(node_index, Direction::Outgoing)
47 .map(|edge| {
48 let message_ty = edge.weight();
49 let subsystem_generic_consumer = graph[edge.target()].clone();
50 Ok((to_variant(message_ty, span.clone())?, subsystem_generic_consumer))
51 })
52 .collect::<Result<Vec<(Ident, Ident)>>>()?;
53
54 let outgoing_variant = outgoing_to_consumer.iter().map(|x| x.0.clone()).collect::<Vec<_>>();
56 let subsystem_generic = outgoing_to_consumer.into_iter().map(|x| x.1).collect::<Vec<_>>();
57
58 ts.extend(quote! {
59 impl ::std::convert::From< #outgoing_wrapper > for #all_messages_wrapper {
60 fn from(message: #outgoing_wrapper) -> Self {
61 match message {
62 #(
63 #outgoing_wrapper :: #outgoing_variant ( msg ) => #all_messages_wrapper :: #subsystem_generic ( msg ),
64 )*
65 #outgoing_wrapper :: Empty => #all_messages_wrapper :: Empty,
66 #[allow(unreachable_patterns)]
68 unused_msg => {
69 #support_crate :: tracing :: warn!("Nothing consumes {:?}", unused_msg);
70 #all_messages_wrapper :: Empty
71 }
72 }
73 }
74 }
75
76 impl ::std::convert::TryFrom< #all_messages_wrapper > for #outgoing_wrapper {
77 type Error = ();
78 fn try_from(message: #all_messages_wrapper) -> ::std::result::Result<Self, Self::Error> {
79 match message {
80 #(
81 #all_messages_wrapper :: #subsystem_generic ( msg ) => Ok(#outgoing_wrapper :: #outgoing_variant ( msg )),
82 )*
83 #all_messages_wrapper :: Empty => Ok(#outgoing_wrapper :: Empty),
84 _ => Err(()),
85 }
86 }
87 }
88 })
89 }
90
91 #[cfg(feature = "dotgraph")]
93 {
94 let dest = std::path::PathBuf::from(env!("OUT_DIR"))
95 .join(orchestra_name.to_string().to_lowercase() + "-subsystem-messaging.dot");
96 if let Err(e) = cg.render_graphs(&dest) {
97 eprintln!("Hetscher/Hiccup: {}", e);
98 e.chain().skip(1).for_each(|cause| eprintln!("caused by: {}", cause));
99 }
100 }
101
102 let subsystem_sender_name = &Ident::new(&(orchestra_name.to_string() + "Sender"), span);
103 let subsystem_ctx_name = &Ident::new(&(orchestra_name.to_string() + "SubsystemContext"), span);
104 ts.extend(impl_subsystem_context(info, &subsystem_sender_name, &subsystem_ctx_name));
105
106 ts.extend(impl_associate_outgoing_messages_trait(&all_messages_wrapper));
107
108 ts.extend(impl_subsystem_sender(
109 support_crate,
110 info.subsystems().iter().map(|ssf| {
111 let outgoing_wrapper =
112 Ident::new(&(ssf.generic.to_string() + "OutgoingMessages"), span);
113 outgoing_wrapper
114 }),
115 &all_messages_wrapper,
116 &subsystem_sender_name,
117 ));
118
119 for ssf in info.subsystems() {
121 let subsystem_name = ssf.generic.to_string();
122 let outgoing_wrapper = &Ident::new(&(subsystem_name.clone() + "OutgoingMessages"), span);
123 let message_to_consume = ssf.message_to_consume();
124
125 let subsystem_ctx_trait = &Ident::new(&(subsystem_name.clone() + "ContextTrait"), span);
126 let subsystem_sender_trait = &Ident::new(&(subsystem_name.clone() + "SenderTrait"), span);
127
128 ts.extend(impl_per_subsystem_helper_traits(
129 info,
130 subsystem_ctx_name,
131 subsystem_ctx_trait,
132 subsystem_sender_name,
133 subsystem_sender_trait,
134 &message_to_consume,
135 &ssf.messages_to_send,
136 outgoing_wrapper,
137 ));
138
139 ts.extend(impl_associate_outgoing_messages(&message_to_consume, &outgoing_wrapper));
140 ts.extend(impl_wrapper_enum(&outgoing_wrapper, ssf.messages_to_send.as_slice())?);
141 }
142
143 ts.extend({
145 let mut messages = TokenStream::new();
146 for ssf in info.subsystems() {
147 messages.extend(ssf.gen_dummy_message_ty());
148 }
149 let comment = "The exclusive home of all generated dummy messages (if any at all)";
150 quote! {
151 #[doc = #comment]
152 pub mod messages {
153 #messages
154 }
155 }
156 });
157
158 let empty_tuple: Type = parse_quote! { () };
160 ts.extend(impl_subsystem_context_trait_for(
161 info,
162 empty_tuple.clone(),
163 &[],
164 empty_tuple.clone(),
165 all_messages_wrapper,
166 subsystem_ctx_name,
167 subsystem_sender_name,
168 support_crate,
169 signal_ty,
170 error_ty,
171 ));
172
173 Ok(ts)
174}
175
176fn to_variant(path: &Path, span: Span) -> Result<Ident> {
178 let ident = path
179 .segments
180 .last()
181 .ok_or_else(|| syn::Error::new(span, "Path is empty, but it must end with an identifier"))
182 .map(|segment| segment.ident.clone())?;
183 Ok(ident)
184}
185
186fn to_variants(message_types: &[Path], span: Span) -> Result<Vec<Ident>> {
191 let variants: Vec<_> =
192 Result::from_iter(message_types.into_iter().map(|path| to_variant(path, span.clone())))?;
193 Ok(variants)
194}
195
196pub(crate) fn impl_wrapper_enum(wrapper: &Ident, message_types: &[Path]) -> Result<TokenStream> {
198 let variants = to_variants(message_types, wrapper.span())?;
201
202 let ts = quote! {
203 #[allow(missing_docs, clippy::large_enum_variant)]
204 #[derive(Debug)]
205 pub enum #wrapper {
206 #(
207 #variants ( #message_types ),
208 )*
209 Empty,
210 }
211
212 #(
213 impl ::std::convert::From< #message_types > for #wrapper {
214 fn from(message: #message_types) -> Self {
215 #wrapper :: #variants ( message )
216 }
217 }
218 )*
219
220 impl ::std::convert::From< () > for #wrapper {
222 fn from(_message: ()) -> Self {
223 #wrapper :: Empty
224 }
225 }
226 };
227 Ok(ts)
228}
229
230pub(crate) fn impl_subsystem_sender(
233 support_crate: &Path,
234 outgoing_wrappers: impl IntoIterator<Item = Ident>,
235 all_messages_wrapper: &Ident,
236 subsystem_sender_name: &Ident,
237) -> TokenStream {
238 let mut ts = quote! {
239 #[derive(Debug)]
242 pub struct #subsystem_sender_name < OutgoingWrapper > {
243 channels: ChannelsOut,
245 signals_received: SignalsReceived,
247 _phantom: ::core::marker::PhantomData< OutgoingWrapper >,
249 }
250
251 impl<OutgoingWrapper> std::clone::Clone for #subsystem_sender_name < OutgoingWrapper > {
254 fn clone(&self) -> Self {
255 Self {
256 channels: self.channels.clone(),
257 signals_received: self.signals_received.clone(),
258 _phantom: ::core::marker::PhantomData,
259 }
260 }
261 }
262 };
263
264 let wrapped = |outgoing_wrapper: &TokenStream| {
269 quote! {
270 #[allow(clippy::unit_arg)]
271 #[#support_crate ::async_trait]
272 impl<OutgoingMessage> SubsystemSender< OutgoingMessage > for #subsystem_sender_name < #outgoing_wrapper >
273 where
274 OutgoingMessage: ::std::convert::TryFrom<#all_messages_wrapper> + Send + 'static,
275 #outgoing_wrapper: ::std::convert::From<OutgoingMessage> + Send,
276 #all_messages_wrapper: ::std::convert::From< #outgoing_wrapper > + Send,
277 <OutgoingMessage as ::std::convert::TryFrom<#all_messages_wrapper>>::Error: ::std::fmt::Debug,
278 {
279 async fn send_message(&mut self, msg: OutgoingMessage)
280 {
281 self.send_message_with_priority::<#support_crate ::NormalPriority>(msg).await;
282 }
283
284 async fn send_message_with_priority<P: #support_crate ::Priority>(&mut self, msg: OutgoingMessage)
285 {
286 self.channels.send_and_log_error::<P>(
287 self.signals_received.load(),
288 <#all_messages_wrapper as ::std::convert::From<_>> ::from (
289 <#outgoing_wrapper as ::std::convert::From<_>> :: from ( msg )
290 ),
291 ).await;
292 }
293
294 fn try_send_message(&mut self, msg: OutgoingMessage) -> ::std::result::Result<(), #support_crate ::metered::TrySendError<OutgoingMessage>>
295 {
296 self.try_send_message_with_priority::<#support_crate ::NormalPriority>(msg)
297 }
298
299 fn try_send_message_with_priority<P: #support_crate ::Priority>(&mut self, msg: OutgoingMessage) -> ::std::result::Result<(), #support_crate ::metered::TrySendError<OutgoingMessage>>
300 {
301 self.channels.try_send::<P>(
302 self.signals_received.load(),
303 <#all_messages_wrapper as ::std::convert::From<_>> ::from (
304 <#outgoing_wrapper as ::std::convert::From<_>> :: from ( msg )
305 ),
306 ).map_err(|err| match err {
307 #support_crate ::metered::TrySendError::Full(inner) => #support_crate ::metered::TrySendError::Full(inner.try_into().expect("we should be able to unwrap what we wrap, qed")),
308 #support_crate ::metered::TrySendError::Closed(inner) => #support_crate ::metered::TrySendError::Closed(inner.try_into().expect("we should be able to unwrap what we wrap, qed")),
309 })
310 }
311
312 async fn send_messages<I>(&mut self, msgs: I)
313 where
314 I: IntoIterator<Item=OutgoingMessage> + Send,
315 I::IntoIter: Iterator<Item=OutgoingMessage> + Send,
316 {
317 for msg in msgs {
318 self.send_message( msg ).await;
319 }
320 }
321
322 fn send_unbounded_message(&mut self, msg: OutgoingMessage)
323 {
324 self.channels.send_unbounded_and_log_error(
325 self.signals_received.load(),
326 <#all_messages_wrapper as ::std::convert::From<_>> ::from (
327 <#outgoing_wrapper as ::std::convert::From<_>> :: from ( msg )
328 )
329 );
330 }
331 }
332 }
333 };
334
335 for outgoing_wrapper in outgoing_wrappers {
336 ts.extend(wrapped("e! {
337 #outgoing_wrapper
338 }));
339 }
340
341 ts.extend(wrapped("e! {
342 ()
343 }));
344
345 ts
346}
347
348pub(crate) fn impl_associate_outgoing_messages_trait(all_messages_wrapper: &Ident) -> TokenStream {
350 quote! {
351 pub trait AssociateOutgoing: ::std::fmt::Debug + Send {
357 type OutgoingMessages: Into< #all_messages_wrapper > + ::std::fmt::Debug + Send;
359 }
360
361 impl AssociateOutgoing for () {
363 type OutgoingMessages = ();
364 }
365
366 impl AssociateOutgoing for #all_messages_wrapper {
369 type OutgoingMessages = #all_messages_wrapper ;
370 }
371 }
372}
373
374pub(crate) fn impl_associate_outgoing_messages(
381 consumes: &Path,
382 outgoing_wrapper: &Ident,
383) -> TokenStream {
384 quote! {
385 impl AssociateOutgoing for #outgoing_wrapper {
386 type OutgoingMessages = #outgoing_wrapper;
387 }
388
389 impl AssociateOutgoing for #consumes {
390 type OutgoingMessages = #outgoing_wrapper;
391 }
392 }
393}
394
395pub(crate) fn impl_subsystem_context_trait_for(
398 info: &OrchestraInfo,
399 consumes: Type,
400 outgoing: &[Type],
401 outgoing_wrapper: Type,
402 all_messages_wrapper: &Ident,
403 subsystem_ctx_name: &Ident,
404 subsystem_sender_name: &Ident,
405 support_crate: &Path,
406 signal: &Path,
407 error_ty: &Path,
408) -> TokenStream {
409 let where_clause = quote! {
411 #consumes: AssociateOutgoing + ::std::fmt::Debug + Send + 'static,
412 #all_messages_wrapper: From< #outgoing_wrapper >,
413 #all_messages_wrapper: From< #consumes >,
414 #outgoing_wrapper: #( From< #outgoing > )+*,
415 };
416
417 let maybe_unbox_packet = if info.boxed_messages {
418 quote! { *packet.message }
419 } else {
420 quote! { packet.message }
421 };
422
423 quote! {
424 #[#support_crate ::async_trait]
425 impl #support_crate ::SubsystemContext for #subsystem_ctx_name < #consumes >
426 where
427 #where_clause
428 {
429 type Message = #consumes;
430 type Signal = #signal;
431 type OutgoingMessages = #outgoing_wrapper;
432 type Sender = #subsystem_sender_name < #outgoing_wrapper >;
433 type Error = #error_ty;
434
435 async fn try_recv(&mut self) -> ::std::result::Result<Option<FromOrchestra< Self::Message, #signal>>, ()> {
436 match #support_crate ::poll!(self.recv()) {
437 #support_crate ::Poll::Ready(msg) => Ok(Some(msg.map_err(|_| ())?)),
438 #support_crate ::Poll::Pending => Ok(None),
439 }
440 }
441
442 #[allow(clippy::suspicious_else_formatting)]
443 async fn recv(&mut self) -> ::std::result::Result<FromOrchestra<Self::Message, #signal>, #error_ty> {
444 loop {
445 if let Some((needs_signals_received, msg)) = self.pending_incoming.take() {
448 if needs_signals_received <= self.signals_received.load() {
449 return Ok( #support_crate ::FromOrchestra::Communication { msg });
450 } else {
451 self.pending_incoming = Some((needs_signals_received, msg));
452
453 let signal = self.signals.next().await
455 .ok_or(#support_crate ::OrchestraError::Context(
456 "Signal channel is terminated and empty."
457 .to_owned()
458 ))?;
459
460 self.signals_received.inc();
461 return Ok( #support_crate ::FromOrchestra::Signal(signal))
462 }
463 }
464
465 let mut await_message = self.messages.next().fuse();
466 let mut await_signal = self.signals.next().fuse();
467 let signals_received = self.signals_received.load();
468 let pending_incoming = &mut self.pending_incoming;
469
470 let from_orchestra = #support_crate ::futures::select_biased! {
472 signal = await_signal => {
473 let signal = signal
474 .ok_or( #support_crate ::OrchestraError::Context(
475 "Signal channel is terminated and empty."
476 .to_owned()
477 ))?;
478
479 #support_crate ::FromOrchestra::Signal(signal)
480 }
481 msg = await_message => {
482 let packet = msg
483 .ok_or( #support_crate ::OrchestraError::Context(
484 "Message channel is terminated and empty."
485 .to_owned()
486 ))?;
487
488 if packet.signals_received > signals_received {
489 *pending_incoming = Some((packet.signals_received, #maybe_unbox_packet));
491 continue;
492 } else {
493 #support_crate ::FromOrchestra::Communication { msg: #maybe_unbox_packet}
495 }
496 }
497 };
498
499 if let #support_crate ::FromOrchestra::Signal(_) = from_orchestra {
500 self.signals_received.inc();
501 }
502
503 return Ok(from_orchestra);
504 }
505 }
506
507 async fn recv_signal(&mut self) -> ::std::result::Result<#signal, #error_ty> {
508 let result = self.signals.next().await.ok_or(#support_crate ::OrchestraError::Context(
509 "Signal channel is terminated and empty.".to_owned(),
510 ).into());
511 if result.is_ok() {
512 self.signals_received.inc();
513 }
514 result
515 }
516
517 fn sender(&mut self) -> &mut Self::Sender {
518 &mut self.to_subsystems
519 }
520
521 fn spawn(&mut self, name: &'static str, s: Pin<Box<dyn Future<Output = ()> + Send>>)
522 -> ::std::result::Result<(), #error_ty>
523 {
524 self.to_orchestra.unbounded_send(#support_crate ::ToOrchestra::SpawnJob {
525 name,
526 subsystem: Some(self.name()),
527 s,
528 }).map_err(|_| #support_crate ::OrchestraError::TaskSpawn(name))?;
529 Ok(())
530 }
531
532 fn spawn_blocking(&mut self, name: &'static str, s: Pin<Box<dyn Future<Output = ()> + Send>>)
533 -> ::std::result::Result<(), #error_ty>
534 {
535 self.to_orchestra.unbounded_send(#support_crate ::ToOrchestra::SpawnBlockingJob {
536 name,
537 subsystem: Some(self.name()),
538 s,
539 }).map_err(|_| #support_crate ::OrchestraError::TaskSpawn(name))?;
540 Ok(())
541 }
542 }
543 }
544}
545
546pub(crate) fn impl_per_subsystem_helper_traits(
549 info: &OrchestraInfo,
550 subsystem_ctx_name: &Ident,
551 subsystem_ctx_trait: &Ident,
552 subsystem_sender_name: &Ident,
553 subsystem_sender_trait: &Ident,
554 consumes: &Path,
555 outgoing: &[Path],
556 outgoing_wrapper: &Ident,
557) -> TokenStream {
558 let all_messages_wrapper = &info.message_wrapper;
559 let signal_ty = &info.extern_signal_ty;
560 let error_ty = &info.extern_error_ty;
561 let support_crate = info.support_crate_name();
562
563 let mut ts = TokenStream::new();
564
565 let acc_sender_trait_bounds = quote! {
568 #support_crate ::SubsystemSender< #outgoing_wrapper >
569 #(
570 + #support_crate ::SubsystemSender< #outgoing >
571 )*
572 + #support_crate ::SubsystemSender< () >
573 + Send
574 + 'static
575 };
576
577 ts.extend(quote! {
578 pub trait #subsystem_sender_trait : #acc_sender_trait_bounds
580 {}
581
582 impl<T> #subsystem_sender_trait for T
583 where
584 T: #acc_sender_trait_bounds
585 {}
586 });
587
588 let where_clause = quote! {
590 #consumes: AssociateOutgoing + ::std::fmt::Debug + Send + 'static,
591 #all_messages_wrapper: From< #outgoing_wrapper >,
592 #all_messages_wrapper: From< #consumes >,
593 #all_messages_wrapper: From< () >,
594 #outgoing_wrapper: #( From< #outgoing > )+*,
595 #outgoing_wrapper: From< () >,
596 };
597
598 ts.extend(quote! {
599 pub trait #subsystem_ctx_trait : SubsystemContext <
601 Message = #consumes,
602 Signal = #signal_ty,
603 OutgoingMessages = #outgoing_wrapper,
604 Error = #error_ty,
606 >
607 where
608 #where_clause
609 <Self as SubsystemContext>::Sender:
610 #subsystem_sender_trait
611 + #acc_sender_trait_bounds,
612 {
613 type Sender: #subsystem_sender_trait;
615 }
616
617 impl<T> #subsystem_ctx_trait for T
618 where
619 T: SubsystemContext <
620 Message = #consumes,
621 Signal = #signal_ty,
622 OutgoingMessages = #outgoing_wrapper,
623 Error = #error_ty,
625 >,
626 #where_clause
627 <T as SubsystemContext>::Sender:
628 #subsystem_sender_trait
629 + #acc_sender_trait_bounds,
630 {
631 type Sender = <T as SubsystemContext>::Sender;
632 }
633 });
634
635 ts.extend(impl_subsystem_context_trait_for(
636 info,
637 parse_quote! { #consumes },
638 &Vec::from_iter(outgoing.iter().map(|path| {
639 parse_quote! { #path }
640 })),
641 parse_quote! { #outgoing_wrapper },
642 all_messages_wrapper,
643 subsystem_ctx_name,
644 subsystem_sender_name,
645 support_crate,
646 signal_ty,
647 error_ty,
648 ));
649 ts
650}
651
652pub(crate) fn impl_subsystem_context(
656 info: &OrchestraInfo,
657 subsystem_sender_name: &Ident,
658 subsystem_ctx_name: &Ident,
659) -> TokenStream {
660 let signal_ty = &info.extern_signal_ty;
661 let support_crate = info.support_crate_name();
662 let maybe_boxed_message_generic: Type = if info.boxed_messages {
663 parse_quote! { ::std::boxed::Box<M> }
664 } else {
665 parse_quote! { M }
666 };
667
668 let ts = quote! {
669 #[derive(Debug)]
677 #[allow(missing_docs)]
678 pub struct #subsystem_ctx_name<M: AssociateOutgoing + Send + 'static> {
679 signals: #support_crate ::metered::MeteredReceiver< #signal_ty >,
680 messages: SubsystemIncomingMessages< #maybe_boxed_message_generic >,
681 to_subsystems: #subsystem_sender_name < <M as AssociateOutgoing>::OutgoingMessages >,
682 to_orchestra: #support_crate ::metered::UnboundedMeteredSender<
683 #support_crate ::ToOrchestra
684 >,
685 signals_received: SignalsReceived,
686 pending_incoming: Option<(usize, M)>,
687 name: &'static str
688 }
689
690 impl<M> #subsystem_ctx_name <M>
691 where
692 M: AssociateOutgoing + Send + 'static,
693 {
694 fn new(
696 signals: #support_crate ::metered::MeteredReceiver< #signal_ty >,
697 messages: SubsystemIncomingMessages< #maybe_boxed_message_generic >,
698 to_subsystems: ChannelsOut,
699 to_orchestra: #support_crate ::metered::UnboundedMeteredSender<#support_crate:: ToOrchestra>,
700 name: &'static str
701 ) -> Self {
702 let signals_received = SignalsReceived::default();
703 #subsystem_ctx_name :: <M> {
704 signals,
705 messages,
706 to_subsystems: #subsystem_sender_name :: < <M as AssociateOutgoing>::OutgoingMessages > {
707 channels: to_subsystems,
708 signals_received: signals_received.clone(),
709 _phantom: ::core::marker::PhantomData,
710 },
711 to_orchestra,
712 signals_received,
713 pending_incoming: None,
714 name
715 }
716 }
717
718 fn name(&self) -> &'static str {
719 self.name
720 }
721 }
722 };
723
724 ts
725}