diff --git a/Cargo.lock b/Cargo.lock
index 005a9b4..94abe65 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -624,7 +624,6 @@ dependencies = [
  "serde",
  "serde_json",
  "sqlx",
- "tempfile",
 "testcontainers",
 "tiberius",
 "tokio",
diff --git a/Cargo.toml b/Cargo.toml
index 3c55a0c..8174889 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -6,25 +6,17 @@ edition = "2021"
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-# https://nalgebra.org/docs/user_guide/getting_started
 nalgebra = "0.33"
-
-# https://docs.rs/csv/1.1.6/csv/
 csv = "1"
 serde = { version = "1", features = ["derive"] }
-
-# num = "0.4"
 clap = { version = "4", features = ["derive"] }
 anyhow = "1"
-
 itertools = "0.14.0"
 chrono = { version = "0.4.39", features = ["default", "serde"] }
-
 rayon = "1.10.0"
 tokio = { version = "1.42.0", features = ["full"] }
-sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "any"] }
+sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "any", "mysql", "postgres", "sqlite"] }
 rmp-serde = "1.1"
-tempfile = "3.7"
 polars = { version = "0.45.1", features = ["lazy", "performant", "parquet", "streaming", "cse", "dtype-datetime"] }
 polars-sql = "0.45.1"
 serde_json = "1.0.122"
diff --git a/src/graph/dynamic/csv_reader.rs b/src/graph/dynamic/csv_reader.rs
index 32606a3..b96444d 100644
--- a/src/graph/dynamic/csv_reader.rs
+++ b/src/graph/dynamic/csv_reader.rs
@@ -1,4 +1,7 @@
-use super::{csv_row::CsvRow, dynamic_state::{vato007::ingey::types::HostCsvReader, DynamicState, ReadMapData}};
+use super::{
+    csv_row::CsvRow,
+    dynamic_state::{vato007::ingey::types::HostCsvReader, DynamicState, ReadMapData},
+};
 use csv::{Reader, StringRecord};
 use polars::datatypes::AnyValue;
 use polars::prelude::{col, lit, LazyCsvReader, LazyFileListReader};
