Refactor graph run to not be mutable, use generic for node retriever
This commit is contained in:
37
src/graph.rs
37
src/graph.rs
@@ -150,27 +150,32 @@ pub enum NodeStatus {
|
||||
|
||||
pub struct RunnableGraph {
|
||||
pub graph: Graph,
|
||||
pub node_statuses: HashMap<i64, NodeStatus>,
|
||||
}
|
||||
|
||||
impl RunnableGraph {
|
||||
pub fn from_graph(graph: Graph) -> RunnableGraph {
|
||||
RunnableGraph {
|
||||
graph,
|
||||
node_statuses: HashMap::new(),
|
||||
}
|
||||
RunnableGraph { graph }
|
||||
}
|
||||
|
||||
pub fn run_default_tasks(&mut self, num_threads: usize) -> anyhow::Result<()> {
|
||||
self.run(num_threads, Box::new(|node| get_runnable_node(node)))
|
||||
self.run(
|
||||
num_threads,
|
||||
Box::new(|node| get_runnable_node(node)),
|
||||
|id, status| {},
|
||||
)
|
||||
}
|
||||
|
||||
// Make this not mutable, emit node status when required in a function or some other message
|
||||
pub fn run(
|
||||
&mut self,
|
||||
pub fn run<'a, F, StatusChanged>(
|
||||
&self,
|
||||
num_threads: usize,
|
||||
get_node_fn: Box<dyn Fn(Node) -> Box<dyn RunnableNode> + Send + Sync>,
|
||||
) -> anyhow::Result<()> {
|
||||
get_node_fn: F,
|
||||
node_status_changed_fn: StatusChanged,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
F: Fn(Node) -> Box<dyn RunnableNode> + Send + Sync + 'static,
|
||||
StatusChanged: Fn(i64, NodeStatus),
|
||||
{
|
||||
let mut nodes = self.graph.nodes.clone();
|
||||
// 1. nodes the nodes based on dependencies (i.e. nodes without dependencies go first)
|
||||
nodes.sort_by(|a, b| {
|
||||
@@ -188,10 +193,10 @@ impl RunnableGraph {
|
||||
// Sync version
|
||||
if num_threads < 2 {
|
||||
for node in &nodes {
|
||||
self.node_statuses.insert(node.id, NodeStatus::Running);
|
||||
node_status_changed_fn(node.id, NodeStatus::Running);
|
||||
match get_node_fn(node.clone()).run() {
|
||||
Ok(_) => self.node_statuses.insert(node.id, NodeStatus::Completed),
|
||||
Err(_) => self.node_statuses.insert(node.id, NodeStatus::Failed),
|
||||
Ok(_) => node_status_changed_fn(node.id, NodeStatus::Completed),
|
||||
Err(_) => node_status_changed_fn(node.id, NodeStatus::Failed),
|
||||
};
|
||||
}
|
||||
return Ok(());
|
||||
@@ -231,7 +236,7 @@ impl RunnableGraph {
|
||||
running_threads.insert(i);
|
||||
// Run all nodes that have no dependencies
|
||||
let node = nodes.remove(i);
|
||||
self.node_statuses.insert(node.id, NodeStatus::Running);
|
||||
node_status_changed_fn(node.id, NodeStatus::Running);
|
||||
running_nodes.insert(node.id);
|
||||
senders[i % senders.len()].send(node);
|
||||
}
|
||||
@@ -241,7 +246,7 @@ impl RunnableGraph {
|
||||
for (n, node) in listen_finish_task {
|
||||
running_threads.remove(&n);
|
||||
// TODO: Add error check here
|
||||
self.node_statuses.insert(node.id, NodeStatus::Completed);
|
||||
node_status_changed_fn(node.id, NodeStatus::Completed);
|
||||
running_nodes.remove(&node.id);
|
||||
completed_nodes.insert(node.id);
|
||||
// Run all the nodes that can be run and aren't in completed
|
||||
@@ -281,7 +286,6 @@ impl RunnableGraph {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use chrono::Local;
|
||||
|
||||
@@ -290,7 +294,6 @@ mod tests {
|
||||
#[test]
|
||||
fn test_basic() -> anyhow::Result<()> {
|
||||
let mut graph = RunnableGraph {
|
||||
node_statuses: HashMap::new(),
|
||||
graph: super::Graph {
|
||||
name: "Test".to_owned(),
|
||||
nodes: vec![super::Node {
|
||||
|
||||
Reference in New Issue
Block a user