Start adding row-level splitting, refactor cli and graph into subcrates

This commit is contained in:
2024-08-09 22:13:43 +09:30
parent 3cdaa81da1
commit 0ee88e3a99
11 changed files with 259 additions and 110 deletions

View File

@@ -1,6 +1,5 @@
use clap::Parser;
use cli::Cli;
mod cli;
use coster_rs::cli::Cli;
fn main() -> anyhow::Result<()> {
env_logger::init();

View File

@@ -10,13 +10,15 @@ use std::io::Write;
use clap::{command, Parser};
pub use commands::Commands;
use coster_rs::{
use log::info;
use schemars::schema_for;
use crate::{
create_products::InputFile,
graph::{Graph, RunnableGraph},
SourceType,
};
use log::info;
use schemars::schema_for;
mod commands;
@@ -41,7 +43,7 @@ impl Cli {
output,
use_numeric_accounts,
flush_pass,
} => coster_rs::move_money(
} => crate::move_money(
&mut csv::Reader::from_path(rules)?,
&mut csv::Reader::from_path(lines)?,
&mut csv::Reader::from_path(accounts)?,
@@ -66,7 +68,7 @@ impl Cli {
} => {
if msgpack_serialisation {
let mut file = BufWriter::new(File::create(output)?);
coster_rs::reciprocal_allocation(
crate::reciprocal_allocation(
&mut csv::Reader::from_path(lines)?,
&mut csv::Reader::from_path(accounts)?,
&mut csv::Reader::from_path(allocation_statistics)?,
@@ -81,7 +83,7 @@ impl Cli {
zero_threshold,
)
} else {
coster_rs::reciprocal_allocation(
crate::reciprocal_allocation(
&mut csv::Reader::from_path(lines)?,
&mut csv::Reader::from_path(accounts)?,
&mut csv::Reader::from_path(allocation_statistics)?,
@@ -165,7 +167,7 @@ impl Cli {
date_order_column: None,
},
);
coster_rs::create_products::create_products_polars(definitions, vec![], output)
crate::create_products::create_products_polars(definitions, vec![], output)
}
Commands::RunGraph { graph, threads } => {
let file = File::open(graph)?;

View File

@@ -1,14 +1,11 @@
use std::{collections::BTreeMap, str::FromStr};
use anyhow::bail;
use itertools::Itertools;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::{
io::{RecordDeserializer, RecordSerializer},
node::RunnableNode,
};
use crate::io::{RecordDeserializer, RecordSerializer};
use super::node::RunnableNode;
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum DeriveColumnType {
@@ -16,29 +13,14 @@ pub enum DeriveColumnType {
Constant(String),
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum DatePart {
Year,
Month,
Week,
Day,
Hour,
Minute,
Second,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum SplitType {
DateTime(String, DatePart),
Numeric(String, isize),
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
#[derive(Serialize, Deserialize, Clone, JsonSchema, PartialEq)]
pub enum MatchComparisonType {
Equal,
GreaterThan,
LessThan,
NotEqual,
In,
NotIn,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
@@ -49,11 +31,6 @@ pub enum DeriveOperation {
Subtract(Vec<DeriveColumnType>),
Divide(Vec<DeriveColumnType>),
Map(String),
// Might be better putting this into its own node, then we can do sorting operations
// and ensure the split only happens when a particular column changes value. Could
// also just leave these more complex use cases for SQL/Code nodes instead (if even possible
// in an SQL node, and code nodes aren't even implemented yet)
Split(String, SplitType),
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
@@ -68,7 +45,7 @@ pub enum ValueType {
pub struct DeriveFilter {
pub column_name: String,
pub comparator: MatchComparisonType,
pub match_value: String,
pub match_value: Vec<String>,
pub value_type: ValueType,
}
@@ -126,36 +103,47 @@ impl<T: FromStr + PartialOrd> DataValidator for FilterRule<T> {
pub fn to_filter_rules(filters: &Vec<DeriveFilter>) -> anyhow::Result<Vec<Box<dyn DataValidator>>> {
filters
.iter()
// For some reason inlining to_filter_rules causes a compiler error, so leaving
// in a separate function (it is cleaner at least)
.map(|filter| to_filter_rule(filter))
.collect()
}
fn to_filter_rule(filter: &DeriveFilter) -> anyhow::Result<Box<dyn DataValidator>> {
let value = filter.match_value.clone();
let value = &filter.match_value;
match filter.value_type {
crate::derive::ValueType::String => Ok(Box::new(get_filter_rule(filter, value))),
crate::derive::ValueType::Integer => {
Ok(Box::new(get_filter_rule(filter, value.parse::<i64>()?)))
}
crate::derive::ValueType::Float => {
Ok(Box::new(get_filter_rule(filter, value.parse::<f64>()?)))
}
crate::derive::ValueType::Boolean => {
Ok(Box::new(get_filter_rule(filter, value.parse::<bool>()?)))
}
ValueType::String => Ok(Box::new(get_filter_rule(filter, value.clone()))),
ValueType::Integer => Ok(Box::new(get_filter_rule(
filter,
parse_values(value, |value| value.parse::<i64>())?,
))),
ValueType::Float => Ok(Box::new(get_filter_rule(
filter,
parse_values(value, |value| value.parse::<f64>())?,
))),
ValueType::Boolean => Ok(Box::new(get_filter_rule(
filter,
parse_values(value, |value| value.parse::<bool>())?,
))),
}
}
fn get_filter_rule<T: PartialOrd>(filter: &DeriveFilter, value: T) -> FilterRule<T> {
fn parse_values<T, E, F>(values: &Vec<String>, parse: F) -> Result<Vec<T>, E>
where
F: Fn(&String) -> Result<T, E>,
{
let values: Result<Vec<T>, E> = values.into_iter().map(|value| parse(value)).collect();
values
}
fn get_filter_rule<T: PartialOrd + Clone>(filter: &DeriveFilter, value: Vec<T>) -> FilterRule<T> {
FilterRule {
column_name: filter.column_name.clone(),
comparator: match filter.comparator {
MatchComparisonType::Equal => Comparator::Equal(value),
MatchComparisonType::GreaterThan => Comparator::GreaterThan(value),
MatchComparisonType::LessThan => Comparator::LessThan(value),
MatchComparisonType::NotEqual => Comparator::NotEqual(value),
MatchComparisonType::Equal => Comparator::Equal(value[0].clone()),
MatchComparisonType::GreaterThan => Comparator::GreaterThan(value[0].clone()),
MatchComparisonType::LessThan => Comparator::LessThan(value[0].clone()),
MatchComparisonType::NotEqual => Comparator::NotEqual(value[0].clone()),
MatchComparisonType::In => Comparator::In(value),
MatchComparisonType::NotIn => Comparator::NotIn(value),
},
}
}
@@ -170,6 +158,7 @@ pub struct DeriveColumnOperation {
pub struct DeriveRule {
pub operations: Vec<DeriveColumnOperation>,
pub filters: Vec<DeriveFilter>,
pub copy_all_columns: bool,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
@@ -177,6 +166,7 @@ pub struct DeriveNode {
pub rules: Vec<DeriveRule>,
pub input_file_path: String,
pub output_file_path: String,
pub copy_all_columns: bool,
}
pub struct RunnableDeriveRule {
@@ -248,17 +238,20 @@ fn derive_line(
line: BTreeMap<String, String>,
rules: &Vec<RunnableDeriveRule>,
output: &mut impl RecordSerializer,
copy_all_columns: bool,
) -> anyhow::Result<()> {
let mut line = line;
let mut output_line;
if copy_all_columns {
output_line = line.clone();
} else {
output_line = BTreeMap::new();
}
for rule in rules {
if !is_line_valid(&line, &rule.filters) {
continue;
}
for operation in &rule.operations {
if let DeriveOperation::Split(_, _) = operation.operation {
continue;
}
let value = match &operation.operation {
DeriveOperation::Concat(concat) => concat_columns(&line, concat),
DeriveOperation::Add(columns) => {
@@ -274,53 +267,28 @@ fn derive_line(
reduce_numeric_columns(&line, columns, |a, b| a / b)
}
DeriveOperation::Map(mapped_value) => mapped_value.clone(),
DeriveOperation::Split(_, _) => {
bail!("Invalid state, split type must be checked after other operations")
}
};
line.insert(operation.column_name.clone(), value);
output_line.insert(operation.column_name.clone(), value);
}
}
let split_operations = rules
.iter()
.flat_map(|rule| {
if !is_line_valid(&line, &rule.filters) {
return vec![];
}
rule.operations
.iter()
.filter(|operation| {
if let DeriveOperation::Split(_, _) = operation.operation {
return true;
}
false
})
.collect_vec()
})
.collect_vec();
if split_operations.is_empty() {
output.serialize(line)?;
} else {
}
Ok(())
output.serialize(output_line)
}
fn derive(
rules: &Vec<RunnableDeriveRule>,
input: &mut impl RecordDeserializer,
output: &mut impl RecordSerializer,
copy_all_columns: bool,
) -> anyhow::Result<()> {
if let Some(line) = input.deserialize()? {
let line: BTreeMap<String, String> = line;
output.write_header(&line)?;
derive_line(line, rules, output)?;
derive_line(line, rules, output, copy_all_columns)?;
while let Some(line) = input.deserialize()? {
let line: BTreeMap<String, String> = line;
derive_line(line, rules, output)?;
derive_line(line, rules, output, copy_all_columns)?;
}
}
Ok(())
@@ -341,6 +309,11 @@ impl RunnableNode for DeriveNodeRunner {
.map(|rule| rule.to_runnable_rule())
.collect();
let rules = rules?;
derive(&rules, &mut reader, &mut writer)
derive(
&rules,
&mut reader,
&mut writer,
self.derive_node.copy_all_columns,
)
}
}

View File

@@ -3,11 +3,12 @@ use std::collections::BTreeMap;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::{
derive::{is_line_valid, to_filter_rules, DataValidators, DeriveFilter},
io::{RecordDeserializer, RecordSerializer},
node::RunnableNode,
};
use crate::io::{RecordDeserializer, RecordSerializer};
use super::derive::{DataValidators, DeriveFilter};
use super::derive;
use super::node::RunnableNode;
/**
* Write all lines from the input file to the output file, skipping records
@@ -22,13 +23,13 @@ pub fn filter_file(
let line: BTreeMap<String, String> = line;
output.write_header(&line)?;
if is_line_valid(&line, &rules) {
if derive::is_line_valid(&line, &rules) {
output.write_record(&line)?;
}
while let Some(line) = input.deserialize()? {
let line: BTreeMap<String, String> = line;
if is_line_valid(&line, rules) {
if derive::is_line_valid(&line, rules) {
output.write_record(&line)?;
}
}
@@ -52,7 +53,7 @@ impl RunnableNode for FilterNodeRunner {
fn run(&self) -> anyhow::Result<()> {
let mut reader = csv::Reader::from_path(&self.filter_node.input_file_path)?;
let mut writer = csv::Writer::from_path(&self.filter_node.output_file_path)?;
let rules = to_filter_rules(&self.filter_node.filters)?;
let rules = derive::to_filter_rules(&self.filter_node.filters)?;
filter_file(&rules, &mut reader, &mut writer)
}
}
@@ -60,7 +61,7 @@ impl RunnableNode for FilterNodeRunner {
#[cfg(test)]
mod tests {
use crate::derive::{Comparator, FilterRule};
use super::derive::{Comparator, FilterRule};
use super::filter_file;

View File

@@ -11,8 +11,9 @@ use std::{
use chrono::Local;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use split::{SplitNode, SplitNodeRunner};
use crate::{
use {
derive::DeriveNode,
filter::{FilterNode, FilterNodeRunner},
node::RunnableNode,
@@ -20,6 +21,13 @@ use crate::{
upload_to_db::{UploadNode, UploadNodeRunner},
};
mod derive;
mod filter;
mod node;
mod split;
mod sql_rule;
mod upload_to_db;
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum NodeConfiguration {
FileNode,
@@ -31,6 +39,7 @@ pub enum NodeConfiguration {
UploadNode(UploadNode),
SQLNode(SQLNode),
Dynamic,
SplitNode(SplitNode),
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
@@ -133,6 +142,7 @@ fn get_runnable_node(node: Node) -> Box<dyn RunnableNode> {
NodeConfiguration::UploadNode(upload_node) => Box::new(UploadNodeRunner { upload_node }),
NodeConfiguration::SQLNode(sql_node) => Box::new(SQLNodeRunner { sql_node }),
NodeConfiguration::Dynamic => todo!(),
NodeConfiguration::SplitNode(split_node) => Box::new(SplitNodeRunner { split_node }),
}
}

169
src/graph/split.rs Normal file
View File

@@ -0,0 +1,169 @@
use std::{collections::BTreeMap, fs::File};
use chrono::DateTime;
use polars::{
io::SerWriter,
prelude::{CsvWriter, LazyCsvReader, LazyFileListReader},
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tempfile::tempfile;
use crate::io::{RecordDeserializer, RecordSerializer};
use super::{
derive::{self, DataValidator, DeriveFilter},
node::RunnableNode,
};
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum DatePart {
Year,
Month,
Week,
Day,
Hour,
Minute,
Second,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub enum SplitType {
// Column, frequency
DateTime(DatePart),
Numeric(isize),
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct SplitOnChangeInColumn {
id_column: String,
change_column: String,
limit: Option<u64>,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct Split {
column: String,
split_type: SplitType,
// If specified, a split will also be made when the change column changes for the id column
change_in_column: Option<SplitOnChangeInColumn>,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct SplitRule {
pub filters: Vec<DeriveFilter>,
pub splits: Vec<Split>,
}
impl SplitRule {
fn to_runnable_rule(&self) -> anyhow::Result<RunnableSplitRule> {
let filters = derive::to_filter_rules(&self.filters)?;
Ok(RunnableSplitRule {
filters,
splits: self.splits.clone(),
})
}
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct SplitNode {
pub filters: Vec<DeriveFilter>,
pub rules: Vec<SplitRule>,
pub input_file_path: String,
pub output_file_path: String,
}
pub struct RunnableSplitRule {
pub filters: Vec<Box<dyn DataValidator>>,
pub splits: Vec<Split>,
}
pub struct SplitNodeRunner {
pub split_node: SplitNode,
}
fn split_line(
line: BTreeMap<String, String>,
rules: &Vec<RunnableSplitRule>,
output: &mut impl RecordSerializer,
date_format: &str,
last_split_value: Option<(String, String)>,
) -> anyhow::Result<Option<(String, String)>> {
let mut result_lines = vec![];
for rule in rules {
if !derive::is_line_valid(&line, &rule.filters) {
continue;
}
for split in &rule.splits {
let value = line.get(&split.column);
if let Some(value) = value {
// Parse the value in the column for the rule
match &split.split_type {
SplitType::DateTime(frequency) => {
let date_time = DateTime::parse_from_str(&value, &date_format)?;
// TODO: Now split the row up based on the frequency in the rule
}
SplitType::Numeric(frequency) => {
// TODO: Just skip unparsable values, log out it's incorrect?
let number = value.parse::<f64>()?;
}
}
}
}
}
result_lines.push(line);
for line in result_lines {
output.serialize(line)?;
}
Ok(None)
}
fn split(
rules: &Vec<RunnableSplitRule>,
input: &String,
output: &mut impl RecordSerializer,
date_format: &str,
) -> anyhow::Result<()> {
// First sort the input file into the output file
let mut temp_path = tempfile()?;
// This needs to be done for each split rule with a change column specified
let df = LazyCsvReader::new(input).finish()?;
let df = df.sort(["", ""], Default::default());
CsvWriter::new(&mut temp_path).finish(&mut df.collect()?)?;
// Then read from the temporary file (since it's sorted) and do the standard split over each row
let mut input = csv::Reader::from_reader(temp_path);
if let Some(line) = input.deserialize().next() {
let line: BTreeMap<String, String> = line?;
output.write_header(&line)?;
let mut last_split_value = split_line(line, rules, output, &date_format, None)?;
for line in input.deserialize() {
let line: BTreeMap<String, String> = line?;
last_split_value = split_line(line, rules, output, &date_format, last_split_value)?;
}
}
Ok(())
}
impl RunnableNode for SplitNodeRunner {
fn run(&self) -> anyhow::Result<()> {
let mut output = csv::Writer::from_path(&self.split_node.output_file_path)?;
let rules: anyhow::Result<Vec<RunnableSplitRule>> = self
.split_node
.rules
.iter()
.map(|rule| rule.to_runnable_rule())
.collect();
let rules = rules?;
split(
&rules,
&self.split_node.input_file_path,
&mut output,
"%Y-%m-%d %H-%M-%S",
)
}
}

View File

@@ -8,7 +8,7 @@ use polars_sql::SQLContext;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use crate::node::RunnableNode;
use super::node::RunnableNode;
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct CSVFile {

View File

@@ -5,7 +5,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use sqlx::{Any, Pool, QueryBuilder};
use crate::node::RunnableNode;
use super::node::RunnableNode;
const BIND_LIMIT: usize = 65535;

View File

@@ -11,15 +11,10 @@ pub use self::products::create_products;
pub use self::products::csv::SourceType;
mod shared_models;
pub use self::shared_models::*;
pub mod code_rule;
pub mod derive;
pub mod filter;
pub mod cli;
pub mod graph;
mod io;
pub mod link;
pub mod node;
pub mod sql_rule;
pub mod upload_to_db;
#[no_mangle]
pub extern "C" fn move_money_from_text(