diff --git a/src/graph/pull_from_db.rs b/src/graph/pull_from_db.rs
index 00c83fe..5e00df7 100644
--- a/src/graph/pull_from_db.rs
+++ b/src/graph/pull_from_db.rs
@@ -5,6 +5,7 @@ use crate::io::{DataSource, RecordSerializer};
 use async_trait::async_trait;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
+use sqlx::any::install_default_drivers;
 use sqlx::AnyPool;
 use tiberius::{Config, EncryptionLevel};
 use tokio_util::compat::TokioAsyncWriteCompatExt;
@@ -60,6 +61,7 @@ impl RunnableNode for PullFromDBNodeRunner {
             pull_from_db(&mut client, &node).await?;
         }
         _ => {
+            install_default_drivers();
             let mut pool = AnyPool::connect(&node.connection_string).await?;
             pull_from_db(&mut pool, &node).await?;
         }
diff --git a/src/graph/sql.rs b/src/graph/sql.rs
index 2c6cfc5..e0cecaf 100644
--- a/src/graph/sql.rs
+++ b/src/graph/sql.rs
@@ -9,27 +9,31 @@ pub trait QueryExecutor {
     /// Retrieve data from a database
    async fn get_rows(
         &mut self,
-        query: &str,
+        query: impl AsRef<str>,
         params: &[String],
         // TODO: This is looking pretty ugly, simpler way to handle it? Maybe with an iterator?
         row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
     ) -> anyhow::Result<()>;
 
     /// Run a query that returns no results (e.g. bulk insert, insert)
-    async fn execute_query(&mut self, query: &str, params: &[String]) -> anyhow::Result<u64>;
+    async fn execute_query(
+        &mut self,
+        query: impl AsRef<str>,
+        params: &[String],
+    ) -> anyhow::Result<u64>;
 
     /// Execute an unprepared query. Avoid where possible as sql injection is possible if not used carefully
-    async fn execute_unchecked(&mut self, query: &str) -> anyhow::Result<u64>;
+    async fn execute_unchecked(&mut self, query: impl AsRef<str>) -> anyhow::Result<u64>;
 }
 
 impl QueryExecutor for Client<Compat<TcpStream>> {
     async fn get_rows(
         &mut self,
-        query: &str,
+        query: impl AsRef<str>,
         params: &[String],
         row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
     ) -> anyhow::Result<()> {
-        let mut query = Query::new(query);
+        let mut query = Query::new(query.as_ref());
         for param in params {
             query.bind(param);
         }
@@ -50,8 +54,12 @@ impl QueryExecutor for Client<Compat<TcpStream>> {
         Ok(())
     }
 
-    async fn execute_query(&mut self, query: &str, params: &[String]) -> anyhow::Result<u64> {
-        let mut query = Query::new(query);
+    async fn execute_query(
+        &mut self,
+        query: impl AsRef<str>,
+        params: &[String],
+    ) -> anyhow::Result<u64> {
+        let mut query = Query::new(query.as_ref());
         for param in params {
             query.bind(param);
         }
@@ -62,8 +70,8 @@ impl QueryExecutor for Client<Compat<TcpStream>> {
         Ok(result.rows_affected()[0])
     }
 
-    async fn execute_unchecked(&mut self, query: &str) -> anyhow::Result<u64> {
-        let result = self.execute(query, &[]).await?;
+    async fn execute_unchecked(&mut self, query: impl AsRef<str>) -> anyhow::Result<u64> {
+        let result = self.execute(query.as_ref(), &[]).await?;
         Ok(result.rows_affected()[0])
     }
 }
@@ -71,11 +79,11 @@ impl QueryExecutor for Pool<Any> {
 
     async fn get_rows(
         &mut self,
-        query: &str,
+        query: impl AsRef<str>,
         params: &[String],
         row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
     ) -> anyhow::Result<()> {
-        let mut query = sqlx::query(query);
+        let mut query = sqlx::query(query.as_ref());
         for param in params {
             query = query.bind(param);
         }
@@ -91,8 +99,12 @@ impl QueryExecutor for Pool<Any> {
         Ok(())
     }
 
-    async fn execute_query(&mut self, query: &str, params: &[String]) -> anyhow::Result<u64> {
-        let mut query = sqlx::query(query);
+    async fn execute_query(
+        &mut self,
+        query: impl AsRef<str>,
+        params: &[String],
+    ) -> anyhow::Result<u64> {
+        let mut query = sqlx::query(query.as_ref());
         for param in params {
             query = query.bind(param);
         }
@@ -100,8 +112,8 @@ impl QueryExecutor for Pool<Any> {
         Ok(result.rows_affected())
     }
 
-    async fn execute_unchecked(&mut self, query: &str) -> anyhow::Result<u64> {
-        let result = self.execute(query).await?;
+    async fn execute_unchecked(&mut self, query: impl AsRef<str>) -> anyhow::Result<u64> {
+        let result = self.execute(query.as_ref()).await?;
         Ok(result.rows_affected())
     }
 }
diff --git a/src/graph/upload_to_db.rs b/src/graph/upload_to_db.rs
index 27a19c2..e6e5ebe 100644
--- a/src/graph/upload_to_db.rs
+++ b/src/graph/upload_to_db.rs
@@ -27,9 +27,6 @@ pub async fn upload_file_bulk(
             r#"COPY "{}" FROM '{}' DELIMITERS ',' CSV HEADER QUOTE '"';"#,
             upload_node.table_name, upload_node.file_path
         )),
-        // TODO: Revisit this when sqlx lets this work, currently mysql won't allow this
-        // to execute since sqlx forces it to be inside of a prepare, which
-        // isn't allowed
         DBType::Mysql => Some(format!(
             "LOAD DATA INFILE '{}' INTO TABLE `{}`
             FIELDS TERMINATED BY ',' ENCLOSED BY '\"'
@@ -72,9 +69,12 @@ pub async fn upload_file_bulk(
         csv_columns.clone()
     };
     let query_template = format!(
-        "INSERT INTO [{}]({}) \n",
-        upload_node.table_name,
-        table_columns.join(",")
+        "INSERT INTO {} ({}) \n",
+        db_type.quote_name(&upload_node.table_name),
+        table_columns
+            .iter()
+            .map(|column| db_type.quote_name(column))
+            .join(",")
     );
     let mut params = vec![];
     let mut insert_query = "".to_owned();
@@ -88,13 +88,13 @@ pub async fn upload_file_bulk(
                 "VALUES ({})",
                 result
                     .iter()
-                    .enumerate()
-                    .map(|(index, _)| db_type.get_param_name(index))
+                    // TODO: This is done because postgres won't do an implicit cast on text in the insert,
+                    // if this is resolved we can go back to parameterised queries instead
+                    // as mysql/sql server work fine with implicit cast on parameters
+                    .map(|value| format!("'{}'", value.replace("'", "''")))
                     .join(",")
             )
             .as_str();
-            let mut values = result.iter().map(|value| value.to_owned()).collect_vec();
-            params.append(&mut values);
             num_params += csv_columns.len();
             if num_params == bind_limit {
                 running_row_total += executor.execute_query(&query_template, &params).await?;
@@ -104,15 +104,18 @@ pub async fn upload_file_bulk(
             }
         }
         if !insert_query.is_empty() {
-            running_row_total += executor.execute_query(&query_template, &params).await?;
+            running_row_total += executor
+                .execute_query(format!("{}{}", query_template, insert_query), &params)
+                .await?;
         }
         rows_affected = Some(running_row_total);
     }
 
-    if let Some(post_script) = &upload_node.post_script {
-        executor.execute_query(post_script, &[]).await?;
+    if rows_affected.is_some() {
+        if let Some(post_script) = &upload_node.post_script {
+            executor.execute_query(post_script, &[]).await?;
+        }
     }
-
     match rows_affected {
         Some(rows_affected) => Ok(rows_affected),
         None => bail!(
@@ -130,12 +133,13 @@ pub enum DBType {
 }
 
 impl DBType {
-    pub fn get_param_name(&self, index: usize) -> String {
+    pub fn quote_name(&self, quoted: impl AsRef<str>) -> String {
+        let quoted = quoted.as_ref();
         match self {
-            DBType::Postgres => format!("${}", index),
-            DBType::Mysql => "?".to_owned(),
-            DBType::Mssql => format!("@P{}", index),
-            DBType::Sqlite => "?".to_owned(),
+            DBType::Postgres => format!(r#""{}""#, quoted),
+            DBType::Mysql => format!("`{}`", quoted),
+            DBType::Mssql => format!("[{}]", quoted),
+            DBType::Sqlite => format!("[{}]", quoted),
         }
     }
 }
@@ -335,4 +339,138 @@ mod tests {
         assert_eq!("Hello", column2);
         Ok(())
     }
+
+    #[tokio::test]
+    pub async fn check_insert_mssql() -> anyhow::Result<()> {
+        let container = GenericImage::new(
+            "gitea.michaelpivato.dev/vato007/ingey-test-db-mssql",
+            "latest",
+        )
+        .with_exposed_port(1433.tcp())
+        .with_wait_for(WaitFor::message_on_stdout(
+            "Recovery is complete.".to_owned(),
+        ))
+        .with_env_var("ACCEPT_EULA", "Y")
+        .with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123")
+        .start()
+        .await?;
+        let host = container.get_host().await?;
+        let port = container.get_host_port_ipv4(1433).await?;
+        let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123;database=TestIngeyDatabase", host, port).to_owned();
+        let file = "./testing/input/upload_to_db/test.csv".to_owned();
+        let table_name = "My Test Table".to_string();
+        let upload_node = UploadNodeRunner {
+            upload_node: UploadNode {
+                file_path: file.to_owned(),
+                table_name: table_name.clone(),
+                column_mappings: None,
+                post_script: None,
+                db_type: DBType::Mssql,
+                connection_string,
+            },
+        };
+        upload_node.run().await?;
+        let mut config = Config::from_jdbc_string(&upload_node.upload_node.connection_string)?;
+        config.encryption(EncryptionLevel::NotSupported);
+        let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
+        tcp.set_nodelay(true)?;
+        let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
+        let result = client.query("SELECT * FROM [My Test Table]", &[]).await?;
+        let result = result.into_first_result().await?;
+        assert_eq!(1, result.len());
+
+        let first_row = result.get(0).unwrap();
+        assert_eq!(Some(1), first_row.get("column1"));
+        assert_eq!(Some("Hello"), first_row.get("column2"));
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    pub async fn check_insert_postgres() -> anyhow::Result<()> {
+        let container = GenericImage::new(
+            "gitea.michaelpivato.dev/vato007/ingey-test-db-postgres",
+            "latest",
+        )
+        .with_exposed_port(5432.tcp())
+        .with_wait_for(WaitFor::message_on_stderr(
+            "database system is ready to accept connections",
+        ))
+        .with_env_var("POSTGRES_PASSWORD", "TestOnlyContainer123")
+        .start()
+        .await?;
+        let host = container.get_host().await?;
+        let port = container.get_host_port_ipv4(5432).await?;
+        let connection_string = format!(
+            "postgres://postgres:TestOnlyContainer123@{}:{}/testingeydatabase",
+            host, port
+        )
+        .to_owned();
+        let file = "./testing/input/upload_to_db/test.csv".to_owned();
+        let table_name = "My Test Table".to_string();
+        let upload_node = UploadNodeRunner {
+            upload_node: UploadNode {
+                file_path: file.to_owned(),
+                table_name: table_name.clone(),
+                column_mappings: None,
+                post_script: None,
+                db_type: DBType::Postgres,
+                connection_string,
+            },
+        };
+        upload_node.run().await?;
+        let pool = AnyPool::connect(&upload_node.upload_node.connection_string).await?;
+        let result = sqlx::query(r#"SELECT * FROM "My Test Table""#)
+            .fetch_one(&pool)
+            .await?;
+
+        let column1: i32 = result.try_get("column1")?;
+        let column2: &str = result.try_get("column2")?;
+        assert_eq!(1, column1);
+        assert_eq!("Hello", column2);
+        Ok(())
+    }
+
+    #[tokio::test]
+    pub async fn check_insert_mysql() -> anyhow::Result<()> {
+        let container = GenericImage::new(
+            "gitea.michaelpivato.dev/vato007/ingey-test-db-mysql",
+            "latest",
+        )
+        .with_exposed_port(3306.tcp())
+        .with_wait_for(WaitFor::message_on_stderr("ready for connections."))
+        .with_env_var("MYSQL_ROOT_PASSWORD", "TestOnlyContainer123")
+        .start()
+        .await?;
+        let host = container.get_host().await?;
+        let port = container.get_host_port_ipv4(3306).await?;
+        let connection_string = format!(
+            "mysql://root:TestOnlyContainer123@{}:{}/TestIngeyDatabase",
+            host, port
+        )
+        .to_owned();
+        let file = "./testing/input/upload_to_db/test.csv".to_owned();
+        let table_name = "My Test Table".to_string();
+        let upload_node = UploadNodeRunner {
+            upload_node: UploadNode {
+                file_path: file.to_owned(),
+                table_name: table_name.clone(),
+                column_mappings: None,
+                post_script: None,
+                db_type: DBType::Mysql,
+                connection_string,
+            },
+        };
+        upload_node.run().await?;
+        let pool = AnyPool::connect(&upload_node.upload_node.connection_string).await?;
+        let result = sqlx::query(r#"SELECT * FROM `My Test Table`"#)
+            .fetch_one(&pool)
+            .await?;
+
+        let column1: i32 = result.try_get("column1")?;
+        let column2: &str = result.try_get("column2")?;
+        assert_eq!(1, column1);
+        assert_eq!("Hello", column2);
+        Ok(())
+    }
 }
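
Note (not part of the patch): the standalone sketch below only illustrates the identifier-quoting rules that the new DBType::quote_name helper introduces, and why the switch to impl AsRef<str> query arguments lets callers pass either &str or a freshly built String. The simplified enum and main function here are assumptions for the example, not code from the repository.

// Illustrative sketch only: mirrors the quoting rules added in upload_to_db.rs.
#[derive(Debug, Clone, Copy)]
enum DBType {
    Postgres,
    Mysql,
    Mssql,
    Sqlite,
}

impl DBType {
    // Wrap a table or column name in the delimiter each backend expects.
    fn quote_name(&self, name: impl AsRef<str>) -> String {
        let name = name.as_ref();
        match self {
            DBType::Postgres => format!(r#""{}""#, name), // "My Test Table"
            DBType::Mysql => format!("`{}`", name),        // `My Test Table`
            DBType::Mssql | DBType::Sqlite => format!("[{}]", name), // [My Test Table]
        }
    }
}

fn main() {
    // An `impl AsRef<str>` parameter accepts &str, String, and &String alike,
    // which is what allows execute_query to take a format!()-built String.
    let table: String = "My Test Table".to_string();
    for db in [DBType::Postgres, DBType::Mysql, DBType::Mssql, DBType::Sqlite] {
        println!("INSERT INTO {} (...)", db.quote_name(&table));
        println!("INSERT INTO {} (...)", db.quote_name("My Test Table"));
    }
}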