diff --git a/src/graph/dynamic/csv_reader.rs b/src/graph/dynamic/csv_reader.rs
index 9389454..32606a3 100644
--- a/src/graph/dynamic/csv_reader.rs
+++ b/src/graph/dynamic/csv_reader.rs
@@ -76,7 +76,7 @@ impl HostCsvReader for DynamicState {
                     HashMap::new()
                 }
             }
-            Err(err) => {
+            Err(_) => {
                 HashMap::new()
             }
         };
diff --git a/src/graph/dynamic/mod.rs b/src/graph/dynamic/mod.rs
index 06f0f79..a193b44 100644
--- a/src/graph/dynamic/mod.rs
+++ b/src/graph/dynamic/mod.rs
@@ -9,7 +9,6 @@ use wasmtime::{Config, Engine, Store};
 mod dynamic_state;
 use crate::graph::dynamic::csv_readers::CsvReadersData;
 use crate::graph::dynamic::csv_writer::CsvWriterData;
-use crate::graph::dynamic::dynamic_state::ReadMap;
 use crate::graph::dynamic::read_map::ReadMapData;
 use dynamic_state::{Dynamic, DynamicState};
diff --git a/src/graph/node.rs b/src/graph/node.rs
index ca9efde..2449aa5 100644
--- a/src/graph/node.rs
+++ b/src/graph/node.rs
@@ -6,6 +6,7 @@
 pub trait RunnableNode {
     // TODO: Status
-    // TODO: Is it possible to make this async?
+    // TODO: Runtime attributes passed in here would be nice, so a task can change depending
+    //       on the attributes
     async fn run(&self) -> anyhow::Result<()>;
 }
diff --git a/src/graph/sql.rs b/src/graph/sql.rs
index d17c7ad..5fc31d6 100644
--- a/src/graph/sql.rs
+++ b/src/graph/sql.rs
@@ -11,20 +11,20 @@ pub trait QueryExecutor {
     async fn get_rows(
         &mut self,
         query: &str,
-        params: &Vec<String>,
+        params: &[String],
         // TODO: This is looking pretty ugly, simpler way to handle it? Maybe with an iterator?
         row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
     ) -> anyhow::Result<()>;
 
     // Run a query that returns no results (e.g. bulk insert, insert)
-    async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64>;
+    async fn execute_query(&mut self, query: &str, params: &[String]) -> anyhow::Result<u64>;
 }
 
 impl QueryExecutor for Client<Compat<TcpStream>> {
     async fn get_rows(
         &mut self,
         query: &str,
-        params: &Vec<String>,
+        params: &[String],
         row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
     ) -> anyhow::Result<()> {
         let mut query = Query::new(query);
@@ -46,7 +46,7 @@ impl QueryExecutor for Client<Compat<TcpStream>> {
         Ok(())
     }
 
-    async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64> {
+    async fn execute_query(&mut self, query: &str, params: &[String]) -> anyhow::Result<u64> {
         let mut query = Query::new(query);
         for param in params {
             query.bind(param);
@@ -63,7 +63,7 @@ impl QueryExecutor for Pool<Any> {
     async fn get_rows(
         &mut self,
         query: &str,
-        params: &Vec<String>,
+        params: &[String],
         row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
     ) -> anyhow::Result<()> {
         let mut query = sqlx::query(query);
@@ -82,7 +82,7 @@ impl QueryExecutor for Pool<Any> {
         Ok(())
     }
 
-    async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64> {
+    async fn execute_query(&mut self, query: &str, params: &[String]) -> anyhow::Result<u64> {
         let mut query = sqlx::query(query);
         for param in params {
             query = query.bind(param);
diff --git a/src/graph/upload_to_db.rs b/src/graph/upload_to_db.rs
index dcb2e04..fa042a7 100644
--- a/src/graph/upload_to_db.rs
+++ b/src/graph/upload_to_db.rs
@@ -1,12 +1,11 @@
-use std::collections::HashMap;
-
 use anyhow::bail;
 use async_trait::async_trait;
 use itertools::Itertools;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use sqlx::AnyPool;
-use tiberius::Config;
+use std::collections::HashMap;
+use tiberius::{Config, EncryptionLevel};
 use tokio_util::compat::TokioAsyncWriteCompatExt;
 
 use super::{node::RunnableNode, sql::QueryExecutor};
@@ -17,8 +16,28 @@ pub async fn upload_file_bulk(
     executor: &mut impl QueryExecutor,
     upload_node: &UploadNode,
     bind_limit: usize,
+    db_type: &DBType,
+    allow_table_creation: bool,
 ) -> anyhow::Result<u64> {
     let mut rows_affected = None;
+    let mut num_matching_tables: usize = 0;
+    executor.get_rows(&format!("SELECT COUNT(*) Count
+FROM information_schema.tables
+WHERE table_schema = {}
+    AND table_name = ?
+LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut |row| {
+        num_matching_tables = row.get("Count").map(|count| count.parse().unwrap_or(0)).unwrap_or(0);
+        Ok(())
+    }).await?;
+    if num_matching_tables == 0 {
+        if allow_table_creation {
+            // TODO: Create the table with the columns in the file to be uploaded
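+            // One possible shape for this TODO, as an untested sketch (`csv_columns`
+            // is a hypothetical Vec<String> parsed from the CSV header): create
+            // every column as text, e.g. for MSSQL:
+            //   let cols = csv_columns.iter().map(|c| format!("[{}] NVARCHAR(MAX)", c)).join(", ");
+            //   executor.execute_query(&format!("CREATE TABLE [{}] ({})", upload_node.table_name, cols), &[]).await?;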
+
+        } else {
+            bail!("Table creation not allowed and table does not yet exist, aborting");
+        }
+    }
+
     if upload_node.column_mappings.is_none() {
         let insert_from_file_query = match upload_node.db_type {
             DBType::Postgres => Some(format!("COPY {} FROM $1", upload_node.table_name)),
@@ -26,17 +45,19 @@ pub async fn upload_file_bulk(
                 "LOAD DATA INFILE ? INTO {}",
                 upload_node.table_name,
             )),
-            DBType::Mssql => Some(format!("BULK INSERT {} FROM ?", upload_node.table_name)),
+            DBType::Mssql => Some(format!("BULK INSERT [{}] FROM ? WITH ( FORMAT = 'CSV');", upload_node.table_name)),
             _ => None,
         };
 
         if let Some(insert_from_file_query) = insert_from_file_query {
             let result = executor
                 .execute_query(
                     &insert_from_file_query,
-                    &vec![upload_node.file_path.clone()],
+                    &[upload_node.file_path.clone()],
                 )
-                .await?;
-            rows_affected = Some(result);
+                .await;
+            if let Ok(result) = result {
+                rows_affected = Some(result);
+            }
         }
     }
@@ -57,7 +78,7 @@ pub async fn upload_file_bulk(
             csv_columns.clone()
         };
         let query_template = format!(
-            "INSERT INTO {}({}) \n",
+            "INSERT INTO [{}]({}) \n",
             upload_node.table_name,
             table_columns.join(",")
         );
@@ -87,7 +108,7 @@ pub async fn upload_file_bulk(
     }
 
     if let Some(post_script) = &upload_node.post_script {
-        executor.execute_query(post_script, &vec![]).await?;
+        executor.execute_query(post_script, &[]).await?;
     }
 
     match rows_affected {
@@ -106,6 +127,17 @@ pub enum DBType {
     Sqlite,
 }
 
+impl DBType {
+    pub fn get_schema_fn(&self) -> String {
+        match self {
+            DBType::Postgres => "CURRENT_SCHEMA()",
+            DBType::Mysql => "DATABASE()",
+            DBType::Mssql => "DB_NAME()",
+            DBType::Sqlite => "'sqlite'",
+        }.to_owned()
+    }
+}
+
 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct UploadNode {
     file_path: String,
@@ -115,6 +147,7 @@ pub struct UploadNode {
     post_script: Option<String>,
     db_type: DBType,
     connection_string: String,
+    allow_table_creation: bool,
 }
 
 pub struct UploadNodeRunner {
@@ -126,14 +159,16 @@ impl RunnableNode for UploadNodeRunner {
     async fn run(&self) -> anyhow::Result<()> {
         let upload_node = self.upload_node.clone();
         if upload_node.db_type == DBType::Mssql {
-            let config = Config::from_jdbc_string(&upload_node.connection_string)?;
+            let mut config = Config::from_jdbc_string(&upload_node.connection_string)?;
+            // TODO: Restore encryption for remote hosts, doesn't work on localhost without encryption.
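+            // EncryptionLevel::NotSupported turns TLS off for the connection entirely,
+            // so this should only ever be pointed at local or throwaway databases.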
+            config.encryption(EncryptionLevel::NotSupported);
             let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
             tcp.set_nodelay(true)?;
             let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
-            upload_file_bulk(&mut client, &upload_node, BIND_LIMIT).await?;
+            upload_file_bulk(&mut client, &upload_node, BIND_LIMIT, &upload_node.db_type, upload_node.allow_table_creation).await?;
         } else {
             let mut pool = AnyPool::connect(&upload_node.connection_string).await?;
-            upload_file_bulk(&mut pool, &upload_node, BIND_LIMIT).await?;
+            upload_file_bulk(&mut pool, &upload_node, BIND_LIMIT, &upload_node.db_type, upload_node.allow_table_creation).await?;
         }
         Ok(())
     }
 }
@@ -141,23 +176,45 @@
 
 #[cfg(test)]
 mod tests {
+    use crate::graph::node::RunnableNode;
     use crate::graph::upload_to_db::{DBType, UploadNode, UploadNodeRunner};
+    use testcontainers::core::{IntoContainerPort, WaitFor};
+    use testcontainers::runners::AsyncRunner;
+    use testcontainers::{GenericImage, ImageExt};
+    use tiberius::{Config, EncryptionLevel};
+    use tokio_util::compat::TokioAsyncWriteCompatExt;
 
-    #[test]
-    pub fn check_basic_upload() {
-        let upload_ode = UploadNodeRunner {
+    #[tokio::test]
+    pub async fn check_basic_upload() -> anyhow::Result<()> {
+        let container = GenericImage::new("gitea.michaelpivato.dev/vato007/ingey-test-db-mssql", "latest")
+            .with_exposed_port(1433.tcp())
+            .with_wait_for(WaitFor::message_on_stdout("Recovery is complete.".to_owned()))
+            .with_env_var("ACCEPT_EULA", "Y")
+            .with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123")
+            .start()
+            .await?;
+        let host = container.get_host().await?;
+        let port = container.get_host_port_ipv4(1433).await?;
+        let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123;database=TestIngeyDatabase", host, port).to_owned();
+        let file = "testing/input/upload_to_db/test.csv";
+        let table_name = "My Test Table".to_string();
+        let upload_node = UploadNodeRunner {
             upload_node: UploadNode {
-                file_path: "".to_owned(),
-                table_name: "".to_string(),
+                file_path: file.to_owned(),
+                table_name: table_name.clone(),
                 column_mappings: None,
                 post_script: None,
-                db_type: DBType::Mysql,
-                connection_string: "".to_string(),
+                db_type: DBType::Mssql,
+                connection_string,
+                allow_table_creation: false,
             }
         };
-
+        upload_node.run().await?;
+        let mut config = Config::from_jdbc_string(&upload_node.upload_node.connection_string)?;
+        config.encryption(EncryptionLevel::NotSupported);
+        let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
+        tcp.set_nodelay(true)?;
+        let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
+        Ok(())
     }
-
-    #[test]
-    pub fn check_batch_upload() {}
 }
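
Note on the new check_basic_upload test: it connects a second tiberius client after run() but never asserts anything with it. A verification step could look like the untested sketch below (assert_row_count and the expected row count are hypothetical; it leans on the QueryExecutor impl for tiberius::Client from src/graph/sql.rs):

    use crate::graph::sql::QueryExecutor;

    // Count the rows in `table_name` through the QueryExecutor abstraction and
    // compare against the number of data rows expected from the test CSV.
    async fn assert_row_count(
        executor: &mut impl QueryExecutor,
        table_name: &str,
        expected: usize,
    ) -> anyhow::Result<()> {
        let mut rows = 0usize;
        executor
            .get_rows(
                &format!("SELECT COUNT(*) Count FROM [{}]", table_name),
                &[],
                // No `move` here: the closure must borrow `rows` mutably so the
                // update is visible after the call returns.
                &mut |row| {
                    rows = row.get("Count").map_or(0, |c| c.parse().unwrap_or(0));
                    Ok(())
                },
            )
            .await?;
        anyhow::ensure!(rows == expected, "expected {expected} rows, found {rows}");
        Ok(())
    }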