Add basic pull from db support

2024-12-26 16:19:38 +10:30
parent 375e1f9638
commit 139d6fb7fd
9 changed files with 802 additions and 130 deletions

Cargo.lock generated (677 lines changed)

File diff suppressed because it is too large

View File

@@ -18,10 +18,10 @@ clap = { version = "4", features = ["derive"] }
 anyhow = "1"
 itertools = "0.13.0"
-chrono = { version = "0.4", features = ["default", "serde"] }
+chrono = { version = "0.4.39", features = ["default", "serde"] }
-rayon = "1.6.0"
-tokio = { version = "1.39", features = ["full"] }
+rayon = "1.10.0"
+tokio = { version = "1.42.0", features = ["full"] }
 sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "any"] }
 rmp-serde = "1.1"
 tempfile = "3.7"
@@ -31,12 +31,13 @@ serde_json = "1.0.122"
 num_cpus = "1.16.0"
 schemars = { version = "0.8.21", features = ["chrono"] }
 log = "0.4.22"
-env_logger = "0.11.5"
+env_logger = "0.11.6"
 tiberius = { version = "0.12.3", features = ["chrono", "tokio"] }
-futures-io = "0.3.30"
-futures = "0.3.30"
-tokio-util = { version = "0.7.11", features = ["compat"] }
-async-trait = "0.1.81"
+futures-io = "0.3.31"
+futures = "0.3.31"
+tokio-util = { version = "0.7.13", features = ["compat"] }
+async-trait = "0.1.83"
+testcontainers = "0.23.1"
 # More info on targets: https://doc.rust-lang.org/cargo/reference/cargo-targets.html#configuring-a-target
 [lib]
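
Review note on the sqlx line above: the "any" feature only pulls in the type-erased Any driver. Since sqlx 0.7, the concrete backends behind it have to be enabled as separate cargo features (e.g. "postgres", "mysql") and registered at runtime, otherwise AnyPool::connect fails when resolving the connection string. A minimal sketch, assuming registration happens somewhere at startup (the function name is illustrative):

// Assumption: register whichever sqlx drivers were compiled in, so that
// AnyPool::connect can resolve URLs like "postgres://..." at runtime.
fn init_sqlx() {
    sqlx::any::install_default_drivers();
}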

View File

@@ -4,7 +4,7 @@ use async_trait::async_trait;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use crate::io::{RecordDeserializer, RecordSerializer};
+use crate::io::{DataSource, RecordDeserializer, RecordSerializer};
 use super::node::RunnableNode;
@@ -165,8 +165,8 @@ pub struct DeriveRule {
 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct DeriveNode {
     pub rules: Vec<DeriveRule>,
-    pub input_file_path: String,
-    pub output_file_path: String,
+    pub input_data_source: DataSource,
+    pub output_data_source: DataSource,
     pub copy_all_columns: bool,
 }
@@ -302,8 +302,8 @@ pub struct DeriveNodeRunner {
 #[async_trait]
 impl RunnableNode for DeriveNodeRunner {
     async fn run(&self) -> anyhow::Result<()> {
-        let mut reader = csv::Reader::from_path(&self.derive_node.input_file_path)?;
-        let mut writer = csv::Writer::from_path(&self.derive_node.output_file_path)?;
+        let mut reader = csv::Reader::from_path(&self.derive_node.input_data_source.path)?;
+        let mut writer = csv::Writer::from_path(&self.derive_node.output_data_source.path)?;
         let rules: anyhow::Result<Vec<RunnableDeriveRule>> = self
             .derive_node
             .rules
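
Since DeriveNode derives Serialize/Deserialize, swapping the two path strings for DataSource also changes the serialized node configuration. A sketch of the new shape, assuming serde defaults (no rename attributes) and code living in the same module as the types:

use std::path::PathBuf;

// Illustrative only: print the JSON layout a DeriveNode now serializes to.
fn print_example_config() -> anyhow::Result<()> {
    let node = DeriveNode {
        rules: vec![],
        input_data_source: DataSource { path: PathBuf::from("in.csv"), source_type: SourceType::CSV },
        output_data_source: DataSource { path: PathBuf::from("out.csv"), source_type: SourceType::CSV },
        copy_all_columns: true,
    };
    // => {"rules":[],"input_data_source":{"path":"in.csv","source_type":"CSV"},...}
    println!("{}", serde_json::to_string(&node)?);
    Ok(())
}

Existing configurations that still carry input_file_path/output_file_path would presumably fail to deserialize, so any stored graphs need migrating.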

View File

@@ -4,7 +4,7 @@ use async_trait::async_trait;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use crate::io::{RecordDeserializer, RecordSerializer};
+use crate::io::{DataSource, RecordDeserializer, RecordSerializer};
 use super::derive::{DataValidators, DeriveFilter};
@@ -42,8 +42,8 @@ pub fn filter_file(
 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct FilterNode {
     pub filters: Vec<DeriveFilter>,
-    pub input_file_path: String,
-    pub output_file_path: String,
+    pub input_data_source: DataSource,
+    pub output_data_source: DataSource,
 }
 
 pub struct FilterNodeRunner {
@@ -53,8 +53,8 @@ pub struct FilterNodeRunner {
 #[async_trait]
 impl RunnableNode for FilterNodeRunner {
     async fn run(&self) -> anyhow::Result<()> {
-        let mut reader = csv::Reader::from_path(&self.filter_node.input_file_path)?;
-        let mut writer = csv::Writer::from_path(&self.filter_node.output_file_path)?;
+        let mut reader = csv::Reader::from_path(&self.filter_node.input_data_source.path)?;
+        let mut writer = csv::Writer::from_path(&self.filter_node.output_data_source.path)?;
         let rules = derive::to_filter_rules(&self.filter_node.filters)?;
         filter_file(&rules, &mut reader, &mut writer)
     }
@@ -62,7 +62,6 @@ impl RunnableNode for FilterNodeRunner {
 #[cfg(test)]
 mod tests {
     use super::derive::{Comparator, FilterRule};
     use super::filter_file;

View File

@@ -308,9 +308,10 @@ impl RunnableGraph {
 #[cfg(test)]
 mod tests {
-    use chrono::Local;
+    use super::{NodeConfiguration, RunnableGraph};
+    use crate::io::{DataSource, SourceType};
+    use chrono::Local;
+    use std::path::PathBuf;
 
     #[tokio::test]
     async fn test_basic() -> anyhow::Result<()> {
@@ -324,8 +325,8 @@ mod tests {
             name: "Hello".to_owned(),
             configuration: NodeConfiguration::FilterNode(super::FilterNode {
                 filters: vec![],
-                input_file_path: "".to_owned(),
-                output_file_path: "".to_owned(),
+                input_data_source: DataSource { path: PathBuf::from(""), source_type: SourceType::CSV },
+                output_data_source: DataSource { path: PathBuf::from(""), source_type: SourceType::CSV },
             }),
             output_files: vec![],
             dynamic_configuration: None,

View File

@@ -1,21 +1,36 @@
 use super::sql::QueryExecutor;
 use crate::graph::node::RunnableNode;
 use crate::graph::upload_to_db::{upload_file_bulk, DBType};
+use crate::io::{DataSource, RecordSerializer};
 use async_trait::async_trait;
 use polars::prelude::CsvWriter;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use sqlx::AnyPool;
-use tiberius::Config;
+use std::collections::BTreeMap;
+use tiberius::{AuthMethod, Config, EncryptionLevel};
 use tokio_util::compat::TokioAsyncWriteCompatExt;
 
 /**
 * Pull data from a db using a db query into a csv file that can be used by another node
 */
-async fn pull_from_db(executor: &mut impl QueryExecutor, node: &PullFromDBNode) {}
+async fn pull_from_db(executor: &mut impl QueryExecutor, node: &PullFromDBNode) -> anyhow::Result<()> {
+    let mut output_file = csv::Writer::from_path(node.output_data_source.path.clone())?;
+    let mut first_row = true;
+    executor.get_rows(&node.query, &node.parameters, &mut move |row| {
+        if first_row {
+            output_file.write_record(row.keys())?;
+            first_row = false;
+        }
+        output_file.write_record(row.values())?;
+        Ok(())
+    }).await?;
+    Ok(())
+}
 
 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct PullFromDBNode {
-    file_path: String,
+    output_data_source: DataSource,
     query: String,
     parameters: Vec<String>,
     db_type: DBType,
@@ -33,17 +48,75 @@ impl RunnableNode for PullFromDBNodeRunner {
         // TODO: Clean up grabbing of connection/executor so don't need to repeat this between upload/download
         match node.db_type {
             DBType::Mssql => {
-                let config = Config::from_jdbc_string(&node.connection_string)?;
+                let mut config = Config::from_jdbc_string(&node.connection_string)?;
+                // TODO: Restore encryption for remote hosts, doesn't work on localhost without encryption.
+                config.encryption(EncryptionLevel::NotSupported);
                 let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
                 tcp.set_nodelay(true)?;
                 let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
-                pull_from_db(&mut client, &node).await;
+                pull_from_db(&mut client, &node).await?;
             }
             _ => {
                 let mut pool = AnyPool::connect(&node.connection_string).await?;
-                pull_from_db(&mut pool, &node).await;
+                pull_from_db(&mut pool, &node).await?;
             }
         }
         Ok(())
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::graph::node::RunnableNode;
+    use crate::graph::pull_from_db::{PullFromDBNode, PullFromDBNodeRunner};
+    use crate::graph::upload_to_db::DBType::Mssql;
+    use crate::io::DataSource;
+    use crate::io::SourceType::CSV;
+    use std::fs::File;
+    use std::io::Read;
+    use std::path::PathBuf;
+    use testcontainers::core::{IntoContainerPort, WaitFor};
+    use testcontainers::runners::AsyncRunner;
+    use testcontainers::{GenericImage, ImageExt};
+
+    #[tokio::test]
+    async fn test_sql_server() -> anyhow::Result<()> {
+        let container = GenericImage::new("mcr.microsoft.com/mssql/server", "2022-latest")
+            .with_exposed_port(1433.tcp())
+            .with_wait_for(WaitFor::message_on_stdout("Recovery is complete.".to_owned()))
+            .with_env_var("ACCEPT_EULA", "Y")
+            .with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123")
+            .start()
+            .await?;
+        let host = container.get_host().await?;
+        let port = container.get_host_port_ipv4(1433).await?;
+        let port = 1433;
+        let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123", host, port).to_owned();
+        let connection_string = "jdbc:sqlserver://localhost:1433;username=sa;password=TestOnlyContainer123;Encrypt=False".to_owned();
+        let runner = PullFromDBNodeRunner {
+            pull_from_db_node: PullFromDBNode {
+                db_type: Mssql,
+                query: "SELECT '1' Test".to_owned(),
+                parameters: vec![],
+                connection_string,
+                output_data_source: DataSource {
+                    path: PathBuf::from("test_pull.csv"),
+                    source_type: CSV,
+                },
+            }
+        };
+        runner.run().await?;
+        let mut result_contents = String::new();
+        let result_length = File::open("test_pull.csv")?.read_to_string(&mut result_contents)?;
+        assert_eq!(
+            "Test
+1
+",
+            result_contents,
+            "Should pull the correct data from sql"
+        );
+        Ok(())
+    }
+}
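
One behaviour worth flagging in pull_from_db above: rows are keyed by BTreeMap, which iterates in sorted key order, so the CSV columns come out alphabetically rather than in SELECT order. keys() and values() iterate in the same order, so the header always lines up with the data, and the single-column test cannot detect the reordering. A small illustration:

use std::collections::BTreeMap;

fn main() {
    // A query like "SELECT Zebra, Alpha FROM t" would be written as Alpha,Zebra.
    let mut row = BTreeMap::new();
    row.insert("Zebra".to_owned(), "1".to_owned());
    row.insert("Alpha".to_owned(), "2".to_owned());
    let header: Vec<&str> = row.keys().map(|k| k.as_str()).collect();
    assert_eq!(header, vec!["Alpha", "Zebra"]);
}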

View File

@@ -1,21 +1,21 @@
-use std::borrow::Borrow;
 use futures::TryStreamExt;
 use futures_io::{AsyncRead, AsyncWrite};
 use itertools::Itertools;
-use sqlx::{Any, AnyPool, Column, Pool, Row};
+use sqlx::{Any, Column, Pool, Row};
+use std::borrow::Borrow;
+use std::collections::BTreeMap;
 use tiberius::{Client, Query};
 
 // TODO: This doesn't seem to work. Suggestion by compiler is to instead create an enum and implement
 // the trait on the enum (basically use a match in the implementation depending on which enum we have)
 pub trait QueryExecutor {
     // TODO: Params binding for filtering the same query?
     // Retrieve data from a database
     async fn get_rows(
         &mut self,
         query: &str,
         params: &Vec<String>,
-    ) -> anyhow::Result<Vec<Vec<(String, String)>>>;
+        // TODO: This is looking pretty ugly, simpler way to handle it? Maybe with an iterator?
+        row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
+    ) -> anyhow::Result<()>;
 
     // Run a query that returns no results (e.g. bulk insert, insert)
     async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64>;
@@ -26,32 +26,25 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> QueryExecutor for Client<S> {
         &mut self,
         query: &str,
         params: &Vec<String>,
-    ) -> anyhow::Result<Vec<Vec<(String, String)>>> {
+        row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
+    ) -> anyhow::Result<()> {
         let mut query = Query::new(query);
         for param in params {
            query.bind(param);
        }
        let query_result = query.query(self).await?;
-        let results = query_result.into_first_result().await?;
-        let results = results
-            .into_iter()
-            .map(|row| {
-                row.columns()
-                    .into_iter()
-                    .map(|column| {
-                        (
-                            column.name().to_owned(),
-                            match row.get(column.name()) {
-                                Some(value) => value,
-                                None => "",
-                            }
-                            .to_owned(),
-                        )
-                    })
-                    .collect_vec()
-            })
-            .collect();
-        Ok(results)
+        let mut query_stream = query_result.into_row_stream();
+        while let Some(row) = query_stream.try_next().await? {
+            let mut returned_row = BTreeMap::new();
+            // TODO: Check how empty columns are handled by tiberius
+            for column in row.columns().into_iter() {
+                returned_row.insert(
+                    column.name().to_owned(),
+                    row.get(column.name()).unwrap_or_else(|| "").to_owned(),
+                );
+            }
+            row_consumer(returned_row)?;
+        }
+        Ok(())
     }
 
    async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64> {
@@ -72,22 +65,22 @@ impl QueryExecutor for Pool<Any> {
         &mut self,
         query: &str,
         params: &Vec<String>,
-    ) -> anyhow::Result<Vec<Vec<(String, String)>>> {
+        row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
+    ) -> anyhow::Result<()> {
         let mut query = sqlx::query(query);
         for param in params {
             query = query.bind(param);
         }
         let mut rows = query.fetch(self.borrow());
-        let mut results = vec![];
         while let Some(row) = rows.try_next().await? {
-            results.push(
-                row.columns()
-                    .into_iter()
-                    .map(|column| (column.name().to_owned(), row.get(column.name())))
-                    .collect(),
-            );
+            let mut returned_row = BTreeMap::new();
+            for column in row.columns().into_iter() {
+                returned_row.insert(column.name().to_owned(), row.get(column.name()));
+            }
+            row_consumer(returned_row)?;
         }
-        Ok(results)
+        Ok(())
     }
 
     async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64> {
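
On the TODO at the top of this file: async fns and impl Trait parameters make QueryExecutor non-object-safe, so it cannot be used as dyn QueryExecutor, which is presumably what "doesn't seem to work" refers to. The compiler's suggestion could look roughly like this sketch (DbExecutor is hypothetical, not part of this commit):

// Hypothetical enum dispatch over the two concrete executors; each call
// matches on the variant and delegates to the QueryExecutor impls above.
enum DbExecutor<S: AsyncRead + AsyncWrite + Unpin + Send> {
    Mssql(Client<S>),
    Sqlx(Pool<Any>),
}

impl<S: AsyncRead + AsyncWrite + Unpin + Send> DbExecutor<S> {
    async fn get_rows(
        &mut self,
        query: &str,
        params: &Vec<String>,
        row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
    ) -> anyhow::Result<()> {
        match self {
            DbExecutor::Mssql(client) => client.get_rows(query, params, row_consumer).await,
            DbExecutor::Sqlx(pool) => pool.get_rows(query, params, row_consumer).await,
        }
    }
}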

View File

@@ -1,6 +1,10 @@
 use std::fs::File;
+use super::node::RunnableNode;
+use crate::io::{DataSource, SourceType};
 use async_trait::async_trait;
+use polars::io::SerReader;
+use polars::prelude::{IntoLazy, LazyFrame, ParquetReader, ScanArgsParquet};
 use polars::{
     io::SerWriter,
     prelude::{CsvWriter, LazyCsvReader, LazyFileListReader},
 };
@@ -9,21 +13,24 @@ use polars_sql::SQLContext;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
-use super::node::RunnableNode;
 
 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
-pub struct CSVFile {
-    name: String,
-    path: String,
+pub struct SqlFile {
+    pub name: String,
+    pub data_source: DataSource,
 }
 
 /**
 * Run SQL over files using polars, export results to output file
 */
-fn run_sql(files: &Vec<CSVFile>, output_path: &String, query: &String) -> anyhow::Result<()> {
+fn run_sql(files: &Vec<SqlFile>, output_path: &String, query: &String) -> anyhow::Result<()> {
     let mut ctx = SQLContext::new();
     for file in files {
-        let df = LazyCsvReader::new(&file.path).finish()?;
+        let df = match file.data_source.source_type {
+            SourceType::CSV => LazyCsvReader::new(&file.data_source.path).finish()?,
+            SourceType::PARQUET => {
+                LazyFrame::scan_parquet(&file.data_source.path, ScanArgsParquet::default())?
+            }
+        };
         ctx.register(&file.name, df);
     }
     let result = ctx.execute(&query)?;
@@ -34,7 +41,7 @@ fn run_sql(files: &Vec<CSVFile>, output_path: &String, query: &String) -> anyhow
 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct SQLNode {
-    pub files: Vec<CSVFile>,
+    pub files: Vec<SqlFile>,
     pub output_file: String,
     pub query: String,
 }
@@ -55,17 +62,21 @@ impl RunnableNode for SQLNodeRunner {
 #[cfg(test)]
 mod tests {
+    use super::{run_sql, SqlFile};
+    use crate::io::{DataSource, SourceType};
+    use std::path::PathBuf;
     use std::{fs::File, io::Read};
-    use super::{run_sql, CSVFile};
 
     #[test]
     fn basic_query_works() -> anyhow::Result<()> {
         let output_path = "./testing/output/output.csv".to_owned();
         run_sql(
-            &vec![CSVFile {
+            &vec![SqlFile {
                 name: "Account".to_owned(),
-                path: "./testing/test.csv".to_owned(),
+                data_source: DataSource {
+                    source_type: SourceType::CSV,
+                    path: PathBuf::from("./testing/test.csv"),
+                }
             }],
             &output_path,
             &"SELECT * FROM Account WHERE Code = 'A195950'".to_owned(),
@@ -76,7 +87,7 @@ mod tests {
         assert_eq!(
             output,
             "Code,Description,Type,CostOutput,PercentFixed
-A195950,A195950 Staff Related Other,E,GS,100.00
+A195950,A195950 Staff Related Other,E,GS,100.0
 "
         );
         Ok(())
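
With the SourceType match in run_sql, parquet and CSV inputs can now be mixed in a single query. A sketch of such a call from inside this module (file names are illustrative):

// Illustrative: join a CSV table against a parquet table via polars-sql.
run_sql(
    &vec![
        SqlFile {
            name: "Account".to_owned(),
            data_source: DataSource { source_type: SourceType::CSV, path: PathBuf::from("./accounts.csv") },
        },
        SqlFile {
            name: "Balance".to_owned(),
            data_source: DataSource { source_type: SourceType::PARQUET, path: PathBuf::from("./balances.parquet") },
        },
    ],
    &"./output/joined.csv".to_owned(),
    &"SELECT a.Code, b.Amount FROM Account a JOIN Balance b ON a.Code = b.Code".to_owned(),
)?;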

View File

@@ -1,11 +1,24 @@
+use anyhow::bail;
+use rmp_serde::{decode::ReadReader, Deserializer, Serializer};
+use schemars::JsonSchema;
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
+use std::path::PathBuf;
 use std::{
     collections::BTreeMap,
     io::{Read, Seek, Write},
 };
-use anyhow::bail;
-use rmp_serde::{decode::ReadReader, Deserializer, Serializer};
-use serde::{de::DeserializeOwned, Deserialize, Serialize};
 
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
+pub enum SourceType {
+    CSV,
+    PARQUET,
+}
+
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
+pub struct DataSource {
+    pub path: PathBuf,
+    pub source_type: SourceType,
+}
+
 pub trait RecordSerializer {
     fn serialize(&mut self, record: impl Serialize) -> anyhow::Result<()>;
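
Because the new types also derive JsonSchema, they feed straight into whatever configuration schema schemars already generates for the node structs. A quick way to inspect the result (a sketch, using the existing schemars 0.8 dependency):

// Sketch: dump the generated JSON Schema for DataSource; with serde defaults
// the SourceType variants appear as the string literals "CSV" and "PARQUET".
fn print_data_source_schema() -> anyhow::Result<()> {
    let schema = schemars::schema_for!(DataSource);
    println!("{}", serde_json::to_string_pretty(&schema)?);
    Ok(())
}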