diff --git a/src/graph.rs b/src/graph.rs index b6a0e9c..9909cb9 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -150,27 +150,32 @@ pub enum NodeStatus { pub struct RunnableGraph { pub graph: Graph, - pub node_statuses: HashMap, } impl RunnableGraph { pub fn from_graph(graph: Graph) -> RunnableGraph { - RunnableGraph { - graph, - node_statuses: HashMap::new(), - } + RunnableGraph { graph } } pub fn run_default_tasks(&mut self, num_threads: usize) -> anyhow::Result<()> { - self.run(num_threads, Box::new(|node| get_runnable_node(node))) + self.run( + num_threads, + Box::new(|node| get_runnable_node(node)), + |id, status| {}, + ) } // Make this not mutable, emit node status when required in a function or some other message - pub fn run( - &mut self, + pub fn run<'a, F, StatusChanged>( + &self, num_threads: usize, - get_node_fn: Box Box + Send + Sync>, - ) -> anyhow::Result<()> { + get_node_fn: F, + node_status_changed_fn: StatusChanged, + ) -> anyhow::Result<()> + where + F: Fn(Node) -> Box + Send + Sync + 'static, + StatusChanged: Fn(i64, NodeStatus), + { let mut nodes = self.graph.nodes.clone(); // 1. nodes the nodes based on dependencies (i.e. nodes without dependencies go first) nodes.sort_by(|a, b| { @@ -188,10 +193,10 @@ impl RunnableGraph { // Sync version if num_threads < 2 { for node in &nodes { - self.node_statuses.insert(node.id, NodeStatus::Running); + node_status_changed_fn(node.id, NodeStatus::Running); match get_node_fn(node.clone()).run() { - Ok(_) => self.node_statuses.insert(node.id, NodeStatus::Completed), - Err(_) => self.node_statuses.insert(node.id, NodeStatus::Failed), + Ok(_) => node_status_changed_fn(node.id, NodeStatus::Completed), + Err(_) => node_status_changed_fn(node.id, NodeStatus::Failed), }; } return Ok(()); @@ -231,7 +236,7 @@ impl RunnableGraph { running_threads.insert(i); // Run all nodes that have no dependencies let node = nodes.remove(i); - self.node_statuses.insert(node.id, NodeStatus::Running); + node_status_changed_fn(node.id, NodeStatus::Running); running_nodes.insert(node.id); senders[i % senders.len()].send(node); } @@ -241,7 +246,7 @@ impl RunnableGraph { for (n, node) in listen_finish_task { running_threads.remove(&n); // TODO: Add error check here - self.node_statuses.insert(node.id, NodeStatus::Completed); + node_status_changed_fn(node.id, NodeStatus::Completed); running_nodes.remove(&node.id); completed_nodes.insert(node.id); // Run all the nodes that can be run and aren't in completed @@ -281,7 +286,6 @@ impl RunnableGraph { #[cfg(test)] mod tests { - use std::collections::HashMap; use chrono::Local; @@ -290,7 +294,6 @@ mod tests { #[test] fn test_basic() -> anyhow::Result<()> { let mut graph = RunnableGraph { - node_statuses: HashMap::new(), graph: super::Graph { name: "Test".to_owned(), nodes: vec![super::Node {