Get bulk upload to db working
Some checks failed
test / test (push) Failing after 14m35s

This commit is contained in:
2025-02-04 17:28:40 +10:30
parent c02a8cd5ab
commit 284375eb3f
2 changed files with 30 additions and 36 deletions

View File

@@ -1,10 +1,11 @@
use anyhow::bail; use anyhow::bail;
use async_trait::async_trait; use async_trait::async_trait;
use itertools::Itertools; use itertools::Itertools;
use log::{log, Level};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use sqlx::AnyPool; use sqlx::AnyPool;
use std::collections::HashMap; use std::{collections::HashMap, fmt::format};
use tiberius::{Config, EncryptionLevel}; use tiberius::{Config, EncryptionLevel};
use tokio_util::compat::TokioAsyncWriteCompatExt; use tokio_util::compat::TokioAsyncWriteCompatExt;
@@ -17,27 +18,8 @@ pub async fn upload_file_bulk(
upload_node: &UploadNode, upload_node: &UploadNode,
bind_limit: usize, bind_limit: usize,
db_type: &DBType, db_type: &DBType,
allow_table_creation: bool,
) -> anyhow::Result<u64> { ) -> anyhow::Result<u64> {
let mut rows_affected = None; let mut rows_affected = None;
let mut num_matching_tables: usize = 0;
executor.get_rows(&format!("SELECT COUNT(*) Count
FROM information_schema.tables
WHERE table_schema = {}
AND table_name = ?
LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut move |row| {
num_matching_tables = row.get("Count").map(|count| count.parse().unwrap_or(0)).unwrap_or(0);
Ok(())
}).await?;
if num_matching_tables == 0 {
if allow_table_creation {
// TODO: Create the table with the columns in the file to be uploaded
} else {
bail!("Table creation not allowed and table does not yet exist, aborting");
}
}
if upload_node.column_mappings.is_none() { if upload_node.column_mappings.is_none() {
let insert_from_file_query = match upload_node.db_type { let insert_from_file_query = match upload_node.db_type {
DBType::Postgres => Some(format!("COPY {} FROM $1", upload_node.table_name)), DBType::Postgres => Some(format!("COPY {} FROM $1", upload_node.table_name)),
@@ -45,7 +27,7 @@ LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut mov
"LOAD DATA INFILE ? INTO {}", "LOAD DATA INFILE ? INTO {}",
upload_node.table_name, upload_node.table_name,
)), )),
DBType::Mssql => Some(format!("BULK INSERT [{}] FROM ? WITH ( FORMAT = 'CSV');", upload_node.table_name)), DBType::Mssql => Some(format!("DECLARE @insertStatement VARCHAR(MAX) = 'BULK INSERT [{}] FROM ''' + @P1 + ''' WITH ( FORMAT = ''CSV'', FIRSTROW = 2 )'; EXEC(@insertStatement);", upload_node.table_name)),
_ => None, _ => None,
}; };
if let Some(insert_from_file_query) = insert_from_file_query { if let Some(insert_from_file_query) = insert_from_file_query {
@@ -57,13 +39,14 @@ LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut mov
.await; .await;
if let Ok(result) = result { if let Ok(result) = result {
rows_affected = Some(result); rows_affected = Some(result);
} else {
log!(Level::Debug, "Failed to bulk insert, trying sql insert instead");
} }
} }
} }
if rows_affected == None { if rows_affected == None {
let mut file_reader = csv::Reader::from_path(upload_node.file_path.clone())?; let mut file_reader = csv::Reader::from_path(upload_node.file_path.clone())?;
let csv_columns = file_reader let csv_columns = file_reader
.headers()? .headers()?
.iter() .iter()
@@ -90,7 +73,7 @@ LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut mov
for result in file_reader.records() { for result in file_reader.records() {
let result = result?; let result = result?;
insert_query = insert_query insert_query = insert_query
+ format!("VALUES ({})", result.iter().map(|_| "?").join(",")).as_str(); + format!("VALUES ({})", result.iter().enumerate().map(|(index, _)| db_type.get_param_name(index)).join(",")).as_str();
let mut values = result.iter().map(|value| value.to_owned()).collect_vec(); let mut values = result.iter().map(|value| value.to_owned()).collect_vec();
params.append(&mut values); params.append(&mut values);
num_params += csv_columns.len(); num_params += csv_columns.len();
@@ -128,13 +111,13 @@ pub enum DBType {
} }
impl DBType { impl DBType {
pub fn get_schema_fn(&self) -> String { pub fn get_param_name(&self, index: usize) -> String {
match self { match self {
DBType::Postgres => "CURRENT_SCHEMA()", DBType::Postgres => format!("${}", index),
DBType::Mysql => "DATABASE()", DBType::Mysql => "?".to_owned(),
DBType::Mssql => "DB_NAME()", DBType::Mssql => format!("@P{}", index),
DBType::Sqlite => "'sqlite'", DBType::Sqlite => "?".to_owned(),
}.to_owned() }
} }
} }
@@ -147,7 +130,6 @@ pub struct UploadNode {
post_script: Option<String>, post_script: Option<String>,
db_type: DBType, db_type: DBType,
connection_string: String, connection_string: String,
allow_table_creation: bool,
} }
pub struct UploadNodeRunner { pub struct UploadNodeRunner {
@@ -165,10 +147,10 @@ impl RunnableNode for UploadNodeRunner {
let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?; let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
tcp.set_nodelay(true)?; tcp.set_nodelay(true)?;
let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?; let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
upload_file_bulk(&mut client, &upload_node, BIND_LIMIT, &upload_node.db_type, upload_node.allow_table_creation).await?; upload_file_bulk(&mut client, &upload_node, BIND_LIMIT, &upload_node.db_type).await?;
} else { } else {
let mut pool = AnyPool::connect(&upload_node.connection_string).await?; let mut pool = AnyPool::connect(&upload_node.connection_string).await?;
upload_file_bulk(&mut pool, &upload_node, BIND_LIMIT, &upload_node.db_type, upload_node.allow_table_creation).await?; upload_file_bulk(&mut pool, &upload_node, BIND_LIMIT, &upload_node.db_type).await?;
} }
Ok(()) Ok(())
} }
@@ -176,27 +158,30 @@ impl RunnableNode for UploadNodeRunner {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::path;
use crate::graph::node::RunnableNode; use crate::graph::node::RunnableNode;
use crate::graph::upload_to_db::{DBType, UploadNode, UploadNodeRunner}; use crate::graph::upload_to_db::{DBType, UploadNode, UploadNodeRunner};
use testcontainers::core::{IntoContainerPort, WaitFor}; use testcontainers::core::{IntoContainerPort, Mount, WaitFor};
use testcontainers::runners::AsyncRunner; use testcontainers::runners::AsyncRunner;
use testcontainers::{GenericImage, ImageExt}; use testcontainers::{GenericImage, ImageExt};
use tiberius::{Config, EncryptionLevel}; use tiberius::{Config, EncryptionLevel};
use tokio_util::compat::TokioAsyncWriteCompatExt; use tokio_util::compat::TokioAsyncWriteCompatExt;
#[tokio::test] #[tokio::test]
pub async fn check_basic_upload() -> anyhow::Result<()> { pub async fn check_bulk_upload_mssql() -> anyhow::Result<()> {
let container = GenericImage::new("gitea.michaelpivato.dev/vato007/ingey-test-db-mssql", "latest") let container = GenericImage::new("gitea.michaelpivato.dev/vato007/ingey-test-db-mssql", "latest")
.with_exposed_port(1433.tcp()) .with_exposed_port(1433.tcp())
.with_wait_for(WaitFor::message_on_stdout("Recovery is complete.".to_owned())) .with_wait_for(WaitFor::message_on_stdout("Recovery is complete.".to_owned()))
.with_env_var("ACCEPT_EULA", "Y") .with_env_var("ACCEPT_EULA", "Y")
.with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123") .with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123")
.with_mount(Mount::bind_mount(path::absolute("./testing/input/upload_to_db")?.to_string_lossy().to_string(), "/upload_to_db"))
.start() .start()
.await?; .await?;
let host = container.get_host().await?; let host = container.get_host().await?;
let port = container.get_host_port_ipv4(1433).await?; let port = container.get_host_port_ipv4(1433).await?;
let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123;database=TestIngeyDatabase", host, port).to_owned(); let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123;database=TestIngeyDatabase", host, port).to_owned();
let file = "testing/input/upload_to_db/test.csv"; let file = "/upload_to_db/test.csv".to_owned();
let table_name = "My Test Table".to_string(); let table_name = "My Test Table".to_string();
let upload_node = UploadNodeRunner { let upload_node = UploadNodeRunner {
upload_node: UploadNode { upload_node: UploadNode {
@@ -206,7 +191,6 @@ mod tests {
post_script: None, post_script: None,
db_type: DBType::Mssql, db_type: DBType::Mssql,
connection_string, connection_string,
allow_table_creation: false,
} }
}; };
upload_node.run().await?; upload_node.run().await?;
@@ -215,6 +199,14 @@ mod tests {
let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?; let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
tcp.set_nodelay(true)?; tcp.set_nodelay(true)?;
let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?; let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
let result = client.query("SELECT * FROM [My Test Table]", &[]).await?;
let result = result.into_first_result().await?;
assert_eq!(1, result.len());
let first_row = result.get(0).unwrap();
assert_eq!(Some(1), first_row.get("column1"));
assert_eq!(Some("Hello"), first_row.get("column2"));
Ok(()) Ok(())
} }
} }

View File

@@ -0,0 +1,2 @@
column1,column2
1,Hello
1 column1 column2
2 1 Hello