This commit is contained in:
@@ -5,6 +5,7 @@ use crate::io::{DataSource, RecordSerializer};
|
||||
use async_trait::async_trait;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::any::install_default_drivers;
|
||||
use sqlx::AnyPool;
|
||||
use tiberius::{Config, EncryptionLevel};
|
||||
use tokio_util::compat::TokioAsyncWriteCompatExt;
|
||||
@@ -60,6 +61,7 @@ impl RunnableNode for PullFromDBNodeRunner {
|
||||
pull_from_db(&mut client, &node).await?;
|
||||
}
|
||||
_ => {
|
||||
install_default_drivers();
|
||||
let mut pool = AnyPool::connect(&node.connection_string).await?;
|
||||
pull_from_db(&mut pool, &node).await?;
|
||||
}
|
||||
|
||||
@@ -9,27 +9,31 @@ pub trait QueryExecutor {
|
||||
/// Retrieve data from a database
|
||||
async fn get_rows(
|
||||
&mut self,
|
||||
query: &str,
|
||||
query: impl AsRef<str>,
|
||||
params: &[String],
|
||||
// TODO: This is looking pretty ugly, simpler way to handle it? Maybe with an iterator?
|
||||
row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
|
||||
) -> anyhow::Result<()>;
|
||||
|
||||
/// Run a query that returns no results (e.g. bulk insert, insert)
|
||||
async fn execute_query(&mut self, query: &str, params: &[String]) -> anyhow::Result<u64>;
|
||||
async fn execute_query(
|
||||
&mut self,
|
||||
query: impl AsRef<str>,
|
||||
params: &[String],
|
||||
) -> anyhow::Result<u64>;
|
||||
|
||||
/// Execute an unprepared query. Avoid where possible as sql injection is possible if not used carefully
|
||||
async fn execute_unchecked(&mut self, query: &str) -> anyhow::Result<u64>;
|
||||
async fn execute_unchecked(&mut self, query: impl AsRef<str>) -> anyhow::Result<u64>;
|
||||
}
|
||||
|
||||
impl<S: AsyncRead + AsyncWrite + Unpin + Send> QueryExecutor for Client<S> {
|
||||
async fn get_rows(
|
||||
&mut self,
|
||||
query: &str,
|
||||
query: impl AsRef<str>,
|
||||
params: &[String],
|
||||
row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut query = Query::new(query);
|
||||
let mut query = Query::new(query.as_ref());
|
||||
for param in params {
|
||||
query.bind(param);
|
||||
}
|
||||
@@ -50,8 +54,12 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> QueryExecutor for Client<S> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn execute_query(&mut self, query: &str, params: &[String]) -> anyhow::Result<u64> {
|
||||
let mut query = Query::new(query);
|
||||
async fn execute_query(
|
||||
&mut self,
|
||||
query: impl AsRef<str>,
|
||||
params: &[String],
|
||||
) -> anyhow::Result<u64> {
|
||||
let mut query = Query::new(query.as_ref());
|
||||
for param in params {
|
||||
query.bind(param);
|
||||
}
|
||||
@@ -62,8 +70,8 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> QueryExecutor for Client<S> {
|
||||
Ok(result.rows_affected()[0])
|
||||
}
|
||||
|
||||
async fn execute_unchecked(&mut self, query: &str) -> anyhow::Result<u64> {
|
||||
let result = self.execute(query, &[]).await?;
|
||||
async fn execute_unchecked(&mut self, query: impl AsRef<str>) -> anyhow::Result<u64> {
|
||||
let result = self.execute(query.as_ref(), &[]).await?;
|
||||
Ok(result.rows_affected()[0])
|
||||
}
|
||||
}
|
||||
@@ -71,11 +79,11 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> QueryExecutor for Client<S> {
|
||||
impl QueryExecutor for Pool<Any> {
|
||||
async fn get_rows(
|
||||
&mut self,
|
||||
query: &str,
|
||||
query: impl AsRef<str>,
|
||||
params: &[String],
|
||||
row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut query = sqlx::query(query);
|
||||
let mut query = sqlx::query(query.as_ref());
|
||||
for param in params {
|
||||
query = query.bind(param);
|
||||
}
|
||||
@@ -91,8 +99,12 @@ impl QueryExecutor for Pool<Any> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn execute_query(&mut self, query: &str, params: &[String]) -> anyhow::Result<u64> {
|
||||
let mut query = sqlx::query(query);
|
||||
async fn execute_query(
|
||||
&mut self,
|
||||
query: impl AsRef<str>,
|
||||
params: &[String],
|
||||
) -> anyhow::Result<u64> {
|
||||
let mut query = sqlx::query(query.as_ref());
|
||||
for param in params {
|
||||
query = query.bind(param);
|
||||
}
|
||||
@@ -100,8 +112,8 @@ impl QueryExecutor for Pool<Any> {
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
|
||||
async fn execute_unchecked(&mut self, query: &str) -> anyhow::Result<u64> {
|
||||
let result = self.execute(query).await?;
|
||||
async fn execute_unchecked(&mut self, query: impl AsRef<str>) -> anyhow::Result<u64> {
|
||||
let result = self.execute(query.as_ref()).await?;
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,9 +27,6 @@ pub async fn upload_file_bulk(
|
||||
r#"COPY "{}" FROM '{}' DELIMITERS ',' CSV HEADER QUOTE '"';"#,
|
||||
upload_node.table_name, upload_node.file_path
|
||||
)),
|
||||
// TODO: Revisit this when sqlx lets this work, currently mysql won't allow this
|
||||
// to execute since sqlx forces it to be inside of a prepare, which
|
||||
// isn't allowed
|
||||
DBType::Mysql => Some(format!(
|
||||
"LOAD DATA INFILE '{}' INTO TABLE `{}`
|
||||
FIELDS TERMINATED BY ',' ENCLOSED BY '\"'
|
||||
@@ -72,9 +69,12 @@ pub async fn upload_file_bulk(
|
||||
csv_columns.clone()
|
||||
};
|
||||
let query_template = format!(
|
||||
"INSERT INTO [{}]({}) \n",
|
||||
upload_node.table_name,
|
||||
table_columns.join(",")
|
||||
"INSERT INTO {} ({}) \n",
|
||||
db_type.quote_name(&upload_node.table_name),
|
||||
table_columns
|
||||
.iter()
|
||||
.map(|column| db_type.quote_name(column))
|
||||
.join(",")
|
||||
);
|
||||
let mut params = vec![];
|
||||
let mut insert_query = "".to_owned();
|
||||
@@ -88,13 +88,13 @@ pub async fn upload_file_bulk(
|
||||
"VALUES ({})",
|
||||
result
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, _)| db_type.get_param_name(index))
|
||||
// TODO: This is done because postgres won't do an implicit cast on text in the insert,
|
||||
// if this is resolved we can go back to parameterised queries instead
|
||||
// as mysql/sql server work fine with implicit cast on parameters
|
||||
.map(|value| format!("'{}'", value.replace("'", "''")))
|
||||
.join(",")
|
||||
)
|
||||
.as_str();
|
||||
let mut values = result.iter().map(|value| value.to_owned()).collect_vec();
|
||||
params.append(&mut values);
|
||||
num_params += csv_columns.len();
|
||||
if num_params == bind_limit {
|
||||
running_row_total += executor.execute_query(&query_template, ¶ms).await?;
|
||||
@@ -104,15 +104,18 @@ pub async fn upload_file_bulk(
|
||||
}
|
||||
}
|
||||
if !insert_query.is_empty() {
|
||||
running_row_total += executor.execute_query(&query_template, ¶ms).await?;
|
||||
running_row_total += executor
|
||||
.execute_query(format!("{}{}", query_template, insert_query), ¶ms)
|
||||
.await?;
|
||||
}
|
||||
rows_affected = Some(running_row_total);
|
||||
}
|
||||
|
||||
if rows_affected.is_some() {
|
||||
if let Some(post_script) = &upload_node.post_script {
|
||||
executor.execute_query(post_script, &[]).await?;
|
||||
}
|
||||
|
||||
}
|
||||
match rows_affected {
|
||||
Some(rows_affected) => Ok(rows_affected),
|
||||
None => bail!(
|
||||
@@ -130,12 +133,13 @@ pub enum DBType {
|
||||
}
|
||||
|
||||
impl DBType {
|
||||
pub fn get_param_name(&self, index: usize) -> String {
|
||||
pub fn quote_name(&self, quoted: impl AsRef<str>) -> String {
|
||||
let quoted = quoted.as_ref();
|
||||
match self {
|
||||
DBType::Postgres => format!("${}", index),
|
||||
DBType::Mysql => "?".to_owned(),
|
||||
DBType::Mssql => format!("@P{}", index),
|
||||
DBType::Sqlite => "?".to_owned(),
|
||||
DBType::Postgres => format!(r#""{}""#, quoted),
|
||||
DBType::Mysql => format!("`{}`", quoted),
|
||||
DBType::Mssql => format!("[{}]", quoted),
|
||||
DBType::Sqlite => format!("[{}]", quoted),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -335,4 +339,138 @@ mod tests {
|
||||
assert_eq!("Hello", column2);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
pub async fn check_insert_mssql() -> anyhow::Result<()> {
|
||||
let container = GenericImage::new(
|
||||
"gitea.michaelpivato.dev/vato007/ingey-test-db-mssql",
|
||||
"latest",
|
||||
)
|
||||
.with_exposed_port(1433.tcp())
|
||||
.with_wait_for(WaitFor::message_on_stdout(
|
||||
"Recovery is complete.".to_owned(),
|
||||
))
|
||||
.with_env_var("ACCEPT_EULA", "Y")
|
||||
.with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123")
|
||||
.start()
|
||||
.await?;
|
||||
let host = container.get_host().await?;
|
||||
let port = container.get_host_port_ipv4(1433).await?;
|
||||
let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123;database=TestIngeyDatabase", host, port).to_owned();
|
||||
let file = "./testing/input/upload_to_db/test.csv".to_owned();
|
||||
let table_name = "My Test Table".to_string();
|
||||
let upload_node = UploadNodeRunner {
|
||||
upload_node: UploadNode {
|
||||
file_path: file.to_owned(),
|
||||
table_name: table_name.clone(),
|
||||
column_mappings: None,
|
||||
post_script: None,
|
||||
db_type: DBType::Mssql,
|
||||
connection_string,
|
||||
},
|
||||
};
|
||||
upload_node.run().await?;
|
||||
let mut config = Config::from_jdbc_string(&upload_node.upload_node.connection_string)?;
|
||||
config.encryption(EncryptionLevel::NotSupported);
|
||||
let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
|
||||
tcp.set_nodelay(true)?;
|
||||
let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
|
||||
let result = client.query("SELECT * FROM [My Test Table]", &[]).await?;
|
||||
let result = result.into_first_result().await?;
|
||||
assert_eq!(1, result.len());
|
||||
|
||||
let first_row = result.get(0).unwrap();
|
||||
assert_eq!(Some(1), first_row.get("column1"));
|
||||
assert_eq!(Some("Hello"), first_row.get("column2"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
pub async fn check_insert_postgres() -> anyhow::Result<()> {
|
||||
let container = GenericImage::new(
|
||||
"gitea.michaelpivato.dev/vato007/ingey-test-db-postgres",
|
||||
"latest",
|
||||
)
|
||||
.with_exposed_port(5432.tcp())
|
||||
.with_wait_for(WaitFor::message_on_stderr(
|
||||
"database system is ready to accept connections",
|
||||
))
|
||||
.with_env_var("POSTGRES_PASSWORD", "TestOnlyContainer123")
|
||||
.start()
|
||||
.await?;
|
||||
let host = container.get_host().await?;
|
||||
let port = container.get_host_port_ipv4(5432).await?;
|
||||
let connection_string = format!(
|
||||
"postgres://postgres:TestOnlyContainer123@{}:{}/testingeydatabase",
|
||||
host, port
|
||||
)
|
||||
.to_owned();
|
||||
let file = "./testing/input/upload_to_db/test.csv".to_owned();
|
||||
let table_name = "My Test Table".to_string();
|
||||
let upload_node = UploadNodeRunner {
|
||||
upload_node: UploadNode {
|
||||
file_path: file.to_owned(),
|
||||
table_name: table_name.clone(),
|
||||
column_mappings: None,
|
||||
post_script: None,
|
||||
db_type: DBType::Postgres,
|
||||
connection_string,
|
||||
},
|
||||
};
|
||||
upload_node.run().await?;
|
||||
let pool = AnyPool::connect(&upload_node.upload_node.connection_string).await?;
|
||||
let result = sqlx::query(r#"SELECT * FROM "My Test Table""#)
|
||||
.fetch_one(&pool)
|
||||
.await?;
|
||||
|
||||
let column1: i32 = result.try_get("column1")?;
|
||||
let column2: &str = result.try_get("column2")?;
|
||||
assert_eq!(1, column1);
|
||||
assert_eq!("Hello", column2);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
pub async fn check_insert_mysql() -> anyhow::Result<()> {
|
||||
let container = GenericImage::new(
|
||||
"gitea.michaelpivato.dev/vato007/ingey-test-db-mysql",
|
||||
"latest",
|
||||
)
|
||||
.with_exposed_port(3306.tcp())
|
||||
.with_wait_for(WaitFor::message_on_stderr("ready for connections."))
|
||||
.with_env_var("MYSQL_ROOT_PASSWORD", "TestOnlyContainer123")
|
||||
.start()
|
||||
.await?;
|
||||
let host = container.get_host().await?;
|
||||
let port = container.get_host_port_ipv4(3306).await?;
|
||||
let connection_string = format!(
|
||||
"mysql://root:TestOnlyContainer123@{}:{}/TestIngeyDatabase",
|
||||
host, port
|
||||
)
|
||||
.to_owned();
|
||||
let file = "./testing/input/upload_to_db/test.csv".to_owned();
|
||||
let table_name = "My Test Table".to_string();
|
||||
let upload_node = UploadNodeRunner {
|
||||
upload_node: UploadNode {
|
||||
file_path: file.to_owned(),
|
||||
table_name: table_name.clone(),
|
||||
column_mappings: None,
|
||||
post_script: None,
|
||||
db_type: DBType::Mysql,
|
||||
connection_string,
|
||||
},
|
||||
};
|
||||
upload_node.run().await?;
|
||||
let pool = AnyPool::connect(&upload_node.upload_node.connection_string).await?;
|
||||
let result = sqlx::query(r#"SELECT * FROM `My Test Table`"#)
|
||||
.fetch_one(&pool)
|
||||
.await?;
|
||||
|
||||
let column1: i32 = result.try_get("column1")?;
|
||||
let column2: &str = result.try_get("column2")?;
|
||||
assert_eq!(1, column1);
|
||||
assert_eq!("Hello", column2);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user