Start adding row-level splitting, refactor cli and graph into subcrates

This commit is contained in:
2024-08-09 22:13:43 +09:30
parent 3cdaa81da1
commit 0ee88e3a99
11 changed files with 259 additions and 110 deletions

338
src/graph/mod.rs Normal file
View File

@@ -0,0 +1,338 @@
use std::{
cmp::{min, Ordering},
collections::{HashMap, HashSet},
sync::{
mpsc::{self, Sender},
Arc,
},
thread,
};
use chrono::Local;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use split::{SplitNode, SplitNodeRunner};
use {
derive::DeriveNode,
filter::{FilterNode, FilterNodeRunner},
node::RunnableNode,
sql_rule::{SQLNode, SQLNodeRunner},
upload_to_db::{UploadNode, UploadNodeRunner},
};
mod derive;
mod filter;
mod node;
mod split;
mod sql_rule;
mod upload_to_db;
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum NodeConfiguration {
FileNode,
MoveMoneyNode(MoveMoneyNode),
MergeNode(MergeNode),
DeriveNode(DeriveNode),
CodeRuleNode(CodeRuleNode),
FilterNode(FilterNode),
UploadNode(UploadNode),
SQLNode(SQLNode),
Dynamic,
SplitNode(SplitNode),
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct DynamicConfiguration {
pub node_type: String,
pub parameters: HashMap<String, String>,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct NodeInfo {
pub name: String,
pub output_files: Vec<String>,
pub configuration: NodeConfiguration,
pub dynamic_configuration: Option<DynamicConfiguration>,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum MoveMoneyAmountType {
Percent,
Amount,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct MoveMoneyRule {
pub from_account: String,
pub from_cc: String,
pub to_account: String,
pub to_cc: String,
pub value: f64,
pub amount_type: MoveMoneyAmountType,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct MoveMoneyNode {
pub departments_path: String,
pub accounts_path: String,
pub gl_path: String,
pub rules: Vec<MoveMoneyRule>,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum JoinType {
Left,
Inner,
Right,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct MergeJoin {
pub join_type: JoinType,
pub left_column_name: String,
pub right_column_name: String,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct MergeNode {
pub input_files: Vec<String>,
pub joins: Vec<MergeJoin>,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum CodeRuleLanguage {
Javascript,
Rust,
Go,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct CodeRuleNode {
pub language: CodeRuleLanguage,
pub text: String,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct Node {
pub id: i64,
pub info: NodeInfo,
pub dependent_node_ids: Vec<i64>,
// Lets us work out whether a task should be rerun, by
// inspecting the files at the output paths and check if
// their timestamp is before this
// TODO: Could just be seconds since unix epoch?
pub last_modified: chrono::DateTime<Local>,
}
impl Node {
pub fn has_dependent_nodes(&self) -> bool {
!self.dependent_node_ids.is_empty()
}
}
fn get_runnable_node(node: Node) -> Box<dyn RunnableNode> {
match node.info.configuration {
NodeConfiguration::FileNode => todo!(),
NodeConfiguration::MoveMoneyNode(_) => todo!(),
NodeConfiguration::MergeNode(_) => todo!(),
NodeConfiguration::DeriveNode(_) => todo!(),
NodeConfiguration::CodeRuleNode(_) => todo!(),
NodeConfiguration::FilterNode(filter_node) => Box::new(FilterNodeRunner { filter_node }),
NodeConfiguration::UploadNode(upload_node) => Box::new(UploadNodeRunner { upload_node }),
NodeConfiguration::SQLNode(sql_node) => Box::new(SQLNodeRunner { sql_node }),
NodeConfiguration::Dynamic => todo!(),
NodeConfiguration::SplitNode(split_node) => Box::new(SplitNodeRunner { split_node }),
}
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct Graph {
pub name: String,
pub nodes: Vec<Node>,
}
#[derive(Debug)]
pub enum NodeStatus {
Completed,
Running,
// Error code
Failed(anyhow::Error),
}
pub struct RunnableGraph {
pub graph: Graph,
}
impl RunnableGraph {
pub fn from_graph(graph: Graph) -> RunnableGraph {
RunnableGraph { graph }
}
pub fn run_default_tasks<F>(&self, num_threads: usize, status_changed: F) -> anyhow::Result<()>
where
F: Fn(i64, NodeStatus),
{
self.run(
num_threads,
Box::new(|node| get_runnable_node(node)),
status_changed,
)
}
// Make this not mutable, emit node status when required in a function or some other message
pub fn run<'a, F, StatusChanged>(
&self,
num_threads: usize,
get_node_fn: F,
node_status_changed_fn: StatusChanged,
) -> anyhow::Result<()>
where
F: Fn(Node) -> Box<dyn RunnableNode> + Send + Sync + 'static,
StatusChanged: Fn(i64, NodeStatus),
{
let mut nodes = self.graph.nodes.clone();
// 1. nodes the nodes based on dependencies (i.e. nodes without dependencies go first)
nodes.sort_by(|a, b| {
if b.dependent_node_ids.contains(&a.id) {
return Ordering::Greater;
}
if a.dependent_node_ids.contains(&b.id) {
return Ordering::Less;
}
Ordering::Equal
});
let num_threads = min(num_threads, nodes.len());
// Sync version
if num_threads < 2 {
for node in &nodes {
node_status_changed_fn(node.id, NodeStatus::Running);
match get_node_fn(node.clone()).run() {
Ok(_) => node_status_changed_fn(node.id, NodeStatus::Completed),
Err(err) => node_status_changed_fn(node.id, NodeStatus::Failed(err)),
};
}
return Ok(());
}
let mut running_nodes = HashSet::new();
let mut completed_nodes = HashSet::new();
let mut senders: Vec<Sender<Node>> = vec![];
let mut handles = vec![];
let (finish_task, listen_finish_task) = mpsc::sync_channel(num_threads);
let node_fn = Arc::new(get_node_fn);
for n in 0..num_threads {
let finish_task = finish_task.clone();
let (tx, rx) = mpsc::channel();
senders.push(tx);
let node_fn = node_fn.clone();
let handle = thread::spawn(move || {
for node in rx {
let status = match node_fn(node.clone()).run() {
Ok(_) => NodeStatus::Completed,
Err(err) => NodeStatus::Failed(err),
};
finish_task
.send((n, node, status))
.expect("Failed to notify node status completion");
}
println!("Thread {} finished", n);
});
handles.push(handle);
}
let mut running_threads = HashSet::new();
// Run nodes without dependencies. There'll always be at least one
for i in 0..nodes.len() {
// Ensure we don't overload threads
if i >= num_threads {
break;
}
if !nodes[i].has_dependent_nodes() {
running_threads.insert(i);
// Run all nodes that have no dependencies
let node = nodes.remove(i);
node_status_changed_fn(node.id, NodeStatus::Running);
running_nodes.insert(node.id);
senders[i % senders.len()].send(node)?;
}
}
// Run each dependent node after a graph above finishes.
for (n, node, error) in listen_finish_task {
running_threads.remove(&n);
node_status_changed_fn(node.id, error);
running_nodes.remove(&node.id);
completed_nodes.insert(node.id);
// Run all the nodes that can be run and aren't in completed
let mut i = 0;
while running_threads.len() < num_threads && i < nodes.len() {
if !running_nodes.contains(&nodes[i].id)
&& nodes[i]
.dependent_node_ids
.iter()
.all(|id| completed_nodes.contains(id))
{
let node = nodes.remove(i);
for i in 0..num_threads {
if !running_threads.contains(&i) {
senders[i].send(node)?;
break;
}
}
}
i += 1;
}
if nodes.is_empty() {
break;
}
}
for sender in senders {
drop(sender);
}
for handle in handles {
handle.join().expect("Failed to join thread");
}
println!("Process finished");
Ok(())
}
}
#[cfg(test)]
mod tests {
use chrono::Local;
use super::{NodeConfiguration, RunnableGraph};
#[test]
fn test_basic() -> anyhow::Result<()> {
let graph = RunnableGraph {
graph: super::Graph {
name: "Test".to_owned(),
nodes: vec![super::Node {
id: 1,
dependent_node_ids: vec![],
info: super::NodeInfo {
name: "Hello".to_owned(),
configuration: NodeConfiguration::FilterNode(super::FilterNode {
filters: vec![],
input_file_path: "".to_owned(),
output_file_path: "".to_owned(),
}),
output_files: vec![],
dynamic_configuration: None,
},
last_modified: Local::now(),
}],
},
};
graph.run_default_tasks(2, |_, _| {})?;
Ok(())
}
}