@@ -1,10 +1,11 @@
use anyhow::bail;
use async_trait::async_trait;
use itertools::Itertools;
use log::{log, Level};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use sqlx::AnyPool;
use std::collections::HashMap;
use std::{collections::HashMap, fmt::format};
use tiberius::{Config, EncryptionLevel};
use tokio_util::compat::TokioAsyncWriteCompatExt;
@@ -17,27 +18,8 @@ pub async fn upload_file_bulk(
    upload_node: &UploadNode,
    bind_limit: usize,
    db_type: &DBType,
    allow_table_creation: bool,
) -> anyhow::Result<u64> {
    let mut rows_affected = None;
    let mut num_matching_tables: usize = 0;
    executor.get_rows(&format!("SELECT COUNT(*) Count
FROM information_schema.tables
WHERE table_schema = {}
AND table_name = ?
LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut move |row| {
        num_matching_tables = row.get("Count").map(|count| count.parse().unwrap_or(0)).unwrap_or(0);
        Ok(())
    }).await?;
    if num_matching_tables == 0 {
        if allow_table_creation {
            // TODO: Create the table with the columns in the file to be uploaded

        } else {
            bail!("Table creation not allowed and table does not yet exist, aborting");
        }
    }

    if upload_node.column_mappings.is_none() {
        let insert_from_file_query = match upload_node.db_type {
            DBType::Postgres => Some(format!("COPY {} FROM $1", upload_node.table_name)),
@@ -45,7 +27,7 @@ LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut mov
                "LOAD DATA INFILE ? INTO {}",
                upload_node.table_name,
            )),
            DBType::Mssql => Some(format!("BULK INSERT [{}] FROM ? WITH ( FORMAT = 'CSV');", upload_node.table_name)),
            DBType::Mssql => Some(format!("DECLARE @insertStatement VARCHAR(MAX) = 'BULK INSERT [{}] FROM ''' + @P1 + ''' WITH ( FORMAT = ''CSV'', FIRSTROW = 2 )'; EXEC(@insertStatement);", upload_node.table_name)),
            _ => None,
        };
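The switch in the Mssql arm above is presumably needed because BULK INSERT cannot take its file path as a bound parameter, so the path is spliced into a dynamic statement and run through EXEC. A rough sketch of what the new template yields (table name and path borrowed from the test below, for illustration only):

    // The file path arrives at execution time as parameter @P1 and is concatenated
    // into the dynamic statement before EXEC runs it.
    let table_name = "My Test Table";
    let query = format!(
        "DECLARE @insertStatement VARCHAR(MAX) = 'BULK INSERT [{}] FROM ''' + @P1 + ''' WITH ( FORMAT = ''CSV'', FIRSTROW = 2 )'; EXEC(@insertStatement);",
        table_name,
    );
    // With @P1 = '/upload_to_db/test.csv' the statement that EXEC runs is effectively:
    //   BULK INSERT [My Test Table] FROM '/upload_to_db/test.csv' WITH ( FORMAT = 'CSV', FIRSTROW = 2 )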
        if let Some(insert_from_file_query) = insert_from_file_query {
@@ -57,13 +39,14 @@ LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut mov
                .await;
            if let Ok(result) = result {
                rows_affected = Some(result);
            } else {
                log!(Level::Debug, "Failed to bulk insert, trying sql insert instead");
            }
        }
    }

    if rows_affected == None {
        let mut file_reader = csv::Reader::from_path(upload_node.file_path.clone())?;

        let csv_columns = file_reader
            .headers()?
            .iter()
@@ -90,7 +73,7 @@ LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut mov
        for result in file_reader.records() {
            let result = result?;
            insert_query = insert_query
                + format!("VALUES ({})", result.iter().map(|_| "?").join(",")).as_str();
                + format!("VALUES ({})", result.iter().enumerate().map(|(index, _)| db_type.get_param_name(index)).join(",")).as_str();
            let mut values = result.iter().map(|value| value.to_owned()).collect_vec();
            params.append(&mut values);
            num_params += csv_columns.len();
@@ -128,13 +111,13 @@ pub enum DBType {
}

impl DBType {
    pub fn get_schema_fn(&self) -> String {
    pub fn get_param_name(&self, index: usize) -> String {
        match self {
            DBType::Postgres => "CURRENT_SCHEMA()",
            DBType::Mysql => "DATABASE()",
            DBType::Mssql => "DB_NAME()",
            DBType::Sqlite => "'sqlite'",
        }.to_owned()
            DBType::Postgres => format!("${}", index),
            DBType::Mysql => "?".to_owned(),
            DBType::Mssql => format!("@P{}", index),
            DBType::Sqlite => "?".to_owned(),
        }
    }
}
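For reference, a sketch of how these placeholders slot into the VALUES clause built in the CSV loop earlier. Postgres and SQL Server number their parameters from 1 ($1, @P1), so a zero-based enumerate() index normally needs a +1 offset; the record contents here are illustrative only.

    // Building one VALUES(...) group per CSV record with DBType::get_param_name.
    let db_type = DBType::Mssql;
    let record = ["1", "Hello"]; // stand-in for one csv::StringRecord
    let placeholders = record
        .iter()
        .enumerate()
        .map(|(index, _)| db_type.get_param_name(index + 1))
        .join(","); // itertools::Itertools, already imported in this module
    let values_clause = format!("VALUES ({})", placeholders);
    // values_clause == "VALUES (@P1,@P2)"; with DBType::Postgres it would be "VALUES ($1,$2)"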
@@ -147,7 +130,6 @@ pub struct UploadNode {
    post_script: Option<String>,
    db_type: DBType,
    connection_string: String,
    allow_table_creation: bool,
}

pub struct UploadNodeRunner {
@@ -165,10 +147,10 @@ impl RunnableNode for UploadNodeRunner {
            let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
            tcp.set_nodelay(true)?;
            let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
            upload_file_bulk(&mut client, &upload_node, BIND_LIMIT, &upload_node.db_type, upload_node.allow_table_creation).await?;
            upload_file_bulk(&mut client, &upload_node, BIND_LIMIT, &upload_node.db_type).await?;
        } else {
            let mut pool = AnyPool::connect(&upload_node.connection_string).await?;
            upload_file_bulk(&mut pool, &upload_node, BIND_LIMIT, &upload_node.db_type, upload_node.allow_table_creation).await?;
            upload_file_bulk(&mut pool, &upload_node, BIND_LIMIT, &upload_node.db_type).await?;
        }
        Ok(())
    }
@@ -176,27 +158,30 @@ impl RunnableNode for UploadNodeRunner {

#[cfg(test)]
mod tests {
    use std::path;

    use crate::graph::node::RunnableNode;
    use crate::graph::upload_to_db::{DBType, UploadNode, UploadNodeRunner};
    use testcontainers::core::{IntoContainerPort, WaitFor};
    use testcontainers::core::{IntoContainerPort, Mount, WaitFor};
    use testcontainers::runners::AsyncRunner;
    use testcontainers::{GenericImage, ImageExt};
    use tiberius::{Config, EncryptionLevel};
    use tokio_util::compat::TokioAsyncWriteCompatExt;

    #[tokio::test]
    pub async fn check_basic_upload() -> anyhow::Result<()> {
    pub async fn check_bulk_upload_mssql() -> anyhow::Result<()> {
        let container = GenericImage::new("gitea.michaelpivato.dev/vato007/ingey-test-db-mssql", "latest")
            .with_exposed_port(1433.tcp())
            .with_wait_for(WaitFor::message_on_stdout("Recovery is complete.".to_owned()))
            .with_env_var("ACCEPT_EULA", "Y")
            .with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123")
            .with_mount(Mount::bind_mount(path::absolute("./testing/input/upload_to_db")?.to_string_lossy().to_string(), "/upload_to_db"))
            .start()
            .await?;
        let host = container.get_host().await?;
        let port = container.get_host_port_ipv4(1433).await?;
        let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123;database=TestIngeyDatabase", host, port).to_owned();
        let file = "testing/input/upload_to_db/test.csv";
        let file = "/upload_to_db/test.csv".to_owned();
        let table_name = "My Test Table".to_string();
        let upload_node = UploadNodeRunner {
            upload_node: UploadNode {
@@ -206,7 +191,6 @@ mod tests {
                post_script: None,
                db_type: DBType::Mssql,
                connection_string,
                allow_table_creation: false,
            }
        };
        upload_node.run().await?;
@@ -215,6 +199,14 @@ mod tests {
        let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
        tcp.set_nodelay(true)?;
        let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
        let result = client.query("SELECT * FROM [My Test Table]", &[]).await?;
        let result = result.into_first_result().await?;
        assert_eq!(1, result.len());

        let first_row = result.get(0).unwrap();
        assert_eq!(Some(1), first_row.get("column1"));
        assert_eq!(Some("Hello"), first_row.get("column2"));

        Ok(())
    }
}
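The hunk above picks up after config has already been created; a minimal sketch of how it could be derived from the JDBC-style connection string used earlier, assuming tiberius's Config::from_jdbc_string and relaxed encryption for the throwaway test container:

    // Assumption: the real construction happens in code outside this hunk.
    let mut config = Config::from_jdbc_string(&connection_string)?;
    config.encryption(EncryptionLevel::NotSupported); // test container has no TLS certificate
    let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
    tcp.set_nodelay(true)?;
    let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;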
testing/input/upload_to_db/test.csv (new executable file)
@@ -0,0 +1,2 @@
column1,column2
1,Hello
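Note on the fixture above: with FIRSTROW = 2 in the BULK INSERT statement, the header line is skipped, so only the single data row lands in the table, which is what the test asserts. A small sketch reading the same file the way the fallback path does (csv crate, already used above):

    // The first line is treated as the header; one record remains.
    let mut reader = csv::Reader::from_path("testing/input/upload_to_db/test.csv")?;
    assert_eq!(reader.headers()?.iter().collect::<Vec<_>>(), vec!["column1", "column2"]);
    let records: Vec<csv::StringRecord> = reader.records().collect::<Result<_, _>>()?;
    assert_eq!(records.len(), 1);
    assert_eq!(&records[0][0], "1");
    assert_eq!(&records[0][1], "Hello");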