diff --git a/src/graph/pull_from_db.rs b/src/graph/pull_from_db.rs index a0210f4..c0ee49c 100644 --- a/src/graph/pull_from_db.rs +++ b/src/graph/pull_from_db.rs @@ -1,12 +1,49 @@ +use super::sql::QueryExecutor; +use crate::graph::node::RunnableNode; +use crate::graph::upload_to_db::{upload_file_bulk, DBType}; +use async_trait::async_trait; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; - -use super::sql::QueryExecutor; +use sqlx::AnyPool; +use tiberius::Config; +use tokio_util::compat::TokioAsyncWriteCompatExt; /** * Pull data from a db using a db query into a csv file that can be used by another node */ -fn pull_from_db(executor: &mut impl QueryExecutor, node: PullFromDBNode) {} +async fn pull_from_db(executor: &mut impl QueryExecutor, node: &PullFromDBNode) {} #[derive(Serialize, Deserialize, Clone, JsonSchema)] -pub struct PullFromDBNode {} +pub struct PullFromDBNode { + file_path: String, + query: String, + parameters: Vec, + db_type: DBType, + connection_string: String, +} + +pub struct PullFromDBNodeRunner { + pub pull_from_db_node: PullFromDBNode, +} + +#[async_trait] +impl RunnableNode for PullFromDBNodeRunner { + async fn run(&self) -> anyhow::Result<()> { + let node = self.pull_from_db_node.clone(); + // TODO: Clean up grabbing of connection/executor so don't need to repeat this between upload/download + match node.db_type { + DBType::Mssql => { + let config = Config::from_jdbc_string(&node.connection_string)?; + let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?; + tcp.set_nodelay(true)?; + let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?; + pull_from_db(&mut client, &node).await; + } + _ => { + let mut pool = AnyPool::connect(&node.connection_string).await?; + pull_from_db(&mut pool, &node).await; + } + } + Ok(()) + } +} diff --git a/src/graph/upload_to_db.rs b/src/graph/upload_to_db.rs index 9ee09fa..9bf95f9 100644 --- a/src/graph/upload_to_db.rs +++ b/src/graph/upload_to_db.rs @@ -43,20 +43,24 @@ pub async fn upload_file_bulk( if rows_affected == None { let mut file_reader = csv::Reader::from_path(upload_node.file_path.clone())?; - let csv_columns = file_reader.headers()?.iter().map(|header| header.to_owned()).collect_vec(); + let csv_columns = file_reader + .headers()? + .iter() + .map(|header| header.to_owned()) + .collect_vec(); let table_columns = if let Some(column_mappings) = &upload_node.column_mappings { csv_columns .iter() - .map(|column| { - column_mappings - .get(column).unwrap_or(column) - .clone() - }) + .map(|column| column_mappings.get(column).unwrap_or(column).clone()) .collect_vec() } else { csv_columns.clone() }; - let query_template = format!("INSERT INTO {}({}) \n", upload_node.table_name, table_columns.join(",")); + let query_template = format!( + "INSERT INTO {}({}) \n", + upload_node.table_name, + table_columns.join(",") + ); let mut params = vec![]; let mut insert_query = "".to_owned(); let mut num_params = 0; @@ -64,7 +68,8 @@ pub async fn upload_file_bulk( for result in file_reader.records() { let result = result?; - insert_query = insert_query + format!("VALUES ({})", result.iter().map(|_| "?").join(",")).as_str(); + insert_query = insert_query + + format!("VALUES ({})", result.iter().map(|_| "?").join(",")).as_str(); let mut values = result.iter().map(|value| value.to_owned()).collect_vec(); params.append(&mut values); num_params += csv_columns.len(); @@ -120,24 +125,16 @@ pub struct UploadNodeRunner { impl RunnableNode for UploadNodeRunner { async fn run(&self) -> anyhow::Result<()> { let upload_node = self.upload_node.clone(); - if upload_node.db_type == DBType::Mssql { - let config = Config::from_jdbc_string(&upload_node.connection_string); - if let Ok(config) = config { - let tcp = tokio::net::TcpStream::connect(config.get_addr()).await; - if let Ok(tcp) = tcp { - tcp.set_nodelay(true)?; - let client = tiberius::Client::connect(config, tcp.compat_write()).await; - if let Ok(mut client) = client { - upload_file_bulk(&mut client, &upload_node, BIND_LIMIT).await?; - } - } - } - }else { - let pool = AnyPool::connect(&upload_node.connection_string).await; - if let Ok(mut pool) = pool { - upload_file_bulk(&mut pool, &upload_node, BIND_LIMIT).await?; - } - } + if upload_node.db_type == DBType::Mssql { + let config = Config::from_jdbc_string(&upload_node.connection_string)?; + let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?; + tcp.set_nodelay(true)?; + let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?; + upload_file_bulk(&mut client, &upload_node, BIND_LIMIT).await?; + } else { + let mut pool = AnyPool::connect(&upload_node.connection_string).await?; + upload_file_bulk(&mut pool, &upload_node, BIND_LIMIT).await?; + } Ok(()) } }