1use std::cell::UnsafeCell;
2use std::collections::{BTreeMap, HashMap};
3use std::io::Error;
4use std::marker::PhantomData;
5use std::pin::Pin;
6
7use bytes::Bytes;
8use futures::{Sink, Stream};
9use proc_macro2::Span;
10use serde::Serialize;
11use serde::de::DeserializeOwned;
12use stageleft::QuotedWithContext;
13
14use super::built::build_inner;
15use super::compiled::CompiledFlow;
16use crate::deploy::{
17 ClusterSpec, Deploy, ExternalSpec, IntoProcessSpec, LocalDeploy, Node, ProcessSpec,
18 RegisterPort,
19};
20use crate::ir::HydroLeaf;
21use crate::location::external_process::{
22 ExternalBincodeSink, ExternalBincodeStream, ExternalBytesPort,
23};
24use crate::location::{Cluster, ExternalProcess, Location, LocationId, Process};
25use crate::staging_util::Invariant;
26
/// A flow whose locations are being bound to concrete deployment targets of type `D`.
///
/// Targets are assigned with the `with_*` builder methods; the flow is then consumed by
/// `compile`, `compile_no_network`, or `deploy`. Each of those sets `used`, and the `Drop`
/// impl reports (by panicking) a flow that was never consumed.
pub struct DeployFlow<'a, D: LocalDeploy<'a>> {
    // `UnsafeCell` lets `preview_compile(&self)` hand `&mut` access to `build_inner` while
    // `ir(&self)` also lends out shared references to the same Vec.
    // NOTE(review): those two borrows must never overlap; the types do not enforce it.
    pub(super) ir: UnsafeCell<Vec<HydroLeaf>>,

    /// Deployment target for each process, keyed by location id.
    pub(super) processes: HashMap<usize, D::Process>,

    /// `(location id, name)` pairs; `with_remaining_processes` builds a spec for each entry
    /// using this name. Presumably lists every process created in the flow — TODO confirm.
    pub(super) process_id_name: Vec<(usize, String)>,

    /// Deployment target for each external process, keyed by location id.
    pub(super) externals: HashMap<usize, D::ExternalProcess>,
    /// `(location id, name)` pairs consumed by `with_remaining_externals`.
    pub(super) external_id_name: Vec<(usize, String)>,

    /// Deployment target for each cluster, keyed by location id.
    pub(super) clusters: HashMap<usize, D::Cluster>,
    /// `(location id, name)` pairs consumed by `with_remaining_clusters`; `deploy` also
    /// moves them into the returned `DeployResult`.
    pub(super) cluster_id_name: Vec<(usize, String)>,
    /// Whether the flow has been consumed by `compile`, `compile_no_network`, or `deploy`.
    pub(super) used: bool,

