Add custom graph executor and implement filter node to test it (#2)
Reviewed-on: vato007/coster-rs#2
src/graph.rs
@@ -1,9 +1,23 @@
use itertools::Itertools;
use std::{
    cmp::{min, Ordering},
    collections::{HashMap, HashSet},
    sync::{
        mpsc::{self, Sender},
        Arc,
    },
    thread,
};

use chrono::Local;
use serde::{Deserialize, Serialize};

use crate::filter::FilterNodeRunner;
use crate::{
    derive::DeriveNode,
    filter::{FilterNode, FilterNodeRunner},
    node::RunnableNode,
    upload_to_db::{UploadNode, UploadNodeRunner},
};

// TODO: Break all this up into separate files in the graph module
#[derive(Serialize, Deserialize, Clone)]
pub enum NodeConfiguration {
    FileNode,
@@ -12,6 +26,14 @@ pub enum NodeConfiguration {
    DeriveNode(DeriveNode),
    CodeRuleNode(CodeRuleNode),
    FilterNode(FilterNode),
    UploadNode(UploadNode),
    Dynamic,
}

#[derive(Serialize, Deserialize, Clone)]
pub struct DynamicConfiguration {
    pub node_type: String,
    pub parameters: HashMap<String, String>,
}
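
// Illustrative example (not part of this change): with the derives above and
// serde_json's default representation, a `DynamicConfiguration` round-trips as
// plain JSON along these lines (the field values here are made up):
//
// {
//   "node_type": "custom_filter",
//   "parameters": { "threshold": "10" }
// }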

#[derive(Serialize, Deserialize, Clone)]
@@ -19,6 +41,7 @@ pub struct NodeInfo {
    pub name: String,
    pub output_files: Vec<String>,
    pub configuration: NodeConfiguration,
    pub dynamic_configuration: Option<DynamicConfiguration>,
}

#[derive(Serialize, Deserialize, Clone)]
@@ -65,70 +88,6 @@ pub struct MergeNode {
    pub joins: Vec<MergeJoin>,
}

#[derive(Serialize, Deserialize, Clone)]
pub enum DeriveColumnType {
    Column(String),
    Constant(String),
}

#[derive(Serialize, Deserialize, Clone)]
pub struct MapOperation {
    pub mapped_value: String,
}

#[derive(Serialize, Deserialize, Clone)]
pub enum DatePart {
    Year,
    Month,
    Week,
    Day,
    Hour,
    Minute,
    Second,
}

#[derive(Serialize, Deserialize, Clone)]
pub enum SplitType {
    DateTime(String, DatePart),
    Numeric(String, isize),
}

#[derive(Serialize, Deserialize, Clone)]
pub enum MatchComparisonType {
    Equal,
    GreaterThan,
    LessThan,
}

#[derive(Serialize, Deserialize, Clone)]
pub enum DeriveOperation {
    Concat(Vec<DeriveColumnType>),
    Add(Vec<DeriveColumnType>),
    Multiply(Vec<DeriveColumnType>),
    Subtract(DeriveColumnType, DeriveColumnType),
    Divide(DeriveColumnType, DeriveColumnType),
    Map(String, Vec<MapOperation>),
    Split(String, SplitType),
}

#[derive(Serialize, Deserialize, Clone)]
pub struct DeriveFilter {
    pub column_name: String,
    pub comparator: MatchComparisonType,
    pub match_value: String,
}

#[derive(Serialize, Deserialize, Clone)]
pub struct DeriveRule {
    pub operations: Vec<DeriveOperation>,
    pub filters: Vec<DeriveFilter>,
}

#[derive(Serialize, Deserialize, Clone)]
pub struct DeriveNode {
    pub rules: Vec<DeriveRule>,
}

#[derive(Serialize, Deserialize, Clone)]
pub enum CodeRuleLanguage {
    Javascript,
@@ -142,16 +101,16 @@ pub struct CodeRuleNode {
    pub text: String,
}

#[derive(Serialize, Deserialize, Clone)]
pub struct FilterNode {
    pub filters: Vec<DeriveFilter>,
}

#[derive(Serialize, Deserialize, Clone)]
pub struct Node {
    pub id: i64,
    pub info: NodeInfo,
    pub dependent_node_ids: Vec<i64>,
    // Lets us work out whether a task should be rerun, by
    // inspecting the files at the output paths and checking if
    // their timestamp is before this
    // TODO: Could just be seconds since unix epoch?
    pub last_modified: chrono::DateTime<Local>,
}
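
// Illustrative sketch (not part of this change): one way the rerun check the
// comment above describes could look, assuming each entry in `output_files` is
// a plain file path. The helper name `needs_rerun` is made up for the example.
fn needs_rerun(node: &Node) -> bool {
    node.info.output_files.iter().any(|path| {
        match std::fs::metadata(path).and_then(|meta| meta.modified()) {
            // Rerun when any output predates the node's last modification.
            Ok(modified) => chrono::DateTime::<Local>::from(modified) < node.last_modified,
            // A missing or unreadable output also forces a rerun.
            Err(_) => true,
        }
    })
}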

impl Node {
@@ -160,21 +119,16 @@ impl Node {
    }
}

impl Into<RunnableGraphNode> for Node {
    fn into(self) -> RunnableGraphNode {
        RunnableGraphNode {
            runnable_node: Box::new(FilterNodeRunner {}),
            // TODO: Construct node objects
            // runnable_node: match &self.info.configuration {
            //     NodeConfiguration::FileNode => todo!(),
            //     NodeConfiguration::MoveMoneyNode(_) => todo!(),
            //     NodeConfiguration::MergeNode(_) => todo!(),
            //     NodeConfiguration::DeriveNode(_) => todo!(),
            //     NodeConfiguration::CodeRuleNode(_) => todo!(),
            //     NodeConfiguration::FilterNode(_) => todo!(),
            // },
            node: self,
        }
fn get_runnable_node(node: Node) -> Box<dyn RunnableNode> {
    match node.info.configuration {
        NodeConfiguration::FileNode => todo!(),
        NodeConfiguration::MoveMoneyNode(_) => todo!(),
        NodeConfiguration::MergeNode(_) => todo!(),
        NodeConfiguration::DeriveNode(_) => todo!(),
        NodeConfiguration::CodeRuleNode(_) => todo!(),
        NodeConfiguration::FilterNode(filter_node) => Box::new(FilterNodeRunner { filter_node }),
        NodeConfiguration::UploadNode(upload_node) => Box::new(UploadNodeRunner { upload_node }),
        NodeConfiguration::Dynamic => todo!(),
    }
}
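
// Illustrative only, not part of this change: a minimal runner that could be
// plugged into the executor when testing it. This assumes the `RunnableNode`
// trait imported from `crate::node` has the shape the executor relies on below,
// i.e. `fn run(&self) -> anyhow::Result<()>`; the name `NoOpRunner` is made up
// for the example.
struct NoOpRunner;

impl RunnableNode for NoOpRunner {
    fn run(&self) -> anyhow::Result<()> {
        // Succeed without doing any work, so only the scheduling is exercised.
        Ok(())
    }
}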

@@ -184,32 +138,176 @@ pub struct Graph {
    pub nodes: Vec<Node>,
}

pub trait RunnableNode {
    fn run(&self);
}

pub struct RunnableGraphNode {
    pub runnable_node: Box<dyn RunnableNode>,
    pub node: Node,
pub enum NodeStatus {
    Completed,
    Running,
    // TODO: Error code?
    Failed,
}

pub struct RunnableGraph {
    pub name: String,
    pub nodes: Vec<RunnableGraphNode>,
    pub graph: Graph,
    pub node_statuses: HashMap<i64, NodeStatus>,
}

impl RunnableGraph {
    pub fn from_graph(graph: &Graph) -> RunnableGraph {
    pub fn from_graph(graph: Graph) -> RunnableGraph {
        RunnableGraph {
            name: graph.name.clone(),
            nodes: graph
                .nodes
                .iter()
                .map(|node| {
                    let runnable_graph_node: RunnableGraphNode = node.clone().into();
                    runnable_graph_node
                })
                .collect_vec(),
            graph,
            node_statuses: HashMap::new(),
        }
    }

    pub fn run_default_tasks(&mut self, num_threads: usize) -> anyhow::Result<()> {
        self.run(num_threads, Box::new(|node| get_runnable_node(node)))
    }
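
    // Illustrative usage (not part of this change): `run` also accepts a custom
    // factory, which is the hook a test of the executor can use; `NoOpRunner`
    // refers to the hypothetical runner sketched above.
    //
    //     graph.run(4, Box::new(|_node| -> Box<dyn RunnableNode> { Box::new(NoOpRunner) }))?;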

    // Make this not mutable, emit node status when required in a function or some other message
    pub fn run(
        &mut self,
        num_threads: usize,
        get_node_fn: Box<dyn Fn(Node) -> Box<dyn RunnableNode> + Send + Sync>,
    ) -> anyhow::Result<()> {
        let mut nodes = self.graph.nodes.clone();
        // 1. Sort the nodes based on dependencies (i.e. nodes without dependencies go first)
        nodes.sort_by(|a, b| {
            if b.dependent_node_ids.contains(&a.id) {
                return Ordering::Greater;
            }
            if a.dependent_node_ids.contains(&b.id) {
                return Ordering::Less;
            }
            Ordering::Equal
        });

        let num_threads = min(num_threads, nodes.len());

        // Sync version
        if num_threads < 2 {
            for node in &nodes {
                self.node_statuses.insert(node.id, NodeStatus::Running);
                match get_node_fn(node.clone()).run() {
                    Ok(_) => self.node_statuses.insert(node.id, NodeStatus::Completed),
                    Err(_) => self.node_statuses.insert(node.id, NodeStatus::Failed),
                };
            }
            return Ok(());
        }

        let mut running_nodes = HashSet::new();
        let mut completed_nodes = HashSet::new();

        let mut senders: Vec<Sender<Node>> = vec![];
        let mut handles = vec![];
        let (finish_task, listen_finish_task) = mpsc::sync_channel(num_threads);
        let node_fn = Arc::new(get_node_fn);
        for n in 0..num_threads {
            let finish_task = finish_task.clone();
            let (tx, rx) = mpsc::channel();
            senders.push(tx);
            let node_fn = node_fn.clone();
            let handle = thread::spawn(move || {
                for node in rx {
                    node_fn(node.clone()).run();
                    finish_task.send((n, node));
                }
                println!("Thread {} finished", n);
            });
            handles.push(handle);
        }

        let mut running_threads = HashSet::new();
        // Run nodes without dependencies. There'll always be at least one

        for i in 0..nodes.len() {
            // Ensure we don't overload threads
            if i >= num_threads {
                break;
            }
            if !nodes[i].has_dependent_nodes() {
                running_threads.insert(i);
                // Run all nodes that have no dependencies
                let node = nodes.remove(i);
                self.node_statuses.insert(node.id, NodeStatus::Running);
                running_nodes.insert(node.id);
                senders[i % senders.len()].send(node);
            }
        }

        // Run each dependent node after a node it depends on finishes.
        for (n, node) in listen_finish_task {
            running_threads.remove(&n);
            // TODO: Add error check here
            self.node_statuses.insert(node.id, NodeStatus::Completed);
            running_nodes.remove(&node.id);
            completed_nodes.insert(node.id);
            // Run all the nodes that can be run and aren't already completed
            let mut i = 0;
            while running_threads.len() < num_threads && i < nodes.len() {
                if !running_nodes.contains(&nodes[i].id)
                    && nodes[i]
                        .dependent_node_ids
                        .iter()
                        .all(|id| completed_nodes.contains(id))
                {
                    let node = nodes.remove(i);
                    for i in 0..num_threads {
                        if !running_threads.contains(&i) {
                            senders[i].send(node);
                            break;
                        }
                    }
                }
                i += 1;
            }
            if nodes.is_empty() {
                break;
            }
        }
        for sender in senders {
            drop(sender);
        }

        for handle in handles {
            handle.join().expect("Failed to join thread");
        }
        println!("Process finished");
        Ok(())
    }
}
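
// Descriptive note on `run` above (added for clarity, not in the committed
// file): each worker thread gets its own `mpsc::channel` for receiving nodes,
// while a single bounded `sync_channel` carries `(thread_index, node)`
// completion messages back to the scheduling loop, which then dispatches any
// node whose `dependent_node_ids` are all in `completed_nodes`. Dropping the
// senders at the end closes the per-worker channels and lets the threads exit.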

#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use chrono::Local;

    use super::{NodeConfiguration, RunnableGraph};

    #[test]
    fn test_basic() -> anyhow::Result<()> {
        let mut graph = RunnableGraph {
            node_statuses: HashMap::new(),
            graph: super::Graph {
                name: "Test".to_owned(),
                nodes: vec![super::Node {
                    id: 1,
                    dependent_node_ids: vec![],
                    info: super::NodeInfo {
                        name: "Hello".to_owned(),
                        configuration: NodeConfiguration::FilterNode(super::FilterNode {
                            filters: vec![],
                            input_file_path: "".to_owned(),
                            output_file_path: "".to_owned(),
                        }),
                        output_files: vec![],
                        dynamic_configuration: None,
                    },
                    last_modified: Local::now(),
                }],
            },
        };
        graph.run_default_tasks(2)?;
        Ok(())
    }
}
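
// Illustrative follow-up (not part of this change): the dependency ordering in
// `run` could be exercised without touching real files by injecting a custom
// factory instead of calling `run_default_tasks`, e.g. a runner that pushes its
// node's id onto a shared `Arc<Mutex<Vec<i64>>>`, then asserting that every id
// appears only after the ids listed in that node's `dependent_node_ids`.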