This commit is contained in:
@@ -1,10 +1,11 @@
|
|||||||
use anyhow::bail;
|
use anyhow::bail;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
use log::{log, Level};
|
||||||
use schemars::JsonSchema;
|
use schemars::JsonSchema;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use sqlx::AnyPool;
|
use sqlx::AnyPool;
|
||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, fmt::format};
|
||||||
use tiberius::{Config, EncryptionLevel};
|
use tiberius::{Config, EncryptionLevel};
|
||||||
use tokio_util::compat::TokioAsyncWriteCompatExt;
|
use tokio_util::compat::TokioAsyncWriteCompatExt;
|
||||||
|
|
||||||
@@ -17,27 +18,8 @@ pub async fn upload_file_bulk(
|
|||||||
upload_node: &UploadNode,
|
upload_node: &UploadNode,
|
||||||
bind_limit: usize,
|
bind_limit: usize,
|
||||||
db_type: &DBType,
|
db_type: &DBType,
|
||||||
allow_table_creation: bool,
|
|
||||||
) -> anyhow::Result<u64> {
|
) -> anyhow::Result<u64> {
|
||||||
let mut rows_affected = None;
|
let mut rows_affected = None;
|
||||||
let mut num_matching_tables: usize = 0;
|
|
||||||
executor.get_rows(&format!("SELECT COUNT(*) Count
|
|
||||||
FROM information_schema.tables
|
|
||||||
WHERE table_schema = {}
|
|
||||||
AND table_name = ?
|
|
||||||
LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut move |row| {
|
|
||||||
num_matching_tables = row.get("Count").map(|count| count.parse().unwrap_or(0)).unwrap_or(0);
|
|
||||||
Ok(())
|
|
||||||
}).await?;
|
|
||||||
if num_matching_tables == 0 {
|
|
||||||
if allow_table_creation {
|
|
||||||
// TODO: Create the table with the columns in the file to be uploaded
|
|
||||||
|
|
||||||
} else {
|
|
||||||
bail!("Table creation not allowed and table does not yet exist, aborting");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if upload_node.column_mappings.is_none() {
|
if upload_node.column_mappings.is_none() {
|
||||||
let insert_from_file_query = match upload_node.db_type {
|
let insert_from_file_query = match upload_node.db_type {
|
||||||
DBType::Postgres => Some(format!("COPY {} FROM $1", upload_node.table_name)),
|
DBType::Postgres => Some(format!("COPY {} FROM $1", upload_node.table_name)),
|
||||||
@@ -45,7 +27,7 @@ LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut mov
|
|||||||
"LOAD DATA INFILE ? INTO {}",
|
"LOAD DATA INFILE ? INTO {}",
|
||||||
upload_node.table_name,
|
upload_node.table_name,
|
||||||
)),
|
)),
|
||||||
DBType::Mssql => Some(format!("BULK INSERT [{}] FROM ? WITH ( FORMAT = 'CSV');", upload_node.table_name)),
|
DBType::Mssql => Some(format!("DECLARE @insertStatement VARCHAR(MAX) = 'BULK INSERT [{}] FROM ''' + @P1 + ''' WITH ( FORMAT = ''CSV'', FIRSTROW = 2 )'; EXEC(@insertStatement);", upload_node.table_name)),
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
if let Some(insert_from_file_query) = insert_from_file_query {
|
if let Some(insert_from_file_query) = insert_from_file_query {
|
||||||
@@ -57,13 +39,14 @@ LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut mov
|
|||||||
.await;
|
.await;
|
||||||
if let Ok(result) = result {
|
if let Ok(result) = result {
|
||||||
rows_affected = Some(result);
|
rows_affected = Some(result);
|
||||||
|
} else {
|
||||||
|
log!(Level::Debug, "Failed to bulk insert, trying sql insert instead");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if rows_affected == None {
|
if rows_affected == None {
|
||||||
let mut file_reader = csv::Reader::from_path(upload_node.file_path.clone())?;
|
let mut file_reader = csv::Reader::from_path(upload_node.file_path.clone())?;
|
||||||
|
|
||||||
let csv_columns = file_reader
|
let csv_columns = file_reader
|
||||||
.headers()?
|
.headers()?
|
||||||
.iter()
|
.iter()
|
||||||
@@ -90,7 +73,7 @@ LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut mov
|
|||||||
for result in file_reader.records() {
|
for result in file_reader.records() {
|
||||||
let result = result?;
|
let result = result?;
|
||||||
insert_query = insert_query
|
insert_query = insert_query
|
||||||
+ format!("VALUES ({})", result.iter().map(|_| "?").join(",")).as_str();
|
+ format!("VALUES ({})", result.iter().enumerate().map(|(index, _)| db_type.get_param_name(index)).join(",")).as_str();
|
||||||
let mut values = result.iter().map(|value| value.to_owned()).collect_vec();
|
let mut values = result.iter().map(|value| value.to_owned()).collect_vec();
|
||||||
params.append(&mut values);
|
params.append(&mut values);
|
||||||
num_params += csv_columns.len();
|
num_params += csv_columns.len();
|
||||||
@@ -128,13 +111,13 @@ pub enum DBType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl DBType {
|
impl DBType {
|
||||||
pub fn get_schema_fn(&self) -> String {
|
pub fn get_param_name(&self, index: usize) -> String {
|
||||||
match self {
|
match self {
|
||||||
DBType::Postgres => "CURRENT_SCHEMA()",
|
DBType::Postgres => format!("${}", index),
|
||||||
DBType::Mysql => "DATABASE()",
|
DBType::Mysql => "?".to_owned(),
|
||||||
DBType::Mssql => "DB_NAME()",
|
DBType::Mssql => format!("@P{}", index),
|
||||||
DBType::Sqlite => "'sqlite'",
|
DBType::Sqlite => "?".to_owned(),
|
||||||
}.to_owned()
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,7 +130,6 @@ pub struct UploadNode {
|
|||||||
post_script: Option<String>,
|
post_script: Option<String>,
|
||||||
db_type: DBType,
|
db_type: DBType,
|
||||||
connection_string: String,
|
connection_string: String,
|
||||||
allow_table_creation: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct UploadNodeRunner {
|
pub struct UploadNodeRunner {
|
||||||
@@ -165,10 +147,10 @@ impl RunnableNode for UploadNodeRunner {
|
|||||||
let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
|
let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
|
||||||
tcp.set_nodelay(true)?;
|
tcp.set_nodelay(true)?;
|
||||||
let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
|
let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
|
||||||
upload_file_bulk(&mut client, &upload_node, BIND_LIMIT, &upload_node.db_type, upload_node.allow_table_creation).await?;
|
upload_file_bulk(&mut client, &upload_node, BIND_LIMIT, &upload_node.db_type).await?;
|
||||||
} else {
|
} else {
|
||||||
let mut pool = AnyPool::connect(&upload_node.connection_string).await?;
|
let mut pool = AnyPool::connect(&upload_node.connection_string).await?;
|
||||||
upload_file_bulk(&mut pool, &upload_node, BIND_LIMIT, &upload_node.db_type, upload_node.allow_table_creation).await?;
|
upload_file_bulk(&mut pool, &upload_node, BIND_LIMIT, &upload_node.db_type).await?;
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -176,27 +158,30 @@ impl RunnableNode for UploadNodeRunner {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use std::path;
|
||||||
|
|
||||||
use crate::graph::node::RunnableNode;
|
use crate::graph::node::RunnableNode;
|
||||||
use crate::graph::upload_to_db::{DBType, UploadNode, UploadNodeRunner};
|
use crate::graph::upload_to_db::{DBType, UploadNode, UploadNodeRunner};
|
||||||
use testcontainers::core::{IntoContainerPort, WaitFor};
|
use testcontainers::core::{IntoContainerPort, Mount, WaitFor};
|
||||||
use testcontainers::runners::AsyncRunner;
|
use testcontainers::runners::AsyncRunner;
|
||||||
use testcontainers::{GenericImage, ImageExt};
|
use testcontainers::{GenericImage, ImageExt};
|
||||||
use tiberius::{Config, EncryptionLevel};
|
use tiberius::{Config, EncryptionLevel};
|
||||||
use tokio_util::compat::TokioAsyncWriteCompatExt;
|
use tokio_util::compat::TokioAsyncWriteCompatExt;
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
pub async fn check_basic_upload() -> anyhow::Result<()> {
|
pub async fn check_bulk_upload_mssql() -> anyhow::Result<()> {
|
||||||
let container = GenericImage::new("gitea.michaelpivato.dev/vato007/ingey-test-db-mssql", "latest")
|
let container = GenericImage::new("gitea.michaelpivato.dev/vato007/ingey-test-db-mssql", "latest")
|
||||||
.with_exposed_port(1433.tcp())
|
.with_exposed_port(1433.tcp())
|
||||||
.with_wait_for(WaitFor::message_on_stdout("Recovery is complete.".to_owned()))
|
.with_wait_for(WaitFor::message_on_stdout("Recovery is complete.".to_owned()))
|
||||||
.with_env_var("ACCEPT_EULA", "Y")
|
.with_env_var("ACCEPT_EULA", "Y")
|
||||||
.with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123")
|
.with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123")
|
||||||
|
.with_mount(Mount::bind_mount(path::absolute("./testing/input/upload_to_db")?.to_string_lossy().to_string(), "/upload_to_db"))
|
||||||
.start()
|
.start()
|
||||||
.await?;
|
.await?;
|
||||||
let host = container.get_host().await?;
|
let host = container.get_host().await?;
|
||||||
let port = container.get_host_port_ipv4(1433).await?;
|
let port = container.get_host_port_ipv4(1433).await?;
|
||||||
let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123;database=TestIngeyDatabase", host, port).to_owned();
|
let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123;database=TestIngeyDatabase", host, port).to_owned();
|
||||||
let file = "testing/input/upload_to_db/test.csv";
|
let file = "/upload_to_db/test.csv".to_owned();
|
||||||
let table_name = "My Test Table".to_string();
|
let table_name = "My Test Table".to_string();
|
||||||
let upload_node = UploadNodeRunner {
|
let upload_node = UploadNodeRunner {
|
||||||
upload_node: UploadNode {
|
upload_node: UploadNode {
|
||||||
@@ -206,7 +191,6 @@ mod tests {
|
|||||||
post_script: None,
|
post_script: None,
|
||||||
db_type: DBType::Mssql,
|
db_type: DBType::Mssql,
|
||||||
connection_string,
|
connection_string,
|
||||||
allow_table_creation: false,
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
upload_node.run().await?;
|
upload_node.run().await?;
|
||||||
@@ -215,6 +199,14 @@ mod tests {
|
|||||||
let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
|
let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
|
||||||
tcp.set_nodelay(true)?;
|
tcp.set_nodelay(true)?;
|
||||||
let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
|
let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
|
||||||
|
let result = client.query("SELECT * FROM [My Test Table]", &[]).await?;
|
||||||
|
let result = result.into_first_result().await?;
|
||||||
|
assert_eq!(1, result.len());
|
||||||
|
|
||||||
|
let first_row = result.get(0).unwrap();
|
||||||
|
assert_eq!(Some(1), first_row.get("column1"));
|
||||||
|
assert_eq!(Some("Hello"), first_row.get("column2"));
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
2
testing/input/upload_to_db/test.csv
Executable file
2
testing/input/upload_to_db/test.csv
Executable file
@@ -0,0 +1,2 @@
|
|||||||
|
column1,column2
|
||||||
|
1,Hello
|
||||||
|
Reference in New Issue
Block a user