    // Marker tying the struct to `'a` and `D`; judging by its name it is presumably
    // invariant over both — TODO confirm against `staging_util::Invariant`.
    pub(super) _phantom: Invariant<'a, D>,
}
49
50impl<'a, D: LocalDeploy<'a>> Drop for DeployFlow<'a, D> {
51 fn drop(&mut self) {
52 if !self.used {
53 panic!(
54 "Dropped DeployFlow without instantiating, you may have forgotten to call `compile` or `deploy`."
55 );
56 }
57 }
58}
59
60impl<'a, D: LocalDeploy<'a>> DeployFlow<'a, D> {
61 pub fn ir(&self) -> &Vec<HydroLeaf> {
62 unsafe {
63 &*self.ir.get()
65 }
66 }
67
68 pub fn with_process<P>(
69 mut self,
70 process: &Process<P>,
71 spec: impl IntoProcessSpec<'a, D>,
72 ) -> Self {
73 let tag_name = std::any::type_name::<P>().to_string();
74 self.processes.insert(
75 process.id,
76 spec.into_process_spec().build(process.id, &tag_name),
77 );
78 self
79 }
80
81 pub fn with_remaining_processes<S: IntoProcessSpec<'a, D> + 'a>(
82 mut self,
83 spec: impl Fn() -> S,
84 ) -> Self {
85 for (id, name) in &self.process_id_name {
86 self.processes
87 .insert(*id, spec().into_process_spec().build(*id, name));
88 }
89
90 self
91 }
92
93 pub fn with_external<P>(
94 mut self,
95 process: &ExternalProcess<P>,
96 spec: impl ExternalSpec<'a, D>,
97 ) -> Self {
98 let tag_name = std::any::type_name::<P>().to_string();
99 self.externals
100 .insert(process.id, spec.build(process.id, &tag_name));
101 self
102 }
103
104 pub fn with_remaining_externals<S: ExternalSpec<'a, D> + 'a>(
105 mut self,
106 spec: impl Fn() -> S,
107 ) -> Self {
108 for (id, name) in &self.external_id_name {
109 self.externals.insert(*id, spec().build(*id, name));
110 }
111
112 self
113 }
114
115 pub fn with_cluster<C>(mut self, cluster: &Cluster<C>, spec: impl ClusterSpec<'a, D>) -> Self {
116 let tag_name = std::any::type_name::<C>().to_string();
117 self.clusters
118 .insert(cluster.id, spec.build(cluster.id, &tag_name));
119 self
120 }
121
122 pub fn with_remaining_clusters<S: ClusterSpec<'a, D> + 'a>(
123 mut self,
124 spec: impl Fn() -> S,
125 ) -> Self {
126 for (id, name) in &self.cluster_id_name {
127 self.clusters.insert(*id, spec().build(*id, name));
128 }
129
130 self
131 }
132
133 pub fn preview_compile(&self) -> CompiledFlow<'a, ()> {
136 CompiledFlow {
137 dfir: build_inner(unsafe {
138 &mut *self.ir.get()
141 }),
142 #[cfg(feature = "staged_macro")]
143 extra_stmts: BTreeMap::new(),
144 _phantom: PhantomData,
145 }
146 }
147
148 pub fn compile_no_network(mut self) -> CompiledFlow<'a, D::GraphId> {
149 self.used = true;
150
151 CompiledFlow {
152 dfir: build_inner(self.ir.get_mut()),
153 #[cfg(feature = "staged_macro")]
154 extra_stmts: BTreeMap::new(),
155 _phantom: PhantomData,
156 }
157 }
158}
159
160impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
161 pub fn compile(mut self, env: &D::CompileEnv) -> CompiledFlow<'a, D::GraphId> {
162 self.used = true;
163
164 let mut seen_tees: HashMap<_, _> = HashMap::new();
165 let mut seen_tee_locations: HashMap<_, _> = HashMap::new();
166 self.ir.get_mut().iter_mut().for_each(|leaf| {
167 leaf.compile_network::<D>(
168 env,
169 &mut seen_tees,
170 &mut seen_tee_locations,
171 &self.processes,
172 &self.clusters,
173 &self.externals,
174 );
175 });
176
177 #[cfg(feature = "staged_macro")]
178 let extra_stmts = self.extra_stmts(env);
179
180 CompiledFlow {
181 dfir: build_inner(self.ir.get_mut()),
182 #[cfg(feature = "staged_macro")]
183 extra_stmts,
184 _phantom: PhantomData,
185 }
186 }
187
188 fn extra_stmts(&self, env: &<D as Deploy<'a>>::CompileEnv) -> BTreeMap<usize, Vec<syn::Stmt>> {
189 let mut extra_stmts: BTreeMap<usize, Vec<syn::Stmt>> = BTreeMap::new();
190
191 let mut all_clusters_sorted = self.clusters.keys().collect::<Vec<_>>();
192 all_clusters_sorted.sort();
193
194 for &c_id in all_clusters_sorted {
195 let self_id_ident = syn::Ident::new(
196 &format!("__hydro_lang_cluster_self_id_{}", c_id),
197 Span::call_site(),
198 );
199 let self_id_expr = D::cluster_self_id(env).splice_untyped();
200 extra_stmts
201 .entry(c_id)
202 .or_default()
203 .push(syn::parse_quote! {
204 let #self_id_ident = #self_id_expr;
205 });
206
207 for other_location in self.processes.keys().chain(self.clusters.keys()) {
208 let other_id_ident = syn::Ident::new(
209 &format!("__hydro_lang_cluster_ids_{}", c_id),
210 Span::call_site(),
211 );
212 let other_id_expr = D::cluster_ids(env, c_id).splice_untyped();
213 extra_stmts
214 .entry(*other_location)
215 .or_default()
216 .push(syn::parse_quote! {
217 let #other_id_ident = #other_id_expr;
218 });
219 }
220 }
221 extra_stmts
222 }
223}
224
225impl<'a, D: Deploy<'a, CompileEnv = ()>> DeployFlow<'a, D> {
226 #[must_use]
227 pub fn deploy(mut self, env: &mut D::InstantiateEnv) -> DeployResult<'a, D> {
228 self.used = true;
229
230 let mut seen_tees_instantiate: HashMap<_, _> = HashMap::new();
231 let mut seen_tee_locations: HashMap<_, _> = HashMap::new();
232 self.ir.get_mut().iter_mut().for_each(|leaf| {
233 leaf.compile_network::<D>(
234 &(),
235 &mut seen_tees_instantiate,
236 &mut seen_tee_locations,
237 &self.processes,
238 &self.clusters,
239 &self.externals,
240 );
241 });
242
243 let mut compiled = build_inner(self.ir.get_mut());
244 let mut extra_stmts = self.extra_stmts(&());
245 let mut meta = D::Meta::default();
246
247 let (mut processes, mut clusters, mut externals) = (
248 std::mem::take(&mut self.processes)
249 .into_iter()
250 .filter_map(|(node_id, node)| {
251 if let Some(ir) = compiled.remove(&node_id) {
252 node.instantiate(
253 env,
254 &mut meta,
255 ir,
256 extra_stmts.remove(&node_id).unwrap_or_default(),
257 );
258 Some((node_id, node))
259 } else {
260 None
261 }
262 })
263 .collect::<HashMap<_, _>>(),
264 std::mem::take(&mut self.clusters)
265 .into_iter()
266 .filter_map(|(cluster_id, cluster)| {
267 if let Some(ir) = compiled.remove(&cluster_id) {
268 cluster.instantiate(
269 env,
270 &mut meta,
271 ir,
272 extra_stmts.remove(&cluster_id).unwrap_or_default(),
273 );
274 Some((cluster_id, cluster))
275 } else {
276 None
277 }
278 })
279 .collect::<HashMap<_, _>>(),
280 std::mem::take(&mut self.externals)
281 .into_iter()
282 .map(|(external_id, external)| {
283 external.instantiate(
284 env,
285 &mut meta,
286 compiled.remove(&external_id).unwrap(),
287 extra_stmts.remove(&external_id).unwrap_or_default(),
288 );
289 (external_id, external)
290 })
291 .collect::<HashMap<_, _>>(),
292 );
293
294 for node in processes.values_mut() {
295 node.update_meta(&meta);
296 }
297
298 for cluster in clusters.values_mut() {
299 cluster.update_meta(&meta);
300 }
301
302 for external in externals.values_mut() {
303 external.update_meta(&meta);
304 }
305
306 let mut seen_tees_connect = HashMap::new();
307 self.ir.get_mut().iter_mut().for_each(|leaf| {
308 leaf.connect_network(&mut seen_tees_connect);
309 });
310
311 DeployResult {
312 processes,
313 clusters,
314 externals,
315 cluster_id_name: std::mem::take(&mut self.cluster_id_name)
316 .into_iter()
317 .collect(),
318 }
319 }
320}
321
/// Handles to the locations that `DeployFlow::deploy` instantiated.
pub struct DeployResult<'a, D: Deploy<'a>> {
    /// Instantiated processes, keyed by location id.
    processes: HashMap<usize, D::Process>,
    /// Instantiated clusters, keyed by location id.
    clusters: HashMap<usize, D::Cluster>,
    /// Instantiated external processes, keyed by location id.
    externals: HashMap<usize, D::ExternalProcess>,
    /// Name recorded for each cluster (the names `with_remaining_clusters` passes to
    /// `ClusterSpec::build`), keyed by location id.
    cluster_id_name: HashMap<usize, String>,
}
328
329impl<'a, D: Deploy<'a>> DeployResult<'a, D> {
330 pub fn get_process<P>(&self, p: &Process<P>) -> &D::Process {
331 let id = match p.id() {
332 LocationId::Process(id) => id,
333 _ => panic!("Process ID expected"),
334 };
335
336 self.processes.get(&id).unwrap()
337 }
338
339 pub fn get_cluster<C>(&self, c: &Cluster<'a, C>) -> &D::Cluster {
340 let id = match c.id() {
341 LocationId::Cluster(id) => id,
342 _ => panic!("Cluster ID expected"),
343 };
344
345 self.clusters.get(&id).unwrap()
346 }
347
348 pub fn get_all_clusters(&self) -> impl Iterator<Item = (LocationId, String, &D::Cluster)> {
349 self.clusters.iter().map(|(&id, c)| {
350 (
351 LocationId::Cluster(id),
352 self.cluster_id_name.get(&id).unwrap().clone(),
353 c,
354 )
355 })
356 }
357
358 pub fn get_external<P>(&self, p: &ExternalProcess<P>) -> &D::ExternalProcess {
359 self.externals.get(&p.id).unwrap()
360 }
361
362 pub fn raw_port(&self, port: ExternalBytesPort) -> D::ExternalRawPort {
363 self.externals
364 .get(&port.process_id)
365 .unwrap()
366 .raw_port(port.port_id)
367 }
368
369 pub async fn connect_sink_bytes(
370 &self,
371 port: ExternalBytesPort,
372 ) -> Pin<Box<dyn Sink<Bytes, Error = Error>>> {
373 self.externals
374 .get(&port.process_id)
375 .unwrap()
376 .as_bytes_sink(port.port_id)
377 .await
378 }
379
380 pub async fn connect_sink_bincode<T: Serialize + DeserializeOwned + 'static>(
381 &self,
382 port: ExternalBincodeSink<T>,
383 ) -> Pin<Box<dyn Sink<T, Error = Error>>> {
384 self.externals
385 .get(&port.process_id)
386 .unwrap()
387 .as_bincode_sink(port.port_id)
388 .await
389 }
390
391 pub async fn connect_source_bytes(
392 &self,
393 port: ExternalBytesPort,
394 ) -> Pin<Box<dyn Stream<Item = Bytes>>> {
395 self.externals
396 .get(&port.process_id)
397 .unwrap()
398 .as_bytes_source(port.port_id)
399 .await
400 }
401
402 pub async fn connect_source_bincode<T: Serialize + DeserializeOwned + 'static>(
403 &self,
404 port: ExternalBincodeStream<T>,
405 ) -> Pin<Box<dyn Stream<Item = T>>> {
406 self.externals
407 .get(&port.process_id)
408 .unwrap()
409 .as_bincode_source(port.port_id)
410 .await
411 }
412}