use std::{ cmp::{min, Ordering}, collections::{HashMap, HashSet}, sync::{ mpsc::{self, Sender}, Arc, }, thread, }; use chrono::Local; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use split::{SplitNode, SplitNodeRunner}; use { derive::DeriveNode, filter::{FilterNode, FilterNodeRunner}, node::RunnableNode, sql_rule::{SQLNode, SQLNodeRunner}, upload_to_db::{UploadNode, UploadNodeRunner}, }; mod derive; mod filter; mod node; mod split; mod sql_rule; mod upload_to_db; #[derive(Serialize, Deserialize, Clone, JsonSchema)] pub enum NodeConfiguration { FileNode, MoveMoneyNode(MoveMoneyNode), MergeNode(MergeNode), DeriveNode(DeriveNode), CodeRuleNode(CodeRuleNode), FilterNode(FilterNode), UploadNode(UploadNode), SQLNode(SQLNode), Dynamic, SplitNode(SplitNode), } #[derive(Serialize, Deserialize, Clone, JsonSchema)] pub struct DynamicConfiguration { pub node_type: String, pub parameters: HashMap, } #[derive(Serialize, Deserialize, Clone, JsonSchema)] pub struct NodeInfo { pub name: String, pub output_files: Vec, pub configuration: NodeConfiguration, pub dynamic_configuration: Option, } #[derive(Serialize, Deserialize, Clone, JsonSchema)] pub enum MoveMoneyAmountType { Percent, Amount, } #[derive(Serialize, Deserialize, Clone, JsonSchema)] pub struct MoveMoneyRule { pub from_account: String, pub from_cc: String, pub to_account: String, pub to_cc: String, pub value: f64, pub amount_type: MoveMoneyAmountType, } #[derive(Serialize, Deserialize, Clone, JsonSchema)] pub struct MoveMoneyNode { pub departments_path: String, pub accounts_path: String, pub gl_path: String, pub rules: Vec, } #[derive(Serialize, Deserialize, Clone, JsonSchema)] pub enum JoinType { Left, Inner, Right, } #[derive(Serialize, Deserialize, Clone, JsonSchema)] pub struct MergeJoin { pub join_type: JoinType, pub left_column_name: String, pub right_column_name: String, } #[derive(Serialize, Deserialize, Clone, JsonSchema)] pub struct MergeNode { pub input_files: Vec, pub joins: Vec, } #[derive(Serialize, Deserialize, Clone, JsonSchema)] pub enum CodeRuleLanguage { Javascript, Rust, Go, } #[derive(Serialize, Deserialize, Clone, JsonSchema)] pub struct CodeRuleNode { pub language: CodeRuleLanguage, pub text: String, } #[derive(Serialize, Deserialize, Clone, JsonSchema)] pub struct Node { pub id: i64, pub info: NodeInfo, pub dependent_node_ids: Vec, // Lets us work out whether a task should be rerun, by // inspecting the files at the output paths and check if // their timestamp is before this // TODO: Could just be seconds since unix epoch? pub last_modified: chrono::DateTime, } impl Node { pub fn has_dependent_nodes(&self) -> bool { !self.dependent_node_ids.is_empty() } } fn get_runnable_node(node: Node) -> Box { match node.info.configuration { NodeConfiguration::FileNode => todo!(), NodeConfiguration::MoveMoneyNode(_) => todo!(), NodeConfiguration::MergeNode(_) => todo!(), NodeConfiguration::DeriveNode(_) => todo!(), NodeConfiguration::CodeRuleNode(_) => todo!(), NodeConfiguration::FilterNode(filter_node) => Box::new(FilterNodeRunner { filter_node }), NodeConfiguration::UploadNode(upload_node) => Box::new(UploadNodeRunner { upload_node }), NodeConfiguration::SQLNode(sql_node) => Box::new(SQLNodeRunner { sql_node }), NodeConfiguration::Dynamic => todo!(), NodeConfiguration::SplitNode(split_node) => Box::new(SplitNodeRunner { split_node }), } } #[derive(Serialize, Deserialize, Clone, JsonSchema)] pub struct Graph { pub name: String, pub nodes: Vec, } #[derive(Debug)] pub enum NodeStatus { Completed, Running, // Error code Failed(anyhow::Error), } pub struct RunnableGraph { pub graph: Graph, } impl RunnableGraph { pub fn from_graph(graph: Graph) -> RunnableGraph { RunnableGraph { graph } } pub fn run_default_tasks(&self, num_threads: usize, status_changed: F) -> anyhow::Result<()> where F: Fn(i64, NodeStatus), { self.run( num_threads, Box::new(|node| get_runnable_node(node)), status_changed, ) } // Make this not mutable, emit node status when required in a function or some other message pub fn run<'a, F, StatusChanged>( &self, num_threads: usize, get_node_fn: F, node_status_changed_fn: StatusChanged, ) -> anyhow::Result<()> where F: Fn(Node) -> Box + Send + Sync + 'static, StatusChanged: Fn(i64, NodeStatus), { let mut nodes = self.graph.nodes.clone(); // 1. nodes the nodes based on dependencies (i.e. nodes without dependencies go first) nodes.sort_by(|a, b| { if b.dependent_node_ids.contains(&a.id) { return Ordering::Greater; } if a.dependent_node_ids.contains(&b.id) { return Ordering::Less; } Ordering::Equal }); let num_threads = min(num_threads, nodes.len()); // Sync version if num_threads < 2 { for node in &nodes { node_status_changed_fn(node.id, NodeStatus::Running); match get_node_fn(node.clone()).run() { Ok(_) => node_status_changed_fn(node.id, NodeStatus::Completed), Err(err) => node_status_changed_fn(node.id, NodeStatus::Failed(err)), }; } return Ok(()); } let mut running_nodes = HashSet::new(); let mut completed_nodes = HashSet::new(); let mut senders: Vec> = vec![]; let mut handles = vec![]; let (finish_task, listen_finish_task) = mpsc::sync_channel(num_threads); let node_fn = Arc::new(get_node_fn); for n in 0..num_threads { let finish_task = finish_task.clone(); let (tx, rx) = mpsc::channel(); senders.push(tx); let node_fn = node_fn.clone(); let handle = thread::spawn(move || { for node in rx { let status = match node_fn(node.clone()).run() { Ok(_) => NodeStatus::Completed, Err(err) => NodeStatus::Failed(err), }; finish_task .send((n, node, status)) .expect("Failed to notify node status completion"); } println!("Thread {} finished", n); }); handles.push(handle); } let mut running_threads = HashSet::new(); // Run nodes without dependencies. There'll always be at least one for i in 0..nodes.len() { // Ensure we don't overload threads if i >= num_threads { break; } if !nodes[i].has_dependent_nodes() { running_threads.insert(i); // Run all nodes that have no dependencies let node = nodes.remove(i); node_status_changed_fn(node.id, NodeStatus::Running); running_nodes.insert(node.id); senders[i % senders.len()].send(node)?; } } // Run each dependent node after a graph above finishes. for (n, node, error) in listen_finish_task { running_threads.remove(&n); node_status_changed_fn(node.id, error); running_nodes.remove(&node.id); completed_nodes.insert(node.id); // Run all the nodes that can be run and aren't in completed let mut i = 0; while running_threads.len() < num_threads && i < nodes.len() { if !running_nodes.contains(&nodes[i].id) && nodes[i] .dependent_node_ids .iter() .all(|id| completed_nodes.contains(id)) { let node = nodes.remove(i); for i in 0..num_threads { if !running_threads.contains(&i) { senders[i].send(node)?; break; } } } i += 1; } if nodes.is_empty() { break; } } for sender in senders { drop(sender); } for handle in handles { handle.join().expect("Failed to join thread"); } println!("Process finished"); Ok(()) } } #[cfg(test)] mod tests { use chrono::Local; use super::{NodeConfiguration, RunnableGraph}; #[test] fn test_basic() -> anyhow::Result<()> { let graph = RunnableGraph { graph: super::Graph { name: "Test".to_owned(), nodes: vec![super::Node { id: 1, dependent_node_ids: vec![], info: super::NodeInfo { name: "Hello".to_owned(), configuration: NodeConfiguration::FilterNode(super::FilterNode { filters: vec![], input_file_path: "".to_owned(), output_file_path: "".to_owned(), }), output_files: vec![], dynamic_configuration: None, }, last_modified: Local::now(), }], }, }; graph.run_default_tasks(2, |_, _| {})?; Ok(()) } }