Files
ingey/src/graph.rs
2024-07-31 20:00:33 +09:30

317 lines
9.2 KiB
Rust

use std::{
cmp::{min, Ordering},
collections::{HashMap, HashSet},
sync::{
mpsc::{self, Sender},
Arc,
},
thread,
};
use chrono::Local;
use serde::{Deserialize, Serialize};
use crate::{
derive::DeriveNode,
filter::{FilterNode, FilterNodeRunner},
node::RunnableNode,
sql_rule::{SQLNode, SQLNodeRunner},
upload_to_db::{UploadNode, UploadNodeRunner},
};
/// The kind of work a graph node performs, together with its type-specific settings.
///
/// `get_runnable_node` currently builds runners only for `FilterNode`,
/// `UploadNode` and `SQLNode`; every other variant reaches `todo!()` there.
#[derive(Serialize, Deserialize, Clone)]
pub enum NodeConfiguration {
    FileNode,
    MoveMoneyNode(MoveMoneyNode),
    MergeNode(MergeNode),
    DeriveNode(DeriveNode),
    CodeRuleNode(CodeRuleNode),
    FilterNode(FilterNode),
    UploadNode(UploadNode),
    SQLNode(SQLNode),
    // NOTE(review): carries no payload; presumably configured through
    // `NodeInfo::dynamic_configuration` — confirm.
    Dynamic,
}
/// Free-form settings for a node identified by a type name.
///
/// NOTE(review): presumably paired with `NodeConfiguration::Dynamic` — confirm.
#[derive(Serialize, Deserialize, Clone)]
pub struct DynamicConfiguration {
    /// Name of the node type these parameters belong to.
    pub node_type: String,
    /// String key/value parameters for that node type.
    pub parameters: HashMap<String, String>,
}
/// Descriptive information and configuration carried by every node.
#[derive(Serialize, Deserialize, Clone)]
pub struct NodeInfo {
    /// Name of the node.
    pub name: String,
    /// Presumably the paths of the files this node writes (see the
    /// `last_modified` note on `Node`) — confirm.
    pub output_files: Vec<String>,
    /// What the node does; `get_runnable_node` matches on this to pick a runner.
    pub configuration: NodeConfiguration,
    /// Optional extra settings; presumably only meaningful for
    /// `NodeConfiguration::Dynamic` — confirm.
    pub dynamic_configuration: Option<DynamicConfiguration>,
}
/// How `MoveMoneyRule::value` is interpreted.
///
/// NOTE(review): presumably a percentage vs. an absolute amount — confirm
/// against the (not yet implemented) move-money runner.
#[derive(Serialize, Deserialize, Clone)]
pub enum MoveMoneyAmountType {
    Percent,
    Amount,
}
/// A single transfer rule from one account/cost-centre pair to another.
///
/// NOTE(review): `cc` presumably means "cost centre" — confirm.
#[derive(Serialize, Deserialize, Clone)]
pub struct MoveMoneyRule {
    /// Account the value is moved out of.
    pub from_account: String,
    /// Cost centre the value is moved out of.
    pub from_cc: String,
    /// Account the value is moved into.
    pub to_account: String,
    /// Cost centre the value is moved into.
    pub to_cc: String,
    /// Quantity to move; its meaning depends on `amount_type`.
    pub value: f64,
    /// Whether `value` is a percentage or an amount.
    pub amount_type: MoveMoneyAmountType,
}
/// Configuration for a node that applies `MoveMoneyRule`s.
///
/// `get_runnable_node` has no runner for this configuration yet (`todo!()`).
#[derive(Serialize, Deserialize, Clone)]
pub struct MoveMoneyNode {
    /// Path to the departments input file.
    pub departments_path: String,
    /// Path to the accounts input file.
    pub accounts_path: String,
    /// Path to the general-ledger input file (presumably; "gl" — confirm).
    pub gl_path: String,
    /// Rules to apply, in order.
    pub rules: Vec<MoveMoneyRule>,
}
/// Join flavour used by a `MergeJoin`.
#[derive(Serialize, Deserialize, Clone)]
pub enum JoinType {
    Left,
    Inner,
    Right,
}
/// One join step of a `MergeNode`.
#[derive(Serialize, Deserialize, Clone)]
pub struct MergeJoin {
    /// Kind of join to perform.
    pub join_type: JoinType,
    /// Presumably the column matched on the left-hand input — confirm.
    pub left_column_name: String,
    /// Presumably the column matched on the right-hand input — confirm.
    pub right_column_name: String,
}
/// Configuration for a node that merges several input files.
///
/// `get_runnable_node` has no runner for this configuration yet (`todo!()`).
/// NOTE(review): how `joins` pair up with `input_files` is not visible here.
#[derive(Serialize, Deserialize, Clone)]
pub struct MergeNode {
    /// Paths of the files to merge.
    pub input_files: Vec<String>,
    /// Join steps to apply.
    pub joins: Vec<MergeJoin>,
}
/// Language in which a `CodeRuleNode`'s `text` is written.
#[derive(Serialize, Deserialize, Clone)]
pub enum CodeRuleLanguage {
    Javascript,
    Rust,
    Go,
}
/// A node whose behaviour is defined by user-supplied code.
///
/// `get_runnable_node` has no runner for this configuration yet (`todo!()`).
#[derive(Serialize, Deserialize, Clone)]
pub struct CodeRuleNode {
    /// Language of `text`.
    pub language: CodeRuleLanguage,
    /// Presumably the rule's source code in `language` — confirm.
    pub text: String,
}
/// A single unit of work in a `Graph`.
#[derive(Serialize, Deserialize, Clone)]
pub struct Node {
    /// Identifier referenced by other nodes' `dependent_node_ids` and used as
    /// the key of `RunnableGraph::node_statuses`.
    pub id: i64,
    /// Name, outputs and configuration of the node.
    pub info: NodeInfo,
    /// Ids of the nodes this node depends on: `RunnableGraph::run` does not
    /// start this node until the nodes listed here have finished.
    pub dependent_node_ids: Vec<i64>,
    // Lets us work out whether a task should be rerun, by
    // inspecting the files at the output paths and checking whether
    // their timestamps are before this one.
    // TODO: Could just be seconds since unix epoch?
    pub last_modified: chrono::DateTime<Local>,
}
impl Node {
    /// Report whether this node lists at least one node it depends on.
    ///
    /// Nodes for which this returns `false` have nothing to wait for and can
    /// be started straight away.
    pub fn has_dependent_nodes(&self) -> bool {
        self.dependent_node_ids.first().is_some()
    }
}
/// Build the runner that executes `node`.
///
/// Only filter, upload and SQL nodes have runners so far; every other
/// configuration panics via `todo!()`.
fn get_runnable_node(node: Node) -> Box<dyn RunnableNode> {
    let configuration = node.info.configuration;
    match configuration {
        NodeConfiguration::FilterNode(filter_node) => Box::new(FilterNodeRunner { filter_node }),
        NodeConfiguration::UploadNode(upload_node) => Box::new(UploadNodeRunner { upload_node }),
        NodeConfiguration::SQLNode(sql_node) => Box::new(SQLNodeRunner { sql_node }),
        // Not implemented yet.
        NodeConfiguration::FileNode
        | NodeConfiguration::MoveMoneyNode(_)
        | NodeConfiguration::MergeNode(_)
        | NodeConfiguration::DeriveNode(_)
        | NodeConfiguration::CodeRuleNode(_)
        | NodeConfiguration::Dynamic => todo!(),
    }
}
/// A named collection of nodes; wrap it in `RunnableGraph` to execute it.
#[derive(Serialize, Deserialize, Clone)]
pub struct Graph {
    /// Name of the graph.
    pub name: String,
    /// Every node in the graph; dependencies refer to nodes by `Node::id`.
    pub nodes: Vec<Node>,
}
/// Execution state of a single node, as recorded in
/// `RunnableGraph::node_statuses`.
///
/// Derives `Debug`/`PartialEq`/`Copy` so callers can print, compare and copy
/// the statuses the graph exposes publicly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeStatus {
    /// The node ran and succeeded.
    Completed,
    /// The node has been started but has not finished yet.
    Running,
    // TODO: Error code?
    /// The node did not complete successfully.
    Failed,
}
/// A `Graph` paired with the per-node statuses recorded while it runs.
pub struct RunnableGraph {
    /// The graph being executed.
    pub graph: Graph,
    /// Latest status of each node, keyed by `Node::id`; filled in by `run`.
    pub node_statuses: HashMap<i64, NodeStatus>,
}
impl RunnableGraph {
pub fn from_graph(graph: Graph) -> RunnableGraph {
RunnableGraph {
graph,
node_statuses: HashMap::new(),
}
}
pub fn run_default_tasks(&mut self, num_threads: usize) -> anyhow::Result<()> {
self.run(num_threads, Box::new(|node| get_runnable_node(node)))
}
// Make this not mutable, emit node status when required in a function or some other message
pub fn run(
&mut self,
num_threads: usize,
get_node_fn: Box<dyn Fn(Node) -> Box<dyn RunnableNode> + Send + Sync>,
) -> anyhow::Result<()> {
let mut nodes = self.graph.nodes.clone();
// 1. nodes the nodes based on dependencies (i.e. nodes without dependencies go first)
nodes.sort_by(|a, b| {
if b.dependent_node_ids.contains(&a.id) {
return Ordering::Greater;
}
if a.dependent_node_ids.contains(&b.id) {
return Ordering::Less;
}
Ordering::Equal
});
let num_threads = min(num_threads, nodes.len());
// Sync version
if num_threads < 2 {
for node in &nodes {
self.node_statuses.insert(node.id, NodeStatus::Running);
match get_node_fn(node.clone()).run() {
Ok(_) => self.node_statuses.insert(node.id, NodeStatus::Completed),
Err(_) => self.node_statuses.insert(node.id, NodeStatus::Failed),
};
}
return Ok(());
}
let mut running_nodes = HashSet::new();
let mut completed_nodes = HashSet::new();
let mut senders: Vec<Sender<Node>> = vec![];
let mut handles = vec![];
let (finish_task, listen_finish_task) = mpsc::sync_channel(num_threads);
let node_fn = Arc::new(get_node_fn);
for n in 0..num_threads {
let finish_task = finish_task.clone();
let (tx, rx) = mpsc::channel();
senders.push(tx);
let node_fn = node_fn.clone();
let handle = thread::spawn(move || {
for node in rx {
node_fn(node.clone()).run();
finish_task.send((n, node));
}
println!("Thread {} finished", n);
});
handles.push(handle);
}
let mut running_threads = HashSet::new();
// Run nodes without dependencies. There'll always be at least one
for i in 0..nodes.len() {
// Ensure we don't overload threads
if i >= num_threads {
break;
}
if !nodes[i].has_dependent_nodes() {
running_threads.insert(i);
// Run all nodes that have no dependencies
let node = nodes.remove(i);
self.node_statuses.insert(node.id, NodeStatus::Running);
running_nodes.insert(node.id);
senders[i % senders.len()].send(node);
}
}
// Run each dependent node after a graph above finishes.
for (n, node) in listen_finish_task {
running_threads.remove(&n);
// TODO: Add error check here
self.node_statuses.insert(node.id, NodeStatus::Completed);
running_nodes.remove(&node.id);
completed_nodes.insert(node.id);
// Run all the nodes that can be run and aren't in completed
let mut i = 0;
while running_threads.len() < num_threads && i < nodes.len() {
if !running_nodes.contains(&nodes[i].id)
&& nodes[i]
.dependent_node_ids
.iter()
.all(|id| completed_nodes.contains(id))
{
let node = nodes.remove(i);
for i in 0..num_threads {
if !running_threads.contains(&i) {
senders[i].send(node);
break;
}
}
}
i += 1;
}
if nodes.is_empty() {
break;
}
}
for sender in senders {
drop(sender);
}
for handle in handles {
handle.join().expect("Failed to join thread");
}
println!("Process finished");
Ok(())
}
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use chrono::Local;

    use super::{NodeConfiguration, NodeStatus, RunnableGraph};

    /// Runs a one-node graph and checks the node ends in a terminal status.
    #[test]
    fn test_basic() -> anyhow::Result<()> {
        let mut graph = RunnableGraph {
            node_statuses: HashMap::new(),
            graph: super::Graph {
                name: "Test".to_owned(),
                nodes: vec![super::Node {
                    id: 1,
                    dependent_node_ids: vec![],
                    info: super::NodeInfo {
                        name: "Hello".to_owned(),
                        configuration: NodeConfiguration::FilterNode(super::FilterNode {
                            filters: vec![],
                            input_file_path: "".to_owned(),
                            output_file_path: "".to_owned(),
                        }),
                        output_files: vec![],
                        dynamic_configuration: None,
                    },
                    last_modified: Local::now(),
                }],
            },
        };
        graph.run_default_tasks(2)?;
        // `run` must record a final outcome for every node it was given; a
        // missing entry or a leftover `Running` means the node never finished.
        assert!(matches!(
            graph.node_statuses.get(&1),
            Some(NodeStatus::Completed) | Some(NodeStatus::Failed)
        ));
        Ok(())
    }
}