// hydro_lang/builder/deploy.rs

1use std::cell::UnsafeCell;
2use std::collections::{BTreeMap, HashMap};
3use std::io::Error;
4use std::marker::PhantomData;
5use std::pin::Pin;
6
7use bytes::Bytes;
8use futures::{Sink, Stream};
9use proc_macro2::Span;
10use serde::Serialize;
11use serde::de::DeserializeOwned;
12use stageleft::QuotedWithContext;
13
14use super::built::build_inner;
15use super::compiled::CompiledFlow;
16use crate::deploy::{
17    ClusterSpec, Deploy, ExternalSpec, IntoProcessSpec, LocalDeploy, Node, ProcessSpec,
18    RegisterPort,
19};
20use crate::ir::HydroLeaf;
21use crate::location::external_process::{
22    ExternalBincodeSink, ExternalBincodeStream, ExternalBytesPort,
23};
24use crate::location::{Cluster, ExternalProcess, Location, LocationId, Process};
25use crate::staging_util::Invariant;
26
27pub struct DeployFlow<'a, D: LocalDeploy<'a>> {
28    // We need to grab an `&mut` reference to the IR in `preview_compile` even though
29    // that function does not modify the IR. Using an `UnsafeCell` allows us to do this
30    // while still being able to lend out immutable references to the IR.
31    pub(super) ir: UnsafeCell<Vec<HydroLeaf>>,
32
33    /// Deployed instances of each process in the flow
34    pub(super) processes: HashMap<usize, D::Process>,
35
36    /// Lists all the processes that were created in the flow, same ID as `processes`
37    /// but with the type name of the tag.
38    pub(super) process_id_name: Vec<(usize, String)>,
39
40    pub(super) externals: HashMap<usize, D::ExternalProcess>,
41    pub(super) external_id_name: Vec<(usize, String)>,
42
43    pub(super) clusters: HashMap<usize, D::Cluster>,
44    pub(super) cluster_id_name: Vec<(usize, String)>,
45    pub(super) used: bool,
46
47    pub(super) _phantom: Invariant<'a, D>,
48}
49
50impl<'a, D: LocalDeploy<'a>> Drop for DeployFlow<'a, D> {
51    fn drop(&mut self) {
52        if !self.used {
53            panic!(
54                "Dropped DeployFlow without instantiating, you may have forgotten to call `compile` or `deploy`."
55            );
56        }
57    }
58}
59
60impl<'a, D: LocalDeploy<'a>> DeployFlow<'a, D> {
61    pub fn ir(&self) -> &Vec<HydroLeaf> {
62        unsafe {
63            // SAFETY: even when we grab this as mutable in `preview_compile`, we do not modify it
64            &*self.ir.get()
65        }
66    }
67
68    pub fn with_process<P>(
69        mut self,
70        process: &Process<P>,
71        spec: impl IntoProcessSpec<'a, D>,
72    ) -> Self {
73        let tag_name = std::any::type_name::<P>().to_string();
74        self.processes.insert(
75            process.id,
76            spec.into_process_spec().build(process.id, &tag_name),
77        );
78        self
79    }
80
81    pub fn with_remaining_processes<S: IntoProcessSpec<'a, D> + 'a>(
82        mut self,
83        spec: impl Fn() -> S,
84    ) -> Self {
85        for (id, name) in &self.process_id_name {
86            self.processes
87                .insert(*id, spec().into_process_spec().build(*id, name));
88        }
89
90        self
91    }
92
93    pub fn with_external<P>(
94        mut self,
95        process: &ExternalProcess<P>,
96        spec: impl ExternalSpec<'a, D>,
97    ) -> Self {
98        let tag_name = std::any::type_name::<P>().to_string();
99        self.externals
100            .insert(process.id, spec.build(process.id, &tag_name));
101        self
102    }
103
104    pub fn with_remaining_externals<S: ExternalSpec<'a, D> + 'a>(
105        mut self,
106        spec: impl Fn() -> S,
107    ) -> Self {
108        for (id, name) in &self.external_id_name {
109            self.externals.insert(*id, spec().build(*id, name));
110        }
111
112        self
113    }
114
115    pub fn with_cluster<C>(mut self, cluster: &Cluster<C>, spec: impl ClusterSpec<'a, D>) -> Self {
116        let tag_name = std::any::type_name::<C>().to_string();
117        self.clusters
118            .insert(cluster.id, spec.build(cluster.id, &tag_name));
119        self
120    }
121
122    pub fn with_remaining_clusters<S: ClusterSpec<'a, D> + 'a>(
123        mut self,
124        spec: impl Fn() -> S,
125    ) -> Self {
126        for (id, name) in &self.cluster_id_name {
127            self.clusters.insert(*id, spec().build(*id, name));
128        }
129
130        self
131    }
132
133    /// Compiles the flow into DFIR using placeholders for the network.
134    /// Useful for generating Mermaid diagrams of the DFIR.
135    pub fn preview_compile(&self) -> CompiledFlow<'a, ()> {
136        CompiledFlow {
137            dfir: build_inner(unsafe {
138                // SAFETY: `build_inner` does not mutate the IR, &mut is required
139                // only because the shared traversal logic requires it
140                &mut *self.ir.get()
141            }),
142            #[cfg(feature = "staged_macro")]
143            extra_stmts: BTreeMap::new(),
144            _phantom: PhantomData,
145        }
146    }
147
148    pub fn compile_no_network(mut self) -> CompiledFlow<'a, D::GraphId> {
149        self.used = true;
150
151        CompiledFlow {
152            dfir: build_inner(self.ir.get_mut()),
153            #[cfg(feature = "staged_macro")]
154            extra_stmts: BTreeMap::new(),
155            _phantom: PhantomData,
156        }
157    }
158}
159
160impl<'a, D: Deploy<'a>> DeployFlow<'a, D> {
161    pub fn compile(mut self, env: &D::CompileEnv) -> CompiledFlow<'a, D::GraphId> {
162        self.used = true;
163
164        let mut seen_tees: HashMap<_, _> = HashMap::new();
165        let mut seen_tee_locations: HashMap<_, _> = HashMap::new();
166        self.ir.get_mut().iter_mut().for_each(|leaf| {
167            leaf.compile_network::<D>(
168                env,
169                &mut seen_tees,
170                &mut seen_tee_locations,
171                &self.processes,
172                &self.clusters,
173                &self.externals,
174            );
175        });
176
177        #[cfg(feature = "staged_macro")]
178        let extra_stmts = self.extra_stmts(env);
179
180        CompiledFlow {
181            dfir: build_inner(self.ir.get_mut()),
182            #[cfg(feature = "staged_macro")]
183            extra_stmts,
184            _phantom: PhantomData,
185        }
186    }
187
188    fn extra_stmts(&self, env: &<D as Deploy<'a>>::CompileEnv) -> BTreeMap<usize, Vec<syn::Stmt>> {
189        let mut extra_stmts: BTreeMap<usize, Vec<syn::Stmt>> = BTreeMap::new();
190
191        let mut all_clusters_sorted = self.clusters.keys().collect::<Vec<_>>();
192        all_clusters_sorted.sort();
193
194        for &c_id in all_clusters_sorted {
195            let self_id_ident = syn::Ident::new(
196                &format!("__hydro_lang_cluster_self_id_{}", c_id),
197                Span::call_site(),
198            );
199            let self_id_expr = D::cluster_self_id(env).splice_untyped();
200            extra_stmts
201                .entry(c_id)
202                .or_default()
203                .push(syn::parse_quote! {
204                    let #self_id_ident = #self_id_expr;
205                });
206
207            for other_location in self.processes.keys().chain(self.clusters.keys()) {
208                let other_id_ident = syn::Ident::new(
209                    &format!("__hydro_lang_cluster_ids_{}", c_id),
210                    Span::call_site(),
211                );
212                let other_id_expr = D::cluster_ids(env, c_id).splice_untyped();
213                extra_stmts
214                    .entry(*other_location)
215                    .or_default()
216                    .push(syn::parse_quote! {
217                        let #other_id_ident = #other_id_expr;
218                    });
219            }
220        }
221        extra_stmts
222    }
223}
224
225impl<'a, D: Deploy<'a, CompileEnv = ()>> DeployFlow<'a, D> {
226    #[must_use]
227    pub fn deploy(mut self, env: &mut D::InstantiateEnv) -> DeployResult<'a, D> {
228        self.used = true;
229
230        let mut seen_tees_instantiate: HashMap<_, _> = HashMap::new();
231        let mut seen_tee_locations: HashMap<_, _> = HashMap::new();
232        self.ir.get_mut().iter_mut().for_each(|leaf| {
233            leaf.compile_network::<D>(
234                &(),
235                &mut seen_tees_instantiate,
236                &mut seen_tee_locations,
237                &self.processes,
238                &self.clusters,
239                &self.externals,
240            );
241        });
242
243        let mut compiled = build_inner(self.ir.get_mut());
244        let mut extra_stmts = self.extra_stmts(&());
245        let mut meta = D::Meta::default();
246
247        let (mut processes, mut clusters, mut externals) = (
248            std::mem::take(&mut self.processes)
249                .into_iter()
250                .filter_map(|(node_id, node)| {
251                    if let Some(ir) = compiled.remove(&node_id) {
252                        node.instantiate(
253                            env,
254                            &mut meta,
255                            ir,
256                            extra_stmts.remove(&node_id).unwrap_or_default(),
257                        );
258                        Some((node_id, node))
259                    } else {
260                        None
261                    }
262                })
263                .collect::<HashMap<_, _>>(),
264            std::mem::take(&mut self.clusters)
265                .into_iter()
266                .filter_map(|(cluster_id, cluster)| {
267                    if let Some(ir) = compiled.remove(&cluster_id) {
268                        cluster.instantiate(
269                            env,
270                            &mut meta,
271                            ir,
272                            extra_stmts.remove(&cluster_id).unwrap_or_default(),
273                        );
274                        Some((cluster_id, cluster))
275                    } else {
276                        None
277                    }
278                })
279                .collect::<HashMap<_, _>>(),
280            std::mem::take(&mut self.externals)
281                .into_iter()
282                .map(|(external_id, external)| {
283                    external.instantiate(
284                        env,
285                        &mut meta,
286                        compiled.remove(&external_id).unwrap(),
287                        extra_stmts.remove(&external_id).unwrap_or_default(),
288                    );
289                    (external_id, external)
290                })
291                .collect::<HashMap<_, _>>(),
292        );
293
294        for node in processes.values_mut() {
295            node.update_meta(&meta);
296        }
297
298        for cluster in clusters.values_mut() {
299            cluster.update_meta(&meta);
300        }
301
302        for external in externals.values_mut() {
303            external.update_meta(&meta);
304        }
305
306        let mut seen_tees_connect = HashMap::new();
307        self.ir.get_mut().iter_mut().for_each(|leaf| {
308            leaf.connect_network(&mut seen_tees_connect);
309        });
310
311        DeployResult {
312            processes,
313            clusters,
314            externals,
315            cluster_id_name: std::mem::take(&mut self.cluster_id_name)
316                .into_iter()
317                .collect(),
318        }
319    }
320}
321
322pub struct DeployResult<'a, D: Deploy<'a>> {
323    processes: HashMap<usize, D::Process>,
324    clusters: HashMap<usize, D::Cluster>,
325    externals: HashMap<usize, D::ExternalProcess>,
326    cluster_id_name: HashMap<usize, String>,
327}
328
329impl<'a, D: Deploy<'a>> DeployResult<'a, D> {
330    pub fn get_process<P>(&self, p: &Process<P>) -> &D::Process {
331        let id = match p.id() {
332            LocationId::Process(id) => id,
333            _ => panic!("Process ID expected"),
334        };
335
336        self.processes.get(&id).unwrap()
337    }
338
339    pub fn get_cluster<C>(&self, c: &Cluster<'a, C>) -> &D::Cluster {
340        let id = match c.id() {
341            LocationId::Cluster(id) => id,
342            _ => panic!("Cluster ID expected"),
343        };
344
345        self.clusters.get(&id).unwrap()
346    }
347
348    pub fn get_all_clusters(&self) -> impl Iterator<Item = (LocationId, String, &D::Cluster)> {
349        self.clusters.iter().map(|(&id, c)| {
350            (
351                LocationId::Cluster(id),
352                self.cluster_id_name.get(&id).unwrap().clone(),
353                c,
354            )
355        })
356    }
357
358    pub fn get_external<P>(&self, p: &ExternalProcess<P>) -> &D::ExternalProcess {
359        self.externals.get(&p.id).unwrap()
360    }
361
362    pub fn raw_port(&self, port: ExternalBytesPort) -> D::ExternalRawPort {
363        self.externals
364            .get(&port.process_id)
365            .unwrap()
366            .raw_port(port.port_id)
367    }
368
369    pub async fn connect_sink_bytes(
370        &self,
371        port: ExternalBytesPort,
372    ) -> Pin<Box<dyn Sink<Bytes, Error = Error>>> {
373        self.externals
374            .get(&port.process_id)
375            .unwrap()
376            .as_bytes_sink(port.port_id)
377            .await
378    }
379
380    pub async fn connect_sink_bincode<T: Serialize + DeserializeOwned + 'static>(
381        &self,
382        port: ExternalBincodeSink<T>,
383    ) -> Pin<Box<dyn Sink<T, Error = Error>>> {
384        self.externals
385            .get(&port.process_id)
386            .unwrap()
387            .as_bincode_sink(port.port_id)
388            .await
389    }
390
391    pub async fn connect_source_bytes(
392        &self,
393        port: ExternalBytesPort,
394    ) -> Pin<Box<dyn Stream<Item = Bytes>>> {
395        self.externals
396            .get(&port.process_id)
397            .unwrap()
398            .as_bytes_source(port.port_id)
399            .await
400    }
401
402    pub async fn connect_source_bincode<T: Serialize + DeserializeOwned + 'static>(
403        &self,
404        port: ExternalBincodeStream<T>,
405    ) -> Pin<Box<dyn Stream<Item = T>>> {
406        self.externals
407            .get(&port.process_id)
408            .unwrap()
409            .as_bincode_source(port.port_id)
410            .await
411    }
412}