339 lines
9.9 KiB
Rust
339 lines
9.9 KiB
Rust
use std::{
|
|
cmp::{min, Ordering},
|
|
collections::{HashMap, HashSet},
|
|
sync::{
|
|
mpsc::{self, Sender},
|
|
Arc,
|
|
},
|
|
thread,
|
|
};
|
|
|
|
use chrono::Local;
|
|
use schemars::JsonSchema;
|
|
use serde::{Deserialize, Serialize};
|
|
use split::{SplitNode, SplitNodeRunner};
|
|
|
|
use {
|
|
derive::DeriveNode,
|
|
filter::{FilterNode, FilterNodeRunner},
|
|
node::RunnableNode,
|
|
sql_rule::{SQLNode, SQLNodeRunner},
|
|
upload_to_db::{UploadNode, UploadNodeRunner},
|
|
};
|
|
|
|
mod derive;
|
|
mod filter;
|
|
mod node;
|
|
mod split;
|
|
mod sql_rule;
|
|
mod upload_to_db;
|
|
|
|
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
|
|
pub enum NodeConfiguration {
|
|
FileNode,
|
|
MoveMoneyNode(MoveMoneyNode),
|
|
MergeNode(MergeNode),
|
|
DeriveNode(DeriveNode),
|
|
CodeRuleNode(CodeRuleNode),
|
|
FilterNode(FilterNode),
|
|
UploadNode(UploadNode),
|
|
SQLNode(SQLNode),
|
|
Dynamic,
|
|
SplitNode(SplitNode),
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
|
|
pub struct DynamicConfiguration {
|
|
pub node_type: String,
|
|
pub parameters: HashMap<String, String>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
|
|
pub struct NodeInfo {
|
|
pub name: String,
|
|
pub output_files: Vec<String>,
|
|
pub configuration: NodeConfiguration,
|
|
pub dynamic_configuration: Option<DynamicConfiguration>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
|
|
pub enum MoveMoneyAmountType {
|
|
Percent,
|
|
Amount,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
|
|
pub struct MoveMoneyRule {
|
|
pub from_account: String,
|
|
pub from_cc: String,
|
|
pub to_account: String,
|
|
pub to_cc: String,
|
|
pub value: f64,
|
|
pub amount_type: MoveMoneyAmountType,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
|
|
pub struct MoveMoneyNode {
|
|
pub departments_path: String,
|
|
pub accounts_path: String,
|
|
pub gl_path: String,
|
|
pub rules: Vec<MoveMoneyRule>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
|
|
pub enum JoinType {
|
|
Left,
|
|
Inner,
|
|
Right,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
|
|
pub struct MergeJoin {
|
|
pub join_type: JoinType,
|
|
pub left_column_name: String,
|
|
pub right_column_name: String,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
|
|
pub struct MergeNode {
|
|
pub input_files: Vec<String>,
|
|
pub joins: Vec<MergeJoin>,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
|
|
pub enum CodeRuleLanguage {
|
|
Javascript,
|
|
Rust,
|
|
Go,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
|
|
pub struct CodeRuleNode {
|
|
pub language: CodeRuleLanguage,
|
|
pub text: String,
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
|
|
pub struct Node {
|
|
pub id: i64,
|
|
pub info: NodeInfo,
|
|
pub dependent_node_ids: Vec<i64>,
|
|
// Lets us work out whether a task should be rerun, by
|
|
// inspecting the files at the output paths and check if
|
|
// their timestamp is before this
|
|
// TODO: Could just be seconds since unix epoch?
|
|
pub last_modified: chrono::DateTime<Local>,
|
|
}
|
|
|
|
impl Node {
|
|
pub fn has_dependent_nodes(&self) -> bool {
|
|
!self.dependent_node_ids.is_empty()
|
|
}
|
|
}
|
|
|
|
fn get_runnable_node(node: Node) -> Box<dyn RunnableNode> {
|
|
match node.info.configuration {
|
|
NodeConfiguration::FileNode => todo!(),
|
|
NodeConfiguration::MoveMoneyNode(_) => todo!(),
|
|
NodeConfiguration::MergeNode(_) => todo!(),
|
|
NodeConfiguration::DeriveNode(_) => todo!(),
|
|
NodeConfiguration::CodeRuleNode(_) => todo!(),
|
|
NodeConfiguration::FilterNode(filter_node) => Box::new(FilterNodeRunner { filter_node }),
|
|
NodeConfiguration::UploadNode(upload_node) => Box::new(UploadNodeRunner { upload_node }),
|
|
NodeConfiguration::SQLNode(sql_node) => Box::new(SQLNodeRunner { sql_node }),
|
|
NodeConfiguration::Dynamic => todo!(),
|
|
NodeConfiguration::SplitNode(split_node) => Box::new(SplitNodeRunner { split_node }),
|
|
}
|
|
}
|
|
|
|
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
|
|
pub struct Graph {
|
|
pub name: String,
|
|
pub nodes: Vec<Node>,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub enum NodeStatus {
|
|
Completed,
|
|
Running,
|
|
// Error code
|
|
Failed(anyhow::Error),
|
|
}
|
|
|
|
pub struct RunnableGraph {
|
|
pub graph: Graph,
|
|
}
|
|
|
|
impl RunnableGraph {
|
|
pub fn from_graph(graph: Graph) -> RunnableGraph {
|
|
RunnableGraph { graph }
|
|
}
|
|
|
|
pub fn run_default_tasks<F>(&self, num_threads: usize, status_changed: F) -> anyhow::Result<()>
|
|
where
|
|
F: Fn(i64, NodeStatus),
|
|
{
|
|
self.run(
|
|
num_threads,
|
|
Box::new(|node| get_runnable_node(node)),
|
|
status_changed,
|
|
)
|
|
}
|
|
|
|
// Make this not mutable, emit node status when required in a function or some other message
|
|
pub fn run<'a, F, StatusChanged>(
|
|
&self,
|
|
num_threads: usize,
|
|
get_node_fn: F,
|
|
node_status_changed_fn: StatusChanged,
|
|
) -> anyhow::Result<()>
|
|
where
|
|
F: Fn(Node) -> Box<dyn RunnableNode> + Send + Sync + 'static,
|
|
StatusChanged: Fn(i64, NodeStatus),
|
|
{
|
|
let mut nodes = self.graph.nodes.clone();
|
|
// 1. nodes the nodes based on dependencies (i.e. nodes without dependencies go first)
|
|
nodes.sort_by(|a, b| {
|
|
if b.dependent_node_ids.contains(&a.id) {
|
|
return Ordering::Greater;
|
|
}
|
|
if a.dependent_node_ids.contains(&b.id) {
|
|
return Ordering::Less;
|
|
}
|
|
Ordering::Equal
|
|
});
|
|
|
|
let num_threads = min(num_threads, nodes.len());
|
|
|
|
// Sync version
|
|
if num_threads < 2 {
|
|
for node in &nodes {
|
|
node_status_changed_fn(node.id, NodeStatus::Running);
|
|
match get_node_fn(node.clone()).run() {
|
|
Ok(_) => node_status_changed_fn(node.id, NodeStatus::Completed),
|
|
Err(err) => node_status_changed_fn(node.id, NodeStatus::Failed(err)),
|
|
};
|
|
}
|
|
return Ok(());
|
|
}
|
|
|
|
let mut running_nodes = HashSet::new();
|
|
let mut completed_nodes = HashSet::new();
|
|
|
|
let mut senders: Vec<Sender<Node>> = vec![];
|
|
let mut handles = vec![];
|
|
let (finish_task, listen_finish_task) = mpsc::sync_channel(num_threads);
|
|
let node_fn = Arc::new(get_node_fn);
|
|
for n in 0..num_threads {
|
|
let finish_task = finish_task.clone();
|
|
let (tx, rx) = mpsc::channel();
|
|
senders.push(tx);
|
|
let node_fn = node_fn.clone();
|
|
let handle = thread::spawn(move || {
|
|
for node in rx {
|
|
let status = match node_fn(node.clone()).run() {
|
|
Ok(_) => NodeStatus::Completed,
|
|
Err(err) => NodeStatus::Failed(err),
|
|
};
|
|
finish_task
|
|
.send((n, node, status))
|
|
.expect("Failed to notify node status completion");
|
|
}
|
|
println!("Thread {} finished", n);
|
|
});
|
|
handles.push(handle);
|
|
}
|
|
|
|
let mut running_threads = HashSet::new();
|
|
// Run nodes without dependencies. There'll always be at least one
|
|
|
|
for i in 0..nodes.len() {
|
|
// Ensure we don't overload threads
|
|
if i >= num_threads {
|
|
break;
|
|
}
|
|
if !nodes[i].has_dependent_nodes() {
|
|
running_threads.insert(i);
|
|
// Run all nodes that have no dependencies
|
|
let node = nodes.remove(i);
|
|
node_status_changed_fn(node.id, NodeStatus::Running);
|
|
running_nodes.insert(node.id);
|
|
senders[i % senders.len()].send(node)?;
|
|
}
|
|
}
|
|
|
|
// Run each dependent node after a graph above finishes.
|
|
for (n, node, error) in listen_finish_task {
|
|
running_threads.remove(&n);
|
|
node_status_changed_fn(node.id, error);
|
|
running_nodes.remove(&node.id);
|
|
completed_nodes.insert(node.id);
|
|
// Run all the nodes that can be run and aren't in completed
|
|
let mut i = 0;
|
|
while running_threads.len() < num_threads && i < nodes.len() {
|
|
if !running_nodes.contains(&nodes[i].id)
|
|
&& nodes[i]
|
|
.dependent_node_ids
|
|
.iter()
|
|
.all(|id| completed_nodes.contains(id))
|
|
{
|
|
let node = nodes.remove(i);
|
|
for i in 0..num_threads {
|
|
if !running_threads.contains(&i) {
|
|
senders[i].send(node)?;
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
i += 1;
|
|
}
|
|
if nodes.is_empty() {
|
|
break;
|
|
}
|
|
}
|
|
for sender in senders {
|
|
drop(sender);
|
|
}
|
|
|
|
for handle in handles {
|
|
handle.join().expect("Failed to join thread");
|
|
}
|
|
println!("Process finished");
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use chrono::Local;

    use super::{FilterNode, Graph, Node, NodeConfiguration, NodeInfo, RunnableGraph};

    /// Smoke test: a graph holding a single filter node with no dependencies
    /// runs through `run_default_tasks` without returning an error.
    #[test]
    fn test_basic() -> anyhow::Result<()> {
        let filter = FilterNode {
            filters: vec![],
            input_file_path: String::new(),
            output_file_path: String::new(),
        };
        let only_node = Node {
            id: 1,
            dependent_node_ids: Vec::new(),
            info: NodeInfo {
                name: String::from("Hello"),
                configuration: NodeConfiguration::FilterNode(filter),
                output_files: Vec::new(),
                dynamic_configuration: None,
            },
            last_modified: Local::now(),
        };
        let runnable = RunnableGraph {
            graph: Graph {
                name: String::from("Test"),
                nodes: vec![only_node],
            },
        };

        // Node statuses are ignored; only the overall result is checked.
        runnable.run_default_tasks(2, |_, _| {})
    }
}
|