Add basic pull from db support

Date: 2024-12-26 16:19:38 +10:30
Parent: 375e1f9638
Commit: 139d6fb7fd
9 changed files with 802 additions and 130 deletions

Cargo.lock (generated, 677 changed lines): diff suppressed because it is too large.

@@ -18,10 +18,10 @@ clap = { version = "4", features = ["derive"] }
 anyhow = "1"
 itertools = "0.13.0"
-chrono = { version = "0.4", features = ["default", "serde"] }
+chrono = { version = "0.4.39", features = ["default", "serde"] }
-rayon = "1.6.0"
+rayon = "1.10.0"
-tokio = { version = "1.39", features = ["full"] }
+tokio = { version = "1.42.0", features = ["full"] }
 sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "any"] }
 rmp-serde = "1.1"
 tempfile = "3.7"
@@ -31,12 +31,13 @@ serde_json = "1.0.122"
 num_cpus = "1.16.0"
 schemars = { version = "0.8.21", features = ["chrono"] }
 log = "0.4.22"
-env_logger = "0.11.5"
+env_logger = "0.11.6"
 tiberius = { version = "0.12.3", features = ["chrono", "tokio"] }
-futures-io = "0.3.30"
+futures-io = "0.3.31"
-futures = "0.3.30"
+futures = "0.3.31"
-tokio-util = { version = "0.7.11", features = ["compat"] }
+tokio-util = { version = "0.7.13", features = ["compat"] }
-async-trait = "0.1.81"
+async-trait = "0.1.83"
+testcontainers = "0.23.1"
 # More info on targets: https://doc.rust-lang.org/cargo/reference/cargo-targets.html#configuring-a-target
 [lib]


@@ -4,7 +4,7 @@ use async_trait::async_trait;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use crate::io::{RecordDeserializer, RecordSerializer};
+use crate::io::{DataSource, RecordDeserializer, RecordSerializer};
 use super::node::RunnableNode;
@@ -165,8 +165,8 @@ pub struct DeriveRule {
 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct DeriveNode {
     pub rules: Vec<DeriveRule>,
-    pub input_file_path: String,
+    pub input_data_source: DataSource,
-    pub output_file_path: String,
+    pub output_data_source: DataSource,
     pub copy_all_columns: bool,
 }
@@ -302,8 +302,8 @@ pub struct DeriveNodeRunner {
 #[async_trait]
 impl RunnableNode for DeriveNodeRunner {
     async fn run(&self) -> anyhow::Result<()> {
-        let mut reader = csv::Reader::from_path(&self.derive_node.input_file_path)?;
+        let mut reader = csv::Reader::from_path(&self.derive_node.input_data_source.path)?;
-        let mut writer = csv::Writer::from_path(&self.derive_node.output_file_path)?;
+        let mut writer = csv::Writer::from_path(&self.derive_node.output_data_source.path)?;
         let rules: anyhow::Result<Vec<RunnableDeriveRule>> = self
             .derive_node
             .rules


@@ -4,7 +4,7 @@ use async_trait::async_trait;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use crate::io::{RecordDeserializer, RecordSerializer};
+use crate::io::{DataSource, RecordDeserializer, RecordSerializer};
 use super::derive::{DataValidators, DeriveFilter};
@@ -42,8 +42,8 @@ pub fn filter_file(
 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct FilterNode {
     pub filters: Vec<DeriveFilter>,
-    pub input_file_path: String,
+    pub input_data_source: DataSource,
-    pub output_file_path: String,
+    pub output_data_source: DataSource,
 }

 pub struct FilterNodeRunner {
@@ -53,8 +53,8 @@ pub struct FilterNodeRunner {
 #[async_trait]
 impl RunnableNode for FilterNodeRunner {
     async fn run(&self) -> anyhow::Result<()> {
-        let mut reader = csv::Reader::from_path(&self.filter_node.input_file_path)?;
+        let mut reader = csv::Reader::from_path(&self.filter_node.input_data_source.path)?;
-        let mut writer = csv::Writer::from_path(&self.filter_node.output_file_path)?;
+        let mut writer = csv::Writer::from_path(&self.filter_node.output_data_source.path)?;
         let rules = derive::to_filter_rules(&self.filter_node.filters)?;
         filter_file(&rules, &mut reader, &mut writer)
     }
@@ -62,7 +62,6 @@ impl RunnableNode for FilterNodeRunner {
 #[cfg(test)]
 mod tests {
-    use super::derive::{Comparator, FilterRule};
     use super::filter_file;


@@ -308,9 +308,10 @@ impl RunnableGraph {
 #[cfg(test)]
 mod tests {
-    use chrono::Local;
     use super::{NodeConfiguration, RunnableGraph};
+    use crate::io::{DataSource, SourceType};
+    use chrono::Local;
+    use std::path::PathBuf;

     #[tokio::test]
     async fn test_basic() -> anyhow::Result<()> {
@@ -324,8 +325,8 @@ mod tests {
             name: "Hello".to_owned(),
             configuration: NodeConfiguration::FilterNode(super::FilterNode {
                 filters: vec![],
-                input_file_path: "".to_owned(),
+                input_data_source: DataSource { path: PathBuf::from(""), source_type: SourceType::CSV },
-                output_file_path: "".to_owned(),
+                output_data_source: DataSource { path: PathBuf::from(""), source_type: SourceType::CSV },
             }),
             output_files: vec![],
             dynamic_configuration: None,


@@ -1,21 +1,36 @@
 use super::sql::QueryExecutor;
 use crate::graph::node::RunnableNode;
 use crate::graph::upload_to_db::{upload_file_bulk, DBType};
+use crate::io::{DataSource, RecordSerializer};
 use async_trait::async_trait;
-use polars::prelude::CsvWriter;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use sqlx::AnyPool;
-use tiberius::Config;
+use std::collections::BTreeMap;
+use tiberius::{AuthMethod, Config, EncryptionLevel};
 use tokio_util::compat::TokioAsyncWriteCompatExt;

 /**
  * Pull data from a db using a db query into a csv file that can be used by another node
  */
-async fn pull_from_db(executor: &mut impl QueryExecutor, node: &PullFromDBNode) {}
+async fn pull_from_db(executor: &mut impl QueryExecutor, node: &PullFromDBNode) -> anyhow::Result<()> {
+    let mut output_file = csv::Writer::from_path(node.output_data_source.path.clone())?;
+    let mut first_row = true;
+    executor.get_rows(&node.query, &node.parameters, &mut move |row| {
+        if first_row {
+            output_file.write_header(&row)?;
+            first_row = false;
+        }
+        output_file.write_record(row.values())?;
+        Ok(())
+    }).await?;
+    Ok(())
+}

 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct PullFromDBNode {
-    file_path: String,
+    output_data_source: DataSource,
     query: String,
     parameters: Vec<String>,
     db_type: DBType,
@@ -33,17 +48,75 @@ impl RunnableNode for PullFromDBNodeRunner {
         // TODO: Clean up grabbing of connection/executor so don't need to repeat this between upload/download
         match node.db_type {
             DBType::Mssql => {
-                let config = Config::from_jdbc_string(&node.connection_string)?;
+                let mut config = Config::from_jdbc_string(&node.connection_string)?;
+                // TODO: Restore encryption for remote hosts, doesn't work on localhost without encryption.
+                config.encryption(EncryptionLevel::NotSupported);
                 let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
                 tcp.set_nodelay(true)?;
                 let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
-                pull_from_db(&mut client, &node).await;
+                pull_from_db(&mut client, &node).await?;
             }
             _ => {
                 let mut pool = AnyPool::connect(&node.connection_string).await?;
-                pull_from_db(&mut pool, &node).await;
+                pull_from_db(&mut pool, &node).await?;
             }
         }
         Ok(())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::graph::node::RunnableNode;
+    use crate::graph::pull_from_db::{PullFromDBNode, PullFromDBNodeRunner};
+    use crate::graph::upload_to_db::DBType::Mssql;
+    use crate::io::DataSource;
+    use crate::io::SourceType::CSV;
+    use std::fs::File;
+    use std::io::Read;
+    use std::path::PathBuf;
+    use testcontainers::core::{IntoContainerPort, WaitFor};
+    use testcontainers::runners::AsyncRunner;
+    use testcontainers::{GenericImage, ImageExt};
+
+    #[tokio::test]
+    async fn test_sql_server() -> anyhow::Result<()> {
+        let container = GenericImage::new("mcr.microsoft.com/mssql/server", "2022-latest")
+            .with_exposed_port(1433.tcp())
+            .with_wait_for(WaitFor::message_on_stdout("Recovery is complete.".to_owned()))
+            .with_env_var("ACCEPT_EULA", "Y")
+            .with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123")
+            .start()
+            .await?;
+        let host = container.get_host().await?;
+        let port = container.get_host_port_ipv4(1433).await?;
+        let port = 1433;
+        let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123", host, port).to_owned();
+        let connection_string = "jdbc:sqlserver://localhost:1433;username=sa;password=TestOnlyContainer123;Encrypt=False".to_owned();
+        let runner = PullFromDBNodeRunner {
+            pull_from_db_node: PullFromDBNode {
+                db_type: Mssql,
+                query: "SELECT '1' Test".to_owned(),
+                parameters: vec![],
+                connection_string,
+                output_data_source: DataSource {
+                    path: PathBuf::from("test_pull.csv"),
+                    source_type: CSV,
+                },
+            }
+        };
+        runner.run().await?;
+        let mut result_contents = String::new();
+        let result_length = File::open("test_pull.csv")?.read_to_string(&mut result_contents)?;
+        assert_eq!(
+            "Test
+1
+",
+            result_contents,
+            "Should pull the correct data from sql"
+        );
+        Ok(())
+    }
+}
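
Because pull_from_db only depends on the QueryExecutor trait, it can also be exercised without a running database. A minimal sketch, assuming a hypothetical MockExecutor added inside the test module (the name and struct are illustrative, not part of this commit):

// Hypothetical mock: feeds pre-built rows straight into the row_consumer callback,
// so pull_from_db can be unit tested without a container.
struct MockExecutor {
    rows: Vec<BTreeMap<String, String>>,
}

impl QueryExecutor for MockExecutor {
    async fn get_rows(
        &mut self,
        _query: &str,
        _params: &Vec<String>,
        row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
    ) -> anyhow::Result<()> {
        // Hand each canned row to the caller's callback, mimicking a streamed result set.
        for row in self.rows.drain(..) {
            row_consumer(row)?;
        }
        Ok(())
    }

    async fn execute_query(&mut self, _query: &str, _params: &Vec<String>) -> anyhow::Result<u64> {
        Ok(0)
    }
}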


@@ -1,21 +1,21 @@
-use std::borrow::Borrow;
 use futures::TryStreamExt;
 use futures_io::{AsyncRead, AsyncWrite};
 use itertools::Itertools;
-use sqlx::{Any, AnyPool, Column, Pool, Row};
+use sqlx::{Any, Column, Pool, Row};
+use std::borrow::Borrow;
+use std::collections::BTreeMap;
 use tiberius::{Client, Query};

-// TODO: This doesn't seem to work. Suggestion by compiler is to instead create an enum and implement
-// the trait on the enum (basically use a match in the implementation depending on which enum we have)
 pub trait QueryExecutor {
-    // TODO: Params binding for filtering the same query?
     // Retrieve data from a database
     async fn get_rows(
         &mut self,
         query: &str,
         params: &Vec<String>,
-    ) -> anyhow::Result<Vec<Vec<(String, String)>>>;
+        // TODO: This is looking pretty ugly, simpler way to handle it? Maybe with an iterator?
+        row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
+    ) -> anyhow::Result<()>;

     // Run a query that returns no results (e.g. bulk insert, insert)
     async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64>;
@@ -26,32 +26,25 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> QueryExecutor for Client<S> {
         &mut self,
         query: &str,
         params: &Vec<String>,
-    ) -> anyhow::Result<Vec<Vec<(String, String)>>> {
+        row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
+    ) -> anyhow::Result<()> {
         let mut query = Query::new(query);
         for param in params {
             query.bind(param);
         }
         let query_result = query.query(self).await?;
-        let results = query_result.into_first_result().await?;
-        let results = results
-            .into_iter()
-            .map(|row| {
-                row.columns()
-                    .into_iter()
-                    .map(|column| {
-                        (
-                            column.name().to_owned(),
-                            match row.get(column.name()) {
-                                Some(value) => value,
-                                None => "",
-                            }
-                            .to_owned(),
-                        )
-                    })
-                    .collect_vec()
-            })
-            .collect();
-        Ok(results)
+        let mut query_stream = query_result.into_row_stream();
+
+        while let Some(row) = query_stream.try_next().await? {
+            let mut returned_row = BTreeMap::new();
+            // TODO: Check how empty columns are handled by tiberius
+            for column in row.columns().into_iter() {
+                returned_row.insert(column.name().to_owned(), row.get(column.name()).unwrap_or_else(|| "")
+                    .to_owned());
+            }
+            row_consumer(returned_row)?;
+        }
+        Ok(())
     }

     async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64> {
@@ -72,22 +65,22 @@ impl QueryExecutor for Pool<Any> {
         &mut self,
         query: &str,
         params: &Vec<String>,
-    ) -> anyhow::Result<Vec<Vec<(String, String)>>> {
+        row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
+    ) -> anyhow::Result<()> {
         let mut query = sqlx::query(query);
         for param in params {
             query = query.bind(param);
         }
         let mut rows = query.fetch(self.borrow());
-        let mut results = vec![];
         while let Some(row) = rows.try_next().await? {
-            results.push(
-                row.columns()
-                    .into_iter()
-                    .map(|column| (column.name().to_owned(), row.get(column.name())))
-                    .collect(),
-            );
+            let mut returned_row = BTreeMap::new();
+            for column in row.columns().into_iter() {
+                returned_row.insert(column.name().to_owned(), row.get(column.name()));
+            }
+            row_consumer(returned_row)?;
         }
-        Ok(results)
+        Ok(())
     }

     async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64> {
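
With this trait change, get_rows no longer buffers the whole result set; it streams rows into the caller's callback. Callers that still want everything in memory can do the buffering themselves. A minimal sketch (the collect_all helper is illustrative and not part of the commit):

// Collect every streamed row into a Vec, roughly recovering what the old
// Vec<Vec<(String, String)>> return value used to provide.
async fn collect_all(
    executor: &mut impl QueryExecutor,
    query: &str,
) -> anyhow::Result<Vec<BTreeMap<String, String>>> {
    let mut all_rows = Vec::new();
    executor
        .get_rows(query, &vec![], &mut |row| {
            all_rows.push(row);
            Ok(())
        })
        .await?;
    Ok(all_rows)
}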


@@ -1,6 +1,10 @@
 use std::fs::File;

+use super::node::RunnableNode;
+use crate::io::{DataSource, SourceType};
 use async_trait::async_trait;
+use polars::io::SerReader;
+use polars::prelude::{IntoLazy, LazyFrame, ParquetReader, ScanArgsParquet};
 use polars::{
     io::SerWriter,
     prelude::{CsvWriter, LazyCsvReader, LazyFileListReader},
@@ -9,21 +13,24 @@ use polars_sql::SQLContext;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

-use super::node::RunnableNode;
 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
-pub struct CSVFile {
+pub struct SqlFile {
-    name: String,
+    pub name: String,
-    path: String,
+    pub data_source: DataSource,
 }

 /**
  * Run SQL over files using polars, export results to output file
  */
-fn run_sql(files: &Vec<CSVFile>, output_path: &String, query: &String) -> anyhow::Result<()> {
+fn run_sql(files: &Vec<SqlFile>, output_path: &String, query: &String) -> anyhow::Result<()> {
     let mut ctx = SQLContext::new();
     for file in files {
-        let df = LazyCsvReader::new(&file.path).finish()?;
+        let df = match file.data_source.source_type {
+            SourceType::CSV => LazyCsvReader::new(&file.data_source.path).finish()?,
+            SourceType::PARQUET => {
+                LazyFrame::scan_parquet(&file.data_source.path, ScanArgsParquet::default())?
+            }
+        };
         ctx.register(&file.name, df);
     }
     let result = ctx.execute(&query)?;
@@ -34,7 +41,7 @@ fn run_sql(files: &Vec<CSVFile>, output_path: &String, query: &String) -> anyhow
 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct SQLNode {
-    pub files: Vec<CSVFile>,
+    pub files: Vec<SqlFile>,
     pub output_file: String,
     pub query: String,
 }
@@ -55,17 +62,21 @@ impl RunnableNode for SQLNodeRunner {
 }

 #[cfg(test)]
 mod tests {
+    use super::{run_sql, SqlFile};
+    use crate::io::{DataSource, SourceType};
+    use std::path::PathBuf;
     use std::{fs::File, io::Read};
-    use super::{run_sql, CSVFile};

     #[test]
     fn basic_query_works() -> anyhow::Result<()> {
         let output_path = "./testing/output/output.csv".to_owned();
         run_sql(
-            &vec![CSVFile {
+            &vec![SqlFile {
                 name: "Account".to_owned(),
-                path: "./testing/test.csv".to_owned(),
+                data_source: DataSource {
+                    source_type: SourceType::CSV,
+                    path: PathBuf::from("./testing/test.csv"),
+                }
             }],
             &output_path,
             &"SELECT * FROM Account WHERE Code = 'A195950'".to_owned(),
@@ -76,7 +87,7 @@ mod tests {
         assert_eq!(
             output,
             "Code,Description,Type,CostOutput,PercentFixed
-A195950,A195950 Staff Related Other,E,GS,100.00
+A195950,A195950 Staff Related Other,E,GS,100.0
 "
         );
         Ok(())
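
The run_sql change means a table can now be registered from either a CSV or a Parquet input. A hedged sketch of a companion test, assuming a ./testing/test.parquet fixture that this commit does not ship:

// Hypothetical test: same SQL-over-files flow, but the registered table is backed
// by a Parquet file instead of a CSV.
#[test]
fn parquet_query_works() -> anyhow::Result<()> {
    run_sql(
        &vec![SqlFile {
            name: "Account".to_owned(),
            data_source: DataSource {
                source_type: SourceType::PARQUET,
                path: PathBuf::from("./testing/test.parquet"),
            },
        }],
        &"./testing/output/output_parquet.csv".to_owned(),
        &"SELECT Code FROM Account".to_owned(),
    )
}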


@@ -1,11 +1,24 @@
+use anyhow::bail;
+use rmp_serde::{decode::ReadReader, Deserializer, Serializer};
+use schemars::JsonSchema;
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
+use std::path::PathBuf;
 use std::{
     collections::BTreeMap,
     io::{Read, Seek, Write},
 };
-use anyhow::bail;
-use rmp_serde::{decode::ReadReader, Deserializer, Serializer};
-use serde::{de::DeserializeOwned, Deserialize, Serialize};

+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
+pub enum SourceType {
+    CSV,
+    PARQUET,
+}
+
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
+pub struct DataSource {
+    pub path: PathBuf,
+    pub source_type: SourceType,
+}
+
 pub trait RecordSerializer {
     fn serialize(&mut self, record: impl Serialize) -> anyhow::Result<()>;
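
Since DataSource and SourceType derive Serialize, Deserialize and JsonSchema, they take part in the same serialized node configuration as the rest of the graph. A minimal sketch (not from the commit) of how a DataSource round-trips through serde_json under serde's default representation:

// Illustrative test: the default serde layout is
// {"path":"./testing/test.csv","source_type":"CSV"}.
#[test]
fn data_source_round_trips() -> anyhow::Result<()> {
    let source = DataSource {
        path: PathBuf::from("./testing/test.csv"),
        source_type: SourceType::CSV,
    };
    let json = serde_json::to_string(&source)?;
    let round_tripped: DataSource = serde_json::from_str(&json)?;
    assert!(matches!(round_tripped.source_type, SourceType::CSV));
    assert_eq!(round_tripped.path, PathBuf::from("./testing/test.csv"));
    Ok(())
}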