Add basic pull from db support
Cargo.lock: 677 changes (generated; file diff suppressed because it is too large)

Cargo.toml: 17 changes
@@ -18,10 +18,10 @@ clap = { version = "4", features = ["derive"] }
 anyhow = "1"

 itertools = "0.13.0"
-chrono = { version = "0.4", features = ["default", "serde"] }
+chrono = { version = "0.4.39", features = ["default", "serde"] }

-rayon = "1.6.0"
-tokio = { version = "1.39", features = ["full"] }
+rayon = "1.10.0"
+tokio = { version = "1.42.0", features = ["full"] }
 sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "any"] }
 rmp-serde = "1.1"
 tempfile = "3.7"
@@ -31,12 +31,13 @@ serde_json = "1.0.122"
 num_cpus = "1.16.0"
 schemars = { version = "0.8.21", features = ["chrono"] }
 log = "0.4.22"
-env_logger = "0.11.5"
+env_logger = "0.11.6"
 tiberius = { version = "0.12.3", features = ["chrono", "tokio"] }
-futures-io = "0.3.30"
-futures = "0.3.30"
-tokio-util = { version = "0.7.11", features = ["compat"] }
-async-trait = "0.1.81"
+futures-io = "0.3.31"
+futures = "0.3.31"
+tokio-util = { version = "0.7.13", features = ["compat"] }
+async-trait = "0.1.83"
+testcontainers = "0.23.1"

 # More info on targets: https://doc.rust-lang.org/cargo/reference/cargo-targets.html#configuring-a-target
 [lib]
@@ -4,7 +4,7 @@ use async_trait::async_trait;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

-use crate::io::{RecordDeserializer, RecordSerializer};
+use crate::io::{DataSource, RecordDeserializer, RecordSerializer};

 use super::node::RunnableNode;

@@ -165,8 +165,8 @@ pub struct DeriveRule {
 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct DeriveNode {
     pub rules: Vec<DeriveRule>,
-    pub input_file_path: String,
-    pub output_file_path: String,
+    pub input_data_source: DataSource,
+    pub output_data_source: DataSource,
     pub copy_all_columns: bool,
 }

@@ -302,8 +302,8 @@ pub struct DeriveNodeRunner {
 #[async_trait]
 impl RunnableNode for DeriveNodeRunner {
     async fn run(&self) -> anyhow::Result<()> {
-        let mut reader = csv::Reader::from_path(&self.derive_node.input_file_path)?;
-        let mut writer = csv::Writer::from_path(&self.derive_node.output_file_path)?;
+        let mut reader = csv::Reader::from_path(&self.derive_node.input_data_source.path)?;
+        let mut writer = csv::Writer::from_path(&self.derive_node.output_data_source.path)?;
         let rules: anyhow::Result<Vec<RunnableDeriveRule>> = self
             .derive_node
             .rules
@@ -4,7 +4,7 @@ use async_trait::async_trait;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

-use crate::io::{RecordDeserializer, RecordSerializer};
+use crate::io::{DataSource, RecordDeserializer, RecordSerializer};

 use super::derive::{DataValidators, DeriveFilter};

@@ -42,8 +42,8 @@ pub fn filter_file(
 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct FilterNode {
     pub filters: Vec<DeriveFilter>,
-    pub input_file_path: String,
-    pub output_file_path: String,
+    pub input_data_source: DataSource,
+    pub output_data_source: DataSource,
 }

 pub struct FilterNodeRunner {
@@ -53,8 +53,8 @@ pub struct FilterNodeRunner {
 #[async_trait]
 impl RunnableNode for FilterNodeRunner {
     async fn run(&self) -> anyhow::Result<()> {
-        let mut reader = csv::Reader::from_path(&self.filter_node.input_file_path)?;
-        let mut writer = csv::Writer::from_path(&self.filter_node.output_file_path)?;
+        let mut reader = csv::Reader::from_path(&self.filter_node.input_data_source.path)?;
+        let mut writer = csv::Writer::from_path(&self.filter_node.output_data_source.path)?;
         let rules = derive::to_filter_rules(&self.filter_node.filters)?;
         filter_file(&rules, &mut reader, &mut writer)
     }
@@ -62,7 +62,6 @@ impl RunnableNode for FilterNodeRunner {

 #[cfg(test)]
 mod tests {
-
     use super::derive::{Comparator, FilterRule};

     use super::filter_file;
@@ -308,9 +308,10 @@ impl RunnableGraph {

 #[cfg(test)]
 mod tests {
-    use chrono::Local;
-
     use super::{NodeConfiguration, RunnableGraph};
+    use crate::io::{DataSource, SourceType};
+    use chrono::Local;
+    use std::path::PathBuf;

     #[tokio::test]
     async fn test_basic() -> anyhow::Result<()> {
@@ -324,8 +325,8 @@ mod tests {
             name: "Hello".to_owned(),
             configuration: NodeConfiguration::FilterNode(super::FilterNode {
                 filters: vec![],
-                input_file_path: "".to_owned(),
-                output_file_path: "".to_owned(),
+                input_data_source: DataSource { path: PathBuf::from(""), source_type: SourceType::CSV },
+                output_data_source: DataSource { path: PathBuf::from(""), source_type: SourceType::CSV },
             }),
             output_files: vec![],
             dynamic_configuration: None,
@@ -1,21 +1,36 @@
 use super::sql::QueryExecutor;
 use crate::graph::node::RunnableNode;
 use crate::graph::upload_to_db::{upload_file_bulk, DBType};
+use crate::io::{DataSource, RecordSerializer};
 use async_trait::async_trait;
+use polars::prelude::CsvWriter;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use sqlx::AnyPool;
-use tiberius::Config;
+use std::collections::BTreeMap;
+use tiberius::{AuthMethod, Config, EncryptionLevel};
 use tokio_util::compat::TokioAsyncWriteCompatExt;

 /**
  * Pull data from a db using a db query into a csv file that can be used by another node
  */
-async fn pull_from_db(executor: &mut impl QueryExecutor, node: &PullFromDBNode) {}
+async fn pull_from_db(executor: &mut impl QueryExecutor, node: &PullFromDBNode) -> anyhow::Result<()> {
+    let mut output_file = csv::Writer::from_path(node.output_data_source.path.clone())?;
+    let mut first_row = true;
+    executor.get_rows(&node.query, &node.parameters, &mut move |row| {
+        if first_row {
+            output_file.write_header(&row)?;
+            first_row = false;
+        }
+        output_file.write_record(row.values())?;
+        Ok(())
+    }).await?;
+    Ok(())
+}

 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct PullFromDBNode {
-    file_path: String,
+    output_data_source: DataSource,
     query: String,
     parameters: Vec<String>,
     db_type: DBType,
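The pull path never materialises the full result set: get_rows hands each row to a callback as it arrives, and the CSV header is derived from the first row's keys. A minimal standalone sketch of that callback pattern (a Vec stands in for the live query stream, println! for the csv writer; anyhow is the only dependency assumed):

    use std::collections::BTreeMap;

    // Stand-in for a query stream: hands each row to the consumer as it is produced.
    fn stream_rows(
        rows: Vec<BTreeMap<String, String>>,
        consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
    ) -> anyhow::Result<()> {
        for row in rows {
            consumer(row)?;
        }
        Ok(())
    }

    fn main() -> anyhow::Result<()> {
        let mut first_row = true;
        let row: BTreeMap<String, String> =
            [("Test".to_owned(), "1".to_owned())].into_iter().collect();
        stream_rows(vec![row], &mut |row| {
            if first_row {
                // BTreeMap keys iterate in sorted order, so column order is deterministic.
                println!("{}", row.keys().cloned().collect::<Vec<_>>().join(","));
                first_row = false;
            }
            println!("{}", row.values().cloned().collect::<Vec<_>>().join(","));
            Ok(())
        })
    }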
@@ -33,17 +48,75 @@ impl RunnableNode for PullFromDBNodeRunner {
         // TODO: Clean up grabbing of connection/executor so don't need to repeat this between upload/download
         match node.db_type {
             DBType::Mssql => {
-                let config = Config::from_jdbc_string(&node.connection_string)?;
+                let mut config = Config::from_jdbc_string(&node.connection_string)?;
+                // TODO: Restore encryption for remote hosts, doesn't work on localhost without encryption.
+                config.encryption(EncryptionLevel::NotSupported);
                 let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
                 tcp.set_nodelay(true)?;
                 let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
-                pull_from_db(&mut client, &node).await;
+                pull_from_db(&mut client, &node).await?;
             }
             _ => {
                 let mut pool = AnyPool::connect(&node.connection_string).await?;
-                pull_from_db(&mut pool, &node).await;
+                pull_from_db(&mut pool, &node).await?;
             }
         }
         Ok(())
     }
 }

+#[cfg(test)]
+mod tests {
+    use crate::graph::node::RunnableNode;
+    use crate::graph::pull_from_db::{PullFromDBNode, PullFromDBNodeRunner};
+    use crate::graph::upload_to_db::DBType::Mssql;
+    use crate::io::DataSource;
+    use crate::io::SourceType::CSV;
+    use std::fs::File;
+    use std::io::Read;
+    use std::path::PathBuf;
+    use testcontainers::core::{IntoContainerPort, WaitFor};
+    use testcontainers::runners::AsyncRunner;
+    use testcontainers::{GenericImage, ImageExt};
+
+    #[tokio::test]
+    async fn test_sql_server() -> anyhow::Result<()> {
+        let container = GenericImage::new("mcr.microsoft.com/mssql/server", "2022-latest")
+            .with_exposed_port(1433.tcp())
+            .with_wait_for(WaitFor::message_on_stdout("Recovery is complete.".to_owned()))
+            .with_env_var("ACCEPT_EULA", "Y")
+            .with_env_var("MSSQL_SA_PASSWORD", "TestOnlyContainer123")
+            .start()
+            .await?;
+        let host = container.get_host().await?;
+        let port = container.get_host_port_ipv4(1433).await?;
+        let port = 1433;
+        let connection_string = format!("jdbc:sqlserver://{}:{};username=sa;password=TestOnlyContainer123", host, port).to_owned();
+        let connection_string = "jdbc:sqlserver://localhost:1433;username=sa;password=TestOnlyContainer123;Encrypt=False".to_owned();
+
+        let runner = PullFromDBNodeRunner {
+            pull_from_db_node: PullFromDBNode {
+                db_type: Mssql,
+                query: "SELECT '1' Test".to_owned(),
+                parameters: vec![],
+                connection_string,
+                output_data_source: DataSource {
+                    path: PathBuf::from("test_pull.csv"),
+                    source_type: CSV,
+                },
+            }
+        };
+        runner.run().await?;
+        let mut result_contents = String::new();
+        let result_length = File::open("test_pull.csv")?.read_to_string(&mut result_contents)?;
+        assert_eq!(
+            "Test
+1
+",
+            result_contents,
+            "Should pull the correct data from sql"
+        );
+        Ok(())
+    }
+}
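Both the upload and download paths repeat the same tiberius handshake (hence the clean-up TODO above). In isolation it is roughly the following sketch; the encryption downgrade mirrors the TODO in the diff and is only appropriate for local test containers:

    use tiberius::{Config, EncryptionLevel};
    use tokio::net::TcpStream;
    use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};

    // Sketch of the shared MSSQL connection steps; mirrors the diff, not a new API.
    async fn connect_mssql(jdbc: &str) -> anyhow::Result<tiberius::Client<Compat<TcpStream>>> {
        let mut config = Config::from_jdbc_string(jdbc)?;
        // Assumption: plaintext is acceptable only for local containers.
        config.encryption(EncryptionLevel::NotSupported);
        let tcp = TcpStream::connect(config.get_addr()).await?;
        tcp.set_nodelay(true)?;
        Ok(tiberius::Client::connect(config, tcp.compat_write()).await?)
    }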
@@ -1,21 +1,21 @@
-use std::borrow::Borrow;
-
 use futures::TryStreamExt;
 use futures_io::{AsyncRead, AsyncWrite};
 use itertools::Itertools;
-use sqlx::{Any, AnyPool, Column, Pool, Row};
+use sqlx::{Any, Column, Pool, Row};
+use std::borrow::Borrow;
+use std::collections::BTreeMap;
 use tiberius::{Client, Query};

-// TODO: This doesn't seem to work. Suggestion by compiler is to instead create an enum and implement
-// the trait on the enum (basically use a match in the implementation depending on which enum we have)
 pub trait QueryExecutor {
-    // TODO: Params binding for filtering the same query?
     // Retrieve data from a database
     async fn get_rows(
         &mut self,
         query: &str,
         params: &Vec<String>,
-    ) -> anyhow::Result<Vec<Vec<(String, String)>>>;
+        // TODO: This is looking pretty ugly, simpler way to handle it? Maybe with an iterator?
+        row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
+    ) -> anyhow::Result<()>;

     // Run a query that returns no results (e.g. bulk insert, insert)
     async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64>;
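With the consumer-callback signature, callers stay generic over the tiberius and sqlx backends below. A hypothetical helper (name and body assumed for illustration, not part of the diff) that works against either implementation:

    use std::collections::BTreeMap;

    // Hypothetical: count result rows via any QueryExecutor backend.
    async fn count_rows(executor: &mut impl QueryExecutor, query: &str) -> anyhow::Result<u64> {
        let mut count = 0u64;
        executor
            .get_rows(query, &vec![], &mut |_row: BTreeMap<String, String>| {
                count += 1;
                Ok(())
            })
            .await?;
        Ok(count)
    }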
@@ -26,32 +26,25 @@ impl<S: AsyncRead + AsyncWrite + Unpin + Send> QueryExecutor for Client<S> {
         &mut self,
         query: &str,
         params: &Vec<String>,
-    ) -> anyhow::Result<Vec<Vec<(String, String)>>> {
+        row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
+    ) -> anyhow::Result<()> {
         let mut query = Query::new(query);
         for param in params {
             query.bind(param);
         }
         let query_result = query.query(self).await?;
-        let results = query_result.into_first_result().await?;
-        let results = results
-            .into_iter()
-            .map(|row| {
-                row.columns()
-                    .into_iter()
-                    .map(|column| {
-                        (
-                            column.name().to_owned(),
-                            match row.get(column.name()) {
-                                Some(value) => value,
-                                None => "",
-                            }
-                            .to_owned(),
-                        )
-                    })
-                    .collect_vec()
-            })
-            .collect();
-        Ok(results)
+        let mut query_stream = query_result.into_row_stream();
+        while let Some(row) = query_stream.try_next().await? {
+            let mut returned_row = BTreeMap::new();
+            // TODO: Check how empty columns are handled by tiberius
+            for column in row.columns().into_iter() {
+                returned_row.insert(column.name().to_owned(), row.get(column.name()).unwrap_or_else(|| "")
+                    .to_owned());
+            }
+            row_consumer(returned_row)?;
+        }
+        Ok(())
     }

     async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64> {
@@ -72,22 +65,22 @@ impl QueryExecutor for Pool<Any> {
         &mut self,
         query: &str,
         params: &Vec<String>,
-    ) -> anyhow::Result<Vec<Vec<(String, String)>>> {
+        row_consumer: &mut impl FnMut(BTreeMap<String, String>) -> anyhow::Result<()>,
+    ) -> anyhow::Result<()> {
         let mut query = sqlx::query(query);
         for param in params {
             query = query.bind(param);
         }
         let mut rows = query.fetch(self.borrow());
-        let mut results = vec![];
         while let Some(row) = rows.try_next().await? {
-            results.push(
-                row.columns()
-                    .into_iter()
-                    .map(|column| (column.name().to_owned(), row.get(column.name())))
-                    .collect(),
-            );
+            let mut returned_row = BTreeMap::new();
+            for column in row.columns().into_iter() {
+                returned_row.insert(column.name().to_owned(), row.get(column.name()));
+            }
+            row_consumer(returned_row)?;
         }
-        Ok(results)
+        Ok(())
     }

     async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64> {
@@ -1,6 +1,10 @@
 use std::fs::File;

+use super::node::RunnableNode;
+use crate::io::{DataSource, SourceType};
 use async_trait::async_trait;
+use polars::io::SerReader;
+use polars::prelude::{IntoLazy, LazyFrame, ParquetReader, ScanArgsParquet};
 use polars::{
     io::SerWriter,
     prelude::{CsvWriter, LazyCsvReader, LazyFileListReader},
@@ -9,21 +13,24 @@ use polars_sql::SQLContext;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};

-use super::node::RunnableNode;
-
 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
-pub struct CSVFile {
-    name: String,
-    path: String,
+pub struct SqlFile {
+    pub name: String,
+    pub data_source: DataSource,
 }

 /**
  * Run SQL over files using polars, export results to output file
  */
-fn run_sql(files: &Vec<CSVFile>, output_path: &String, query: &String) -> anyhow::Result<()> {
+fn run_sql(files: &Vec<SqlFile>, output_path: &String, query: &String) -> anyhow::Result<()> {
     let mut ctx = SQLContext::new();
     for file in files {
-        let df = LazyCsvReader::new(&file.path).finish()?;
+        let df = match file.data_source.source_type {
+            SourceType::CSV => LazyCsvReader::new(&file.data_source.path).finish()?,
+            SourceType::PARQUET => {
+                LazyFrame::scan_parquet(&file.data_source.path, ScanArgsParquet::default())?
+            }
+        };
         ctx.register(&file.name, df);
     }
     let result = ctx.execute(&query)?;
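The existing test only exercises the CSV arm of that match; a hypothetical Parquet-backed call (fixture path assumed for illustration) would go through the same run_sql entry point:

    use std::path::PathBuf;

    // Hypothetical caller; SqlFile, DataSource, SourceType and run_sql come from this module.
    fn query_parquet_example() -> anyhow::Result<()> {
        let files = vec![SqlFile {
            name: "Account".to_owned(),
            data_source: DataSource {
                source_type: SourceType::PARQUET,
                path: PathBuf::from("./testing/test.parquet"), // assumed fixture
            },
        }];
        run_sql(
            &files,
            &"./testing/output/output.csv".to_owned(),
            &"SELECT * FROM Account".to_owned(),
        )
    }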
@@ -34,7 +41,7 @@ fn run_sql(files: &Vec<SqlFile>, output_path: &String, query: &String) -> anyhow

 #[derive(Serialize, Deserialize, Clone, JsonSchema)]
 pub struct SQLNode {
-    pub files: Vec<CSVFile>,
+    pub files: Vec<SqlFile>,
     pub output_file: String,
     pub query: String,
 }
@@ -55,17 +62,21 @@ impl RunnableNode for SQLNodeRunner {
 }
 #[cfg(test)]
 mod tests {
+    use super::{run_sql, SqlFile};
+    use crate::io::{DataSource, SourceType};
+    use std::path::PathBuf;
     use std::{fs::File, io::Read};

-    use super::{run_sql, CSVFile};
-
     #[test]
     fn basic_query_works() -> anyhow::Result<()> {
         let output_path = "./testing/output/output.csv".to_owned();
         run_sql(
-            &vec![CSVFile {
+            &vec![SqlFile {
                 name: "Account".to_owned(),
-                path: "./testing/test.csv".to_owned(),
+                data_source: DataSource {
+                    source_type: SourceType::CSV,
+                    path: PathBuf::from("./testing/test.csv"),
+                }
             }],
             &output_path,
             &"SELECT * FROM Account WHERE Code = 'A195950'".to_owned(),
@@ -76,7 +87,7 @@ mod tests {
         assert_eq!(
             output,
             "Code,Description,Type,CostOutput,PercentFixed
-A195950,A195950 Staff Related Other,E,GS,100.00
+A195950,A195950 Staff Related Other,E,GS,100.0
 "
         );
         Ok(())
src/io.rs: 19 changes
@@ -1,11 +1,24 @@
+use anyhow::bail;
+use rmp_serde::{decode::ReadReader, Deserializer, Serializer};
+use schemars::JsonSchema;
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
+use std::path::PathBuf;
 use std::{
     collections::BTreeMap,
     io::{Read, Seek, Write},
 };

-use anyhow::bail;
-use rmp_serde::{decode::ReadReader, Deserializer, Serializer};
-use serde::{de::DeserializeOwned, Deserialize, Serialize};
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
+pub enum SourceType {
+    CSV,
+    PARQUET,
+}
+
+#[derive(Serialize, Deserialize, Clone, JsonSchema)]
+pub struct DataSource {
+    pub path: PathBuf,
+    pub source_type: SourceType,
+}

 pub trait RecordSerializer {
     fn serialize(&mut self, record: impl Serialize) -> anyhow::Result<()>;
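Because DataSource derives Serialize/Deserialize, node configurations can embed it directly. A quick round-trip sketch (serde_json is already a dependency; the output shown assumes serde's default representations):

    use std::path::PathBuf;

    fn main() -> anyhow::Result<()> {
        let source = DataSource {
            path: PathBuf::from("./testing/test.csv"),
            source_type: SourceType::CSV,
        };
        let json = serde_json::to_string(&source)?;
        // Expected with default representations: {"path":"./testing/test.csv","source_type":"CSV"}
        println!("{json}");
        let _back: DataSource = serde_json::from_str(&json)?;
        Ok(())
    }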