diff --git a/src/graph/upload_to_db.rs b/src/graph/upload_to_db.rs
index fa042a7..c7a5ba0 100644
--- a/src/graph/upload_to_db.rs
+++ b/src/graph/upload_to_db.rs
@@ -1,10 +1,11 @@
 use anyhow::bail;
 use async_trait::async_trait;
 use itertools::Itertools;
+use log::debug;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use sqlx::AnyPool;
 use std::collections::HashMap;
 use tiberius::{Config, EncryptionLevel};
 use tokio_util::compat::TokioAsyncWriteCompatExt;
 
@@ -17,27 +18,8 @@ pub async fn upload_file_bulk(
     upload_node: &UploadNode,
     bind_limit: usize,
     db_type: &DBType,
-    allow_table_creation: bool,
 ) -> anyhow::Result<Option<u64>> {
     let mut rows_affected = None;
-    let mut num_matching_tables: usize = 0;
-    executor.get_rows(&format!("SELECT COUNT(*) Count
-FROM information_schema.tables
-WHERE table_schema = {}
-    AND table_name = ?
-LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut move |row| {
-        num_matching_tables = row.get("Count").map(|count| count.parse().unwrap_or(0)).unwrap_or(0);
-        Ok(())
-    }).await?;
-    if num_matching_tables == 0 {
-        if allow_table_creation {
-            // TODO: Create the table with the columns in the file to be uploaded
-
-        } else {
-            bail!("Table creation not allowed and table does not yet exist, aborting");
-        }
-    }
-
     if upload_node.column_mappings.is_none() {
         let insert_from_file_query = match upload_node.db_type {
             DBType::Postgres => Some(format!("COPY {} FROM $1", upload_node.table_name)),
@@ -45,7 +27,7 @@ LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut mov
                 "LOAD DATA INFILE ? INTO {}",
                 upload_node.table_name,
             )),
-            DBType::Mssql => Some(format!("BULK INSERT [{}] FROM ? WITH ( FORMAT = 'CSV');", upload_node.table_name)),
+            DBType::Mssql => Some(format!("DECLARE @insertStatement VARCHAR(MAX) = 'BULK INSERT [{}] FROM ''' + @P1 + ''' WITH ( FORMAT = ''CSV'', FIRSTROW = 2 )'; EXEC(@insertStatement);", upload_node.table_name)),
             _ => None,
         };
         if let Some(insert_from_file_query) = insert_from_file_query {
@@ -57,13 +39,14 @@ LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut mov
             .await;
             if let Ok(result) = result {
                 rows_affected = Some(result);
+            } else {
+                debug!("Failed to bulk insert, trying sql insert instead");
             }
         }
     }
     if rows_affected == None {
         let mut file_reader = csv::Reader::from_path(upload_node.file_path.clone())?;
-
         let csv_columns = file_reader
             .headers()?
             .iter()
@@ -90,7 +73,7 @@ LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut mov
         for result in file_reader.records() {
             let result = result?;
             insert_query = insert_query
-                + format!("VALUES ({})", result.iter().map(|_| "?").join(",")).as_str();
+                + format!("VALUES ({})", result.iter().enumerate().map(|(index, _)| db_type.get_param_name(params.len() + index)).join(",")).as_str();
             let mut values = result.iter().map(|value| value.to_owned()).collect_vec();
             params.append(&mut values);
             num_params += csv_columns.len();
@@ -128,13 +111,13 @@ pub enum DBType {
 }
 
 impl DBType {
-    pub fn get_schema_fn(&self) -> String {
+    pub fn get_param_name(&self, index: usize) -> String {
         match self {
-            DBType::Postgres => "CURRENT_SCHEMA()",
-            DBType::Mysql => "DATABASE()",
-            DBType::Mssql => "DB_NAME()",
-            DBType::Sqlite => "'sqlite'",
-        }.to_owned()
+            DBType::Postgres => format!("${}", index + 1), // Postgres placeholders are 1-based: $1, $2, ...
+            DBType::Mysql => "?".to_owned(),
+            DBType::Mssql => format!("@P{}", index + 1), // TDS placeholders are 1-based: @P1, @P2, ...
+            DBType::Sqlite => "?".to_owned(),
+        }
     }
 }
 
@@ -147,7 +130,6 @@ pub struct UploadNode {
     post_script: Option<String>,
     db_type: DBType,
     connection_string: String,
-    allow_table_creation: bool,
 }
 
 pub struct UploadNodeRunner {
@@ -165,10 +147,10 @@ impl RunnableNode for UploadNodeRunner {
             let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
             tcp.set_nodelay(true)?;
             let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
-            upload_file_bulk(&mut client, &upload_node, BIND_LIMIT, &upload_node.db_type, upload_node.allow_table_creation).await?;
+            upload_file_bulk(&mut client, &upload_node, BIND_LIMIT, &upload_node.db_type).await?;
         } else {
             let mut pool = AnyPool::connect(&upload_node.connection_string).await?;
-            upload_file_bulk(&mut pool, &upload_node, BIND_LIMIT, &upload_node.db_type, upload_node.allow_table_creation).await?;
+            upload_file_bulk(&mut pool, &upload_node, BIND_LIMIT, &upload_node.db_type).await?;
         }
         Ok(())
     }
 }
@@ -176,27 +158,30 @@ impl RunnableNode for UploadNodeRunner {
 #[cfg(test)]
 mod tests {
+    use std::path;
+
     use crate::graph::node::RunnableNode;
     use crate::graph::upload_to_db::{DBType, UploadNode, UploadNodeRunner};
-    use testcontainers::core::{IntoContainerPort, WaitFor};
+    use testcontainers::core::{IntoContainerPort, Mount, WaitFor};
     use testcontainers::runners::AsyncRunner;
     use testcontainers::{GenericImage, ImageExt};
     use tiberius::{Config, EncryptionLevel};
     use tokio_util::compat::TokioAsyncWriteCompatExt;
 
     #[tokio::test]
-    pub async fn check_basic_upload() -> anyhow::Result<()> {
+    pub async fn check_bulk_upload_mssql() -> anyhow::Result<()> {
         let container = GenericImage::new("gitea.michaelpivato.dev/vato007/ingey-test-db-mssql", "latest")
             .with_exposed_port(1433.tcp())
             .with_wait_for(WaitFor::message_on_stdout("Recovery is complete.".to_owned()))
             .with_env_var("ACCEPT_EULA", "Y")
             .with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123")
+            .with_mount(Mount::bind_mount(path::absolute("./testing/input/upload_to_db")?.to_string_lossy().to_string(), "/upload_to_db"))
            .start()
            .await?;
 
         let host = container.get_host().await?;
         let port = container.get_host_port_ipv4(1433).await?;
         let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123;database=TestIngeyDatabase", host, port).to_owned();
-        let file = "testing/input/upload_to_db/test.csv";
+        let file = "/upload_to_db/test.csv".to_owned();
         let table_name = "My Test Table".to_string();
         let upload_node = UploadNodeRunner {
             upload_node: UploadNode {
@@ -206,7 +191,6 @@ mod tests {
                 post_script: None,
                 db_type: DBType::Mssql,
                 connection_string,
-                allow_table_creation: false,
             }
         };
         upload_node.run().await?;
@@ -215,6 +199,14 @@ mod tests {
         let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
         tcp.set_nodelay(true)?;
         let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
+        let result = client.query("SELECT * FROM [My Test Table]", &[]).await?;
+        let result = result.into_first_result().await?;
+        assert_eq!(1, result.len());
+
+        let first_row = result.get(0).unwrap();
+        assert_eq!(Some(1), first_row.get("column1"));
+        assert_eq!(Some("Hello"), first_row.get("column2"));
+
         Ok(())
     }
 }
diff --git a/testing/input/upload_to_db/test.csv b/testing/input/upload_to_db/test.csv
new file mode 100755
index 0000000..cd46682
--- /dev/null
+++ b/testing/input/upload_to_db/test.csv
@@ -0,0 +1,2 @@
+column1,column2
+1,Hello
\ No newline at end of file
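For reference, the placeholder scheme that `get_param_name` encodes can be exercised in isolation: Postgres and TDS (MSSQL) bind parameters are positional and 1-based (`$1`, `$2`, ... / `@P1`, `@P2`, ...), so when a multi-row `INSERT ... VALUES` statement is accumulated the numbering has to keep advancing across rows, while MySQL and SQLite simply use `?` for every slot. The sketch below is a standalone, std-only illustration of that numbering and is not part of the patch; `DbType`, `param_name`, the table name, and the sample rows are illustrative stand-ins rather than the crate's actual types.

```rust
// Minimal sketch of per-dialect placeholder numbering for a multi-row INSERT.
// `DbType` is a stand-in for the crate's `DBType`; the data is made up.
#[allow(dead_code)]
#[derive(Clone, Copy)]
enum DbType {
    Postgres,
    Mysql,
    Mssql,
    Sqlite,
}

impl DbType {
    /// `index` is the zero-based position of the value in the bind list.
    fn param_name(self, index: usize) -> String {
        match self {
            DbType::Postgres => format!("${}", index + 1), // $1, $2, ...
            DbType::Mssql => format!("@P{}", index + 1),   // @P1, @P2, ...
            DbType::Mysql | DbType::Sqlite => "?".to_owned(),
        }
    }
}

fn main() {
    let rows = vec![vec!["1", "Hello"], vec!["2", "World"]];
    let db = DbType::Postgres;

    let mut params: Vec<String> = Vec::new();
    let values_clause = rows
        .iter()
        .map(|row| {
            // Offset by the values already queued so the numbering keeps
            // increasing across rows instead of restarting at $1.
            let base = params.len();
            let placeholders = row
                .iter()
                .enumerate()
                .map(|(i, _)| db.param_name(base + i))
                .collect::<Vec<_>>()
                .join(",");
            params.extend(row.iter().map(|v| v.to_string()));
            format!("({})", placeholders)
        })
        .collect::<Vec<_>>()
        .join(", ");

    // Prints: INSERT INTO t (column1,column2) VALUES ($1,$2), ($3,$4)
    println!("INSERT INTO t (column1,column2) VALUES {}", values_clause);
    assert_eq!(params, vec!["1", "Hello", "2", "World"]);
}
```

Keeping the zero-based-position-to-placeholder conversion in one helper means the row loop only needs to know how many values it has already queued for the pending statement.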