Start adding db download/upload, migrate graph to async/await

2024-08-29 21:24:23 +09:30
parent 0ee88e3a99
commit 9e225e58cb
12 changed files with 626 additions and 86 deletions

View File

@@ -174,9 +174,11 @@ impl Cli {
let reader = BufReader::new(file);
let graph = serde_json::from_reader(reader)?;
let graph = RunnableGraph::from_graph(graph);
// TODO: Possible to await here?
graph.run_default_tasks(threads, |id, status| {
info!("Node with id {} finished with status {:?}", id, status)
})
});
Ok(())
}
Commands::GenerateSchema { output } => {
let schema = schema_for!(Graph);
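
On the TODO above ("Possible to await here?"): once run_default_tasks is async, the CLI can await it if the binary runs under a Tokio runtime. A minimal sketch, assuming clap's Parser derive and a hypothetical async Cli::run; this is not part of the commit:

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    Cli::parse().run().await
}

If main stays synchronous instead, the handler can drive the future explicitly (graph and threads as in the handler above):

let rt = tokio::runtime::Runtime::new()?;
rt.block_on(graph.run_default_tasks(threads, |id, status| {
    info!("Node with id {} finished with status {:?}", id, status)
}))?;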

View File

@@ -1,5 +1,6 @@
use std::{collections::BTreeMap, str::FromStr};
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -298,8 +299,9 @@ pub struct DeriveNodeRunner {
derive_node: DeriveNode,
}
#[async_trait]
impl RunnableNode for DeriveNodeRunner {
fn run(&self) -> anyhow::Result<()> {
async fn run(&self) -> anyhow::Result<()> {
let mut reader = csv::Reader::from_path(&self.derive_node.input_file_path)?;
let mut writer = csv::Writer::from_path(&self.derive_node.output_file_path)?;
let rules: anyhow::Result<Vec<RunnableDeriveRule>> = self

View File

@@ -1,5 +1,6 @@
use std::collections::BTreeMap;
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -49,8 +50,9 @@ pub struct FilterNodeRunner {
pub filter_node: FilterNode,
}
#[async_trait]
impl RunnableNode for FilterNodeRunner {
fn run(&self) -> anyhow::Result<()> {
async fn run(&self) -> anyhow::Result<()> {
let mut reader = csv::Reader::from_path(&self.filter_node.input_file_path)?;
let mut writer = csv::Writer::from_path(&self.filter_node.output_file_path)?;
let rules = derive::to_filter_rules(&self.filter_node.filters)?;

View File

@@ -2,16 +2,16 @@ use std::{
cmp::{min, Ordering},
collections::{HashMap, HashSet},
sync::{
mpsc::{self, Sender},
Arc,
mpsc, Arc
},
thread,
};
use chrono::Local;
use futures::lock::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use split::{SplitNode, SplitNodeRunner};
use tokio::sync::mpsc::Sender;
use {
derive::DeriveNode,
@@ -24,7 +24,9 @@ use {
mod derive;
mod filter;
mod node;
mod pull_from_db;
mod split;
mod sql;
mod sql_rule;
mod upload_to_db;
@@ -131,7 +133,7 @@ impl Node {
}
}
fn get_runnable_node(node: Node) -> Box<dyn RunnableNode> {
fn get_runnable_node(node: Node) -> Box<dyn RunnableNode + Send> {
match node.info.configuration {
NodeConfiguration::FileNode => todo!(),
NodeConfiguration::MoveMoneyNode(_) => todo!(),
@@ -169,7 +171,7 @@ impl RunnableGraph {
RunnableGraph { graph }
}
pub fn run_default_tasks<F>(&self, num_threads: usize, status_changed: F) -> anyhow::Result<()>
pub async fn run_default_tasks<F>(&self, num_threads: usize, status_changed: F) -> anyhow::Result<()>
where
F: Fn(i64, NodeStatus),
{
@@ -177,18 +179,18 @@ impl RunnableGraph {
num_threads,
Box::new(|node| get_runnable_node(node)),
status_changed,
)
).await
}
// Make this not mutable, emit node status when required in a function or some other message
pub fn run<'a, F, StatusChanged>(
pub async fn run<'a, F, StatusChanged>(
&self,
num_threads: usize,
get_node_fn: F,
node_status_changed_fn: StatusChanged,
) -> anyhow::Result<()>
where
F: Fn(Node) -> Box<dyn RunnableNode> + Send + Sync + 'static,
F: Fn(Node) -> Box<dyn RunnableNode + Send> + Send + Sync + 'static,
StatusChanged: Fn(i64, NodeStatus),
{
let mut nodes = self.graph.nodes.clone();
@@ -209,7 +211,7 @@ impl RunnableGraph {
if num_threads < 2 {
for node in &nodes {
node_status_changed_fn(node.id, NodeStatus::Running);
match get_node_fn(node.clone()).run() {
match get_node_fn(node.clone()).run().await {
Ok(_) => node_status_changed_fn(node.id, NodeStatus::Completed),
Err(err) => node_status_changed_fn(node.id, NodeStatus::Failed(err)),
};
@@ -226,15 +228,18 @@ impl RunnableGraph {
let node_fn = Arc::new(get_node_fn);
for n in 0..num_threads {
let finish_task = finish_task.clone();
let (tx, rx) = mpsc::channel();
// let finish_task = finish_task.clone();
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
senders.push(tx);
let node_fn = node_fn.clone();
let handle = thread::spawn(move || {
for node in rx {
let status = match node_fn(node.clone()).run() {
// TODO: Think this needs to be all reworked to be more in line with async
let handle = tokio::spawn(async move {
for node in rx.recv().await {
let status = match node_fn(node.clone()).run().await {
Ok(_) => NodeStatus::Completed,
Err(err) => NodeStatus::Failed(err),
};
let status = status;
finish_task
.send((n, node, status))
.expect("Failed to notify node status completion");
@@ -258,7 +263,7 @@ impl RunnableGraph {
let node = nodes.remove(i);
node_status_changed_fn(node.id, NodeStatus::Running);
running_nodes.insert(node.id);
senders[i % senders.len()].send(node)?;
senders[i % senders.len()].send(node).await?;
}
}
@@ -280,7 +285,7 @@ impl RunnableGraph {
let node = nodes.remove(i);
for i in 0..num_threads {
if !running_threads.contains(&i) {
senders[i].send(node)?;
senders[i].send(node).await?;
break;
}
}
@@ -296,7 +301,7 @@ impl RunnableGraph {
}
for handle in handles {
handle.join().expect("Failed to join thread");
handle.await.expect("Failed to join thread");
}
println!("Process finished");
Ok(())
@@ -310,8 +315,8 @@ mod tests {
use super::{NodeConfiguration, RunnableGraph};
#[test]
fn test_basic() -> anyhow::Result<()> {
#[tokio::test]
async fn test_basic() -> anyhow::Result<()> {
let graph = RunnableGraph {
graph: super::Graph {
name: "Test".to_owned(),
@@ -332,7 +337,7 @@ mod tests {
}],
},
};
graph.run_default_tasks(2, |_, _| {})?;
graph.run_default_tasks(2, |_, _| {}).await?;
Ok(())
}
}
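
On the TODO about reworking the worker loop: rx.recv().await inside a for loop yields a single Option, so each spawned task handles at most one node before exiting. A sketch of the more usual shape, draining the channel with while let under the same channel and clone setup as above (not what this commit ships):

let handle = tokio::spawn(async move {
    // Keep pulling nodes until every Sender has been dropped.
    while let Some(node) = rx.recv().await {
        let status = match node_fn(node.clone()).run().await {
            Ok(_) => NodeStatus::Completed,
            Err(err) => NodeStatus::Failed(err),
        };
        finish_task
            .send((n, node, status))
            .expect("Failed to notify node status completion");
    }
});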

View File

@@ -1,6 +1,11 @@
use async_trait::async_trait;
#[async_trait]
pub trait RunnableNode {
// TODO: Get inputs/outputs to determine whether we can skip running this task
// TODO: Status
fn run(&self) -> anyhow::Result<()>;
// TODO: Is it possible to make this async?
async fn run(&self) -> anyhow::Result<()>;
}
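
An implementor of the updated trait annotates its impl block the same way, as the runner diffs in this commit show. A minimal illustrative example (NoopNode is hypothetical):

struct NoopNode;

#[async_trait]
impl RunnableNode for NoopNode {
    async fn run(&self) -> anyhow::Result<()> {
        // Real runners do their I/O here and can .await other futures.
        Ok(())
    }
}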

12
src/graph/pull_from_db.rs Normal file
View File

@@ -0,0 +1,12 @@
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use super::sql::QueryExecutor;
/**
 * Pull data from a database using a query into a CSV file that can be used by another node
 */
fn pull_from_db(executor: &mut impl QueryExecutor, node: PullFromDBNode) {}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct PullFromDBNode {}
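
pull_from_db is still an empty stub. A rough sketch of the behaviour the doc comment describes, assuming PullFromDBNode later gains query and output_file_path fields (hypothetical, not in this commit) and reusing the QueryExecutor trait from sql.rs:

// Sketch only: query and output_file_path are assumed fields on PullFromDBNode.
async fn pull_from_db(
    executor: &mut impl QueryExecutor,
    node: &PullFromDBNode,
) -> anyhow::Result<()> {
    let rows = executor.get_rows(&node.query, &vec![]).await?;
    let mut writer = csv::Writer::from_path(&node.output_file_path)?;
    if let Some(first) = rows.first() {
        // Header row from the column names returned by the query.
        writer.write_record(first.iter().map(|(name, _)| name))?;
    }
    for row in &rows {
        writer.write_record(row.iter().map(|(_, value)| value))?;
    }
    writer.flush()?;
    Ok(())
}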

View File

@@ -1,5 +1,6 @@
use std::{collections::BTreeMap, fs::File};
use std::collections::BTreeMap;
use async_trait::async_trait;
use chrono::DateTime;
use polars::{
io::SerWriter,
@@ -9,7 +10,7 @@ use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tempfile::tempfile;
use crate::io::{RecordDeserializer, RecordSerializer};
use crate::io::RecordSerializer;
use super::{
derive::{self, DataValidator, DeriveFilter},
@@ -149,8 +150,9 @@ fn split(
Ok(())
}
#[async_trait]
impl RunnableNode for SplitNodeRunner {
fn run(&self) -> anyhow::Result<()> {
async fn run(&self) -> anyhow::Result<()> {
let mut output = csv::Writer::from_path(&self.split_node.output_file_path)?;
let rules: anyhow::Result<Vec<RunnableSplitRule>> = self
.split_node

101
src/graph/sql.rs Normal file
View File

@@ -0,0 +1,101 @@
use std::borrow::Borrow;
use futures::TryStreamExt;
use futures_io::{AsyncRead, AsyncWrite};
use itertools::Itertools;
use sqlx::{Any, AnyPool, Column, Pool, Row};
use tiberius::{Client, Query};
// TODO: This doesn't seem to work. Suggestion by compiler is to instead create an enum and implement
// the trait on the enum (basically use a match in the implementation depending on which enum we have)
pub trait QueryExecutor {
// TODO: Params binding for filtering the same query?
// Retrieve data from a database
async fn get_rows(
&mut self,
query: &str,
params: &Vec<String>,
) -> anyhow::Result<Vec<Vec<(String, String)>>>;
// Run a query that returns no results (e.g. bulk insert, insert)
async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64>;
}
impl<S: AsyncRead + AsyncWrite + Unpin + Send> QueryExecutor for Client<S> {
async fn get_rows(
&mut self,
query: &str,
params: &Vec<String>,
) -> anyhow::Result<Vec<Vec<(String, String)>>> {
let mut query = Query::new(query);
for param in params {
query.bind(param);
}
let query_result = query.query(self).await?;
let results = query_result.into_first_result().await?;
let results = results
.into_iter()
.map(|row| {
row.columns()
.into_iter()
.map(|column| {
(
column.name().to_owned(),
match row.get(column.name()) {
Some(value) => value,
None => "",
}
.to_owned(),
)
})
.collect_vec()
})
.collect();
Ok(results)
}
async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64> {
let mut query = Query::new(query);
for param in params {
query.bind(param);
}
let result = query.execute(self).await?;
if result.rows_affected().len() == 0 {
return Ok(0);
}
Ok(result.rows_affected()[0])
}
}
impl QueryExecutor for Pool<Any> {
async fn get_rows(
&mut self,
query: &str,
params: &Vec<String>,
) -> anyhow::Result<Vec<Vec<(String, String)>>> {
let mut query = sqlx::query(query);
for param in params {
query = query.bind(param);
}
let mut rows = query.fetch(self.borrow());
let mut results = vec![];
while let Some(row) = rows.try_next().await? {
results.push(
row.columns()
.into_iter()
.map(|column| (column.name().to_owned(), row.get(column.name())))
.collect(),
);
}
Ok(results)
}
async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64> {
let mut query = sqlx::query(query);
for param in params {
query = query.bind(param);
}
let result = query.execute(self.borrow()).await?;
Ok(result.rows_affected())
}
}
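
The enum alternative mentioned in the TODO at the top of this file would look roughly like the sketch below: one enum wrapping the concrete clients, with the async methods written as inherent methods that match on the backend, so no async-in-trait support is needed. Note the tiberius stream type still leaks into the enum's generics; this is an illustration, not part of the commit, and get_rows would follow the same pattern.

pub enum DbExecutor<S: AsyncRead + AsyncWrite + Unpin + Send> {
    Mssql(Client<S>),
    Pool(Pool<Any>),
}

impl<S: AsyncRead + AsyncWrite + Unpin + Send> DbExecutor<S> {
    pub async fn execute_query(&mut self, query: &str, params: &Vec<String>) -> anyhow::Result<u64> {
        match self {
            DbExecutor::Mssql(client) => {
                let mut query = Query::new(query);
                for param in params {
                    query.bind(param);
                }
                let result = query.execute(client).await?;
                Ok(*result.rows_affected().first().unwrap_or(&0))
            }
            DbExecutor::Pool(pool) => {
                let mut query = sqlx::query(query);
                for param in params {
                    query = query.bind(param);
                }
                Ok(query.execute(&*pool).await?.rows_affected())
            }
        }
    }
}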

View File

@@ -1,5 +1,6 @@
use std::fs::File;
use async_trait::async_trait;
use polars::{
io::SerWriter,
prelude::{CsvWriter, LazyCsvReader, LazyFileListReader},
@@ -42,8 +43,9 @@ pub struct SQLNodeRunner {
pub sql_node: SQLNode,
}
#[async_trait]
impl RunnableNode for SQLNodeRunner {
fn run(&self) -> anyhow::Result<()> {
async fn run(&self) -> anyhow::Result<()> {
run_sql(
&self.sql_node.files,
&self.sql_node.output_file,

View File

@@ -1,65 +1,89 @@
use std::collections::HashMap;
use anyhow::bail;
use async_trait::async_trait;
use futures::executor;
use itertools::Itertools;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use sqlx::{Any, Pool, QueryBuilder};
use sqlx::{AnyPool};
use tiberius::Config;
use tokio::{ task};
use tokio_util::compat::TokioAsyncWriteCompatExt;
use super::node::RunnableNode;
use super::{node::RunnableNode, sql::QueryExecutor};
const BIND_LIMIT: usize = 65535;
// Note: right now this is set to mssql only, since sqlx 0.7 is required to use the Any
// type (sqlx 0.6 and earlier hit a query_builder lifetime issue),
// however sqlx >= 0.7 currently doesn't support mssql.
// Upload data in a file to a db table, with an optional post-script to run,
// such as to move data from the upload table into other tables
// TODO: Add bulk insert options for non-mssql dbs
// TODO: Add fallback insert when bulk insert fails (e.g. due to
// permission errors)
pub async fn upload_file_bulk(pool: &Pool<Any>, upload_node: &UploadNode) -> anyhow::Result<u64> {
pub async fn upload_file_bulk(
executor: &mut impl QueryExecutor,
upload_node: &UploadNode,
) -> anyhow::Result<u64> {
let mut rows_affected = None;
if upload_node.column_mappings.is_none() {
let insert_from_file_query = match pool.connect_options().database_url.scheme() {
"postgres" => Some(format!("COPY {} FROM $1", upload_node.table_name)),
"mysql" => Some(format!(
let insert_from_file_query = match upload_node.db_type {
DBType::Postgres => Some(format!("COPY {} FROM $1", upload_node.table_name)),
DBType::Mysql => Some(format!(
"LOAD DATA INFILE ? INTO {}",
upload_node.table_name,
)),
DBType::Mssql => Some(format!("BULK INSERT {} FROM ?", upload_node.table_name)),
_ => None,
};
if let Some(insert_from_file_query) = insert_from_file_query {
let result = sqlx::query(&insert_from_file_query)
.bind(&upload_node.file_path)
.execute(pool)
let result = executor
.execute_query(
&insert_from_file_query,
&vec![upload_node.file_path.clone()],
)
.await?;
rows_affected = Some(result.rows_affected());
rows_affected = Some(result);
}
}
if rows_affected == None {
let rows: Vec<HashMap<String, String>> = vec![];
let mut file_reader = csv::Reader::from_path(upload_node.file_path.clone())?;
// TODO: Columns to insert... needs some kind of mapping from file column name <-> db column
let mut query_builder =
QueryBuilder::new(format!("INSERT INTO {}({}) ", upload_node.table_name, ""));
// TODO: Iterate over all values in file, not the limit
query_builder.push_values(&rows[0..BIND_LIMIT], |mut b, row| {
b.push_bind(row.get("s"));
});
let mut query_builder = query_builder;
// TODO: Looks like this issue: https://github.com/launchbadge/sqlx/issues/1978
// Turns out we need v0.7 for this to not bug out, however mssql is only supported in versions before v0.7, so right now can't use sqlx
// to use this, unless we explicity specified mssql only, not Any as the db type...
// Can probably work around this by specifying an actual implementation?
let query = query_builder.build();
let result = query.execute(pool).await?;
rows_affected = Some(result.rows_affected());
let csv_columns = file_reader.headers()?.iter().map(|header| header.to_owned()).collect_vec();
let table_columns = if let Some(column_mappings) = &upload_node.column_mappings {
csv_columns
.iter()
.map(|column| {
column_mappings
.get(column).unwrap_or(column)
.clone()
})
.collect_vec()
} else {
csv_columns.clone()
};
let query_template = format!("INSERT INTO {}({}) \n", upload_node.table_name, table_columns.join(","));
let mut params = vec![];
let mut insert_query = "".to_owned();
let mut num_params = 0;
let mut running_row_total = 0;
for result in file_reader.records() {
let result = result?;
insert_query = insert_query + format!("VALUES ({})", result.iter().map(|_| "?").join(",")).as_str();
let mut values = result.iter().map(|value| value.to_owned()).collect_vec();
params.append(&mut values);
num_params += csv_columns.len();
if num_params == BIND_LIMIT {
running_row_total += executor.execute_query(&query_template, &params).await?;
insert_query = "".to_owned();
params = vec![];
num_params = 0;
}
}
if !insert_query.is_empty() {
running_row_total += executor.execute_query(&query_template, &params).await?;
}
rows_affected = Some(running_row_total);
}
if let Some(post_script) = &upload_node.post_script {
sqlx::query(&post_script).execute(pool).await?;
executor.execute_query(post_script, &vec![]).await?;
}
match rows_affected {
@@ -70,6 +94,14 @@ pub async fn upload_file_bulk(pool: &Pool<Any>, upload_node: &UploadNode) -> any
}
}
#[derive(Serialize, Deserialize, Clone, JsonSchema, PartialEq)]
pub enum DBType {
Mysql,
Postgres,
Mssql,
Sqlite,
}
#[derive(Serialize, Deserialize, Clone, JsonSchema)]
pub struct UploadNode {
file_path: String,
@@ -77,15 +109,39 @@ pub struct UploadNode {
// Mappings from column in file -> column in db
column_mappings: Option<HashMap<String, String>>,
post_script: Option<String>,
db_type: DBType,
connection_string: String,
}
pub struct UploadNodeRunner {
pub upload_node: UploadNode,
}
#[async_trait]
impl RunnableNode for UploadNodeRunner {
fn run(&self) -> anyhow::Result<()> {
// TODO: Get db connection from some kind of property manager/context
todo!()
async fn run(&self) -> anyhow::Result<()> {
let upload_node = self.upload_node.clone();
if upload_node.db_type == DBType::Mssql {
let mut config = Config::from_jdbc_string(&upload_node.connection_string);
if let Ok(mut config) = config {
let tcp = tokio::net::TcpStream::connect(config.get_addr()).await;
if let Ok(tcp) = tcp {
tcp.set_nodelay(true);
let client = tiberius::Client::connect(config, tcp.compat_write()).await;
if let Ok(mut client) = client {
upload_file_bulk(&mut client, &upload_node).await;
}
}
}
} else {
let mut pool = AnyPool::connect(&upload_node.connection_string).await;
if let Ok(mut pool) = pool {
upload_file_bulk(&mut pool, &upload_node).await;
}
}
// TODO: Message to listen for task completing since join handle doesn't include this
// Alternative is to make run signature async, though that may add more complexity
// to graph mode.
Ok(())
}
}
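
A follow-up note on the runner above: the nested if let Ok blocks drop connection and upload errors even though run now returns anyhow::Result. A sketch of the same flow with errors propagated via ?, using only the APIs already shown in this file (it would sit inside the same #[async_trait] impl); not what the commit ships:

async fn run(&self) -> anyhow::Result<()> {
    let upload_node = self.upload_node.clone();
    if upload_node.db_type == DBType::Mssql {
        let config = Config::from_jdbc_string(&upload_node.connection_string)?;
        let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
        tcp.set_nodelay(true)?;
        let mut client = tiberius::Client::connect(config, tcp.compat_write()).await?;
        upload_file_bulk(&mut client, &upload_node).await?;
    } else {
        let mut pool = AnyPool::connect(&upload_node.connection_string).await?;
        upload_file_bulk(&mut pool, &upload_node).await?;
    }
    Ok(())
}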