Fix non-bulk upload to db
Some checks failed
test / test (push) Failing after 17m46s

2025-02-15 10:39:19 +10:30
parent d9f69ff298
commit aac37d6bf4
3 changed files with 186 additions and 34 deletions

View File

@@ -5,6 +5,7 @@ use crate::io::{DataSource, RecordSerializer};
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
+use sqlx::any::install_default_drivers;
use sqlx::AnyPool;
use tiberius::{Config, EncryptionLevel};
use tokio_util::compat::TokioAsyncWriteCompatExt;
@@ -60,6 +61,7 @@ impl RunnableNode for PullFromDBNodeRunner {
                pull_from_db(&mut client, &node).await?;
            }
            _ => {
+                install_default_drivers();
                let mut pool = AnyPool::connect(&node.connection_string).await?;
                pull_from_db(&mut pool, &node).await?;
            }
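
A note on the added install_default_drivers() call: sqlx's Any driver resolves its backends at runtime, so the pool connection in the fallback arm fails unless the compiled-in drivers have been registered first. Below is a minimal sketch of the pattern, assuming sqlx 0.7+ with the relevant database features enabled; install_default_drivers and AnyPool are real sqlx APIs, the surrounding helper is illustrative only.

    use sqlx::any::install_default_drivers;
    use sqlx::AnyPool;

    // Illustrative helper: register the compiled-in drivers, then connect.
    // Without install_default_drivers(), AnyPool::connect fails at runtime
    // because the Any driver has no backends to dispatch to.
    async fn connect_any(connection_string: &str) -> anyhow::Result<AnyPool> {
        install_default_drivers();
        let pool = AnyPool::connect(connection_string).await?;
        Ok(pool)
    }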

View File

@@ -9,27 +9,31 @@ pub trait QueryExecutor {
    /// Retrieve data from a database
    async fn get_rows(
        &mut self,
-        query: &str,
+        query: impl AsRef<str>,
        params: &[String],
        // TODO: This is looking pretty ugly, simpler way to handle it? Maybe with an iterator?
        row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
    ) -> anyhow::Result<()>;

    /// Run a query that returns no results (e.g. bulk insert, insert)
-    async fn execute_query(&mut self, query: &str, params: &[String]) -> anyhow::Result<u64>;
+    async fn execute_query(
+        &mut self,
+        query: impl AsRef<str>,
+        params: &[String],
+    ) -> anyhow::Result<u64>;

    /// Execute an unprepared query. Avoid where possible as sql injection is possible if not used carefully
-    async fn execute_unchecked(&mut self, query: &str) -> anyhow::Result<u64>;
+    async fn execute_unchecked(&mut self, query: impl AsRef<str>) -> anyhow::Result<u64>;
}

impl<S: AsyncRead + AsyncWrite + Unpin + Send> QueryExecutor for Client<S> {
    async fn get_rows(
        &mut self,
-        query: &str,
+        query: impl AsRef<str>,
        params: &[String],
        row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
    ) -> anyhow::Result<()> {
-        let mut query = Query::new(query);
+        let mut query = Query::new(query.as_ref());
        for param in params {
            query.bind(param);
        }
@@ -50,8 +54,12 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> QueryExecutor for Client<S> {
        Ok(())
    }

-    async fn execute_query(&mut self, query: &str, params: &[String]) -> anyhow::Result<u64> {
-        let mut query = Query::new(query);
+    async fn execute_query(
+        &mut self,
+        query: impl AsRef<str>,
+        params: &[String],
+    ) -> anyhow::Result<u64> {
+        let mut query = Query::new(query.as_ref());
        for param in params {
            query.bind(param);
        }
@@ -62,8 +70,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> QueryExecutor for Client<S> {
        Ok(result.rows_affected()[0])
    }

-    async fn execute_unchecked(&mut self, query: &str) -> anyhow::Result<u64> {
-        let result = self.execute(query, &[]).await?;
+    async fn execute_unchecked(&mut self, query: impl AsRef<str>) -> anyhow::Result<u64> {
+        let result = self.execute(query.as_ref(), &[]).await?;
        Ok(result.rows_affected()[0])
    }
}
@@ -71,11 +79,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> QueryExecutor for Client<S> {
impl QueryExecutor for Pool<Any> {
    async fn get_rows(
        &mut self,
-        query: &str,
+        query: impl AsRef<str>,
        params: &[String],
        row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
    ) -> anyhow::Result<()> {
-        let mut query = sqlx::query(query);
+        let mut query = sqlx::query(query.as_ref());
        for param in params {
            query = query.bind(param);
        }
@@ -91,8 +99,12 @@ impl QueryExecutor for Pool<Any> {
        Ok(())
    }

-    async fn execute_query(&mut self, query: &str, params: &[String]) -> anyhow::Result<u64> {
-        let mut query = sqlx::query(query);
+    async fn execute_query(
+        &mut self,
+        query: impl AsRef<str>,
+        params: &[String],
+    ) -> anyhow::Result<u64> {
+        let mut query = sqlx::query(query.as_ref());
        for param in params {
            query = query.bind(param);
        }
@@ -100,8 +112,8 @@ impl QueryExecutor for Pool<Any> {
        Ok(result.rows_affected())
    }

-    async fn execute_unchecked(&mut self, query: &str) -> anyhow::Result<u64> {
-        let result = self.execute(query).await?;
+    async fn execute_unchecked(&mut self, query: impl AsRef<str>) -> anyhow::Result<u64> {
+        let result = self.execute(query.as_ref()).await?;
        Ok(result.rows_affected())
    }
}
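
Context for the signature change above: with query: &str, callers had to keep a borrowed string alive, which is awkward now that the upload path builds its statements with format!. impl AsRef<str> accepts both &str and an owned String, so a freshly built query can be passed by value. A small sketch of the call-site ergonomics; the function name here is hypothetical and stands in for execute_query:

    // Hypothetical stand-in for execute_query, showing what the bound accepts.
    fn accepts_query(query: impl AsRef<str>) {
        let sql: &str = query.as_ref();
        assert!(!sql.is_empty());
    }

    fn call_sites() {
        accepts_query("SELECT 1");                 // &str still works
        accepts_query(format!("SELECT {}", 1));    // an owned String now works too
    }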

View File

@@ -27,9 +27,6 @@ pub async fn upload_file_bulk(
r#"COPY "{}" FROM '{}' DELIMITERS ',' CSV HEADER QUOTE '"';"#, r#"COPY "{}" FROM '{}' DELIMITERS ',' CSV HEADER QUOTE '"';"#,
upload_node.table_name, upload_node.file_path upload_node.table_name, upload_node.file_path
)), )),
// TODO: Revisit this when sqlx lets this work, currently mysql won't allow this
// to execute since sqlx forces it to be inside of a prepare, which
// isn't allowed
DBType::Mysql => Some(format!( DBType::Mysql => Some(format!(
"LOAD DATA INFILE '{}' INTO TABLE `{}` "LOAD DATA INFILE '{}' INTO TABLE `{}`
FIELDS TERMINATED BY ',' ENCLOSED BY '\"' FIELDS TERMINATED BY ',' ENCLOSED BY '\"'
@@ -72,9 +69,12 @@ pub async fn upload_file_bulk(
        csv_columns.clone()
    };
    let query_template = format!(
-        "INSERT INTO [{}]({}) \n",
-        upload_node.table_name,
-        table_columns.join(",")
+        "INSERT INTO {} ({}) \n",
+        db_type.quote_name(&upload_node.table_name),
+        table_columns
+            .iter()
+            .map(|column| db_type.quote_name(column))
+            .join(",")
    );
    let mut params = vec![];
    let mut insert_query = "".to_owned();
@@ -88,13 +88,13 @@ pub async fn upload_file_bulk(
"VALUES ({})", "VALUES ({})",
result result
.iter() .iter()
.enumerate() // TODO: This is done because postgres won't do an implicit cast on text in the insert,
.map(|(index, _)| db_type.get_param_name(index)) // if this is resolved we can go back to parameterised queries instead
// as mysql/sql server work fine with implicit cast on parameters
.map(|value| format!("'{}'", value.replace("'", "''")))
.join(",") .join(",")
) )
.as_str(); .as_str();
let mut values = result.iter().map(|value| value.to_owned()).collect_vec();
params.append(&mut values);
num_params += csv_columns.len(); num_params += csv_columns.len();
if num_params == bind_limit { if num_params == bind_limit {
running_row_total += executor.execute_query(&query_template, &params).await?; running_row_total += executor.execute_query(&query_template, &params).await?;
@@ -104,15 +104,18 @@ pub async fn upload_file_bulk(
            }
        }
        if !insert_query.is_empty() {
-            running_row_total += executor.execute_query(&query_template, &params).await?;
+            running_row_total += executor
+                .execute_query(format!("{}{}", query_template, insert_query), &params)
+                .await?;
        }
        rows_affected = Some(running_row_total);
    }
+    if rows_affected.is_some() {
        if let Some(post_script) = &upload_node.post_script {
            executor.execute_query(post_script, &[]).await?;
        }
+    }
    match rows_affected {
        Some(rows_affected) => Ok(rows_affected),
        None => bail!(
@@ -130,12 +133,13 @@ pub enum DBType {
}

impl DBType {
-    pub fn get_param_name(&self, index: usize) -> String {
+    pub fn quote_name(&self, quoted: impl AsRef<str>) -> String {
+        let quoted = quoted.as_ref();
        match self {
-            DBType::Postgres => format!("${}", index),
-            DBType::Mysql => "?".to_owned(),
-            DBType::Mssql => format!("@P{}", index),
-            DBType::Sqlite => "?".to_owned(),
+            DBType::Postgres => format!(r#""{}""#, quoted),
+            DBType::Mysql => format!("`{}`", quoted),
+            DBType::Mssql => format!("[{}]", quoted),
+            DBType::Sqlite => format!("[{}]", quoted),
        }
    }
}
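
The get_param_name to quote_name change swaps positional bind parameters for quoted identifiers, while the earlier hunk inlines values as escaped SQL literals (single quotes doubled) because, per the new TODO, Postgres will not implicitly cast text parameters in the insert. A short sketch of the expected behaviour, assuming the DBType enum and quote_name from this file are in scope; the asserted strings are inferred from the format strings above, not taken from the tests:

    // Identifier quoting per backend, as implied by quote_name above.
    fn quoting_examples() {
        assert_eq!(DBType::Postgres.quote_name("My Test Table"), r#""My Test Table""#);
        assert_eq!(DBType::Mysql.quote_name("My Test Table"), "`My Test Table`");
        assert_eq!(DBType::Mssql.quote_name("My Test Table"), "[My Test Table]");
        assert_eq!(DBType::Sqlite.quote_name("My Test Table"), "[My Test Table]");

        // Values are embedded as SQL string literals with single quotes doubled,
        // mirroring value.replace("'", "''") in the insert loop above.
        let value = "O'Brien";
        let literal = format!("'{}'", value.replace("'", "''"));
        assert_eq!(literal, "'O''Brien'");
    }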
@@ -335,4 +339,138 @@ mod tests {
assert_eq!("Hello", column2); assert_eq!("Hello", column2);
Ok(()) Ok(())
} }
#[tokio::test]
pub async fn check_insert_mssql() -> anyhow::Result<()> {
let container = GenericImage::new(
"gitea.michaelpivato.dev/vato007/ingey-test-db-mssql",
"latest",
)
.with_exposed_port(1433.tcp())
.with_wait_for(WaitFor::message_on_stdout(
"Recovery is complete.".to_owned(),
))
.with_env_var("ACCEPT_EULA", "Y")
.with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123")
.start()
.await?;
let host = container.get_host().await?;
let port = container.get_host_port_ipv4(1433).await?;
let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123;database=TestIngeyDatabase", host, port).to_owned();
let file = "./testing/input/upload_to_db/test.csv".to_owned();
let table_name = "My Test Table".to_string();
let upload_node = UploadNodeRunner {
upload_node: UploadNode {
file_path: file.to_owned(),
table_name: table_name.clone(),
column_mappings: None,
post_script: None,
db_type: DBType::Mssql,
connection_string,
},
};
upload_node.run().await?;
let mut config = Config::from_jdbc_string(&upload_node.upload_node.connection_string)?;
config.encryption(EncryptionLevel::NotSupported);
let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
tcp.set_nodelay(true)?;
let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
let result = client.query("SELECT * FROM [My Test Table]", &[]).await?;
let result = result.into_first_result().await?;
assert_eq!(1, result.len());
let first_row = result.get(0).unwrap();
assert_eq!(Some(1), first_row.get("column1"));
assert_eq!(Some("Hello"), first_row.get("column2"));
Ok(())
}
#[tokio::test]
pub async fn check_insert_postgres() -> anyhow::Result<()> {
let container = GenericImage::new(
"gitea.michaelpivato.dev/vato007/ingey-test-db-postgres",
"latest",
)
.with_exposed_port(5432.tcp())
.with_wait_for(WaitFor::message_on_stderr(
"database system is ready to accept connections",
))
.with_env_var("POSTGRES_PASSWORD", "TestOnlyContainer123")
.start()
.await?;
let host = container.get_host().await?;
let port = container.get_host_port_ipv4(5432).await?;
let connection_string = format!(
"postgres://postgres:TestOnlyContainer123@{}:{}/testingeydatabase",
host, port
)
.to_owned();
let file = "./testing/input/upload_to_db/test.csv".to_owned();
let table_name = "My Test Table".to_string();
let upload_node = UploadNodeRunner {
upload_node: UploadNode {
file_path: file.to_owned(),
table_name: table_name.clone(),
column_mappings: None,
post_script: None,
db_type: DBType::Postgres,
connection_string,
},
};
upload_node.run().await?;
let pool = AnyPool::connect(&upload_node.upload_node.connection_string).await?;
let result = sqlx::query(r#"SELECT * FROM "My Test Table""#)
.fetch_one(&pool)
.await?;
let column1: i32 = result.try_get("column1")?;
let column2: &str = result.try_get("column2")?;
assert_eq!(1, column1);
assert_eq!("Hello", column2);
Ok(())
}
#[tokio::test]
pub async fn check_insert_mysql() -> anyhow::Result<()> {
let container = GenericImage::new(
"gitea.michaelpivato.dev/vato007/ingey-test-db-mysql",
"latest",
)
.with_exposed_port(3306.tcp())
.with_wait_for(WaitFor::message_on_stderr("ready for connections."))
.with_env_var("MYSQL_ROOT_PASSWORD", "TestOnlyContainer123")
.start()
.await?;
let host = container.get_host().await?;
let port = container.get_host_port_ipv4(3306).await?;
let connection_string = format!(
"mysql://root:TestOnlyContainer123@{}:{}/TestIngeyDatabase",
host, port
)
.to_owned();
let file = "./testing/input/upload_to_db/test.csv".to_owned();
let table_name = "My Test Table".to_string();
let upload_node = UploadNodeRunner {
upload_node: UploadNode {
file_path: file.to_owned(),
table_name: table_name.clone(),
column_mappings: None,
post_script: None,
db_type: DBType::Mysql,
connection_string,
},
};
upload_node.run().await?;
let pool = AnyPool::connect(&upload_node.upload_node.connection_string).await?;
let result = sqlx::query(r#"SELECT * FROM `My Test Table`"#)
.fetch_one(&pool)
.await?;
let column1: i32 = result.try_get("column1")?;
let column2: &str = result.try_get("column2")?;
assert_eq!(1, column1);
assert_eq!("Hello", column2);
Ok(())
}
}