From 3cdaa81da1a31016a97670b76d3b69e43657ed1a Mon Sep 17 00:00:00 2001 From: vato007 Date: Sat, 3 Aug 2024 16:33:16 +0930 Subject: [PATCH] Add schema generation, refactor cli, add most of the derive operations --- Cargo.lock | 86 ++++++++++++++- Cargo.toml | 5 + src/cli/commands.rs | 107 +++++++++++++++++++ src/cli/mod.rs | 187 +++++++++++++++++++++++++++++++++ src/derive.rs | 169 +++++++++++++++++++++++------- src/filter.rs | 19 +--- src/graph.rs | 63 ++++++----- src/lib.rs | 2 +- src/main.rs | 249 +------------------------------------------- src/sql_rule.rs | 5 +- src/upload_to_db.rs | 3 +- 11 files changed, 567 insertions(+), 328 deletions(-) create mode 100644 src/cli/commands.rs create mode 100644 src/cli/mod.rs diff --git a/Cargo.lock b/Cargo.lock index e755377..965fa53 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -395,13 +395,18 @@ dependencies = [ "chrono", "clap", "csv", + "env_logger", "itertools", + "log", "nalgebra", + "num_cpus", "polars", "polars-sql", "rayon", "rmp-serde", + "schemars", "serde", + "serde_json", "sqlx", "tempfile", "tokio", @@ -583,6 +588,29 @@ dependencies = [ "syn 2.0.72", ] +[[package]] +name = "env_filter" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f2c92ceda6ceec50f43169f9ee8424fe2db276791afde7b2cd8bc084cb376ab" +dependencies = [ + "log", + "regex", +] + +[[package]] +name = "env_logger" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13fa619b91fb2381732789fc5de83b45675e882f66623b7d8cb4f643017018d" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "humantime", + "log", +] + [[package]] name = "equivalent" version = "1.0.1" @@ -844,6 +872,12 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "humantime" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" + [[package]] name = "iana-time-zone" version = "0.1.60" @@ -1222,6 +1256,16 @@ dependencies = [ "libm", ] +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "object" version = "0.36.2" @@ -2090,6 +2134,31 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "schemars" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09c024468a378b7e36765cd36702b7a90cc3cba11654f6685c8f233408e89e92" +dependencies = [ + "chrono", + "dyn-clone", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1eee588578aff73f856ab961cd2f79e36bc45d7ded33a7562adba4667aecc0e" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn 2.0.72", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2127,10 +2196,21 @@ dependencies = [ ] [[package]] -name = "serde_json" -version = "1.0.121" +name = "serde_derive_internals" +version = "0.29.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ab380d7d9f22ef3f21ad3e6c1ebe8e4fc7a2000ccba2e4d71fc96f15b2cb609" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + +[[package]] +name = "serde_json" +version = 
"1.0.122" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "784b6203951c57ff748476b126ccb5e8e2959a5c19e5c617ab1956be3dbc68da" dependencies = [ "itoa", "memchr", diff --git a/Cargo.toml b/Cargo.toml index 1111fe0..48070d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,11 @@ rmp-serde = "1.1" tempfile = "3.7" polars = {version = "0.41", features = ["lazy", "performant", "streaming", "cse", "dtype-datetime"]} polars-sql = "0.41" +serde_json = "1.0.122" +num_cpus = "1.16.0" +schemars = {version = "0.8.21", features = ["chrono"]} +log = "0.4.22" +env_logger = "0.11.5" # More info on targets: https://doc.rust-lang.org/cargo/reference/cargo-targets.html#configuring-a-target [lib] diff --git a/src/cli/commands.rs b/src/cli/commands.rs new file mode 100644 index 0000000..d5539f3 --- /dev/null +++ b/src/cli/commands.rs @@ -0,0 +1,107 @@ +use std::path::PathBuf; + +use clap::Subcommand; + +#[derive(Subcommand)] +pub enum Commands { + /// Moves money between accounts and departments, using the given rules and lines + MoveMoney { + #[arg(short = 'r', long, value_name = "FILE")] + rules: PathBuf, + + #[arg(short = 'l', long, value_name = "FILE")] + lines: PathBuf, + + #[arg(short = 'a', long, value_name = "FILE")] + accounts: PathBuf, + + #[arg(short = 'c', long, value_name = "FILE")] + cost_centres: PathBuf, + + #[arg(short, long, value_name = "FILE")] + output: Option, + + #[arg(short, long)] + use_numeric_accounts: bool, + + #[arg(short, long)] + flush_pass: bool, + }, + /// Allocates servicing department amounts to operating departments + AllocateOverheads { + #[arg(short, long, value_name = "FILE")] + lines: PathBuf, + + #[arg(short, long, value_name = "FILE")] + accounts: PathBuf, + + #[arg(short = 's', long, value_name = "FILE")] + allocation_statistics: PathBuf, + + #[arg(short, long, value_name = "FILE")] + areas: PathBuf, + + #[arg(short, long, value_name = "FILE")] + cost_centres: PathBuf, + + #[arg(short, long)] + use_numeric_accounts: bool, + + #[arg(long, default_value = "E")] + account_type: String, + + #[arg(short, long)] + exclude_negative_allocation_statistics: bool, + + #[arg(short = 'f', long)] + show_from: bool, + + #[arg(short, long, default_value = "0.00000000000000001")] + zero_threshold: f64, + + #[arg(short, long, value_name = "FILE", default_value = "alloc_output.csv")] + output: PathBuf, + + #[arg(short, long)] + msgpack_serialisation: bool, + }, + CreateProducts { + #[arg(short, long, value_name = "FILE")] + definitions: PathBuf, + + #[arg(short, long, value_name = "FILE")] + encounters: PathBuf, + + #[arg(short, long, value_name = "FILE")] + services: PathBuf, + + #[arg(short, long, value_name = "FILE")] + transfers: PathBuf, + + #[arg(short, long, value_name = "FILE")] + procedures: PathBuf, + + #[arg(short, long, value_name = "FILE")] + diagnoses: PathBuf, + + #[arg(short, long, value_name = "FILE")] + patients: PathBuf, + + #[arg(short, long, value_name = "FILE")] + revenues: PathBuf, + + #[arg(short, long, value_name = "FILE")] + output: PathBuf, + }, + RunGraph { + #[arg(short, long, value_name = "FILE")] + graph: PathBuf, + + #[arg(short, long, default_value_t = num_cpus::get())] + threads: usize, + }, + GenerateSchema { + #[arg(short, long, value_name = "FILE", default_value = "schema.json")] + output: PathBuf, + }, +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs new file mode 100644 index 0000000..c2e4e16 --- /dev/null +++ b/src/cli/mod.rs @@ -0,0 +1,187 @@ +use std::{ + collections::HashMap, + fs::File, + io::{BufReader, 
BufWriter}, + path::PathBuf, +}; + +use std::io::Write; + +use clap::{command, Parser}; + +pub use commands::Commands; +use coster_rs::{ + create_products::InputFile, + graph::{Graph, RunnableGraph}, + SourceType, +}; +use log::info; +use schemars::schema_for; + +mod commands; + +#[derive(Parser)] +#[command(name = "coster-rs")] +#[command(author = "Pivato M. ")] +#[command(version = "0.0.1")] +#[command(about = "Simple, fast, efficient costing tool", long_about = None)] +pub struct Cli { + #[clap(subcommand)] + pub command: Commands, +} + +impl Cli { + pub fn run(self) -> anyhow::Result<()> { + match self.command { + Commands::MoveMoney { + rules, + lines, + accounts, + cost_centres, + output, + use_numeric_accounts, + flush_pass, + } => coster_rs::move_money( + &mut csv::Reader::from_path(rules)?, + &mut csv::Reader::from_path(lines)?, + &mut csv::Reader::from_path(accounts)?, + &mut csv::Reader::from_path(cost_centres)?, + &mut csv::Writer::from_path(output.unwrap_or(PathBuf::from("output.csv")))?, + use_numeric_accounts, + flush_pass, + ), + Commands::AllocateOverheads { + lines, + accounts, + allocation_statistics, + areas, + cost_centres, + use_numeric_accounts, + account_type, + exclude_negative_allocation_statistics, + show_from, + zero_threshold, + output, + msgpack_serialisation, + } => { + if msgpack_serialisation { + let mut file = BufWriter::new(File::create(output)?); + coster_rs::reciprocal_allocation( + &mut csv::Reader::from_path(lines)?, + &mut csv::Reader::from_path(accounts)?, + &mut csv::Reader::from_path(allocation_statistics)?, + &mut csv::Reader::from_path(areas)?, + &mut csv::Reader::from_path(cost_centres)?, + &mut rmp_serde::Serializer::new(&mut file), + use_numeric_accounts, + exclude_negative_allocation_statistics, + true, + account_type, + show_from, + zero_threshold, + ) + } else { + coster_rs::reciprocal_allocation( + &mut csv::Reader::from_path(lines)?, + &mut csv::Reader::from_path(accounts)?, + &mut csv::Reader::from_path(allocation_statistics)?, + &mut csv::Reader::from_path(areas)?, + &mut csv::Reader::from_path(cost_centres)?, + &mut csv::Writer::from_path(output)?, + use_numeric_accounts, + exclude_negative_allocation_statistics, + true, + account_type, + show_from, + zero_threshold, + ) + } + } + Commands::CreateProducts { + definitions, + encounters, + services, + transfers, + procedures, + diagnoses, + patients, + revenues, + output, + } => { + let mut inputs = HashMap::new(); + inputs.insert( + SourceType::Encounter, + InputFile { + file_path: encounters, + joins: HashMap::new(), + date_order_column: Some("StartDateTime".to_owned()), + }, + ); + inputs.insert( + SourceType::Service, + InputFile { + file_path: services, + joins: HashMap::new(), + date_order_column: Some("StartDateTime".to_owned()), + }, + ); + inputs.insert( + SourceType::Transfer, + InputFile { + file_path: transfers, + joins: HashMap::new(), + date_order_column: Some("StartDateTime".to_owned()), + }, + ); + inputs.insert( + SourceType::CodingProcedure, + InputFile { + file_path: procedures, + joins: HashMap::new(), + date_order_column: Some("ProcedureDateTime".to_owned()), + }, + ); + inputs.insert( + SourceType::CodingDiagnosis, + InputFile { + file_path: diagnoses, + joins: HashMap::new(), + date_order_column: None, + }, + ); + inputs.insert( + SourceType::Patient, + InputFile { + file_path: patients, + joins: HashMap::new(), + date_order_column: None, + }, + ); + inputs.insert( + SourceType::Revenue, + InputFile { + file_path: revenues, + joins: HashMap::new(), + 
date_order_column: None,
+ },
+ );
+ // NOTE: the `inputs` map built above is not wired through yet;
+ // create_products_polars currently receives an empty list.
+ coster_rs::create_products::create_products_polars(definitions, vec![], output)
+ }
+ Commands::RunGraph { graph, threads } => {
+ let file = File::open(graph)?;
+ let reader = BufReader::new(file);
+ let graph = serde_json::from_reader(reader)?;
+ let graph = RunnableGraph::from_graph(graph);
+ graph.run_default_tasks(threads, |id, status| {
+ info!("Node with id {} finished with status {:?}", id, status)
+ })
+ }
+ Commands::GenerateSchema { output } => {
+ let schema = schema_for!(Graph);
+ let mut output = File::create(output)?;
+ write!(output, "{}", serde_json::to_string_pretty(&schema)?)?;
+ Ok(())
+ }
+ }
+ }
+}
diff --git a/src/derive.rs b/src/derive.rs
index 44794ba..d286d19 100644
--- a/src/derive.rs
+++ b/src/derive.rs
@@ -1,5 +1,8 @@
 use std::{collections::BTreeMap, str::FromStr};
 
+use anyhow::bail;
+use itertools::Itertools;
+use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 
 use crate::{
@@ -7,18 +10,13 @@ use crate::{
 node::RunnableNode,
 };
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub enum DeriveColumnType {
 Column(String),
 Constant(String),
 }
 
-#[derive(Serialize, Deserialize, Clone)]
-pub struct MapOperation {
- pub mapped_value: String,
-}
-
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub enum DatePart {
 Year,
 Month,
@@ -29,13 +27,13 @@ pub enum DatePart {
 Second,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub enum SplitType {
 DateTime(String, DatePart),
 Numeric(String, isize),
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub enum MatchComparisonType {
 Equal,
 GreaterThan,
@@ -43,18 +41,22 @@ pub enum MatchComparisonType {
 NotEqual,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub enum DeriveOperation {
 Concat(Vec<DeriveColumnType>),
 Add(Vec<DeriveColumnType>),
 Multiply(Vec<DeriveColumnType>),
- Subtract(DeriveColumnType, DeriveColumnType),
- Divide(DeriveColumnType, DeriveColumnType),
- Map(String, Vec<MapOperation>),
+ Subtract(Vec<DeriveColumnType>),
+ Divide(Vec<DeriveColumnType>),
+ Map(String),
+ // This might be better as its own node: we could then apply sorting operations
+ // and ensure the split only happens when a particular column changes value.
+ // Alternatively, these more complex use cases could be left to SQL/code nodes
+ // (if an SQL node can even express them; code nodes aren't implemented yet).
 Split(String, SplitType),
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub enum ValueType {
 String,
 Integer,
@@ -62,7 +64,7 @@ pub enum ValueType {
 Boolean,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct DeriveFilter {
 pub column_name: String,
 pub comparator: MatchComparisonType,
@@ -158,19 +160,19 @@ fn get_filter_rule<T>(filter: &DeriveFilter, value: T) -> FilterRule<T> {
 }
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct DeriveColumnOperation {
 pub column_name: String,
 pub operation: DeriveOperation,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct DeriveRule {
 pub operations: Vec<DeriveColumnOperation>,
 pub filters: Vec<DeriveFilter>,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct DeriveNode {
 pub rules: Vec<DeriveRule>,
 pub input_file_path: String,
@@ -192,6 +194,120 @@ impl DeriveRule {
 }
 }
 
+pub fn is_line_valid(line: &BTreeMap<String, String>, rules: &DataValidators) -> bool {
+ rules.iter().all(|rule| {
+ line.get(&rule.get_field_name()).map_or(true, |value| {
+ if value.trim().is_empty() {
+ true
+ } else {
+ rule.is_valid(value)
+ }
+ })
+ })
+}
+
+fn concat_columns(line: &BTreeMap<String, String>, columns: &Vec<DeriveColumnType>) -> String {
+ columns
+ .iter()
+ .map(|col| match col {
+ DeriveColumnType::Column(column) => line.get(column).cloned().unwrap_or_default(),
+ DeriveColumnType::Constant(constant) => constant.clone(),
+ })
+ .collect()
+}
+
+fn reduce_numeric_columns<F>(
+ line: &BTreeMap<String, String>,
+ columns: &Vec<DeriveColumnType>,
+ reducer: F,
+) -> String
+where
+ F: Fn(f64, f64) -> f64,
+{
+ // Columns that are missing, or that fail to parse as f64, are skipped.
+ let value = columns
+ .iter()
+ .map(|col| match col {
+ DeriveColumnType::Column(column) => line
+ .get(column)
+ .and_then(|value| value.parse::<f64>().ok()),
+ DeriveColumnType::Constant(constant) => constant.parse().ok(),
+ })
+ .flatten()
+ .reduce(reducer);
+ value.map(|value| value.to_string()).unwrap_or_default()
+}
+
+fn derive_line(
+ mut line: BTreeMap<String, String>,
+ rules: &Vec<DeriveRule>,
+ output: &mut impl RecordSerializer,
+) -> anyhow::Result<()> {
+ for rule in rules {
+ // If there are no filters, the rule applies to all rows.
+ if !is_line_valid(&line, &rule.filters) {
+ continue;
+ }
+
+ for operation in &rule.operations {
+ let value = match &operation.operation {
+ DeriveOperation::Concat(concat) => concat_columns(&line, concat),
+ DeriveOperation::Add(columns) => {
+ reduce_numeric_columns(&line, columns, |a, b| a + b)
+ }
+ DeriveOperation::Multiply(columns) => {
+ reduce_numeric_columns(&line, columns, |a, b| a * b)
+ }
+ DeriveOperation::Subtract(columns) => {
+ reduce_numeric_columns(&line, columns, |a, b| a - b)
+ }
+ DeriveOperation::Divide(columns) => {
+ reduce_numeric_columns(&line, columns, |a, b| a / b)
+ }
+ DeriveOperation::Map(mapped_value) => mapped_value.clone(),
+ // Split operations fan one record out into several, so they are
+ // applied separately after every other operation has run.
+ DeriveOperation::Split(_, _) => continue,
+ };
+ line.insert(operation.column_name.clone(), value);
+ }
+ }
+
+ let split_operations = rules
+ .iter()
+ .flat_map(|rule| {
+ if !is_line_valid(&line, &rule.filters) {
+ return vec![];
+ }
+ rule.operations
+ .iter()
.filter(|operation| matches!(operation.operation, DeriveOperation::Split(_, _)))
+ .collect_vec()
+ })
+ .collect_vec();
+
+ if split_operations.is_empty() {
+ output.serialize(line)?;
+ } else {
+ // TODO: apply the splits, emitting one output record per split value.
+ bail!("Split operations are not implemented yet")
+ }
+
+ Ok(())
+}
+
 fn derive(
 rules: &Vec<DeriveRule>,
 input: &mut impl RecordDeserializer,
@@ -210,23 +326,6 @@ fn derive(
 Ok(())
 }
 
-fn derive_line(
- line: BTreeMap<String, String>,
- rules: &Vec<DeriveRule>,
- output: &mut impl RecordSerializer,
-) -> anyhow::Result<()> {
- for rule in rules {
- // First check the filter works. If there are no filters, the rule applies to all rows
- for filter in &rule.filters {}
- // TODO: Split operations should be processed separately, after all the other operations have been applied
- // Apply all operations individually, adding as a column to the record map
- for operation in &rule.operations {}
- }
- // for line in line {
- output.serialize(line)
- // }
-}
-
 pub struct DeriveNodeRunner {
 derive_node: DeriveNode,
 }
diff --git a/src/filter.rs b/src/filter.rs
index 6a0427a..e3c96a6 100644
--- a/src/filter.rs
+++ b/src/filter.rs
@@ -1,25 +1,14 @@
 use std::collections::BTreeMap;
 
+use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 
 use crate::{
-    derive::{to_filter_rules, DataValidator, DataValidators, DeriveFilter},
+    derive::{is_line_valid, to_filter_rules, DataValidators, DeriveFilter},
 io::{RecordDeserializer, RecordSerializer},
 node::RunnableNode,
 };
 
-fn is_line_valid(line: &BTreeMap<String, String>, rules: &DataValidators) -> bool {
- rules.iter().all(|rule| {
- line.get(&rule.get_field_name()).map_or(true, |value| {
- if value.trim().is_empty() {
- true
- } else {
- rule.is_valid(value)
- }
- })
- })
-}
-
 /**
 * Write all lines from the input file to the output file, skipping records
 * that don't satisfy the filter criteria
@@ -33,7 +22,7 @@ pub fn filter_file(
 let line: BTreeMap<String, String> = line;
 output.write_header(&line)?;
 
-        if (is_line_valid(&line, &rules)) {
+        if is_line_valid(&line, &rules) {
 output.write_record(&line)?;
 }
 
@@ -48,7 +37,7 @@ pub fn filter_file(
 Ok(())
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct FilterNode {
 pub filters: Vec<DeriveFilter>,
 pub input_file_path: String,
diff --git a/src/graph.rs b/src/graph.rs
index 9909cb9..5ba4a0f 100644
--- a/src/graph.rs
+++ b/src/graph.rs
@@ -9,6 +9,7 @@ use std::{
 };
 
 use chrono::Local;
+use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 
 use crate::{
@@ -19,7 +20,7 @@ use crate::{
 upload_to_db::{UploadNode, UploadNodeRunner},
 };
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub enum NodeConfiguration {
 FileNode,
 MoveMoneyNode(MoveMoneyNode),
@@ -32,13 +33,13 @@ pub enum NodeConfiguration {
 Dynamic,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct DynamicConfiguration {
 pub node_type: String,
 pub parameters: HashMap<String, String>,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct NodeInfo {
 pub name: String,
 pub output_files: Vec<String>,
@@ -46,13 +47,13 @@ pub struct NodeInfo {
 pub dynamic_configuration: Option<DynamicConfiguration>,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub enum MoveMoneyAmountType {
 Percent,
 Amount,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct MoveMoneyRule {
 pub from_account: String,
 pub from_cc: String,
@@ -62,7 +63,7 @@ pub struct MoveMoneyRule {
 pub amount_type: MoveMoneyAmountType,
 }
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct MoveMoneyNode {
 pub departments_path: String,
 pub accounts_path: String,
@@ -70,40 +71,40 @@ pub struct MoveMoneyNode {
 pub rules: Vec<MoveMoneyRule>,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub enum JoinType {
 Left,
 Inner,
 Right,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct MergeJoin {
 pub join_type: JoinType,
 pub left_column_name: String,
 pub right_column_name: String,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct MergeNode {
 pub input_files: Vec<String>,
 pub joins: Vec<MergeJoin>,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub enum CodeRuleLanguage {
 Javascript,
 Rust,
 Go,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct CodeRuleNode {
 pub language: CodeRuleLanguage,
 pub text: String,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct Node {
 pub id: i64,
 pub info: NodeInfo,
@@ -135,17 +136,18 @@ fn get_runnable_node(node: Node) -> Box<dyn RunnableNode> {
 }
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct Graph {
 pub name: String,
 pub nodes: Vec<Node>,
 }
 
+#[derive(Debug)]
 pub enum NodeStatus {
 Completed,
 Running,
- // TODO: Error code?
- Failed,
+ // Carries the error that caused the node to fail
+ Failed(anyhow::Error),
 }
 
 pub struct RunnableGraph {
@@ -157,11 +159,14 @@ impl RunnableGraph {
 RunnableGraph { graph }
 }
 
- pub fn run_default_tasks(&mut self, num_threads: usize) -> anyhow::Result<()> {
+ pub fn run_default_tasks<F>(&self, num_threads: usize, status_changed: F) -> anyhow::Result<()>
+ where
+ F: Fn(i64, NodeStatus),
+ {
 self.run(
 num_threads,
 Box::new(|node| get_runnable_node(node)),
- |id, status| {},
+ status_changed,
 )
 }
 
@@ -196,7 +201,7 @@ impl RunnableGraph {
 node_status_changed_fn(node.id, NodeStatus::Running);
 match get_node_fn(node.clone()).run() {
 Ok(_) => node_status_changed_fn(node.id, NodeStatus::Completed),
- Err(_) => node_status_changed_fn(node.id, NodeStatus::Failed),
+ Err(err) => node_status_changed_fn(node.id, NodeStatus::Failed(err)),
 };
 }
 return Ok(());
@@ -216,8 +221,13 @@ impl RunnableGraph {
 let node_fn = node_fn.clone();
 let handle = thread::spawn(move || {
 for node in rx {
- node_fn(node.clone()).run();
- finish_task.send((n, node));
+ let status = match node_fn(node.clone()).run() {
+ Ok(_) => NodeStatus::Completed,
+ Err(err) => NodeStatus::Failed(err),
+ };
+ finish_task
+ .send((n, node, status))
+ .expect("Failed to notify node status completion");
 }
 println!("Thread {} finished", n);
 });
@@ -238,15 +248,14 @@ impl RunnableGraph {
 let node = nodes.remove(i);
 node_status_changed_fn(node.id, NodeStatus::Running);
 running_nodes.insert(node.id);
- senders[i % senders.len()].send(node);
+ senders[i % senders.len()].send(node)?;
 }
 }
 
 // Run each dependent node once the node above it finishes.
- for (n, node) in listen_finish_task {
+ for (n, node, status) in listen_finish_task {
 running_threads.remove(&n);
- // TODO: Add error check here
- node_status_changed_fn(node.id, NodeStatus::Completed);
+ node_status_changed_fn(node.id, status);
 running_nodes.remove(&node.id);
 completed_nodes.insert(node.id);
 // Run all the nodes that can be run and aren't in completed
@@ -261,7 +270,7 @@ impl RunnableGraph {
 let node = nodes.remove(i);
 for i in 0..num_threads {
 if !running_threads.contains(&i) {
- senders[i].send(node);
+ senders[i].send(node)?;
 break;
 }
 }
@@ -293,7 +302,7 @@ mod tests {
 
 #[test]
 fn test_basic() -> anyhow::Result<()> {
- let mut graph = RunnableGraph {
+ let graph = RunnableGraph {
 graph: super::Graph {
 name: "Test".to_owned(),
 nodes: vec![super::Node {
@@ -313,7 +322,7 @@ mod tests {
 }],
 },
 };
- graph.run_default_tasks(2)?;
+ graph.run_default_tasks(2, |_, _| {})?;
 Ok(())
 }
 }
diff --git a/src/lib.rs b/src/lib.rs
index 5ceeb4a..81afa2c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -14,7 +14,7 @@ pub use self::shared_models::*;
 pub mod code_rule;
 pub mod derive;
 pub mod filter;
-mod graph;
+pub mod graph;
 mod io;
 pub mod link;
 pub mod node;
diff --git a/src/main.rs b/src/main.rs
index 57c8848..07e9c88 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,248 +1,9 @@
-use std::{collections::HashMap, fs::File, io::BufWriter, path::PathBuf};
-
-use clap::{Parser, Subcommand};
-use coster_rs::{create_products::InputFile, SourceType};
-
-#[derive(Parser)]
-#[command(name = "coster-rs")]
-#[command(author = "Pivato M. ")]
-#[command(version = "0.0.1")]
-#[command(about = "Simple, fast, efficient costing tool", long_about = None)]
-struct Cli {
- #[clap(subcommand)]
- command: Commands,
-}
-
-#[derive(Subcommand)]
-enum Commands {
- /// Moves money between accounts and departments, using the given rules and lines
- MoveMoney {
- #[arg(short = 'r', long, value_name = "FILE")]
- rules: PathBuf,
-
- #[arg(short = 'l', long, value_name = "FILE")]
- lines: PathBuf,
-
- #[arg(short = 'a', long, value_name = "FILE")]
- accounts: PathBuf,
-
- #[arg(short = 'c', long, value_name = "FILE")]
- cost_centres: PathBuf,
-
- #[arg(short, long, value_name = "FILE")]
- output: Option<PathBuf>,
-
- #[arg(short, long)]
- use_numeric_accounts: bool,
-
- #[arg(short, long)]
- flush_pass: bool,
- },
- /// Allocates servicing department amounts to operating departments
- AllocateOverheads {
- #[arg(short, long, value_name = "FILE")]
- lines: PathBuf,
-
- #[arg(short, long, value_name = "FILE")]
- accounts: PathBuf,
-
- #[arg(short = 's', long, value_name = "FILE")]
- allocation_statistics: PathBuf,
-
- #[arg(short, long, value_name = "FILE")]
- areas: PathBuf,
-
- #[arg(short, long, value_name = "FILE")]
- cost_centres: PathBuf,
-
- #[arg(short, long)]
- use_numeric_accounts: bool,
-
- #[arg(long, default_value = "E")]
- account_type: String,
-
- #[arg(short, long)]
- exclude_negative_allocation_statistics: bool,
-
- #[arg(short = 'f', long)]
- show_from: bool,
-
- #[arg(short, long, default_value = "0.00000000000000001")]
- zero_threshold: f64,
-
- #[arg(short, long, value_name = "FILE", default_value = "alloc_output.csv")]
- output: PathBuf,
-
- #[arg(short, long)]
- msgpack_serialisation: bool,
- },
- CreateProducts {
- #[arg(short, long, value_name = "FILE")]
- definitions: PathBuf,
-
- #[arg(short, long, value_name = "FILE")]
- encounters: PathBuf,
-
- #[arg(short, long, value_name = "FILE")]
- services: PathBuf,
-
- #[arg(short, long, value_name = "FILE")]
- transfers: PathBuf,
-
- #[arg(short, long, value_name = 
"FILE")] - procedures: PathBuf, - - #[arg(short, long, value_name = "FILE")] - diagnoses: PathBuf, - - #[arg(short, long, value_name = "FILE")] - patients: PathBuf, - - #[arg(short, long, value_name = "FILE")] - revenues: PathBuf, - - #[arg(short, long, value_name = "FILE")] - output: PathBuf, - }, -} +use clap::Parser; +use cli::Cli; +mod cli; fn main() -> anyhow::Result<()> { + env_logger::init(); let cli = Cli::parse(); - - match cli.command { - Commands::MoveMoney { - rules, - lines, - accounts, - cost_centres, - output, - use_numeric_accounts, - flush_pass, - } => coster_rs::move_money( - &mut csv::Reader::from_path(rules)?, - &mut csv::Reader::from_path(lines)?, - &mut csv::Reader::from_path(accounts)?, - &mut csv::Reader::from_path(cost_centres)?, - &mut csv::Writer::from_path(output.unwrap_or(PathBuf::from("output.csv")))?, - use_numeric_accounts, - flush_pass, - ), - Commands::AllocateOverheads { - lines, - accounts, - allocation_statistics, - areas, - cost_centres, - use_numeric_accounts, - account_type, - exclude_negative_allocation_statistics, - show_from, - zero_threshold, - output, - msgpack_serialisation, - } => { - if msgpack_serialisation { - let mut file = BufWriter::new(File::create(output)?); - coster_rs::reciprocal_allocation( - &mut csv::Reader::from_path(lines)?, - &mut csv::Reader::from_path(accounts)?, - &mut csv::Reader::from_path(allocation_statistics)?, - &mut csv::Reader::from_path(areas)?, - &mut csv::Reader::from_path(cost_centres)?, - &mut rmp_serde::Serializer::new(&mut file), - use_numeric_accounts, - exclude_negative_allocation_statistics, - true, - account_type, - show_from, - zero_threshold, - ) - } else { - coster_rs::reciprocal_allocation( - &mut csv::Reader::from_path(lines)?, - &mut csv::Reader::from_path(accounts)?, - &mut csv::Reader::from_path(allocation_statistics)?, - &mut csv::Reader::from_path(areas)?, - &mut csv::Reader::from_path(cost_centres)?, - &mut csv::Writer::from_path(output)?, - use_numeric_accounts, - exclude_negative_allocation_statistics, - true, - account_type, - show_from, - zero_threshold, - ) - } - } - Commands::CreateProducts { - definitions, - encounters, - services, - transfers, - procedures, - diagnoses, - patients, - revenues, - output, - } => { - let mut inputs = HashMap::new(); - inputs.insert( - SourceType::Encounter, - InputFile { - file_path: encounters, - joins: HashMap::new(), - date_order_column: Some("StartDateTime".to_owned()), - }, - ); - inputs.insert( - SourceType::Service, - InputFile { - file_path: services, - joins: HashMap::new(), - date_order_column: Some("StartDateTime".to_owned()), - }, - ); - inputs.insert( - SourceType::Transfer, - InputFile { - file_path: transfers, - joins: HashMap::new(), - date_order_column: Some("StartDateTime".to_owned()), - }, - ); - inputs.insert( - SourceType::CodingProcedure, - InputFile { - file_path: procedures, - joins: HashMap::new(), - date_order_column: Some("ProcedureDateTime".to_owned()), - }, - ); - inputs.insert( - SourceType::CodingDiagnosis, - InputFile { - file_path: diagnoses, - joins: HashMap::new(), - date_order_column: None, - }, - ); - inputs.insert( - SourceType::Patient, - InputFile { - file_path: patients, - joins: HashMap::new(), - date_order_column: None, - }, - ); - inputs.insert( - SourceType::Revenue, - InputFile { - file_path: revenues, - joins: HashMap::new(), - date_order_column: None, - }, - ); - coster_rs::create_products::create_products_polars(definitions, vec![], output) - } - } + cli.run() } diff --git a/src/sql_rule.rs b/src/sql_rule.rs 
index 547a098..1d00d2e 100644
--- a/src/sql_rule.rs
+++ b/src/sql_rule.rs
@@ -5,11 +5,12 @@ use polars::{
 prelude::{CsvWriter, LazyCsvReader, LazyFileListReader},
 };
 use polars_sql::SQLContext;
+use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 
 use crate::node::RunnableNode;
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct CSVFile {
 name: String,
 path: String,
@@ -30,7 +31,7 @@ fn run_sql(files: &Vec<CSVFile>, output_path: &String, query: &String) -> anyhow
 Ok(())
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct SQLNode {
 pub files: Vec<CSVFile>,
 pub output_file: String,
diff --git a/src/upload_to_db.rs b/src/upload_to_db.rs
index 034b7df..1a591ca 100644
--- a/src/upload_to_db.rs
+++ b/src/upload_to_db.rs
@@ -1,6 +1,7 @@
 use std::collections::HashMap;
 
 use anyhow::bail;
+use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use sqlx::{Any, Pool, QueryBuilder};
 
@@ -69,7 +70,7 @@ pub async fn upload_file_bulk(pool: &Pool<Any>, upload_node: &UploadNode) -> any
 }
 }
 
-#[derive(Serialize, Deserialize, Clone)]
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct UploadNode {
 file_path: String,
 table_name: String,
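For reference, a minimal sketch of the RunGraph flow this patch wires up, driving the library directly rather than through the CLI. The graph.json path and the printed message are hypothetical; Graph, RunnableGraph::from_graph, run_default_tasks, and the Fn(i64, NodeStatus) callback are as introduced above:

use std::{fs::File, io::BufReader};

use coster_rs::graph::{Graph, RunnableGraph};

fn main() -> anyhow::Result<()> {
    // A graph definition in the format described by GenerateSchema's schema.json.
    let file = File::open("graph.json")?;
    let graph: Graph = serde_json::from_reader(BufReader::new(file))?;

    // Run with one worker thread per CPU, mirroring the RunGraph subcommand.
    let graph = RunnableGraph::from_graph(graph);
    graph.run_default_tasks(num_cpus::get(), |id, status| {
        // NodeStatus now derives Debug and carries the error when a node fails.
        println!("node {} finished with status {:?}", id, status)
    })
}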