@@ -14,26 +17,38 @@ pub struct CsvReader {
 impl CsvReader {
     pub fn new(path: String) -> Self {
         let reader = Reader::from_path(&path).expect("Failed to create csv reader");
-        CsvReader {
-            path,
-            reader,
-        }
+        CsvReader { path, reader }
     }
 }
 
 impl HostCsvReader for DynamicState {
     fn columns(&mut self, self_: wasmtime::component::Resource<CsvReader>) -> Vec<String> {
-        let resource = self.resources.get_mut(&self_).expect("Failed to find resource");
+        let resource = self
+            .resources
+            .get_mut(&self_)
+            .expect("Failed to find resource");
         if resource.reader.has_headers() {
-            resource.reader.headers().expect("Reader says it has headers but doesn't").iter().map(|element| element.to_owned()).collect()
+            resource
+                .reader
+                .headers()
+                .expect("Reader says it has headers but doesn't")
+                .iter()
+                .map(|element| element.to_owned())
+                .collect()
        } else {
             vec![]
         }
     }
 
     // TODO: These next methods need to be cleaned up badly
-    fn next(&mut self, self_: wasmtime::component::Resource<CsvReader>) -> Result<wasmtime::component::Resource<CsvRow>, String> {
-        let resource = self.resources.get_mut(&self_).expect("Failed to find resource");
+    fn next(
+        &mut self,
+        self_: wasmtime::component::Resource<CsvReader>,
+    ) -> Result<wasmtime::component::Resource<CsvRow>, String> {
+        let resource = self
+            .resources
+            .get_mut(&self_)
+            .expect("Failed to find resource");
         let mut buf = StringRecord::new();
         let result = resource.reader.read_record(&mut buf);
         match result {
@@ -41,25 +56,37 @@ impl HostCsvReader for DynamicState {
                 if read {
                     let mut record_map = BTreeMap::new();
                     if resource.reader.has_headers() {
-                        resource.reader.headers().expect("Reader says it has headers but doesn't").iter().enumerate().for_each(|(i, name)| {
-                            record_map.insert(name.to_owned(), buf.get(i).unwrap().to_owned());
-                        });
+                        resource
+                            .reader
+                            .headers()
+                            .expect("Reader says it has headers but doesn't")
.expect("Reader says it has headers but doesn't") + .iter() + .enumerate() + .for_each(|(i, name)| { + record_map.insert(name.to_owned(), buf.get(i).unwrap().to_owned()); + }); } - let result = self.resources.push(CsvRow { values: record_map }).expect(""); + let result = self + .resources + .push(CsvRow { values: record_map }) + .expect(""); Ok(result) } else { Err("No more records available to read".to_owned()) } } - Err(err) => { - Err(err.to_string()) - } + Err(err) => Err(err.to_string()), } } - - fn next_into_map(&mut self, self_: wasmtime::component::Resource) -> wasmtime::component::Resource { - let resource = self.resources.get_mut(&self_).expect("Failed to find resource"); + fn next_into_map( + &mut self, + self_: wasmtime::component::Resource, + ) -> wasmtime::component::Resource { + let resource = self + .resources + .get_mut(&self_) + .expect("Failed to find resource"); let mut buf = StringRecord::new(); let result = resource.reader.read_record(&mut buf); let record_map = match result { @@ -67,32 +94,50 @@ impl HostCsvReader for DynamicState { if read { let mut record_map = HashMap::new(); if resource.reader.has_headers() { - resource.reader.headers().expect("Reader says it has headers but doesn't").iter().enumerate().for_each(|(i, name)| { - record_map.insert(name.to_owned(), buf.get(i).unwrap().to_owned()); - }); + resource + .reader + .headers() + .expect("Reader says it has headers but doesn't") + .iter() + .enumerate() + .for_each(|(i, name)| { + record_map.insert(name.to_owned(), buf.get(i).unwrap().to_owned()); + }); } record_map } else { HashMap::new() } } - Err(_) => { - HashMap::new() - } + Err(_) => HashMap::new(), }; - self.resources.push(ReadMapData { data: record_map }).expect("") + self.resources + .push(ReadMapData { data: record_map }) + .expect("") } fn has_next(&mut self, self_: wasmtime::component::Resource) -> bool { - let resource = self.resources.get_mut(&self_).expect("Failed to find resource"); + let resource = self + .resources + .get_mut(&self_) + .expect("Failed to find resource"); !resource.reader.is_done() } // TODO: Clean this up as well #[doc = " Get a row by values in one or more columns"] - fn query(&mut self, self_: wasmtime::component::Resource, values: Vec<(String, String,)>) -> wasmtime::component::Resource { - let resource = self.resources.get_mut(&self_).expect("Failed to find resource"); - let mut df = LazyCsvReader::new(&resource.path).finish().expect("Failed to open file"); + fn query( + &mut self, + self_: wasmtime::component::Resource, + values: Vec<(String, String)>, + ) -> wasmtime::component::Resource { + let resource = self + .resources + .get_mut(&self_) + .expect("Failed to find resource"); + let mut df = LazyCsvReader::new(&resource.path) + .finish() + .expect("Failed to open file"); for (key, value) in values { df = df.filter(col(key).eq(lit(value))); } @@ -106,7 +151,9 @@ impl HostCsvReader for DynamicState { record_map.insert(field.name.to_string(), value.to_string()); } } - self.resources.push(CsvRow { values: record_map }).expect("Failed to create csv row") + self.resources + .push(CsvRow { values: record_map }) + .expect("Failed to create csv row") } fn read_into_string(&mut self, self_: wasmtime::component::Resource) -> String { @@ -121,4 +168,4 @@ impl HostCsvReader for DynamicState { self.resources.delete(rep)?; Ok(()) } -} \ No newline at end of file +} diff --git a/src/graph/dynamic/csv_readers.rs b/src/graph/dynamic/csv_readers.rs index c429251..14a5dfd 100644 --- a/src/graph/dynamic/csv_readers.rs +++ 
@@ -1,4 +1,7 @@
-use super::{csv_reader::CsvReader, dynamic_state::{vato007::ingey::types::HostCsvReaders, DynamicState}};
+use super::{
+    csv_reader::CsvReader,
+    dynamic_state::{vato007::ingey::types::HostCsvReaders, DynamicState},
+};
 use std::collections::HashMap;
 use wasmtime::component::Resource;
 
@@ -8,8 +11,15 @@ pub struct CsvReadersData {
 }
 
 impl HostCsvReaders for DynamicState {
-    fn get_reader(&mut self, self_: Resource<CsvReadersData>, name: String) -> Option<Resource<CsvReader>> {
-        let resource = self.resources.get(&self_).expect("Failed to find own resource");
+    fn get_reader(
+        &mut self,
+        self_: Resource<CsvReadersData>,
+        name: String,
+    ) -> Option<Resource<CsvReader>> {
+        let resource = self
+            .resources
+            .get(&self_)
+            .expect("Failed to find own resource");
         let file_path = resource.readers.get(&name);
         if let Some(path) = file_path.cloned() {
             let csv_reader = CsvReader::new(path);
@@ -23,4 +33,4 @@
         self.resources.delete(rep)?;
         Ok(())
     }
-}
\ No newline at end of file
+}
diff --git a/src/graph/dynamic/csv_row.rs b/src/graph/dynamic/csv_row.rs
index c013df5..3763692 100644
--- a/src/graph/dynamic/csv_row.rs
+++ b/src/graph/dynamic/csv_row.rs
@@ -8,22 +8,42 @@ pub struct CsvRow {
 
 impl HostCsvRow for DynamicState {
     fn columns(&mut self, self_: wasmtime::component::Resource<CsvRow>) -> Vec<String> {
-        let resource = self.resources.get(&self_).expect("Failed to find the required resource");
+        let resource = self
+            .resources
+            .get(&self_)
+            .expect("Failed to find the required resource");
         resource.values.keys().cloned().collect()
     }
 
     fn values(&mut self, self_: wasmtime::component::Resource<CsvRow>) -> Vec<String> {
-        let resource = self.resources.get(&self_).expect("Failed to find the required resource");
+        let resource = self
+            .resources
+            .get(&self_)
+            .expect("Failed to find the required resource");
         resource.values.values().cloned().collect()
     }
 
-    fn entries(&mut self, self_: wasmtime::component::Resource<CsvRow>) -> Vec<(String, String,)> {
-        let resource = self.resources.get(&self_).expect("Failed to find the required resource");
-        resource.values.keys().map(|key| (key.clone(), resource.values.get(key).unwrap().clone())).collect()
+    fn entries(&mut self, self_: wasmtime::component::Resource<CsvRow>) -> Vec<(String, String)> {
+        let resource = self
+            .resources
+            .get(&self_)
+            .expect("Failed to find the required resource");
+        resource
+            .values
+            .keys()
+            .map(|key| (key.clone(), resource.values.get(key).unwrap().clone()))
+            .collect()
     }
 
-    fn value(&mut self, self_: wasmtime::component::Resource<CsvRow>, name: String) -> Option<String> {
-        let resource = self.resources.get(&self_).expect("Failed to find the required resource");
+    fn value(
+        &mut self,
+        self_: wasmtime::component::Resource<CsvRow>,
+        name: String,
+    ) -> Option<String> {
+        let resource = self
+            .resources
+            .get(&self_)
+            .expect("Failed to find the required resource");
         resource.values.get(&name).cloned()
     }
 
@@ -31,4 +51,4 @@
         self.resources.delete(rep)?;
         Ok(())
     }
-}
\ No newline at end of file
+}
diff --git a/src/graph/dynamic/csv_writer.rs b/src/graph/dynamic/csv_writer.rs
index 045eef7..113f40f 100644
--- a/src/graph/dynamic/csv_writer.rs
+++ b/src/graph/dynamic/csv_writer.rs
@@ -19,20 +19,32 @@ impl CsvWriterData {
     }
 }
 
-
 impl HostCsvWriter for DynamicState {
-    fn write_row(&mut self, self_: wasmtime::component::Resource<CsvWriterData>, row: Vec<(String, String)>) -> () {
-        let resource = self.resources.get_mut(&self_).expect("Failed to find resource");
+    fn write_row(
+        &mut self,
+        self_: wasmtime::component::Resource<CsvWriterData>,
+        row: Vec<(String, String)>,
+    ) -> () {
+        let resource = self
+            .resources
+            .get_mut(&self_)
+            .expect("Failed to find resource");
         let write_data: BTreeMap<String, String> = row.into_iter().collect();
         if !resource.wrote_header {
-            resource.writer.write_header(&write_data).expect("Failed to write header");
+            resource
+                .writer
+                .write_header(&write_data)
+                .expect("Failed to write header");
             resource.wrote_header = true;
         }
-        resource.writer.write_record(write_data.values()).expect("Failed to write row");
+        resource
+            .writer
+            .write_record(write_data.values())
+            .expect("Failed to write row");
     }
 
     fn drop(&mut self, rep: wasmtime::component::Resource<CsvWriterData>) -> wasmtime::Result<()> {
         self.resources.delete(rep)?;
         Ok(())
     }
-}
\ No newline at end of file
+}
diff --git a/src/graph/dynamic/dynamic_state.rs b/src/graph/dynamic/dynamic_state.rs
index 07fe2e9..d6a0f11 100644
--- a/src/graph/dynamic/dynamic_state.rs
+++ b/src/graph/dynamic/dynamic_state.rs
@@ -22,8 +22,10 @@ pub struct DynamicState {
 
 impl DynamicState {
     pub fn new() -> DynamicState {
-        DynamicState { resources: ResourceTable::new() }
+        DynamicState {
+            resources: ResourceTable::new(),
+        }
     }
 }
 
-impl Host for DynamicState {}
\ No newline at end of file
+impl Host for DynamicState {}
diff --git a/src/graph/dynamic/mod.rs b/src/graph/dynamic/mod.rs
index a193b44..ea00e26 100644
--- a/src/graph/dynamic/mod.rs
+++ b/src/graph/dynamic/mod.rs
@@ -12,13 +12,12 @@ use crate::graph::dynamic::csv_writer::CsvWriterData;
 use crate::graph::dynamic::read_map::ReadMapData;
 use dynamic_state::{Dynamic, DynamicState};
 
-mod csv_row;
 mod csv_reader;
 mod csv_readers;
+mod csv_row;
 mod csv_writer;
 mod read_map;
 
-
 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct DynamicNode {
     pub wasm_file_path: String,
@@ -39,16 +38,17 @@ impl RunnableNode for DynamicNodeRunner {
         let component = Component::from_file(&engine, &self.dynamic_node.wasm_file_path)?;
         let mut linker = Linker::new(&engine);
         Dynamic::add_to_linker(&mut linker, |state: &mut DynamicState| state)?;
-        let mut store = Store::new(
-            &engine,
-            DynamicState::new(),
-        );
+        let mut store = Store::new(&engine, DynamicState::new());
         let bindings = Dynamic::instantiate(&mut store, &component, &linker)?;
-        let read_map = store.data_mut().resources.push(ReadMapData { data: HashMap::new() })?;
-        let readers = store.data_mut().resources.push(CsvReadersData { readers: self.dynamic_node.input_file_paths.clone() })?;
+        let read_map = store.data_mut().resources.push(ReadMapData {
+            data: HashMap::new(),
+        })?;
+        let readers = store.data_mut().resources.push(CsvReadersData {
+            readers: self.dynamic_node.input_file_paths.clone(),
+        })?;
         let writer = CsvWriterData::new(self.dynamic_node.output_file.clone())?;
         let writer = store.data_mut().resources.push(writer)?;
         bindings.call_evaluate(&mut store, read_map, readers, writer)?;
         Ok(())
     }
-}
\ No newline at end of file
+}
diff --git a/src/graph/dynamic/read_map.rs b/src/graph/dynamic/read_map.rs
index 28e4582..b220f54 100644
--- a/src/graph/dynamic/read_map.rs
+++ b/src/graph/dynamic/read_map.rs
@@ -7,12 +7,20 @@ pub struct ReadMapData {
 }
 
 impl HostReadMap for DynamicState {
-    fn get(&mut self, self_: wasmtime::component::Resource<ReadMapData>, key: String) -> Option<String> {
-        self.resources.get(&self_).ok().map(|data| data.data.get(&key).cloned()).flatten()
+    fn get(
+        &mut self,
+        self_: wasmtime::component::Resource<ReadMapData>,
+        key: String,
+    ) -> Option<String> {
+        self.resources
+            .get(&self_)
+            .ok()
+            .map(|data| data.data.get(&key).cloned())
+            .flatten()
     }
 
     fn drop(&mut self, rep: wasmtime::component::Resource<ReadMapData>) -> wasmtime::Result<()> {
         self.resources.delete(rep)?;
         Ok(())
     }
-}
\ No newline at end of file
+}
diff --git a/src/graph/pull_from_db.rs b/src/graph/pull_from_db.rs
index 867ef8c..00c83fe 100644
--- a/src/graph/pull_from_db.rs
+++ b/src/graph/pull_from_db.rs
@@ -12,17 +12,22 @@ use tokio_util::compat::TokioAsyncWriteCompatExt;
 
 /**
  * Pull data from a db using a db query into a csv file that can be used by another node
  */
-async fn pull_from_db(executor: &mut impl QueryExecutor, node: &PullFromDBNode) -> anyhow::Result<()> {
+async fn pull_from_db(
+    executor: &mut impl QueryExecutor,
+    node: &PullFromDBNode,
+) -> anyhow::Result<()> {
     let mut output_file = csv::Writer::from_path(node.output_data_source.path.clone())?;
     let mut first_row = true;
-    executor.get_rows(&node.query, &node.parameters, &mut move |row| {
-        if first_row {
-            output_file.write_header(&row)?;
-            first_row = false;
-        }
-        output_file.write_record(row.values())?;
-        Ok(())
-    }).await?;
+    executor
+        .get_rows(&node.query, &node.parameters, &mut move |row| {
+            if first_row {
+                output_file.write_header(&row)?;
+                first_row = false;
+            }
+            output_file.write_record(row.values())?;
+            Ok(())
+        })
+        .await?;
     Ok(())
 }
@@ -81,14 +86,20 @@ mod tests {
     async fn test_sql_server() -> anyhow::Result<()> {
         let container = GenericImage::new("mcr.microsoft.com/mssql/server", "2022-latest")
             .with_exposed_port(1433.tcp())
-            .with_wait_for(WaitFor::message_on_stdout("Recovery is complete.".to_owned()))
+            .with_wait_for(WaitFor::message_on_stdout(
+                "Recovery is complete.".to_owned(),
+            ))
             .with_env_var("ACCEPT_EULA", "Y")
             .with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123")
             .start()
             .await?;
         let host = container.get_host().await?;
         let port = container.get_host_port_ipv4(1433).await?;
-        let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123", host, port).to_owned();
+        let connection_string = format!(
+            "jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123",
+            host, port
+        )
+        .to_owned();
 
         let runner = PullFromDBNodeRunner {
             pull_from_db_node: PullFromDBNode {
@@ -100,7 +111,7 @@
                 path: PathBuf::from("test_pull.csv"),
                 source_type: CSV,
             },
-        }
+        },
     };
     runner.run().await?;
     let mut result_contents = String::new();
@@ -109,8 +120,7 @@
         assert_eq!(
             "Test 1
",
-            result_contents
-            ,
+            result_contents,
             "Should pull the correct data from sql"
         );
         Ok(())
diff --git a/src/graph/reduction.rs b/src/graph/reduction.rs
index e1ce04f..d978691 100644
--- a/src/graph/reduction.rs
+++ b/src/graph/reduction.rs
@@ -9,13 +9,27 @@ use polars::prelude::{col, lit, CsvWriter, Expr, LazyCsvReader, LazyFileListRead
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 
-
-fn reduce(grouping_nodes: &Vec<String>, operations: &Vec<ReductionOperation>, input: &DataSource, output: &DataSource) -> anyhow::Result<()> {
+fn reduce(
+    grouping_nodes: &Vec<String>,
+    operations: &Vec<ReductionOperation>,
+    input: &DataSource,
+    output: &DataSource,
+) -> anyhow::Result<()> {
     let df = LazyCsvReader::new(&input.path).finish()?;
     let mut df = df
-        .group_by(grouping_nodes.iter().map(|column| col(column)).collect_vec())
-        .agg(&operations.iter().map(|operation| operation.to_aggregate_function()).collect_vec())
+        .group_by(
+            grouping_nodes
+                .iter()
+                .map(|column| col(column))
+                .collect_vec(),
+        )
+        .agg(
+            &operations
+                .iter()
+                .map(|operation| operation.to_aggregate_function())
+                .collect_vec(),
+        )
         .collect()?;
     let mut file = File::create(&output.path)?;
     CsvWriter::new(&mut file).finish(&mut df)?;
     Ok(())
@@ -58,7 +72,16 @@ impl ReductionOperation {
             ReductionOperationType::Min => col(&self.column_name).min(),
col(&self.column_name).min(), ReductionOperationType::Average => col(&self.column_name).mean(), ReductionOperationType::Count => col(&self.column_name).count(), - ReductionOperationType::Concat(concat_properties) => lit(concat_properties.prefix.clone()).append(col(&self.column_name).list().join(lit(concat_properties.separator.clone()), true), false).append(lit(concat_properties.suffix.clone()), false), + ReductionOperationType::Concat(concat_properties) => { + lit(concat_properties.prefix.clone()) + .append( + col(&self.column_name) + .list() + .join(lit(concat_properties.separator.clone()), true), + false, + ) + .append(lit(concat_properties.suffix.clone()), false) + } } } } @@ -77,7 +100,12 @@ pub struct ReductionNodeRunner { #[async_trait] impl RunnableNode for ReductionNodeRunner { async fn run(&self) -> anyhow::Result<()> { - reduce(&self.reduction_node.grouping_columns, &self.reduction_node.operations, &self.reduction_node.input_file, &self.reduction_node.output_file)?; + reduce( + &self.reduction_node.grouping_columns, + &self.reduction_node.operations, + &self.reduction_node.input_file, + &self.reduction_node.output_file, + )?; Ok(()) } -} \ No newline at end of file +} diff --git a/src/graph/sql.rs b/src/graph/sql.rs index 5fc31d6..2c6cfc5 100644 --- a/src/graph/sql.rs +++ b/src/graph/sql.rs @@ -1,13 +1,12 @@ use futures::TryStreamExt; use futures_io::{AsyncRead, AsyncWrite}; -use sqlx::{Any, Column, Pool, Row}; +use sqlx::{Any, Column, Executor, Pool, Row}; use std::borrow::Borrow; use std::collections::BTreeMap; use tiberius::{Client, Query}; - pub trait QueryExecutor { - // Retrieve data from a database + /// Retrieve data from a database async fn get_rows( &mut self, query: &str, @@ -16,8 +15,11 @@ pub trait QueryExecutor { row_consumer: &mut impl FnMut(BTreeMap) -> anyhow::Result<()>, ) -> anyhow::Result<()>; - // Run a query that returns no results (e.g. bulk insert, insert) + /// Run a query that returns no results (e.g. bulk insert, insert) async fn execute_query(&mut self, query: &str, params: &[String]) -> anyhow::Result; + + /// Execute an unprepared query. 
+    async fn execute_unchecked(&mut self, query: &str) -> anyhow::Result<u64>;
 }
 
 impl<T: AsyncRead + AsyncWrite + Unpin + Send> QueryExecutor for Client<T> {
@@ -38,8 +40,10 @@
             let mut returned_row = BTreeMap::new();
             // TODO: Check how empty columns are handled by tiberius
             for column in row.columns().into_iter() {
-                returned_row.insert(column.name().to_owned(), row.get(column.name()).unwrap_or_else(|| "")
-                    .to_owned());
+                returned_row.insert(
+                    column.name().to_owned(),
+                    row.get(column.name()).unwrap_or_else(|| "").to_owned(),
+                );
             }
             row_consumer(returned_row)?;
         }
@@ -57,6 +61,11 @@
         }
         Ok(result.rows_affected()[0])
     }
+
+    async fn execute_unchecked(&mut self, query: &str) -> anyhow::Result<u64> {
+        let result = self.execute(query, &[]).await?;
+        Ok(result.rows_affected()[0])
+    }
 }
 
 impl QueryExecutor for Pool<Any> {
@@ -90,4 +99,9 @@
         let result = query.execute(self.borrow()).await?;
         Ok(result.rows_affected())
     }
+
+    async fn execute_unchecked(&mut self, query: &str) -> anyhow::Result<u64> {
+        let result = self.execute(query).await?;
+        Ok(result.rows_affected())
+    }
 }
diff --git a/src/graph/sql_rule.rs b/src/graph/sql_rule.rs
index 023f402..34f4bdb 100644
--- a/src/graph/sql_rule.rs
+++ b/src/graph/sql_rule.rs
@@ -75,7 +75,7 @@ mod tests {
                 data_source: DataSource {
                     source_type: SourceType::CSV,
                     path: PathBuf::from("./testing/test.csv"),
-                }
+                },
             }],
             &output_path,
             &"SELECT * FROM Account WHERE Code = 'A195950'".to_owned(),
diff --git a/src/graph/upload_to_db.rs b/src/graph/upload_to_db.rs
index 89f0095..27a19c2 100644
--- a/src/graph/upload_to_db.rs
+++ b/src/graph/upload_to_db.rs
@@ -4,7 +4,8 @@ use itertools::Itertools;
 use log::{log, Level};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use sqlx::AnyPool;
+use sqlx::any::install_default_drivers;
+use sqlx::{AnyConnection, AnyPool};
 use std::{collections::HashMap, fmt::format};
 use tiberius::{Config, EncryptionLevel};
 use tokio_util::compat::TokioAsyncWriteCompatExt;
@@ -22,25 +23,35 @@ pub async fn upload_file_bulk(
     let mut rows_affected = None;
     if upload_node.column_mappings.is_none() {
         let insert_from_file_query = match upload_node.db_type {
-            DBType::Postgres => Some(format!("COPY {} FROM $1", upload_node.table_name)),
-            DBType::Mysql => Some(format!(
-                "LOAD DATA INFILE ? INTO {}",
INTO {}", - upload_node.table_name, + DBType::Postgres => Some(format!( + r#"COPY "{}" FROM '{}' DELIMITERS ',' CSV HEADER QUOTE '"';"#, + upload_node.table_name, upload_node.file_path + )), + // TODO: Revisit this when sqlx lets this work, currently mysql won't allow this + // to execute since sqlx forces it to be inside of a prepare, which + // isn't allowed + DBType::Mysql => Some(format!( + "LOAD DATA INFILE '{}' INTO TABLE `{}` + FIELDS TERMINATED BY ',' ENCLOSED BY '\"' + LINES TERMINATED BY '\n' + IGNORE 1 LINES;", + upload_node.file_path, upload_node.table_name, + )), + DBType::Mssql => Some(format!( + "BULK INSERT [{}] FROM '{}' WITH ( FORMAT = 'CSV', FIRSTROW = 2 );", + upload_node.table_name, upload_node.file_path )), - DBType::Mssql => Some(format!("DECLARE @insertStatement VARCHAR(MAX) = 'BULK INSERT [{}] FROM ''' + @P1 + ''' WITH ( FORMAT = ''CSV'', FIRSTROW = 2 )'; EXEC(@insertStatement);", upload_node.table_name)), _ => None, }; if let Some(insert_from_file_query) = insert_from_file_query { - let result = executor - .execute_query( - &insert_from_file_query, - &[upload_node.file_path.clone()], - ) - .await; + let result = executor.execute_unchecked(&insert_from_file_query).await; if let Ok(result) = result { rows_affected = Some(result); } else { - log!(Level::Debug, "Failed to bulk insert, trying sql insert instead"); + log!( + Level::Debug, + "Failed to bulk insert, trying sql insert instead" + ); } } } @@ -73,7 +84,15 @@ pub async fn upload_file_bulk( for result in file_reader.records() { let result = result?; insert_query = insert_query - + format!("VALUES ({})", result.iter().enumerate().map(|(index, _)| db_type.get_param_name(index)).join(",")).as_str(); + + format!( + "VALUES ({})", + result + .iter() + .enumerate() + .map(|(index, _)| db_type.get_param_name(index)) + .join(",") + ) + .as_str(); let mut values = result.iter().map(|value| value.to_owned()).collect_vec(); params.append(&mut values); num_params += csv_columns.len(); @@ -149,6 +168,7 @@ impl RunnableNode for UploadNodeRunner { let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?; upload_file_bulk(&mut client, &upload_node, BIND_LIMIT, &upload_node.db_type).await?; } else { + install_default_drivers(); let mut pool = AnyPool::connect(&upload_node.connection_string).await?; upload_file_bulk(&mut pool, &upload_node, BIND_LIMIT, &upload_node.db_type).await?; } @@ -158,11 +178,11 @@ impl RunnableNode for UploadNodeRunner { #[cfg(test)] mod tests { - use std::path::{self, PathBuf}; - use crate::graph::node::RunnableNode; use crate::graph::upload_to_db::{DBType, UploadNode, UploadNodeRunner}; - use testcontainers::core::{IntoContainerPort, Mount, WaitFor}; + use sqlx::{AnyPool, Row}; + use std::path::PathBuf; + use testcontainers::core::{IntoContainerPort, WaitFor}; use testcontainers::runners::AsyncRunner; use testcontainers::{GenericImage, ImageExt}; use tiberius::{Config, EncryptionLevel}; @@ -170,14 +190,22 @@ mod tests { #[tokio::test] pub async fn check_bulk_upload_mssql() -> anyhow::Result<()> { - let container = GenericImage::new("gitea.michaelpivato.dev/vato007/ingey-test-db-mssql", "latest") - .with_exposed_port(1433.tcp()) - .with_wait_for(WaitFor::message_on_stdout("Recovery is complete.".to_owned())) - .with_env_var("ACCEPT_EULA", "Y") - .with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123") - .with_copy_to("/upload_to_db/test.csv", PathBuf::from("./testing/input/upload_to_db/test.csv")) - .start() - .await?; + let container = GenericImage::new( + 
"gitea.michaelpivato.dev/vato007/ingey-test-db-mssql", + "latest", + ) + .with_exposed_port(1433.tcp()) + .with_wait_for(WaitFor::message_on_stdout( + "Recovery is complete.".to_owned(), + )) + .with_env_var("ACCEPT_EULA", "Y") + .with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123") + .with_copy_to( + "/upload_to_db/test.csv", + PathBuf::from("./testing/input/upload_to_db/test.csv"), + ) + .start() + .await?; let host = container.get_host().await?; let port = container.get_host_port_ipv4(1433).await?; let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123;database=TestIngeyDatabase", host, port).to_owned(); @@ -191,7 +219,7 @@ mod tests { post_script: None, db_type: DBType::Mssql, connection_string, - } + }, }; upload_node.run().await?; let mut config = Config::from_jdbc_string(&upload_node.upload_node.connection_string)?; @@ -209,4 +237,102 @@ mod tests { Ok(()) } + + #[tokio::test] + pub async fn check_bulk_upload_postgres() -> anyhow::Result<()> { + let container = GenericImage::new( + "gitea.michaelpivato.dev/vato007/ingey-test-db-postgres", + "latest", + ) + .with_exposed_port(5432.tcp()) + .with_wait_for(WaitFor::message_on_stderr( + "database system is ready to accept connections", + )) + .with_env_var("POSTGRES_PASSWORD", "TestOnlyContainer123") + .with_copy_to( + "/upload_to_db/test.csv", + PathBuf::from("./testing/input/upload_to_db/test.csv"), + ) + .start() + .await?; + let host = container.get_host().await?; + let port = container.get_host_port_ipv4(5432).await?; + let connection_string = format!( + "postgres://postgres:TestOnlyContainer123@{}:{}/testingeydatabase", + host, port + ) + .to_owned(); + let file = "/upload_to_db/test.csv".to_owned(); + let table_name = "My Test Table".to_string(); + let upload_node = UploadNodeRunner { + upload_node: UploadNode { + file_path: file.to_owned(), + table_name: table_name.clone(), + column_mappings: None, + post_script: None, + db_type: DBType::Postgres, + connection_string, + }, + }; + upload_node.run().await?; + let pool = AnyPool::connect(&upload_node.upload_node.connection_string).await?; + let result = sqlx::query(r#"SELECT * FROM "My Test Table""#) + .fetch_one(&pool) + .await?; + + let column1: i32 = result.try_get("column1")?; + let column2: &str = result.try_get("column2")?; + assert_eq!(1, column1); + assert_eq!("Hello", column2); + Ok(()) + } + + #[tokio::test] + pub async fn check_bulk_upload_mysql() -> anyhow::Result<()> { + let container = GenericImage::new( + "gitea.michaelpivato.dev/vato007/ingey-test-db-mysql", + "latest", + ) + .with_exposed_port(3306.tcp()) + .with_wait_for(WaitFor::message_on_stderr("ready for connections.")) + .with_env_var("MYSQL_ROOT_PASSWORD", "TestOnlyContainer123") + .with_copy_to( + "/upload_to_db/test.csv", + PathBuf::from("./testing/input/upload_to_db/test.csv"), + ) + // https://dev.mysql.com/doc/refman/8.4/en/server-system-variables.html#sysvar_secure_file_priv + .with_cmd(&["--secure-file-priv=/upload_to_db".to_owned()]) + .start() + .await?; + let host = container.get_host().await?; + let port = container.get_host_port_ipv4(3306).await?; + let connection_string = format!( + "mysql://root:TestOnlyContainer123@{}:{}/TestIngeyDatabase", + host, port + ) + .to_owned(); + let file = "/upload_to_db/test.csv".to_owned(); + let table_name = "My Test Table".to_string(); + let upload_node = UploadNodeRunner { + upload_node: UploadNode { + file_path: file.to_owned(), + table_name: table_name.clone(), + column_mappings: None, + post_script: None, + 
+                db_type: DBType::Mysql,
+                connection_string,
+            },
+        };
+        upload_node.run().await?;
+        let pool = AnyPool::connect(&upload_node.upload_node.connection_string).await?;
+        let result = sqlx::query(r#"SELECT * FROM `My Test Table`"#)
+            .fetch_one(&pool)
+            .await?;
+
+        let column1: i32 = result.try_get("column1")?;
+        let column2: &str = result.try_get("column2")?;
+        assert_eq!(1, column1);
+        assert_eq!("Hello", column2);
+        Ok(())
+    }
 }
diff --git a/src/overhead_allocation.rs b/src/overhead_allocation.rs
index 50e9951..2816c29 100644
--- a/src/overhead_allocation.rs
+++ b/src/overhead_allocation.rs
@@ -723,10 +723,10 @@ fn solve_reciprocal_no_from(
                 .map(|cost| TotalDepartmentCost {
                     department: cost.department,
                     value: cost.value, // + if new_cost == 0_f64 || diff == 0_f64 {
-                    //     0_f64
-                    // } else {
-                    //     cost.value / new_cost * diff
-                    // },
+                                       //     0_f64
+                                       // } else {
+                                       //     cost.value / new_cost * diff
+                                       // },
                 })
                 .filter(|cost| cost.value != 0_f64)
                 .collect(),
@@ -880,7 +880,7 @@ mod tests {
             false,
             0.00001,
         )
-            .unwrap();
+        .unwrap();
         assert_eq!(expected_final_allocations.len(), result.len());
         let expected_account = &expected_final_allocations[0];
         let final_account = &result[0];
@@ -890,10 +890,9 @@ mod tests {
         for final_department in &final_account.summed_department_costs {
             if final_department.department == expected_department_a.department {
                 assert_eq!(*expected_department_a, *final_department);
-            }else if final_department.department == expected_department_b.department {
+            } else if final_department.department == expected_department_b.department {
                 assert_eq!(*expected_department_b, *final_department);
-            }
-            else {
+            } else {
                 panic!("Unknown department found!");
             }
         }
diff --git a/src/products/create_products.rs b/src/products/create_products.rs
index acb40b2..9fa9070 100644
--- a/src/products/create_products.rs
+++ b/src/products/create_products.rs
@@ -134,9 +134,7 @@ pub fn build_polars(
             // TODO: What I really want to do is not use source type, instead I want to be referring to a file, which we translate from the sourcetype
             // to an actual filename. I don't want to be limited by a concept of 'sourcetype' at all, instead the definition should treat everything
             // the same, and just translate the imported csv format to the necessary files and columns in files that are expected to be input.
-            Component::Field(_, column) => {
-                built_expression = built_expression + col(column)
-            }
+            Component::Field(_, column) => built_expression = built_expression + col(column),
         }
     }
 
diff --git a/src/products/csv.rs b/src/products/csv.rs
index 87e0622..0f0adbd 100644
--- a/src/products/csv.rs
+++ b/src/products/csv.rs
@@ -381,7 +381,10 @@ mod tests {
     #[test]
     fn test_read_definitions() {
         let definitions = read_definitions(
-            &mut csv::Reader::from_path("testing/input/create_products/service_builder_definitions.csv").unwrap(),
+            &mut csv::Reader::from_path(
+                "testing/input/create_products/service_builder_definitions.csv",
+            )
+            .unwrap(),
         );
         if let Err(error) = &definitions {
             println!("{}", error)