use std::{ cmp::{min, Ordering}, collections::{HashMap, HashSet}, sync::{ mpsc::{self, Sender}, Arc, }, thread, }; use chrono::Local; use serde::{Deserialize, Serialize}; use crate::{ derive::DeriveNode, filter::{FilterNode, FilterNodeRunner}, node::RunnableNode, sql_rule::{SQLNode, SQLNodeRunner}, upload_to_db::{UploadNode, UploadNodeRunner}, }; #[derive(Serialize, Deserialize, Clone)] pub enum NodeConfiguration { FileNode, MoveMoneyNode(MoveMoneyNode), MergeNode(MergeNode), DeriveNode(DeriveNode), CodeRuleNode(CodeRuleNode), FilterNode(FilterNode), UploadNode(UploadNode), SQLNode(SQLNode), Dynamic, } #[derive(Serialize, Deserialize, Clone)] pub struct DynamicConfiguration { pub node_type: String, pub parameters: HashMap, } #[derive(Serialize, Deserialize, Clone)] pub struct NodeInfo { pub name: String, pub output_files: Vec, pub configuration: NodeConfiguration, pub dynamic_configuration: Option, } #[derive(Serialize, Deserialize, Clone)] pub enum MoveMoneyAmountType { Percent, Amount, } #[derive(Serialize, Deserialize, Clone)] pub struct MoveMoneyRule { pub from_account: String, pub from_cc: String, pub to_account: String, pub to_cc: String, pub value: f64, pub amount_type: MoveMoneyAmountType, } #[derive(Serialize, Deserialize, Clone)] pub struct MoveMoneyNode { pub departments_path: String, pub accounts_path: String, pub gl_path: String, pub rules: Vec, } #[derive(Serialize, Deserialize, Clone)] pub enum JoinType { Left, Inner, Right, } #[derive(Serialize, Deserialize, Clone)] pub struct MergeJoin { pub join_type: JoinType, pub left_column_name: String, pub right_column_name: String, } #[derive(Serialize, Deserialize, Clone)] pub struct MergeNode { pub input_files: Vec, pub joins: Vec, } #[derive(Serialize, Deserialize, Clone)] pub enum CodeRuleLanguage { Javascript, Rust, Go, } #[derive(Serialize, Deserialize, Clone)] pub struct CodeRuleNode { pub language: CodeRuleLanguage, pub text: String, } #[derive(Serialize, Deserialize, Clone)] pub struct Node { pub id: i64, pub info: NodeInfo, pub dependent_node_ids: Vec, // Lets us work out whether a task should be rerun, by // inspecting the files at the output paths and check if // their timestamp is before this // TODO: Could just be seconds since unix epoch? pub last_modified: chrono::DateTime, } impl Node { pub fn has_dependent_nodes(&self) -> bool { !self.dependent_node_ids.is_empty() } } fn get_runnable_node(node: Node) -> Box { match node.info.configuration { NodeConfiguration::FileNode => todo!(), NodeConfiguration::MoveMoneyNode(_) => todo!(), NodeConfiguration::MergeNode(_) => todo!(), NodeConfiguration::DeriveNode(_) => todo!(), NodeConfiguration::CodeRuleNode(_) => todo!(), NodeConfiguration::FilterNode(filter_node) => Box::new(FilterNodeRunner { filter_node }), NodeConfiguration::UploadNode(upload_node) => Box::new(UploadNodeRunner { upload_node }), NodeConfiguration::SQLNode(sql_node) => Box::new(SQLNodeRunner { sql_node }), NodeConfiguration::Dynamic => todo!(), } } #[derive(Serialize, Deserialize, Clone)] pub struct Graph { pub name: String, pub nodes: Vec, } pub enum NodeStatus { Completed, Running, // TODO: Error code? Failed, } pub struct RunnableGraph { pub graph: Graph, pub node_statuses: HashMap, } impl RunnableGraph { pub fn from_graph(graph: Graph) -> RunnableGraph { RunnableGraph { graph, node_statuses: HashMap::new(), } } pub fn run_default_tasks(&mut self, num_threads: usize) -> anyhow::Result<()> { self.run(num_threads, Box::new(|node| get_runnable_node(node))) } // Make this not mutable, emit node status when required in a function or some other message pub fn run( &mut self, num_threads: usize, get_node_fn: Box Box + Send + Sync>, ) -> anyhow::Result<()> { let mut nodes = self.graph.nodes.clone(); // 1. nodes the nodes based on dependencies (i.e. nodes without dependencies go first) nodes.sort_by(|a, b| { if b.dependent_node_ids.contains(&a.id) { return Ordering::Greater; } if a.dependent_node_ids.contains(&b.id) { return Ordering::Less; } Ordering::Equal }); let num_threads = min(num_threads, nodes.len()); // Sync version if num_threads < 2 { for node in &nodes { self.node_statuses.insert(node.id, NodeStatus::Running); match get_node_fn(node.clone()).run() { Ok(_) => self.node_statuses.insert(node.id, NodeStatus::Completed), Err(_) => self.node_statuses.insert(node.id, NodeStatus::Failed), }; } return Ok(()); } let mut running_nodes = HashSet::new(); let mut completed_nodes = HashSet::new(); let mut senders: Vec> = vec![]; let mut handles = vec![]; let (finish_task, listen_finish_task) = mpsc::sync_channel(num_threads); let node_fn = Arc::new(get_node_fn); for n in 0..num_threads { let finish_task = finish_task.clone(); let (tx, rx) = mpsc::channel(); senders.push(tx); let node_fn = node_fn.clone(); let handle = thread::spawn(move || { for node in rx { node_fn(node.clone()).run(); finish_task.send((n, node)); } println!("Thread {} finished", n); }); handles.push(handle); } let mut running_threads = HashSet::new(); // Run nodes without dependencies. There'll always be at least one for i in 0..nodes.len() { // Ensure we don't overload threads if i >= num_threads { break; } if !nodes[i].has_dependent_nodes() { running_threads.insert(i); // Run all nodes that have no dependencies let node = nodes.remove(i); self.node_statuses.insert(node.id, NodeStatus::Running); running_nodes.insert(node.id); senders[i % senders.len()].send(node); } } // Run each dependent node after a graph above finishes. for (n, node) in listen_finish_task { running_threads.remove(&n); // TODO: Add error check here self.node_statuses.insert(node.id, NodeStatus::Completed); running_nodes.remove(&node.id); completed_nodes.insert(node.id); // Run all the nodes that can be run and aren't in completed let mut i = 0; while running_threads.len() < num_threads && i < nodes.len() { if !running_nodes.contains(&nodes[i].id) && nodes[i] .dependent_node_ids .iter() .all(|id| completed_nodes.contains(id)) { let node = nodes.remove(i); for i in 0..num_threads { if !running_threads.contains(&i) { senders[i].send(node); break; } } } i += 1; } if nodes.is_empty() { break; } } for sender in senders { drop(sender); } for handle in handles { handle.join().expect("Failed to join thread"); } println!("Process finished"); Ok(()) } } #[cfg(test)] mod tests { use std::collections::HashMap; use chrono::Local; use super::{NodeConfiguration, RunnableGraph}; #[test] fn test_basic() -> anyhow::Result<()> { let mut graph = RunnableGraph { node_statuses: HashMap::new(), graph: super::Graph { name: "Test".to_owned(), nodes: vec![super::Node { id: 1, dependent_node_ids: vec![], info: super::NodeInfo { name: "Hello".to_owned(), configuration: NodeConfiguration::FilterNode(super::FilterNode { filters: vec![], input_file_path: "".to_owned(), output_file_path: "".to_owned(), }), output_files: vec![], dynamic_configuration: None, }, last_modified: Local::now(), }], }, }; graph.run_default_tasks(2)?; Ok(()) } }