Start adding row-level splitting, refactor cli and graph into subcrates

src/graph/derive.rs (new file, 319 lines)
@@ -0,0 +1,319 @@
use std::{collections::BTreeMap, str::FromStr};

use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use crate::io::{RecordDeserializer, RecordSerializer};

use super::node::RunnableNode;

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum DeriveColumnType {
    Column(String),
    Constant(String),
}

#[derive(Serialize, Deserialize, Clone, JsonSchema, PartialEq)]
pub enum MatchComparisonType {
    Equal,
    GreaterThan,
    LessThan,
    NotEqual,
    In,
    NotIn,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum DeriveOperation {
    Concat(Vec<DeriveColumnType>),
    Add(Vec<DeriveColumnType>),
    Multiply(Vec<DeriveColumnType>),
    Subtract(Vec<DeriveColumnType>),
    Divide(Vec<DeriveColumnType>),
    Map(String),
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum ValueType {
    String,
    Integer,
    Float,
    Boolean,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct DeriveFilter {
    pub column_name: String,
    pub comparator: MatchComparisonType,
    pub match_value: Vec<String>,
    pub value_type: ValueType,
}

pub enum Comparator<T: PartialOrd> {
    Equal(T),
    NotEqual(T),
    GreaterThan(T),
    LessThan(T),
    In(Vec<T>),
    NotIn(Vec<T>),
}

impl<T: PartialOrd> Comparator<T> {
    pub fn is_valid(&self, value: T) -> bool {
        match self {
            Comparator::Equal(v) => value == *v,
            Comparator::NotEqual(v) => value != *v,
            Comparator::GreaterThan(v) => value > *v,
            Comparator::LessThan(v) => value < *v,
            Comparator::In(v) => v.contains(&value),
            Comparator::NotIn(v) => !v.contains(&value),
        }
    }
}
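
// Usage sketch (added for illustration): a `Comparator` checks one
// already-parsed value against its configured match value(s).
#[cfg(test)]
mod comparator_tests {
    use super::Comparator;

    #[test]
    fn compares_parsed_values() {
        assert!(Comparator::GreaterThan(10).is_valid(11));
        assert!(!Comparator::Equal("a".to_owned()).is_valid("b".to_owned()));
        assert!(Comparator::In(vec![1, 2, 3]).is_valid(2));
        assert!(Comparator::NotIn(vec![1, 2, 3]).is_valid(4));
    }
}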

pub trait FieldName {
    // Name of the field this validator should work on
    fn get_field_name(&self) -> String;
}

pub type DataValidators = Vec<Box<dyn DataValidator>>;

pub trait DataValidator: FieldName {
    // Whether the given value is valid for this validator
    fn is_valid(&self, s: &str) -> bool;
}

pub struct FilterRule<T: PartialOrd> {
    pub column_name: String,
    pub comparator: Comparator<T>,
}

impl<T: PartialOrd> FieldName for FilterRule<T> {
    fn get_field_name(&self) -> String {
        self.column_name.clone()
    }
}

impl<T: FromStr + PartialOrd> DataValidator for FilterRule<T> {
    fn is_valid(&self, s: &str) -> bool {
        s.parse().map_or(false, |f| self.comparator.is_valid(f))
    }
}

pub fn to_filter_rules(filters: &Vec<DeriveFilter>) -> anyhow::Result<Vec<Box<dyn DataValidator>>> {
    filters
        .iter()
        .map(|filter| to_filter_rule(filter))
        .collect()
}

fn to_filter_rule(filter: &DeriveFilter) -> anyhow::Result<Box<dyn DataValidator>> {
    let value = &filter.match_value;
    match filter.value_type {
        ValueType::String => Ok(Box::new(get_filter_rule(filter, value.clone()))),
        ValueType::Integer => Ok(Box::new(get_filter_rule(
            filter,
            parse_values(value, |value| value.parse::<i64>())?,
        ))),
        ValueType::Float => Ok(Box::new(get_filter_rule(
            filter,
            parse_values(value, |value| value.parse::<f64>())?,
        ))),
        ValueType::Boolean => Ok(Box::new(get_filter_rule(
            filter,
            parse_values(value, |value| value.parse::<bool>())?,
        ))),
    }
}

fn parse_values<T, E, F>(values: &Vec<String>, parse: F) -> Result<Vec<T>, E>
where
    F: Fn(&String) -> Result<T, E>,
{
    values.iter().map(|value| parse(value)).collect()
}

fn get_filter_rule<T: PartialOrd + Clone>(filter: &DeriveFilter, value: Vec<T>) -> FilterRule<T> {
    // Assumes at least one match value is configured for the scalar comparators
    FilterRule {
        column_name: filter.column_name.clone(),
        comparator: match filter.comparator {
            MatchComparisonType::Equal => Comparator::Equal(value[0].clone()),
            MatchComparisonType::GreaterThan => Comparator::GreaterThan(value[0].clone()),
            MatchComparisonType::LessThan => Comparator::LessThan(value[0].clone()),
            MatchComparisonType::NotEqual => Comparator::NotEqual(value[0].clone()),
            MatchComparisonType::In => Comparator::In(value),
            MatchComparisonType::NotIn => Comparator::NotIn(value),
        },
    }
}
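
// Sketch (added for illustration): `to_filter_rules` turns a JSON-configurable
// `DeriveFilter` into a type-aware validator, so with `ValueType::Integer` the
// string "10" compares numerically (10 > 9) rather than lexically ("10" < "9").
#[cfg(test)]
mod filter_rule_tests {
    use super::{to_filter_rules, DeriveFilter, MatchComparisonType, ValueType};

    #[test]
    fn integer_filters_compare_numerically() -> anyhow::Result<()> {
        let rules = to_filter_rules(&vec![DeriveFilter {
            column_name: "Amount".to_owned(),
            comparator: MatchComparisonType::GreaterThan,
            match_value: vec!["9".to_owned()],
            value_type: ValueType::Integer,
        }])?;
        assert!(rules[0].is_valid("10"));
        assert!(!rules[0].is_valid("3"));
        Ok(())
    }
}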

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct DeriveColumnOperation {
    pub column_name: String,
    pub operation: DeriveOperation,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct DeriveRule {
    pub operations: Vec<DeriveColumnOperation>,
    pub filters: Vec<DeriveFilter>,
    pub copy_all_columns: bool,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct DeriveNode {
    pub rules: Vec<DeriveRule>,
    pub input_file_path: String,
    pub output_file_path: String,
    pub copy_all_columns: bool,
}

pub struct RunnableDeriveRule {
    pub operations: Vec<DeriveColumnOperation>,
    pub filters: Vec<Box<dyn DataValidator>>,
}

impl DeriveRule {
    fn to_runnable_rule(&self) -> anyhow::Result<RunnableDeriveRule> {
        let filters = to_filter_rules(&self.filters)?;
        Ok(RunnableDeriveRule {
            operations: self.operations.clone(),
            filters,
        })
    }
}

pub fn is_line_valid(line: &BTreeMap<String, String>, rules: &DataValidators) -> bool {
    // Blank or missing fields are treated as valid; only populated values are checked
    rules.iter().all(|rule| {
        line.get(&rule.get_field_name()).map_or(true, |value| {
            if value.trim().is_empty() {
                true
            } else {
                rule.is_valid(value)
            }
        })
    })
}
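
// Behavior sketch (added): `is_line_valid` deliberately lets blank or missing
// fields through, so filters only reject rows with populated, failing values.
#[cfg(test)]
mod line_validity_tests {
    use std::collections::BTreeMap;

    use super::{is_line_valid, Comparator, DataValidators, FilterRule};

    #[test]
    fn blank_and_missing_fields_pass() {
        let rules: DataValidators = vec![Box::new(FilterRule {
            column_name: "Amount".to_owned(),
            comparator: Comparator::GreaterThan(0i64),
        })];
        let blank = BTreeMap::from([("Amount".to_owned(), " ".to_owned())]);
        assert!(is_line_valid(&blank, &rules));
        assert!(is_line_valid(&BTreeMap::new(), &rules));

        let negative = BTreeMap::from([("Amount".to_owned(), "-1".to_owned())]);
        assert!(!is_line_valid(&negative, &rules));
    }
}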

fn concat_columns(line: &BTreeMap<String, String>, columns: &Vec<DeriveColumnType>) -> String {
    columns
        .iter()
        .map(|col| match col {
            DeriveColumnType::Column(column) => line.get(column).cloned().unwrap_or_default(),
            DeriveColumnType::Constant(constant) => constant.clone(),
        })
        .collect()
}
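
// Added sketch: `Concat` interleaves column values and constants; a missing
// column contributes an empty string rather than failing the row.
#[cfg(test)]
mod concat_tests {
    use std::collections::BTreeMap;

    use super::{concat_columns, DeriveColumnType};

    #[test]
    fn concatenates_columns_and_constants() {
        let line = BTreeMap::from([("First".to_owned(), "Ada".to_owned())]);
        let result = concat_columns(
            &line,
            &vec![
                DeriveColumnType::Column("First".to_owned()),
                DeriveColumnType::Constant(" ".to_owned()),
                DeriveColumnType::Column("Missing".to_owned()),
            ],
        );
        assert_eq!("Ada ", result);
    }
}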

fn reduce_numeric_columns<F>(
    line: &BTreeMap<String, String>,
    columns: &Vec<DeriveColumnType>,
    reducer: F,
) -> String
where
    F: Fn(f64, f64) -> f64,
{
    // Unparsable or missing values are skipped rather than failing the row
    let value = columns
        .iter()
        .filter_map(|col| match col {
            DeriveColumnType::Column(column) => {
                line.get(column).and_then(|value| value.parse::<f64>().ok())
            }
            DeriveColumnType::Constant(constant) => constant.parse().ok(),
        })
        .reduce(reducer);
    value
        .map(|value| value.to_string())
        .unwrap_or_default()
}
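
// Added sketch: the numeric operations reduce over parseable values only, so
// an unparsable cell is skipped instead of poisoning the whole row.
#[cfg(test)]
mod reduce_tests {
    use std::collections::BTreeMap;

    use super::{reduce_numeric_columns, DeriveColumnType};

    #[test]
    fn adds_columns_and_constants_skipping_bad_values() {
        let line = BTreeMap::from([
            ("A".to_owned(), "1.5".to_owned()),
            ("B".to_owned(), "not-a-number".to_owned()),
        ]);
        let result = reduce_numeric_columns(
            &line,
            &vec![
                DeriveColumnType::Column("A".to_owned()),
                DeriveColumnType::Column("B".to_owned()),
                DeriveColumnType::Constant("2".to_owned()),
            ],
            |a, b| a + b,
        );
        assert_eq!("3.5", result);
    }
}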

fn derive_line(
    line: BTreeMap<String, String>,
    rules: &Vec<RunnableDeriveRule>,
    output: &mut impl RecordSerializer,
    copy_all_columns: bool,
) -> anyhow::Result<()> {
    let mut output_line = if copy_all_columns {
        line.clone()
    } else {
        BTreeMap::new()
    };
    for rule in rules {
        if !is_line_valid(&line, &rule.filters) {
            continue;
        }

        for operation in &rule.operations {
            let value = match &operation.operation {
                DeriveOperation::Concat(concat) => concat_columns(&line, concat),
                DeriveOperation::Add(columns) => {
                    reduce_numeric_columns(&line, columns, |a, b| a + b)
                }
                DeriveOperation::Multiply(columns) => {
                    reduce_numeric_columns(&line, columns, |a, b| a * b)
                }
                DeriveOperation::Subtract(columns) => {
                    reduce_numeric_columns(&line, columns, |a, b| a - b)
                }
                DeriveOperation::Divide(columns) => {
                    reduce_numeric_columns(&line, columns, |a, b| a / b)
                }
                DeriveOperation::Map(mapped_value) => mapped_value.clone(),
            };
            output_line.insert(operation.column_name.clone(), value);
        }
    }

    output.serialize(output_line)
}

fn derive(
    rules: &Vec<RunnableDeriveRule>,
    input: &mut impl RecordDeserializer,
    output: &mut impl RecordSerializer,
    copy_all_columns: bool,
) -> anyhow::Result<()> {
    // The first record doubles as the source of the header row
    if let Some(line) = input.deserialize()? {
        let line: BTreeMap<String, String> = line;
        output.write_header(&line)?;
        derive_line(line, rules, output, copy_all_columns)?;

        while let Some(line) = input.deserialize()? {
            let line: BTreeMap<String, String> = line;
            derive_line(line, rules, output, copy_all_columns)?;
        }
    }
    Ok(())
}

pub struct DeriveNodeRunner {
    derive_node: DeriveNode,
}

impl RunnableNode for DeriveNodeRunner {
    fn run(&self) -> anyhow::Result<()> {
        let mut reader = csv::Reader::from_path(&self.derive_node.input_file_path)?;
        let mut writer = csv::Writer::from_path(&self.derive_node.output_file_path)?;
        let rules: anyhow::Result<Vec<RunnableDeriveRule>> = self
            .derive_node
            .rules
            .iter()
            .map(|rule| rule.to_runnable_rule())
            .collect();
        let rules = rules?;
        derive(
            &rules,
            &mut reader,
            &mut writer,
            self.derive_node.copy_all_columns,
        )
    }
}

src/graph/filter.rs (new file, 114 lines)
@@ -0,0 +1,114 @@
use std::collections::BTreeMap;

use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use crate::io::{RecordDeserializer, RecordSerializer};

use super::derive::{DataValidators, DeriveFilter};

use super::derive;
use super::node::RunnableNode;

/**
 * Write all lines from the input file to the output file, skipping records
 * that don't satisfy the filter criteria
 */
pub fn filter_file(
    rules: &DataValidators,
    input: &mut impl RecordDeserializer,
    output: &mut impl RecordSerializer,
) -> anyhow::Result<()> {
    if let Some(line) = input.deserialize()? {
        let line: BTreeMap<String, String> = line;
        output.write_header(&line)?;

        if derive::is_line_valid(&line, rules) {
            output.write_record(&line)?;
        }

        while let Some(line) = input.deserialize()? {
            let line: BTreeMap<String, String> = line;
            if derive::is_line_valid(&line, rules) {
                output.write_record(&line)?;
            }
        }
        output.flush()?;
    }
    Ok(())
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct FilterNode {
    pub filters: Vec<DeriveFilter>,
    pub input_file_path: String,
    pub output_file_path: String,
}

pub struct FilterNodeRunner {
    pub filter_node: FilterNode,
}

impl RunnableNode for FilterNodeRunner {
    fn run(&self) -> anyhow::Result<()> {
        let mut reader = csv::Reader::from_path(&self.filter_node.input_file_path)?;
        let mut writer = csv::Writer::from_path(&self.filter_node.output_file_path)?;
        let rules = derive::to_filter_rules(&self.filter_node.filters)?;
        filter_file(&rules, &mut reader, &mut writer)
    }
}

#[cfg(test)]
mod tests {

    use super::derive::{Comparator, FilterRule};

    use super::filter_file;

    #[test]
    fn no_filters_passes_through() -> anyhow::Result<()> {
        let records = "Column1,Column2
Value1,Value2
Value3,Value4
";
        let mut reader: csv::Reader<&[u8]> = csv::Reader::from_reader(records.as_bytes());
        let mut writer = csv::Writer::from_writer(vec![]);
        filter_file(&vec![], &mut reader, &mut writer)?;
        let result = String::from_utf8(writer.into_inner()?)?;
        assert_eq!(
            records, result,
            "Should not modify input when no filters are defined"
        );
        Ok(())
    }

    #[test]
    fn filters_data() -> anyhow::Result<()> {
        let records = "Column1,Column2
Value1,Value2
Value3,Value4
";
        let mut reader: csv::Reader<&[u8]> = csv::Reader::from_reader(records.as_bytes());
        let mut writer = csv::Writer::from_writer(vec![]);
        filter_file(
            &vec![Box::new(FilterRule {
                column_name: "Column1".to_owned(),
                comparator: Comparator::NotEqual("Value3".to_owned()),
            })],
            &mut reader,
            &mut writer,
        )?;
        let result = String::from_utf8(writer.into_inner()?)?;
        assert_eq!(
            "Column1,Column2
Value1,Value2
",
            result,
            "Should filter out second record due to filter rules"
        );
        Ok(())
    }

    // Completed as a sketch (an assumption, not in the original commit): it
    // relies on `filter_file` writing the header from the first record before
    // any validity check, so the header survives even when every record is
    // filtered out.
    #[test]
    fn should_print_header_when_no_rules_pass() -> anyhow::Result<()> {
        let records = "Column1,Column2
Value1,Value2
";
        let mut reader: csv::Reader<&[u8]> = csv::Reader::from_reader(records.as_bytes());
        let mut writer = csv::Writer::from_writer(vec![]);
        filter_file(
            &vec![Box::new(FilterRule {
                column_name: "Column1".to_owned(),
                comparator: Comparator::Equal("NoSuchValue".to_owned()),
            })],
            &mut reader,
            &mut writer,
        )?;
        let result = String::from_utf8(writer.into_inner()?)?;
        assert_eq!(
            "Column1,Column2
",
            result,
            "Header should still be written when every record is filtered out"
        );
        Ok(())
    }
}

src/graph/mod.rs (new file, 338 lines)
@@ -0,0 +1,338 @@
use std::{
    cmp::{min, Ordering},
    collections::{HashMap, HashSet},
    sync::{
        mpsc::{self, Sender},
        Arc,
    },
    thread,
};

use chrono::Local;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use split::{SplitNode, SplitNodeRunner};

use {
    derive::DeriveNode,
    filter::{FilterNode, FilterNodeRunner},
    node::RunnableNode,
    sql_rule::{SQLNode, SQLNodeRunner},
    upload_to_db::{UploadNode, UploadNodeRunner},
};

mod derive;
mod filter;
mod node;
mod split;
mod sql_rule;
mod upload_to_db;

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum NodeConfiguration {
    FileNode,
    MoveMoneyNode(MoveMoneyNode),
    MergeNode(MergeNode),
    DeriveNode(DeriveNode),
    CodeRuleNode(CodeRuleNode),
    FilterNode(FilterNode),
    UploadNode(UploadNode),
    SQLNode(SQLNode),
    Dynamic,
    SplitNode(SplitNode),
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct DynamicConfiguration {
    pub node_type: String,
    pub parameters: HashMap<String, String>,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct NodeInfo {
    pub name: String,
    pub output_files: Vec<String>,
    pub configuration: NodeConfiguration,
    pub dynamic_configuration: Option<DynamicConfiguration>,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum MoveMoneyAmountType {
    Percent,
    Amount,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct MoveMoneyRule {
    pub from_account: String,
    pub from_cc: String,
    pub to_account: String,
    pub to_cc: String,
    pub value: f64,
    pub amount_type: MoveMoneyAmountType,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct MoveMoneyNode {
    pub departments_path: String,
    pub accounts_path: String,
    pub gl_path: String,
    pub rules: Vec<MoveMoneyRule>,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum JoinType {
    Left,
    Inner,
    Right,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct MergeJoin {
    pub join_type: JoinType,
    pub left_column_name: String,
    pub right_column_name: String,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct MergeNode {
    pub input_files: Vec<String>,
    pub joins: Vec<MergeJoin>,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum CodeRuleLanguage {
    Javascript,
    Rust,
    Go,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct CodeRuleNode {
    pub language: CodeRuleLanguage,
    pub text: String,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct Node {
    pub id: i64,
    pub info: NodeInfo,
    // IDs of the nodes this node depends on (its prerequisites)
    pub dependent_node_ids: Vec<i64>,
    // Lets us work out whether a task should be rerun, by
    // inspecting the files at the output paths and checking if
    // their timestamp is before this
    // TODO: Could just be seconds since unix epoch?
    pub last_modified: chrono::DateTime<Local>,
}

impl Node {
    pub fn has_dependent_nodes(&self) -> bool {
        !self.dependent_node_ids.is_empty()
    }
}

fn get_runnable_node(node: Node) -> Box<dyn RunnableNode> {
    match node.info.configuration {
        NodeConfiguration::FileNode => todo!(),
        NodeConfiguration::MoveMoneyNode(_) => todo!(),
        NodeConfiguration::MergeNode(_) => todo!(),
        NodeConfiguration::DeriveNode(_) => todo!(),
        NodeConfiguration::CodeRuleNode(_) => todo!(),
        NodeConfiguration::FilterNode(filter_node) => Box::new(FilterNodeRunner { filter_node }),
        NodeConfiguration::UploadNode(upload_node) => Box::new(UploadNodeRunner { upload_node }),
        NodeConfiguration::SQLNode(sql_node) => Box::new(SQLNodeRunner { sql_node }),
        NodeConfiguration::Dynamic => todo!(),
        NodeConfiguration::SplitNode(split_node) => Box::new(SplitNodeRunner { split_node }),
    }
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct Graph {
    pub name: String,
    pub nodes: Vec<Node>,
}

#[derive(Debug)]
pub enum NodeStatus {
    Completed,
    Running,
    // Error code
    Failed(anyhow::Error),
}

pub struct RunnableGraph {
    pub graph: Graph,
}

impl RunnableGraph {
    pub fn from_graph(graph: Graph) -> RunnableGraph {
        RunnableGraph { graph }
    }

    pub fn run_default_tasks<F>(&self, num_threads: usize, status_changed: F) -> anyhow::Result<()>
    where
        F: Fn(i64, NodeStatus),
    {
        self.run(
            num_threads,
            Box::new(|node| get_runnable_node(node)),
            status_changed,
        )
    }

    // TODO: Make this not mutable, emit node status when required in a function or some other message
    pub fn run<F, StatusChanged>(
        &self,
        num_threads: usize,
        get_node_fn: F,
        node_status_changed_fn: StatusChanged,
    ) -> anyhow::Result<()>
    where
        F: Fn(Node) -> Box<dyn RunnableNode> + Send + Sync + 'static,
        StatusChanged: Fn(i64, NodeStatus),
    {
        let mut nodes = self.graph.nodes.clone();
        // 1. Sort the nodes based on dependencies (i.e. nodes without dependencies go first).
        // This only orders direct dependencies; the threaded dispatch below also
        // re-checks readiness against the set of completed nodes.
        nodes.sort_by(|a, b| {
            if b.dependent_node_ids.contains(&a.id) {
                // `b` depends on `a`, so `a` must run first
                return Ordering::Less;
            }
            if a.dependent_node_ids.contains(&b.id) {
                return Ordering::Greater;
            }
            Ordering::Equal
        });

        let num_threads = min(num_threads, nodes.len());

        // Sync version
        if num_threads < 2 {
            for node in &nodes {
                node_status_changed_fn(node.id, NodeStatus::Running);
                match get_node_fn(node.clone()).run() {
                    Ok(_) => node_status_changed_fn(node.id, NodeStatus::Completed),
                    Err(err) => node_status_changed_fn(node.id, NodeStatus::Failed(err)),
                };
            }
            return Ok(());
        }

        let mut running_nodes = HashSet::new();
        let mut completed_nodes = HashSet::new();

        let mut senders: Vec<Sender<Node>> = vec![];
        let mut handles = vec![];
        let (finish_task, listen_finish_task) = mpsc::sync_channel(num_threads);
        let node_fn = Arc::new(get_node_fn);
        for n in 0..num_threads {
            let finish_task = finish_task.clone();
            let (tx, rx) = mpsc::channel();
            senders.push(tx);
            let node_fn = node_fn.clone();
            let handle = thread::spawn(move || {
                for node in rx {
                    let status = match node_fn(node.clone()).run() {
                        Ok(_) => NodeStatus::Completed,
                        Err(err) => NodeStatus::Failed(err),
                    };
                    finish_task
                        .send((n, node, status))
                        .expect("Failed to notify node status completion");
                }
                println!("Thread {} finished", n);
            });
            handles.push(handle);
        }

        let mut running_threads = HashSet::new();
        // Run nodes without dependencies first. There'll always be at least one.
        let mut thread_id = 0;
        let mut i = 0;
        while thread_id < num_threads && i < nodes.len() {
            if !nodes[i].has_dependent_nodes() {
                running_threads.insert(thread_id);
                let node = nodes.remove(i);
                node_status_changed_fn(node.id, NodeStatus::Running);
                running_nodes.insert(node.id);
                senders[thread_id].send(node)?;
                thread_id += 1;
                // Don't advance `i`: the removal shifted the next candidate into this slot
            } else {
                i += 1;
            }
        }

        // Run each dependent node after the nodes it depends on finish.
        for (n, node, status) in listen_finish_task {
            running_threads.remove(&n);
            node_status_changed_fn(node.id, status);
            running_nodes.remove(&node.id);
            completed_nodes.insert(node.id);
            // Dispatch every node that isn't running and whose dependencies have all completed
            let mut i = 0;
            while running_threads.len() < num_threads && i < nodes.len() {
                if !running_nodes.contains(&nodes[i].id)
                    && nodes[i]
                        .dependent_node_ids
                        .iter()
                        .all(|id| completed_nodes.contains(id))
                {
                    let node = nodes.remove(i);
                    for t in 0..num_threads {
                        if !running_threads.contains(&t) {
                            running_threads.insert(t);
                            node_status_changed_fn(node.id, NodeStatus::Running);
                            running_nodes.insert(node.id);
                            senders[t].send(node)?;
                            break;
                        }
                    }
                    // Don't advance `i`: the removal shifted the next candidate into this slot
                } else {
                    i += 1;
                }
            }
            // Keep listening until every node has been dispatched and has reported back
            if nodes.is_empty() && running_nodes.is_empty() {
                break;
            }
        }
        for sender in senders {
            drop(sender);
        }

        for handle in handles {
            handle.join().expect("Failed to join thread");
        }
        println!("Process finished");
        Ok(())
    }
}

#[cfg(test)]
mod tests {

    use chrono::Local;

    use super::{NodeConfiguration, RunnableGraph};

    #[test]
    fn test_basic() -> anyhow::Result<()> {
        let graph = RunnableGraph {
            graph: super::Graph {
                name: "Test".to_owned(),
                nodes: vec![super::Node {
                    id: 1,
                    dependent_node_ids: vec![],
                    info: super::NodeInfo {
                        name: "Hello".to_owned(),
                        configuration: NodeConfiguration::FilterNode(super::FilterNode {
                            filters: vec![],
                            input_file_path: "".to_owned(),
                            output_file_path: "".to_owned(),
                        }),
                        output_files: vec![],
                        dynamic_configuration: None,
                    },
                    last_modified: Local::now(),
                }],
            },
        };
        graph.run_default_tasks(2, |_, _| {})?;
        Ok(())
    }
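
    // Added sketch: a stub `RunnableNode` and a two-node chain, exercising
    // `run` directly to check that a node's prerequisite completes first.
    struct NoopNode;

    impl super::node::RunnableNode for NoopNode {
        fn run(&self) -> anyhow::Result<()> {
            Ok(())
        }
    }

    fn stub_node(id: i64, dependent_node_ids: Vec<i64>) -> super::Node {
        super::Node {
            id,
            dependent_node_ids,
            info: super::NodeInfo {
                name: format!("Node {}", id),
                configuration: NodeConfiguration::FileNode,
                output_files: vec![],
                dynamic_configuration: None,
            },
            last_modified: Local::now(),
        }
    }

    #[test]
    fn runs_prerequisites_first() -> anyhow::Result<()> {
        use std::sync::Mutex;

        let completed = Mutex::new(Vec::new());
        let graph = RunnableGraph {
            graph: super::Graph {
                name: "Chain".to_owned(),
                nodes: vec![stub_node(2, vec![1]), stub_node(1, vec![])],
            },
        };
        // One thread forces the synchronous path, which relies on the sort order
        graph.run(
            1,
            |_| Box::new(NoopNode) as Box<dyn super::node::RunnableNode>,
            |id, status| {
                if matches!(status, super::NodeStatus::Completed) {
                    completed.lock().unwrap().push(id);
                }
            },
        )?;
        assert_eq!(vec![1, 2], *completed.lock().unwrap());
        Ok(())
    }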
}

src/graph/node.rs (new file, 6 lines)
@@ -0,0 +1,6 @@
pub trait RunnableNode {
    // TODO: Get inputs/outputs to determine whether we can skip running this task

    // TODO: Status
    fn run(&self) -> anyhow::Result<()>;
}

src/graph/split.rs (new file, 169 lines)
@@ -0,0 +1,169 @@
use std::{collections::BTreeMap, io::Seek};

use chrono::NaiveDateTime;
use polars::{
    io::SerWriter,
    prelude::{CsvWriter, LazyCsvReader, LazyFileListReader},
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tempfile::tempfile;

use crate::io::RecordSerializer;

use super::{
    derive::{self, DataValidator, DeriveFilter},
    node::RunnableNode,
};

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum DatePart {
    Year,
    Month,
    Week,
    Day,
    Hour,
    Minute,
    Second,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum SplitType {
    // Column, frequency
    DateTime(DatePart),
    Numeric(isize),
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct SplitOnChangeInColumn {
    id_column: String,
    change_column: String,
    limit: Option<u64>,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct Split {
    column: String,
    split_type: SplitType,
    // If specified, a split will also be made when the change column changes for the id column
    change_in_column: Option<SplitOnChangeInColumn>,
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct SplitRule {
    pub filters: Vec<DeriveFilter>,
    pub splits: Vec<Split>,
}

impl SplitRule {
    fn to_runnable_rule(&self) -> anyhow::Result<RunnableSplitRule> {
        let filters = derive::to_filter_rules(&self.filters)?;
        Ok(RunnableSplitRule {
            filters,
            splits: self.splits.clone(),
        })
    }
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct SplitNode {
    pub filters: Vec<DeriveFilter>,
    pub rules: Vec<SplitRule>,
    pub input_file_path: String,
    pub output_file_path: String,
}

pub struct RunnableSplitRule {
    pub filters: Vec<Box<dyn DataValidator>>,
    pub splits: Vec<Split>,
}

pub struct SplitNodeRunner {
    pub split_node: SplitNode,
}

fn split_line(
    line: BTreeMap<String, String>,
    rules: &Vec<RunnableSplitRule>,
    output: &mut impl RecordSerializer,
    date_format: &str,
    _last_split_value: Option<(String, String)>,
) -> anyhow::Result<Option<(String, String)>> {
    let mut result_lines = vec![];
    for rule in rules {
        if !derive::is_line_valid(&line, &rule.filters) {
            continue;
        }
        for split in &rule.splits {
            let value = line.get(&split.column);
            if let Some(value) = value {
                // Parse the value in the column for the rule
                match &split.split_type {
                    SplitType::DateTime(_frequency) => {
                        // The default format carries no timezone offset, so
                        // parse it as a naive timestamp
                        let _date_time = NaiveDateTime::parse_from_str(value, date_format)?;
                        // TODO: Now split the row up based on the frequency in the rule
                    }
                    SplitType::Numeric(_frequency) => {
                        // TODO: Just skip unparsable values, log out that they're incorrect?
                        let _number = value.parse::<f64>()?;
                    }
                }
            }
        }
    }

    result_lines.push(line);
    for line in result_lines {
        output.serialize(line)?;
    }
    // TODO: Track and return the last split key once splitting is implemented
    Ok(None)
}
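
// Illustrative sketch (added, not part of the commit): one possible shape for
// the DateTime TODO above, bucketing a parsed timestamp by `DatePart` to get
// the key on which consecutive rows would be split. The `bucket` helper is
// hypothetical.
#[cfg(test)]
mod date_part_sketch {
    use chrono::NaiveDateTime;

    use super::DatePart;

    // Hypothetical helper: format a timestamp down to the requested part
    fn bucket(value: &NaiveDateTime, part: &DatePart) -> String {
        match part {
            DatePart::Year => value.format("%Y").to_string(),
            DatePart::Month => value.format("%Y-%m").to_string(),
            DatePart::Week => value.format("%G-W%V").to_string(),
            DatePart::Day => value.format("%Y-%m-%d").to_string(),
            DatePart::Hour => value.format("%Y-%m-%d %H").to_string(),
            DatePart::Minute => value.format("%Y-%m-%d %H:%M").to_string(),
            DatePart::Second => value.format("%Y-%m-%d %H:%M:%S").to_string(),
        }
    }

    #[test]
    fn buckets_by_month() -> anyhow::Result<()> {
        let parsed = NaiveDateTime::parse_from_str("2024-03-05 10-30-00", "%Y-%m-%d %H-%M-%S")?;
        assert_eq!("2024-03", bucket(&parsed, &DatePart::Month));
        Ok(())
    }
}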

fn split(
    rules: &Vec<RunnableSplitRule>,
    input: &String,
    output: &mut impl RecordSerializer,
    date_format: &str,
) -> anyhow::Result<()> {
    // First sort the input file into a temporary file

    let mut temp_file = tempfile()?;

    // This needs to be done for each split rule with a change column specified
    let df = LazyCsvReader::new(input).finish()?;
    let df = df.sort(["", ""], Default::default());
    CsvWriter::new(&mut temp_file).finish(&mut df.collect()?)?;

    // Rewind to the start so the reader below sees the sorted data, then do
    // the standard split over each row
    temp_file.rewind()?;
    let mut input = csv::Reader::from_reader(temp_file);
    if let Some(line) = input.deserialize().next() {
        let line: BTreeMap<String, String> = line?;
        output.write_header(&line)?;
        let mut last_split_value = split_line(line, rules, output, date_format, None)?;

        for line in input.deserialize() {
            let line: BTreeMap<String, String> = line?;
            last_split_value = split_line(line, rules, output, date_format, last_split_value)?;
        }
    }
    Ok(())
}

impl RunnableNode for SplitNodeRunner {
    fn run(&self) -> anyhow::Result<()> {
        let mut output = csv::Writer::from_path(&self.split_node.output_file_path)?;
        let rules: anyhow::Result<Vec<RunnableSplitRule>> = self
            .split_node
            .rules
            .iter()
            .map(|rule| rule.to_runnable_rule())
            .collect();
        let rules = rules?;
        split(
            &rules,
            &self.split_node.input_file_path,
            &mut output,
            "%Y-%m-%d %H-%M-%S",
        )
    }
}

src/graph/sql_rule.rs (new file, 82 lines)
@@ -0,0 +1,82 @@
use std::fs::File;

use polars::{
    io::SerWriter,
    prelude::{CsvWriter, LazyCsvReader, LazyFileListReader},
};
use polars_sql::SQLContext;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};

use super::node::RunnableNode;

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct CSVFile {
    name: String,
    path: String,
}

/**
 * Run SQL over files using polars, export results to output file
 */
fn run_sql(files: &Vec<CSVFile>, output_path: &String, query: &String) -> anyhow::Result<()> {
    let mut ctx = SQLContext::new();
    for file in files {
        // Register each file as a table under its configured name
        let df = LazyCsvReader::new(&file.path).finish()?;
        ctx.register(&file.name, df);
    }
    let result = ctx.execute(query)?;
    let mut file = File::create(output_path)?;
    CsvWriter::new(&mut file).finish(&mut result.collect()?)?;
    Ok(())
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct SQLNode {
    pub files: Vec<CSVFile>,
    pub output_file: String,
    pub query: String,
}

pub struct SQLNodeRunner {
    pub sql_node: SQLNode,
}

impl RunnableNode for SQLNodeRunner {
    fn run(&self) -> anyhow::Result<()> {
        run_sql(
            &self.sql_node.files,
            &self.sql_node.output_file,
            &self.sql_node.query,
        )
    }
}

#[cfg(test)]
mod tests {
    use std::{fs::File, io::Read};

    use super::{run_sql, CSVFile};

    #[test]
    fn basic_query_works() -> anyhow::Result<()> {
        let output_path = "./testing/output/output.csv".to_owned();
        run_sql(
            &vec![CSVFile {
                name: "Account".to_owned(),
                path: "./testing/test.csv".to_owned(),
            }],
            &output_path,
            &"SELECT * FROM Account WHERE Code = 'A195950'".to_owned(),
        )?;
        let mut output = String::new();
        let mut output_file = File::open(output_path)?;
        output_file.read_to_string(&mut output)?;
        assert_eq!(
            output,
            "Code,Description,Type,CostOutput,PercentFixed
A195950,A195950 Staff Related Other,E,GS,100.00
"
        );
        Ok(())
    }
}

src/graph/upload_to_db.rs (new file, 91 lines)
@@ -0,0 +1,91 @@
use std::collections::HashMap;

use anyhow::bail;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use sqlx::{Any, Pool, QueryBuilder};

use super::node::RunnableNode;

const BIND_LIMIT: usize = 65535;

// Note: right now this is set to mssql only, since sqlx 0.7 is required to use
// the Any type (sqlx 0.6 and earlier have a query_builder lifetime issue),
// however sqlx >=0.7 currently doesn't support mssql.

// Upload data in a file to a db table, with an optional post-script to run,
// such as to move data from the upload table into other tables
// TODO: Add bulk insert options for non-mssql dbs
// TODO: Add fallback insert when bulk insert fails (e.g. due to
// permission errors)
pub async fn upload_file_bulk(pool: &Pool<Any>, upload_node: &UploadNode) -> anyhow::Result<u64> {
    let mut rows_affected = None;
    if upload_node.column_mappings.is_none() {
        // Prefer the database's native bulk-load statement when no column
        // remapping is needed
        let insert_from_file_query = match pool.connect_options().database_url.scheme() {
            "postgres" => Some(format!("COPY {} FROM $1", upload_node.table_name)),
            "mysql" => Some(format!(
                "LOAD DATA INFILE ? INTO {}",
                upload_node.table_name,
            )),
            _ => None,
        };
        if let Some(insert_from_file_query) = insert_from_file_query {
            let result = sqlx::query(&insert_from_file_query)
                .bind(&upload_node.file_path)
                .execute(pool)
                .await?;
            rows_affected = Some(result.rows_affected());
        }
    }

    if rows_affected.is_none() {
        let rows: Vec<HashMap<String, String>> = vec![];

        // TODO: Columns to insert... needs some kind of mapping from file column name <-> db column
        let mut query_builder =
            QueryBuilder::new(format!("INSERT INTO {}({}) ", upload_node.table_name, ""));
        // TODO: Iterate over all values in file, not the limit
        // (slice clamped so the currently-empty row set doesn't panic)
        query_builder.push_values(&rows[0..BIND_LIMIT.min(rows.len())], |mut b, row| {
            b.push_bind(row.get("s"));
        });
        // TODO: Looks like this issue: https://github.com/launchbadge/sqlx/issues/1978
        // Turns out we need v0.7 for this to not bug out, however mssql is only supported
        // in versions before v0.7, so right now we can't use sqlx for this unless we
        // explicitly specify mssql only, not Any, as the db type...
        // Can probably work around this by specifying an actual implementation?
        let query = query_builder.build();
        let result = query.execute(pool).await?;
        rows_affected = Some(result.rows_affected());
    }

    if let Some(post_script) = &upload_node.post_script {
        sqlx::query(post_script).execute(pool).await?;
    }

    match rows_affected {
        Some(rows_affected) => Ok(rows_affected),
        None => bail!(
            "Invalid state, rows_affected must be populated, or an error should have occurred"
        ),
    }
}

#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct UploadNode {
    file_path: String,
    table_name: String,
    // Mappings from column in file -> column in db
    column_mappings: Option<HashMap<String, String>>,
    post_script: Option<String>,
}

pub struct UploadNodeRunner {
    pub upload_node: UploadNode,
}

impl RunnableNode for UploadNodeRunner {
    fn run(&self) -> anyhow::Result<()> {
        // TODO: Get db connection from some kind of property manager/context
        todo!()
    }
}