Refactor graph run to not be mutable, use generic for node retriever

This commit is contained in:
2024-07-31 20:28:49 +09:30
parent e9caf43de3
commit 0f0d40c2a1

View File

@@ -150,27 +150,32 @@ pub enum NodeStatus {
pub struct RunnableGraph { pub struct RunnableGraph {
pub graph: Graph, pub graph: Graph,
pub node_statuses: HashMap<i64, NodeStatus>,
} }
impl RunnableGraph { impl RunnableGraph {
pub fn from_graph(graph: Graph) -> RunnableGraph { pub fn from_graph(graph: Graph) -> RunnableGraph {
RunnableGraph { RunnableGraph { graph }
graph,
node_statuses: HashMap::new(),
}
} }
pub fn run_default_tasks(&mut self, num_threads: usize) -> anyhow::Result<()> { pub fn run_default_tasks(&mut self, num_threads: usize) -> anyhow::Result<()> {
self.run(num_threads, Box::new(|node| get_runnable_node(node))) self.run(
num_threads,
Box::new(|node| get_runnable_node(node)),
|id, status| {},
)
} }
// Make this not mutable, emit node status when required in a function or some other message // Make this not mutable, emit node status when required in a function or some other message
pub fn run( pub fn run<'a, F, StatusChanged>(
&mut self, &self,
num_threads: usize, num_threads: usize,
get_node_fn: Box<dyn Fn(Node) -> Box<dyn RunnableNode> + Send + Sync>, get_node_fn: F,
) -> anyhow::Result<()> { node_status_changed_fn: StatusChanged,
) -> anyhow::Result<()>
where
F: Fn(Node) -> Box<dyn RunnableNode> + Send + Sync + 'static,
StatusChanged: Fn(i64, NodeStatus),
{
let mut nodes = self.graph.nodes.clone(); let mut nodes = self.graph.nodes.clone();
// 1. nodes the nodes based on dependencies (i.e. nodes without dependencies go first) // 1. nodes the nodes based on dependencies (i.e. nodes without dependencies go first)
nodes.sort_by(|a, b| { nodes.sort_by(|a, b| {
@@ -188,10 +193,10 @@ impl RunnableGraph {
// Sync version // Sync version
if num_threads < 2 { if num_threads < 2 {
for node in &nodes { for node in &nodes {
self.node_statuses.insert(node.id, NodeStatus::Running); node_status_changed_fn(node.id, NodeStatus::Running);
match get_node_fn(node.clone()).run() { match get_node_fn(node.clone()).run() {
Ok(_) => self.node_statuses.insert(node.id, NodeStatus::Completed), Ok(_) => node_status_changed_fn(node.id, NodeStatus::Completed),
Err(_) => self.node_statuses.insert(node.id, NodeStatus::Failed), Err(_) => node_status_changed_fn(node.id, NodeStatus::Failed),
}; };
} }
return Ok(()); return Ok(());
@@ -231,7 +236,7 @@ impl RunnableGraph {
running_threads.insert(i); running_threads.insert(i);
// Run all nodes that have no dependencies // Run all nodes that have no dependencies
let node = nodes.remove(i); let node = nodes.remove(i);
self.node_statuses.insert(node.id, NodeStatus::Running); node_status_changed_fn(node.id, NodeStatus::Running);
running_nodes.insert(node.id); running_nodes.insert(node.id);
senders[i % senders.len()].send(node); senders[i % senders.len()].send(node);
} }
@@ -241,7 +246,7 @@ impl RunnableGraph {
for (n, node) in listen_finish_task { for (n, node) in listen_finish_task {
running_threads.remove(&n); running_threads.remove(&n);
// TODO: Add error check here // TODO: Add error check here
self.node_statuses.insert(node.id, NodeStatus::Completed); node_status_changed_fn(node.id, NodeStatus::Completed);
running_nodes.remove(&node.id); running_nodes.remove(&node.id);
completed_nodes.insert(node.id); completed_nodes.insert(node.id);
// Run all the nodes that can be run and aren't in completed // Run all the nodes that can be run and aren't in completed
@@ -281,7 +286,6 @@ impl RunnableGraph {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::collections::HashMap;
use chrono::Local; use chrono::Local;
@@ -290,7 +294,6 @@ mod tests {
#[test] #[test]
fn test_basic() -> anyhow::Result<()> { fn test_basic() -> anyhow::Result<()> {
let mut graph = RunnableGraph { let mut graph = RunnableGraph {
node_statuses: HashMap::new(),
graph: super::Graph { graph: super::Graph {
name: "Test".to_owned(), name: "Test".to_owned(),
nodes: vec![super::Node { nodes: vec![super::Node {