Get bulk upload to db working
Some checks failed
test / test (push) Failing after 14m35s

This commit is contained in:
2025-02-04 17:28:40 +10:30
parent c02a8cd5ab
commit 284375eb3f
2 changed files with 30 additions and 36 deletions

View File

@@ -1,10 +1,11 @@
use anyhow::bail;
use async_trait::async_trait;
use itertools::Itertools;
use log::{log, Level};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use sqlx::AnyPool;
use std::collections::HashMap;
use std::{collections::HashMap, fmt::format};
use tiberius::{Config, EncryptionLevel};
use tokio_util::compat::TokioAsyncWriteCompatExt;
@@ -17,27 +18,8 @@ pub async fn upload_file_bulk(
upload_node: &UploadNode,
bind_limit: usize,
db_type: &DBType,
allow_table_creation: bool,
) -> anyhow::Result<u64> {
let mut rows_affected = None;
let mut num_matching_tables: usize = 0;
executor.get_rows(&format!("SELECT COUNT(*) Count
FROM information_schema.tables
WHERE table_schema = {}
AND table_name = ?
LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut move |row| {
num_matching_tables = row.get("Count").map(|count| count.parse().unwrap_or(0)).unwrap_or(0);
Ok(())
}).await?;
if num_matching_tables == 0 {
if allow_table_creation {
// TODO: Create the table with the columns in the file to be uploaded
} else {
bail!("Table creation not allowed and table does not yet exist, aborting");
}
}
if upload_node.column_mappings.is_none() {
let insert_from_file_query = match upload_node.db_type {
DBType::Postgres => Some(format!("COPY {} FROM $1", upload_node.table_name)),
@@ -45,7 +27,7 @@ LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut mov
"LOAD DATA INFILE ? INTO {}",
upload_node.table_name,
)),
DBType::Mssql => Some(format!("BULK INSERT [{}] FROM ? WITH ( FORMAT = 'CSV');", upload_node.table_name)),
DBType::Mssql => Some(format!("DECLARE @insertStatement VARCHAR(MAX) = 'BULK INSERT [{}] FROM ''' + @P1 + ''' WITH ( FORMAT = ''CSV'', FIRSTROW = 2 )'; EXEC(@insertStatement);", upload_node.table_name)),
_ => None,
};
if let Some(insert_from_file_query) = insert_from_file_query {
@@ -57,13 +39,14 @@ LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut mov
.await;
if let Ok(result) = result {
rows_affected = Some(result);
} else {
log!(Level::Debug, "Failed to bulk insert, trying sql insert instead");
}
}
}
if rows_affected == None {
let mut file_reader = csv::Reader::from_path(upload_node.file_path.clone())?;
let csv_columns = file_reader
.headers()?
.iter()
@@ -90,7 +73,7 @@ LIMIT 1;", db_type.get_schema_fn()), &[upload_node.table_name.clone()], &mut mov
for result in file_reader.records() {
let result = result?;
insert_query = insert_query
+ format!("VALUES ({})", result.iter().map(|_| "?").join(",")).as_str();
+ format!("VALUES ({})", result.iter().enumerate().map(|(index, _)| db_type.get_param_name(index)).join(",")).as_str();
let mut values = result.iter().map(|value| value.to_owned()).collect_vec();
params.append(&mut values);
num_params += csv_columns.len();
@@ -128,13 +111,13 @@ pub enum DBType {
}
// Per-backend SQL dialect helpers for `DBType`.
// NOTE(review): this span is a diff hunk rendered without +/- markers. The `get_schema_fn`
// signature and the string-literal arms ending in `.to_owned()` are the removed side; the
// `get_param_name` signature and the arms that follow are the added side.
impl DBType {
// Removed side: yielded the per-backend SQL expression that the earlier table-existence
// query interpolated into `WHERE table_schema = {}`.
pub fn get_schema_fn(&self) -> String {
// Added side: yields the bind-parameter placeholder text for position `index`.
// Postgres and Mssql embed `index` in the placeholder; Mysql and Sqlite ignore it and
// always return "?".
pub fn get_param_name(&self, index: usize) -> String {
match self {
DBType::Postgres => "CURRENT_SCHEMA()",
DBType::Mysql => "DATABASE()",
DBType::Mssql => "DB_NAME()",
DBType::Sqlite => "'sqlite'",
}.to_owned()
// NOTE(review): `index` is formatted exactly as given. The visible caller passes the
// 0-based index from `enumerate()` over a single record, which produces `$0` / `@P0`,
// while the bulk statements elsewhere in this file use `$1` / `@P1`. Confirm whether
// placeholders should start at 1 and keep counting across rows of a multi-row insert.
DBType::Postgres => format!("${}", index),
DBType::Mysql => "?".to_owned(),
DBType::Mssql => format!("@P{}", index),
DBType::Sqlite => "?".to_owned(),
}
}
}
@@ -147,7 +130,6 @@ pub struct UploadNode {
post_script: Option<String>,
db_type: DBType,
connection_string: String,
allow_table_creation: bool,
}
pub struct UploadNodeRunner {
@@ -165,10 +147,10 @@ impl RunnableNode for UploadNodeRunner {
let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
tcp.set_nodelay(true)?;
let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
upload_file_bulk(&mut client, &upload_node, BIND_LIMIT, &upload_node.db_type, upload_node.allow_table_creation).await?;
upload_file_bulk(&mut client, &upload_node, BIND_LIMIT, &upload_node.db_type).await?;
} else {
let mut pool = AnyPool::connect(&upload_node.connection_string).await?;
upload_file_bulk(&mut pool, &upload_node, BIND_LIMIT, &upload_node.db_type, upload_node.allow_table_creation).await?;
upload_file_bulk(&mut pool, &upload_node, BIND_LIMIT, &upload_node.db_type).await?;
}
Ok(())
}
@@ -176,27 +158,30 @@ impl RunnableNode for UploadNodeRunner {
#[cfg(test)]
mod tests {
use std::path;
use crate::graph::node::RunnableNode;
use crate::graph::upload_to_db::{DBType, UploadNode, UploadNodeRunner};
use testcontainers::core::{IntoContainerPort, WaitFor};
use testcontainers::core::{IntoContainerPort, Mount, WaitFor};
use testcontainers::runners::AsyncRunner;
use testcontainers::{GenericImage, ImageExt};
use tiberius::{Config, EncryptionLevel};
use tokio_util::compat::TokioAsyncWriteCompatExt;
#[tokio::test]
pub async fn check_basic_upload() -> anyhow::Result<()> {
pub async fn check_bulk_upload_mssql() -> anyhow::Result<()> {
let container = GenericImage::new("gitea.michaelpivato.dev/vato007/ingey-test-db-mssql", "latest")
.with_exposed_port(1433.tcp())
.with_wait_for(WaitFor::message_on_stdout("Recovery is complete.".to_owned()))
.with_env_var("ACCEPT_EULA", "Y")
.with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123")
.with_mount(Mount::bind_mount(path::absolute("./testing/input/upload_to_db")?.to_string_lossy().to_string(), "/upload_to_db"))
.start()
.await?;
let host = container.get_host().await?;
let port = container.get_host_port_ipv4(1433).await?;
let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123;database=TestIngeyDatabase", host, port).to_owned();
let file = "testing/input/upload_to_db/test.csv";
let file = "/upload_to_db/test.csv".to_owned();
let table_name = "My Test Table".to_string();
let upload_node = UploadNodeRunner {
upload_node: UploadNode {
@@ -206,7 +191,6 @@ mod tests {
post_script: None,
db_type: DBType::Mssql,
connection_string,
allow_table_creation: false,
}
};
upload_node.run().await?;
@@ -215,6 +199,14 @@ mod tests {
let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
tcp.set_nodelay(true)?;
let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
let result = client.query("SELECT * FROM [My Test Table]", &[]).await?;
let result = result.into_first_result().await?;
assert_eq!(1, result.len());
let first_row = result.get(0).unwrap();
assert_eq!(Some(1), first_row.get("column1"));
assert_eq!(Some("Hello"), first_row.get("column2"));
Ok(())
}
}

View File

@@ -0,0 +1,2 @@
column1,column2
1,Hello
1 column1 column2
2 1 Hello