diff --git a/src/bin/agent2/main.rs b/src/bin/agent2/main.rs index 02e5b03..dd515ae 100644 --- a/src/bin/agent2/main.rs +++ b/src/bin/agent2/main.rs @@ -1,7 +1,4 @@ -use sqlx::{ - mssql::{MssqlConnectOptions, MssqlPoolOptions}, - ConnectOptions, -}; +use sqlx::mssql::MssqlPoolOptions; #[tokio::main] async fn main() -> anyhow::Result<()> { diff --git a/src/filter.rs b/src/filter.rs index 6ccffa1..454fc85 100644 --- a/src/filter.rs +++ b/src/filter.rs @@ -3,23 +3,23 @@ use std::{collections::HashMap, io::Read, str::FromStr}; use crate::io::RecordSerializer; pub enum Comparator { - EQUAL(T), - NOT_EQUAL(T), - GREATER_THAN(T), - LESS_THAN(T), - IN(Vec), - NOT_IN(Vec), + Equal(T), + NotEqual(T), + GreaterThan(T), + LessThan(T), + In(Vec), + NotIn(Vec), } impl Comparator { pub fn is_valid(&self, value: T) -> bool { match self { - Comparator::EQUAL(v) => value == *v, - Comparator::NOT_EQUAL(v) => value != *v, - Comparator::GREATER_THAN(v) => value > *v, - Comparator::LESS_THAN(v) => value < *v, - Comparator::IN(v) => v.contains(&value), - Comparator::NOT_IN(v) => !v.contains(&value), + Comparator::Equal(v) => value == *v, + Comparator::NotEqual(v) => value != *v, + Comparator::GreaterThan(v) => value > *v, + Comparator::LessThan(v) => value < *v, + Comparator::In(v) => v.contains(&value), + Comparator::NotIn(v) => !v.contains(&value), } } } @@ -31,7 +31,7 @@ pub trait FieldName { pub trait DataValidator: FieldName { // Whether the given value is valid for the validator - fn is_valid(&self, s: &String) -> bool; + fn is_valid(&self, s: &str) -> bool; } pub struct FilterRule { @@ -46,7 +46,7 @@ impl FieldName for FilterRule { } impl DataValidator for FilterRule { - fn is_valid(&self, s: &String) -> bool { + fn is_valid(&self, s: &str) -> bool { s.parse().map_or(false, |f| self.comparator.is_valid(f)) } } diff --git a/src/io.rs b/src/io.rs index 4c97a1e..d493c58 100644 --- a/src/io.rs +++ b/src/io.rs @@ -1,14 +1,8 @@ -use std::{ - io::{Read, Seek, Write}, - thread::current, 
-}; +use std::io::{Read, Seek, Write}; use anyhow::bail; use csv::Position; -use rmp_serde::{ - decode::{ReadReader, ReadRefReader, ReadSlice}, - from_read, Deserializer, Serializer, -}; +use rmp_serde::{decode::ReadReader, Deserializer, Serializer}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; pub trait RecordSerializer { diff --git a/src/lib.rs b/src/lib.rs index b39646d..4695db3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -9,6 +9,7 @@ pub use self::overhead_allocation::*; mod products; pub use self::products::create_products; +pub use self::products::CreateProductInputs; mod shared_models; pub use self::shared_models::*; @@ -56,7 +57,7 @@ pub extern "C" fn move_money_from_text( } #[no_mangle] -pub extern "C" fn move_money_from_text_free(s: *mut c_char) { +pub unsafe extern "C" fn move_money_from_text_free(s: *mut c_char) { unsafe { if s.is_null() { return; @@ -150,7 +151,7 @@ fn unwrap_c_char<'a>(s: *const c_char) -> &'a CStr { } #[no_mangle] -pub extern "C" fn allocate_overheads_from_text_free(s: *mut c_char) { +pub unsafe extern "C" fn allocate_overheads_from_text_free(s: *mut c_char) { unsafe { if s.is_null() { return; diff --git a/src/main.rs b/src/main.rs index 01daa25..cd1664e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ use std::{fs::File, io::BufWriter, path::PathBuf}; use clap::{Parser, Subcommand}; +use coster_rs::CreateProductInputs; #[derive(Parser)] #[command(name = "coster-rs")] @@ -177,11 +178,13 @@ fn main() -> anyhow::Result<()> { output, } => coster_rs::create_products( &mut csv::Reader::from_path(definitions)?, - &mut csv::Reader::from_path(encounters)?, - &mut csv::Reader::from_path(services)?, - &mut csv::Reader::from_path(transfers)?, - &mut csv::Reader::from_path(procedures)?, - &mut csv::Reader::from_path(diagnoses)?, + CreateProductInputs { + encounters: csv::Reader::from_path(encounters)?, + services: csv::Reader::from_path(services)?, + transfers: csv::Reader::from_path(transfers)?, + procedures: 
csv::Reader::from_path(procedures)?, + diagnoses: csv::Reader::from_path(diagnoses)?, + }, &mut csv::Writer::from_path(output)?, 1000000, ), diff --git a/src/move_money.rs b/src/move_money.rs index 6396e6c..e626dce 100644 --- a/src/move_money.rs +++ b/src/move_money.rs @@ -201,45 +201,41 @@ where let is_separator = movement_rule.apply == "-DIVIDER-"; let from_accounts = if is_separator { HashSet::new() + } else if movement_rule.cost_output.is_some() { + account_mappings + .iter() + .filter(|(_, account)| { + account.cost_output.is_some() + && account.cost_output.clone().unwrap() + == movement_rule.cost_output.clone().unwrap() + }) + .map(|(code, _)| code.clone()) + .collect() } else { - if movement_rule.cost_output.is_some() { - account_mappings - .iter() - .filter(|(_, account)| { - account.cost_output.is_some() - && account.cost_output.clone().unwrap() - == movement_rule.cost_output.clone().unwrap() - }) - .map(|(code, _)| code.clone()) - .collect() - } else { - extract_range( - movement_rule.source_from_account, - movement_rule.source_to_account, - &all_accounts_sorted, - ) - } + extract_range( + movement_rule.source_from_account, + movement_rule.source_to_account, + &all_accounts_sorted, + ) }; let to_accounts = if is_separator { HashSet::new() + } else if movement_rule.cost_output.is_some() { + account_mappings + .iter() + .filter(|(_, account)| { + account.cost_output.is_some() + && account.cost_output.clone().unwrap() + == movement_rule.cost_output.clone().unwrap() + }) + .map(|(code, _)| code.clone()) + .collect() } else { - if movement_rule.cost_output.is_some() { - account_mappings - .iter() - .filter(|(_, account)| { - account.cost_output.is_some() - && account.cost_output.clone().unwrap() - == movement_rule.cost_output.clone().unwrap() - }) - .map(|(code, _)| code.clone()) - .collect() - } else { - extract_range( - movement_rule.dest_from_account, - movement_rule.dest_to_account, - &all_accounts_sorted, - ) - } + extract_range( + 
movement_rule.dest_from_account, + movement_rule.dest_to_account, + &all_accounts_sorted, + ) }; let from_departments = if is_separator { HashSet::new() diff --git a/src/overhead_allocation.rs b/src/overhead_allocation.rs index 129b003..ac2e2f4 100644 --- a/src/overhead_allocation.rs +++ b/src/overhead_allocation.rs @@ -418,7 +418,7 @@ where initial_account_costs .into_iter() .map(|(account, total_cost)| AccountCost { - account: account, + account, summed_department_costs: total_cost, }) .collect(), @@ -752,7 +752,7 @@ fn solve_reciprocal_with_from( .map(|(department, value)| MovedAmount { account: total_costs.account.clone(), cost_centre: department.clone(), - value: value, + value, from_cost_centre: department.clone(), }) .filter(|cost| cost.value != 0_f64) diff --git a/src/products/create_products.rs b/src/products/create_products.rs index f2023b7..f868aa8 100644 --- a/src/products/create_products.rs +++ b/src/products/create_products.rs @@ -1,9 +1,6 @@ -use core::panic; use std::{ collections::HashMap, io::{Read, Write}, - sync::mpsc, - thread, }; use chrono::NaiveDateTime; @@ -31,15 +28,26 @@ struct Product { source_allocated_amount: Option, } +pub struct CreateProductInputs +where + E: Read, + S: Read, + T: Read, + P: Read, + Di: Read, +{ + pub encounters: csv::Reader, + pub services: csv::Reader, + pub transfers: csv::Reader, + pub procedures: csv::Reader

, + pub diagnoses: csv::Reader, +} + // TODO: Build from linked dataset is pretty hard, it potentially requires knowing everything about the previous year's // costing run (BSCO, Dataset_Encounter_Cache, etc). pub fn create_products( definitions: &mut csv::Reader, - encounters: &mut csv::Reader, - services: &mut csv::Reader, - transfers: &mut csv::Reader, - procedures: &mut csv::Reader

, - diagnoses: &mut csv::Reader, + product_inputs: CreateProductInputs, // TODO: Looks kind of bad, any other way around it? I'd rather not have to depend on crossbeam as well output: &mut csv::Writer, // TODO: Default to 10 million or something sane @@ -82,7 +90,7 @@ where // TODO: Try with and without rayon, should be able to help I think as we're going through so much data sequentially, // although we're still likely to be bottlenecked by just write-speed - let mut encounters = encounters; + let mut encounters = product_inputs.encounters; let headers = encounters.headers()?.clone(); for encounter in encounters.records() { @@ -105,9 +113,9 @@ where } let field = field.unwrap(); if filter.equal { - return filter.value == *field; + filter.value == *field } else { - return filter.value != *field; + filter.value != *field } })) && (definition.constraints.is_empty() @@ -130,7 +138,7 @@ where } // TODO: Generate the built service - output.serialize(Product::default()); + output.serialize(Product::default())?; } // Now do the same with transfers, services, etc, referencing the encounter reader by using the