From 19e08f9ca7b949cbf13a03325bc06dd786c37e2e Mon Sep 17 00:00:00 2001
From: Michael Pivato
Date: Wed, 3 Jan 2024 22:43:08 +1030
Subject: [PATCH] Add necessary joins

---
 src/main.rs                     | 14 +++++-----
 src/products/create_products.rs | 46 ++++++++++++++++++++++++---------
 src/products/csv.rs             | 16 ++++--------
 3 files changed, 46 insertions(+), 30 deletions(-)

diff --git a/src/main.rs b/src/main.rs
index ca3daa6..7169415 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -190,7 +190,7 @@ fn main() -> anyhow::Result<()> {
         SourceType::Encounter,
         InputFile {
             file_path: encounters,
-            joins: vec![],
+            joins: HashMap::new(),
             date_order_column: Some("StartDateTime".to_owned()),
         },
     );
@@ -198,7 +198,7 @@
         SourceType::Service,
         InputFile {
             file_path: services,
-            joins: vec![],
+            joins: HashMap::new(),
             date_order_column: Some("StartDateTime".to_owned()),
         },
     );
@@ -206,7 +206,7 @@
         SourceType::Transfer,
         InputFile {
             file_path: transfers,
-            joins: vec![],
+            joins: HashMap::new(),
             date_order_column: Some("StartDateTime".to_owned()),
         },
     );
@@ -214,7 +214,7 @@
         SourceType::CodingProcedure,
         InputFile {
             file_path: procedures,
-            joins: vec![],
+            joins: HashMap::new(),
             date_order_column: Some("ProcedureDateTime".to_owned()),
         },
     );
@@ -222,7 +222,7 @@
         SourceType::CodingDiagnosis,
         InputFile {
             file_path: diagnoses,
-            joins: vec![],
+            joins: HashMap::new(),
             date_order_column: None,
         },
     );
@@ -230,7 +230,7 @@
         SourceType::Patient,
         InputFile {
             file_path: patients,
-            joins: vec![],
+            joins: HashMap::new(),
             date_order_column: None,
         },
     );
@@ -238,7 +238,7 @@
         SourceType::Revenue,
         InputFile {
             file_path: revenues,
-            joins: vec![],
+            joins: HashMap::new(),
             date_order_column: None,
         },
     );
diff --git a/src/products/create_products.rs b/src/products/create_products.rs
index 337e236..3c9268a 100644
--- a/src/products/create_products.rs
+++ b/src/products/create_products.rs
@@ -1,4 +1,7 @@
-use std::{collections::HashMap, path::PathBuf};
+use std::{
+    collections::{HashMap, HashSet},
+    path::PathBuf,
+};
 
 use anyhow::anyhow;
 use chrono::NaiveDateTime;
@@ -8,7 +11,7 @@ use polars::lazy::dsl::*;
 use polars::prelude::*;
 use serde::Serialize;
 
-use super::csv::{read_definitions, Definition, FileJoin, SourceType};
+use super::csv::{read_definitions, Component, Definition, FileJoin, SourceType};
 
 // TODO: Polars suggests this, but docs suggest it doesn't have very good platform support
 //use jemallocator::Jemalloc;
@@ -36,7 +39,7 @@ struct Product {
 
 pub struct InputFile {
     pub file_path: PathBuf,
-    pub joins: Vec<FileJoin>,
+    pub joins: HashMap<SourceType, String>,
     // if not specified, then don't allow change-in-type builds, as there's no way to detect changes over time
     pub date_order_column: Option<String>,
 }
@@ -54,12 +57,6 @@ pub fn create_products_polars(
     Ok(())
 }
 
-//TODO: This will iterate over the file multiple times, which could technically be
-// slower than just going through the file once since reading from disk is slower
-// than reading from memory. However, reading from
-// Also, we can use a custom definition format that is translated from the
-// ppm format, so things like constraints/filters are one thing, and way more generic
-// (i.e. filter can be based on a join between files).
 pub fn build_polars(
     definition: &Definition,
     inputs: &HashMap<SourceType, InputFile>,
@@ -87,11 +84,36 @@ pub fn build_polars(
     let input_file = inputs
         .get(&definition.source_type)
         .ok_or(anyhow!("Failed to find valid file"))?;
-    let reader = LazyCsvReader::new(&input_file.file_path)
+    let mut reader = LazyCsvReader::new(&input_file.file_path)
         .has_header(true)
         .finish()?;
-    // TODO: Do joins based on usage in definitions components and filters. Ideally just join the columns that are actually wanted.
-    // Can do this by first going over each component/filter, and
+    // Collect every source type referenced by the definition's components and filters.
+    let mut required_files = HashSet::new();
+    for component in &definition.components {
+        if let Component::Field(file, _field) = component {
+            required_files.insert(file);
+        }
+    }
+    for filter in &definition.filters {
+        required_files.insert(&filter.file);
+    }
+    for source_type in required_files {
+        // TODO: Better error messages
+        if source_type != &definition.source_type {
+            let source_file = inputs
+                .get(source_type)
+                .ok_or(anyhow!("Input file was not specified for source type"))?;
+            let join_reader = LazyCsvReader::new(source_file.file_path.clone()).finish()?;
+            let left_column = input_file
+                .joins
+                .get(source_type)
+                .ok_or(anyhow!("Failed to get left join column"))?;
+            let right_column = source_file
+                .joins
+                .get(&definition.source_type)
+                .ok_or(anyhow!("Failed to get right join column"))?;
+            reader = reader.inner_join(join_reader, col(left_column), col(right_column));
+        }
+    }
 
     let mut filtered = match filter {
         Some(filter) => reader.filter(filter),
diff --git a/src/products/csv.rs b/src/products/csv.rs
index 4f4ddc2..262df22 100644
--- a/src/products/csv.rs
+++ b/src/products/csv.rs
@@ -6,12 +6,12 @@ use chrono::NaiveDateTime;
 #[derive(Hash, PartialEq, PartialOrd)]
 pub struct Filter {
     pub filter_type: FilterType,
-    pub file: String,
+    pub file: SourceType,
     pub field: String,
     pub value: String,
 }
 
-#[derive(Hash, PartialEq, PartialOrd, Eq, Ord)]
+#[derive(Hash, PartialEq, PartialOrd, Eq, Ord, Clone)]
 pub enum SourceType {
     CodingDiagnosis,
     CodingProcedure,
@@ -298,7 +298,7 @@ where
         let source_type = SourceType::try_from(record.get("FilterSourceType").unwrap())?;
 
         Filter {
-            // TODO: This all looks wrong
+            // TODO: This looks wrong
            filter_type: if record.get("FilterNotIn").unwrap() != "" {
                 FilterType::Equal
             } else {
@@ -307,12 +307,7 @@ where
             // TODO: extra/classification types need to append Extra:/Classification: to the start of the field
             field: record.get("FilterField").unwrap().clone(),
             value: record.get("FilterValue").unwrap().clone(),
-            // TODO: Work out a way to handle this
-            file: match source_type {
-                SourceType::CodingDiagnosis => "",
-                _ => "",
-            }
-            .to_owned(),
+            file: source_type,
         }
     };
     let all_filters = &mut all_definitions
@@ -354,8 +349,7 @@ where
             field: record.get("ConstraintColumn").unwrap().to_owned(),
             filter_type,
             value: record.get("ConstraintValue").unwrap().to_owned(),
-            // TODO: Figure this out, should be determined from the source type
-            file: "".to_owned(),
+            file: source_type,
         }
     };
     let all_filters = &mut all_definitions
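
Note on usage: the join lookup in build_polars is symmetric — the left column
comes from the definition's own InputFile keyed by the foreign SourceType, and
the right column comes from the foreign InputFile keyed by the definition's
SourceType. A minimal sketch of how the two `joins` maps would be populated in
main.rs under this scheme (the file paths and the "EncounterId" column name
are hypothetical placeholders, not taken from the real inputs):

    use std::collections::HashMap;

    // Building an Encounter-sourced product whose definition references
    // Service fields: build_polars uses
    //   encounter_input.joins[&SourceType::Service]   as the left join column,
    //   service_input.joins[&SourceType::Encounter]   as the right join column,
    // and chains an inner_join onto the lazy CSV reader.
    let encounter_input = InputFile {
        file_path: "encounters.csv".into(), // hypothetical path
        joins: HashMap::from([(SourceType::Service, "EncounterId".to_owned())]),
        date_order_column: Some("StartDateTime".to_owned()),
    };
    let service_input = InputFile {
        file_path: "services.csv".into(), // hypothetical path
        joins: HashMap::from([(SourceType::Encounter, "EncounterId".to_owned())]),
        date_order_column: Some("StartDateTime".to_owned()),
    };

If either side's entry is missing, build_polars surfaces the "Failed to get
left join column" / "Failed to get right join column" error rather than
silently skipping the join, so every pair of files a definition relates needs
both directions filled in.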