Refactor graph run to not be mutable, use generic for node retriever

This commit is contained in:
2024-07-31 20:28:49 +09:30
parent e9caf43de3
commit 0f0d40c2a1

View File

@@ -150,27 +150,32 @@ pub enum NodeStatus {
pub struct RunnableGraph {
pub graph: Graph,
pub node_statuses: HashMap<i64, NodeStatus>,
}
impl RunnableGraph {
pub fn from_graph(graph: Graph) -> RunnableGraph {
RunnableGraph {
graph,
node_statuses: HashMap::new(),
}
RunnableGraph { graph }
}
pub fn run_default_tasks(&mut self, num_threads: usize) -> anyhow::Result<()> {
self.run(num_threads, Box::new(|node| get_runnable_node(node)))
self.run(
num_threads,
Box::new(|node| get_runnable_node(node)),
|id, status| {},
)
}
// Make this not mutable, emit node status when required in a function or some other message
pub fn run(
&mut self,
pub fn run<'a, F, StatusChanged>(
&self,
num_threads: usize,
get_node_fn: Box<dyn Fn(Node) -> Box<dyn RunnableNode> + Send + Sync>,
) -> anyhow::Result<()> {
get_node_fn: F,
node_status_changed_fn: StatusChanged,
) -> anyhow::Result<()>
where
F: Fn(Node) -> Box<dyn RunnableNode> + Send + Sync + 'static,
StatusChanged: Fn(i64, NodeStatus),
{
let mut nodes = self.graph.nodes.clone();
// 1. nodes the nodes based on dependencies (i.e. nodes without dependencies go first)
nodes.sort_by(|a, b| {
@@ -188,10 +193,10 @@ impl RunnableGraph {
// Sync version
if num_threads < 2 {
for node in &nodes {
self.node_statuses.insert(node.id, NodeStatus::Running);
node_status_changed_fn(node.id, NodeStatus::Running);
match get_node_fn(node.clone()).run() {
Ok(_) => self.node_statuses.insert(node.id, NodeStatus::Completed),
Err(_) => self.node_statuses.insert(node.id, NodeStatus::Failed),
Ok(_) => node_status_changed_fn(node.id, NodeStatus::Completed),
Err(_) => node_status_changed_fn(node.id, NodeStatus::Failed),
};
}
return Ok(());
@@ -231,7 +236,7 @@ impl RunnableGraph {
running_threads.insert(i);
// Run all nodes that have no dependencies
let node = nodes.remove(i);
self.node_statuses.insert(node.id, NodeStatus::Running);
node_status_changed_fn(node.id, NodeStatus::Running);
running_nodes.insert(node.id);
senders[i % senders.len()].send(node);
}
@@ -241,7 +246,7 @@ impl RunnableGraph {
for (n, node) in listen_finish_task {
running_threads.remove(&n);
// TODO: Add error check here
self.node_statuses.insert(node.id, NodeStatus::Completed);
node_status_changed_fn(node.id, NodeStatus::Completed);
running_nodes.remove(&node.id);
completed_nodes.insert(node.id);
// Run all the nodes that can be run and aren't in completed
@@ -281,7 +286,6 @@ impl RunnableGraph {
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use chrono::Local;
@@ -290,7 +294,6 @@ mod tests {
#[test]
fn test_basic() -> anyhow::Result<()> {
let mut graph = RunnableGraph {
node_statuses: HashMap::new(),
graph: super::Graph {
name: "Test".to_owned(),
nodes: vec![super::Node {