Add custom graph executor and implement filter node to test it (#2)
Reviewed-on: vato007/coster-rs#2
816  Cargo.lock  (generated)
File diff suppressed because it is too large
Cargo.toml
@@ -22,7 +22,7 @@ chrono = {version = "0.4.31", features = ["default", "serde"]}
 rayon = "1.6.0"
 tokio = { version = "1.26.0", features = ["full"] }
-sqlx = { version = "0.6", features = [ "runtime-tokio-rustls", "mssql", "any" ] }
+sqlx = { version = "0.8", features = [ "runtime-tokio-rustls", "any" ] }
 rmp-serde = "1.1.1"
 tempfile = "3.7.0"
 polars = {version = "0.32.1", features = ["lazy", "performant", "streaming", "cse", "dtype-datetime"]}
@@ -1,5 +1,4 @@
-use coster_rs::upload_to_db;
-use sqlx::{any::AnyPoolOptions, mssql::MssqlPoolOptions};
+use sqlx::any::AnyPoolOptions;
 
 #[tokio::main]
 async fn main() -> anyhow::Result<()> {
@@ -7,13 +6,14 @@ async fn main() -> anyhow::Result<()> {
     let password = "";
     let host = "";
     let database = "";
-    // USing sqlx: https://github.com/launchbadge/sqlx
-    let connection_string = format!("mssq://{}:{}@{}/{}", user, password, host, database);
-    let pool = AnyPoolOptions::new()
+    let database_type = "";
+    let connection_string = format!(
+        "{}://{}:{}@{}/{}",
+        database_type, user, password, host, database
+    );
+    let _ = AnyPoolOptions::new()
         .max_connections(20)
         .connect(&connection_string)
         .await?;
 
-    // upload_to_db::upload_file_bulk(&pool, &"".to_owned(), &"".to_owned(), None, "".to_owned()).await?;
     Ok(())
 }
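The binary now builds a generic scheme://user:password@host/database connection string and hands it to sqlx's Any pool instead of the MSSQL-specific pool. A minimal sketch of how a caller might use that, not part of this commit, assuming sqlx 0.7+ semantics where Any drivers must be installed before connecting and the matching per-database features (e.g. postgres) are enabled; the URL is a placeholder:

    use sqlx::any::AnyPoolOptions;

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        // Required on sqlx 0.7+ before the Any pool can resolve a driver.
        sqlx::any::install_default_drivers();
        // Hypothetical connection string; mysql:// or sqlite:// work the same way.
        let connection_string = "postgres://user:password@localhost/coster";
        let _pool = AnyPoolOptions::new()
            .max_connections(20)
            .connect(connection_string)
            .await?;
        Ok(())
    }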
75  src/derive.rs  (new file)
@@ -0,0 +1,75 @@
+use serde::{Deserialize, Serialize};
+
+#[derive(Serialize, Deserialize, Clone)]
+pub enum DeriveColumnType {
+    Column(String),
+    Constant(String),
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub struct MapOperation {
+    pub mapped_value: String,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub enum DatePart {
+    Year,
+    Month,
+    Week,
+    Day,
+    Hour,
+    Minute,
+    Second,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub enum SplitType {
+    DateTime(String, DatePart),
+    Numeric(String, isize),
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub enum MatchComparisonType {
+    Equal,
+    GreaterThan,
+    LessThan,
+    NotEqual,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub enum DeriveOperation {
+    Concat(Vec<DeriveColumnType>),
+    Add(Vec<DeriveColumnType>),
+    Multiply(Vec<DeriveColumnType>),
+    Subtract(DeriveColumnType, DeriveColumnType),
+    Divide(DeriveColumnType, DeriveColumnType),
+    Map(String, Vec<MapOperation>),
+    Split(String, SplitType),
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub enum ValueType {
+    String,
+    Integer,
+    Float,
+    Boolean,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub struct DeriveFilter {
+    pub column_name: String,
+    pub comparator: MatchComparisonType,
+    pub match_value: String,
+    pub value_type: ValueType,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub struct DeriveRule {
+    pub operations: Vec<DeriveOperation>,
+    pub filters: Vec<DeriveFilter>,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub struct DeriveNode {
+    pub rules: Vec<DeriveRule>,
+}
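These types are plain serde data holders that other nodes consume; the new FilterNode below reuses DeriveFilter, MatchComparisonType and ValueType to describe a single comparison. A small sketch, not part of the commit, of constructing one (the column name and threshold are made-up values):

    use coster_rs::derive::{DeriveFilter, MatchComparisonType, ValueType};

    fn main() {
        // "Keep rows whose Amount column is greater than 100", parsed as an integer.
        let _filter = DeriveFilter {
            column_name: "Amount".to_owned(),
            comparator: MatchComparisonType::GreaterThan,
            match_value: "100".to_owned(),
            value_type: ValueType::Integer,
        };
    }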
139  src/filter.rs
@@ -1,6 +1,12 @@
-use std::{collections::HashMap, io::Read, str::FromStr};
+use std::{collections::BTreeMap, str::FromStr};
 
-use crate::{graph::RunnableNode, io::RecordSerializer};
+use serde::{Deserialize, Serialize};
+
+use crate::{
+    derive::{DeriveFilter, MatchComparisonType},
+    io::{RecordDeserializer, RecordSerializer},
+    node::RunnableNode,
+};
 
 pub enum Comparator<T: PartialOrd> {
     Equal(T),
@@ -56,13 +62,17 @@ impl<T: FromStr + PartialOrd> DataValidator for FilterRule<T> {
  * that don't satisfy the filter criteria
  */
 pub fn filter_file(
-    rules: Vec<&dyn DataValidator>,
-    // TODO: Custom serialisers/deserialisers so we don't rely on csv only
-    input: &mut csv::Reader<impl Read>,
+    rules: &Vec<Box<dyn DataValidator>>,
+    input: &mut impl RecordDeserializer,
     output: &mut impl RecordSerializer,
 ) -> anyhow::Result<()> {
-    for line in input.deserialize() {
-        let line: HashMap<String, String> = line?;
+    if let Some(line) = input.deserialize()? {
+        let line: BTreeMap<String, String> = line;
+        output.write_header(&line)?;
+        output.write_record(&line)?;
+
+        while let Some(line) = input.deserialize()? {
+            let line: BTreeMap<String, String> = line;
         if rules.iter().all(|rule| {
             line.get(&rule.get_field_name()).map_or(true, |value| {
                 if value.trim().is_empty() {
@@ -72,14 +82,123 @@ pub fn filter_file(
                 }
             })
         }) {
-            output.serialize(line)?;
+            output.write_record(&line)?;
         }
     }
+        output.flush()?;
+    }
     Ok(())
 }
 
-pub struct FilterNodeRunner {}
+#[derive(Serialize, Deserialize, Clone)]
+pub struct FilterNode {
+    pub filters: Vec<DeriveFilter>,
+    pub input_file_path: String,
+    pub output_file_path: String,
+}
+
+impl FilterNode {
+    fn to_filter_rules(&self) -> anyhow::Result<Vec<Box<dyn DataValidator>>> {
+        self.filters
+            .iter()
+            // For some reason inlining to_filter_rules causes a compiler error, so leaving
+            // in a separate function (it is cleaner at least)
+            .map(|filter| to_filter_rule(filter))
+            .collect()
+    }
+}
+
+fn to_filter_rule(filter: &DeriveFilter) -> anyhow::Result<Box<dyn DataValidator>> {
+    let value = filter.match_value.clone();
+    match filter.value_type {
+        crate::derive::ValueType::String => Ok(Box::new(get_filter_rule(filter, value))),
+        crate::derive::ValueType::Integer => {
+            Ok(Box::new(get_filter_rule(filter, value.parse::<i64>()?)))
+        }
+        crate::derive::ValueType::Float => {
+            Ok(Box::new(get_filter_rule(filter, value.parse::<f64>()?)))
+        }
+        crate::derive::ValueType::Boolean => {
+            Ok(Box::new(get_filter_rule(filter, value.parse::<bool>()?)))
+        }
+    }
+}
+
+fn get_filter_rule<T: PartialOrd>(filter: &DeriveFilter, value: T) -> FilterRule<T> {
+    FilterRule {
+        column_name: filter.column_name.clone(),
+        comparator: match filter.comparator {
+            MatchComparisonType::Equal => Comparator::Equal(value),
+            MatchComparisonType::GreaterThan => Comparator::GreaterThan(value),
+            MatchComparisonType::LessThan => Comparator::LessThan(value),
+            MatchComparisonType::NotEqual => Comparator::NotEqual(value),
+        },
+    }
+}
+
+pub struct FilterNodeRunner {
+    pub filter_node: FilterNode,
+}
 
 impl RunnableNode for FilterNodeRunner {
-    fn run(&self) {}
+    fn run(&self) -> anyhow::Result<()> {
+        let mut reader = csv::Reader::from_path(&self.filter_node.input_file_path)?;
+        let mut writer = csv::Writer::from_path(&self.filter_node.output_file_path)?;
+        let rules = self.filter_node.to_filter_rules()?;
+        filter_file(&rules, &mut reader, &mut writer)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::filter::FilterRule;
+
+    use super::filter_file;
+
+    #[test]
+    fn no_filters_passes_through() -> anyhow::Result<()> {
+        let records = "Column1,Column2
+Value1,Value2
+Value3,Value4
+";
+        let mut reader: csv::Reader<&[u8]> = csv::Reader::from_reader(records.as_bytes());
+        let mut writer = csv::Writer::from_writer(vec![]);
+        filter_file(&vec![], &mut reader, &mut writer)?;
+        let result = String::from_utf8(writer.into_inner()?)?;
+        assert_eq!(
+            records, result,
+            "Should not modify input when no filters are defined"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn filters_data() -> anyhow::Result<()> {
+        let records = "Column1,Column2
+Value1,Value2
+Value3,Value4
+";
+        let mut reader: csv::Reader<&[u8]> = csv::Reader::from_reader(records.as_bytes());
+        let mut writer = csv::Writer::from_writer(vec![]);
+        filter_file(
+            &vec![Box::new(FilterRule {
+                column_name: "Column1".to_owned(),
+                comparator: crate::filter::Comparator::NotEqual("Value3".to_owned()),
+            })],
+            &mut reader,
+            &mut writer,
+        )?;
+        let result = String::from_utf8(writer.into_inner()?)?;
+        assert_eq!(
+            "Column1,Column2
+Value1,Value2
+",
+            result,
+            "Should filter out second record due to filter rules"
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn should_print_header_when_no_rules_pass() {}
 }
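FilterNodeRunner is the bridge the executor uses: it opens a CSV reader and writer from the node's paths, converts each DeriveFilter into a typed FilterRule, and streams the file through filter_file. A rough sketch of driving it directly, outside the graph (the paths are placeholders and anyhow is assumed as a dependency of the calling crate):

    use coster_rs::filter::{FilterNode, FilterNodeRunner};
    use coster_rs::node::RunnableNode;

    fn main() -> anyhow::Result<()> {
        let runner = FilterNodeRunner {
            filter_node: FilterNode {
                // No filters: the input is copied through unchanged.
                filters: vec![],
                input_file_path: "input.csv".to_owned(),
                output_file_path: "filtered.csv".to_owned(),
            },
        };
        runner.run()
    }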
310  src/graph.rs
@@ -1,9 +1,23 @@
-use itertools::Itertools;
+use std::{
+    cmp::{min, Ordering},
+    collections::{HashMap, HashSet},
+    sync::{
+        mpsc::{self, Sender},
+        Arc,
+    },
+    thread,
+};
+
+use chrono::Local;
 use serde::{Deserialize, Serialize};
 
-use crate::filter::FilterNodeRunner;
+use crate::{
+    derive::DeriveNode,
+    filter::{FilterNode, FilterNodeRunner},
+    node::RunnableNode,
+    upload_to_db::{UploadNode, UploadNodeRunner},
+};
 
-// TODO: Break all this up into separate files in the graph module
 #[derive(Serialize, Deserialize, Clone)]
 pub enum NodeConfiguration {
     FileNode,
@@ -12,6 +26,14 @@ pub enum NodeConfiguration {
     DeriveNode(DeriveNode),
     CodeRuleNode(CodeRuleNode),
     FilterNode(FilterNode),
+    UploadNode(UploadNode),
+    Dynamic,
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub struct DynamicConfiguration {
+    pub node_type: String,
+    pub parameters: HashMap<String, String>,
 }
 
 #[derive(Serialize, Deserialize, Clone)]
@@ -19,6 +41,7 @@ pub struct NodeInfo {
     pub name: String,
     pub output_files: Vec<String>,
     pub configuration: NodeConfiguration,
+    pub dynamic_configuration: Option<DynamicConfiguration>,
 }
 
 #[derive(Serialize, Deserialize, Clone)]
@@ -65,70 +88,6 @@ pub struct MergeNode {
     pub joins: Vec<MergeJoin>,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
-pub enum DeriveColumnType {
-    Column(String),
-    Constant(String),
-}
-
-#[derive(Serialize, Deserialize, Clone)]
-pub struct MapOperation {
-    pub mapped_value: String,
-}
-
-#[derive(Serialize, Deserialize, Clone)]
-pub enum DatePart {
-    Year,
-    Month,
-    Week,
-    Day,
-    Hour,
-    Minute,
-    Second,
-}
-
-#[derive(Serialize, Deserialize, Clone)]
-pub enum SplitType {
-    DateTime(String, DatePart),
-    Numeric(String, isize),
-}
-
-#[derive(Serialize, Deserialize, Clone)]
-pub enum MatchComparisonType {
-    Equal,
-    GreaterThan,
-    LessThan,
-}
-
-#[derive(Serialize, Deserialize, Clone)]
-pub enum DeriveOperation {
-    Concat(Vec<DeriveColumnType>),
-    Add(Vec<DeriveColumnType>),
-    Multiply(Vec<DeriveColumnType>),
-    Subtract(DeriveColumnType, DeriveColumnType),
-    Divide(DeriveColumnType, DeriveColumnType),
-    Map(String, Vec<MapOperation>),
-    Split(String, SplitType),
-}
-
-#[derive(Serialize, Deserialize, Clone)]
-pub struct DeriveFilter {
-    pub column_name: String,
-    pub comparator: MatchComparisonType,
-    pub match_value: String,
-}
-
-#[derive(Serialize, Deserialize, Clone)]
-pub struct DeriveRule {
-    pub operations: Vec<DeriveOperation>,
-    pub filters: Vec<DeriveFilter>,
-}
-
-#[derive(Serialize, Deserialize, Clone)]
-pub struct DeriveNode {
-    pub rules: Vec<DeriveRule>,
-}
-
 #[derive(Serialize, Deserialize, Clone)]
 pub enum CodeRuleLanguage {
     Javascript,
@@ -142,16 +101,16 @@ pub struct CodeRuleNode {
     pub text: String,
 }
 
-#[derive(Serialize, Deserialize, Clone)]
-pub struct FilterNode {
-    pub filters: Vec<DeriveFilter>,
-}
-
 #[derive(Serialize, Deserialize, Clone)]
 pub struct Node {
     pub id: i64,
     pub info: NodeInfo,
     pub dependent_node_ids: Vec<i64>,
+    // Lets us work out whether a task should be rerun, by
+    // inspecting the files at the output paths and check if
+    // their timestamp is before this
+    // TODO: Could just be seconds since unix epoch?
+    pub last_modified: chrono::DateTime<Local>,
 }
 
 impl Node {
@@ -160,21 +119,16 @@ impl Node {
     }
 }
 
-impl Into<RunnableGraphNode> for Node {
-    fn into(self) -> RunnableGraphNode {
-        RunnableGraphNode {
-            runnable_node: Box::new(FilterNodeRunner {}),
-            // TODO: Construct node objects
-            // runnable_node: match &self.info.configuration {
-            //     NodeConfiguration::FileNode => todo!(),
-            //     NodeConfiguration::MoveMoneyNode(_) => todo!(),
-            //     NodeConfiguration::MergeNode(_) => todo!(),
-            //     NodeConfiguration::DeriveNode(_) => todo!(),
-            //     NodeConfiguration::CodeRuleNode(_) => todo!(),
-            //     NodeConfiguration::FilterNode(_) => todo!(),
-            // },
-            node: self,
-        }
+fn get_runnable_node(node: Node) -> Box<dyn RunnableNode> {
+    match node.info.configuration {
+        NodeConfiguration::FileNode => todo!(),
+        NodeConfiguration::MoveMoneyNode(_) => todo!(),
+        NodeConfiguration::MergeNode(_) => todo!(),
+        NodeConfiguration::DeriveNode(_) => todo!(),
+        NodeConfiguration::CodeRuleNode(_) => todo!(),
+        NodeConfiguration::FilterNode(filter_node) => Box::new(FilterNodeRunner { filter_node }),
+        NodeConfiguration::UploadNode(upload_node) => Box::new(UploadNodeRunner { upload_node }),
+        NodeConfiguration::Dynamic => todo!(),
     }
 }
 
@@ -184,32 +138,176 @@ pub struct Graph {
     pub nodes: Vec<Node>,
 }
 
-pub trait RunnableNode {
-    fn run(&self);
-}
-
-pub struct RunnableGraphNode {
-    pub runnable_node: Box<dyn RunnableNode>,
-    pub node: Node,
+pub enum NodeStatus {
+    Completed,
+    Running,
+    // TODO: Error code?
+    Failed,
 }
 
 pub struct RunnableGraph {
-    pub name: String,
-    pub nodes: Vec<RunnableGraphNode>,
+    pub graph: Graph,
+    pub node_statuses: HashMap<i64, NodeStatus>,
 }
 
 impl RunnableGraph {
-    pub fn from_graph(graph: &Graph) -> RunnableGraph {
+    pub fn from_graph(graph: Graph) -> RunnableGraph {
         RunnableGraph {
-            name: graph.name.clone(),
-            nodes: graph
-                .nodes
-                .iter()
-                .map(|node| {
-                    let runnable_graph_node: RunnableGraphNode = node.clone().into();
-                    runnable_graph_node
-                })
-                .collect_vec(),
+            graph,
+            node_statuses: HashMap::new(),
         }
     }
+
+    pub fn run_default_tasks(&mut self, num_threads: usize) -> anyhow::Result<()> {
+        self.run(num_threads, Box::new(|node| get_runnable_node(node)))
+    }
+
+    // Make this not mutable, emit node status when required in a function or some other message
+    pub fn run(
+        &mut self,
+        num_threads: usize,
+        get_node_fn: Box<dyn Fn(Node) -> Box<dyn RunnableNode> + Send + Sync>,
+    ) -> anyhow::Result<()> {
+        let mut nodes = self.graph.nodes.clone();
+        // 1. nodes the nodes based on dependencies (i.e. nodes without dependencies go first)
+        nodes.sort_by(|a, b| {
+            if b.dependent_node_ids.contains(&a.id) {
+                return Ordering::Greater;
+            }
+            if a.dependent_node_ids.contains(&b.id) {
+                return Ordering::Less;
+            }
+            Ordering::Equal
+        });
+
+        let num_threads = min(num_threads, nodes.len());
+
+        // Sync version
+        if num_threads < 2 {
+            for node in &nodes {
+                self.node_statuses.insert(node.id, NodeStatus::Running);
+                match get_node_fn(node.clone()).run() {
+                    Ok(_) => self.node_statuses.insert(node.id, NodeStatus::Completed),
+                    Err(_) => self.node_statuses.insert(node.id, NodeStatus::Failed),
+                };
+            }
+            return Ok(());
+        }
+
+        let mut running_nodes = HashSet::new();
+        let mut completed_nodes = HashSet::new();
+
+        let mut senders: Vec<Sender<Node>> = vec![];
+        let mut handles = vec![];
+        let (finish_task, listen_finish_task) = mpsc::sync_channel(num_threads);
+        let node_fn = Arc::new(get_node_fn);
+        for n in 0..num_threads {
+            let finish_task = finish_task.clone();
+            let (tx, rx) = mpsc::channel();
+            senders.push(tx);
+            let node_fn = node_fn.clone();
+            let handle = thread::spawn(move || {
+                for node in rx {
+                    node_fn(node.clone()).run();
+                    finish_task.send((n, node));
+                }
+                println!("Thread {} finished", n);
+            });
+            handles.push(handle);
+        }
+
+        let mut running_threads = HashSet::new();
+        // Run nodes without dependencies. There'll always be at least one
+
+        for i in 0..nodes.len() {
+            // Ensure we don't overload threads
+            if i >= num_threads {
+                break;
+            }
+            if !nodes[i].has_dependent_nodes() {
+                running_threads.insert(i);
+                // Run all nodes that have no dependencies
+                let node = nodes.remove(i);
+                self.node_statuses.insert(node.id, NodeStatus::Running);
+                running_nodes.insert(node.id);
+                senders[i % senders.len()].send(node);
+            }
+        }
+
+        // Run each dependent node after a graph above finishes.
+        for (n, node) in listen_finish_task {
+            running_threads.remove(&n);
+            // TODO: Add error check here
+            self.node_statuses.insert(node.id, NodeStatus::Completed);
+            running_nodes.remove(&node.id);
+            completed_nodes.insert(node.id);
+            // Run all the nodes that can be run and aren't in completed
+            let mut i = 0;
+            while running_threads.len() < num_threads && i < nodes.len() {
+                if !running_nodes.contains(&nodes[i].id)
+                    && nodes[i]
+                        .dependent_node_ids
+                        .iter()
+                        .all(|id| completed_nodes.contains(id))
+                {
+                    let node = nodes.remove(i);
+                    for i in 0..num_threads {
+                        if !running_threads.contains(&i) {
+                            senders[i].send(node);
+                            break;
+                        }
+                    }
+                }
+                i += 1;
+            }
+            if nodes.is_empty() {
+                break;
+            }
+        }
+        for sender in senders {
+            drop(sender);
+        }
+
+        for handle in handles {
+            handle.join().expect("Failed to join thread");
+        }
+        println!("Process finished");
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::collections::HashMap;
+
+    use chrono::Local;
+
+    use super::{NodeConfiguration, RunnableGraph};
+
+    #[test]
+    fn test_basic() -> anyhow::Result<()> {
+        let mut graph = RunnableGraph {
+            node_statuses: HashMap::new(),
+            graph: super::Graph {
+                name: "Test".to_owned(),
+                nodes: vec![super::Node {
+                    id: 1,
+                    dependent_node_ids: vec![],
+                    info: super::NodeInfo {
+                        name: "Hello".to_owned(),
+                        configuration: NodeConfiguration::FilterNode(super::FilterNode {
+                            filters: vec![],
+                            input_file_path: "".to_owned(),
+                            output_file_path: "".to_owned(),
+                        }),
+                        output_files: vec![],
+                        dynamic_configuration: None,
+                    },
+                    last_modified: Local::now(),
+                }],
+            },
+        };
+        graph.run_default_tasks(2)?;
+        Ok(())
+    }
 }
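The executor sorts nodes so dependencies come first, runs them on worker threads connected by channels, and tracks per-node status; run_default_tasks plugs in get_runnable_node, while run accepts any factory closure, which is what makes it testable with stubbed nodes. A rough crate-internal sketch of supplying a custom factory (graph is a private module, and the no-op node is made up for illustration):

    use crate::graph::{Graph, RunnableGraph};
    use crate::node::RunnableNode;

    struct NoopNode;

    impl RunnableNode for NoopNode {
        fn run(&self) -> anyhow::Result<()> {
            // A real runner would read its inputs and write its outputs here.
            Ok(())
        }
    }

    fn run_with_stubs(graph: Graph) -> anyhow::Result<()> {
        let mut runnable = RunnableGraph::from_graph(graph);
        // Every node is replaced by the stub runner; scheduling still follows dependencies.
        runnable.run(4, Box::new(|_node| Box::new(NoopNode) as Box<dyn RunnableNode>))
    }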
113  src/io.rs
@@ -1,12 +1,21 @@
-use std::io::{Read, Seek, Write};
+use std::{
+    collections::BTreeMap,
+    io::{Read, Seek, Write},
+};
 
 use anyhow::bail;
-use csv::Position;
 use rmp_serde::{decode::ReadReader, Deserializer, Serializer};
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 
 pub trait RecordSerializer {
     fn serialize(&mut self, record: impl Serialize) -> anyhow::Result<()>;
+
+    // For when serde serialization can't be used. Forcing BTreeMap to ensure keys/values are
+    // sorted consistently
+    fn write_header(&mut self, record: &BTreeMap<String, String>) -> anyhow::Result<()>;
+    fn write_record(&mut self, record: &BTreeMap<String, String>) -> anyhow::Result<()>;
+
+    fn flush(&mut self) -> anyhow::Result<()>;
 }
 
 impl<W: Write> RecordSerializer for csv::Writer<W> {
@@ -14,6 +23,21 @@ impl<W: Write> RecordSerializer for csv::Writer<W> {
         self.serialize(record)?;
         Ok(())
     }
+
+    fn flush(&mut self) -> anyhow::Result<()> {
+        self.flush()?;
+        Ok(())
+    }
+
+    fn write_header(&mut self, record: &BTreeMap<String, String>) -> anyhow::Result<()> {
+        self.write_record(record.keys())?;
+        Ok(())
+    }
+
+    fn write_record(&mut self, record: &BTreeMap<String, String>) -> anyhow::Result<()> {
+        self.write_record(record.values())?;
+        Ok(())
+    }
 }
 
 impl<W: Write> RecordSerializer for Serializer<W> {
@@ -21,34 +45,28 @@ impl<W: Write> RecordSerializer for Serializer<W> {
         record.serialize(self)?;
         Ok(())
     }
+
+    fn flush(&mut self) -> anyhow::Result<()> {
+        Ok(())
+    }
+
+    fn write_header(&mut self, _: &BTreeMap<String, String>) -> anyhow::Result<()> {
+        Ok(())
+    }
+
+    fn write_record(&mut self, record: &BTreeMap<String, String>) -> anyhow::Result<()> {
+        self.serialize(record)?;
+        Ok(())
+    }
 }
 
-// TODO: I still don't like this api, should split deserialize and position at the least,
-// and we need a way to get the current position (otherwise it's left to consumers to track current)
-// position
-pub trait RecordDeserializer<P> {
+pub trait RecordDeserializer {
     fn deserialize<D: DeserializeOwned>(&mut self) -> Result<Option<D>, anyhow::Error>;
-
-    // Move the deserializer to the specified position in the underlying reader
-    fn position(&mut self, record: P) -> anyhow::Result<()>;
 }
 
-struct CsvMessagePackDeserializer<R> {
-    reader: csv::Reader<R>,
-}
-
-impl<R: Read> CsvMessagePackDeserializer<R> {
-    fn new(reader: R) -> CsvMessagePackDeserializer<R> {
-        CsvMessagePackDeserializer {
-            reader: csv::Reader::from_reader(reader),
-        }
-    }
-}
-
-impl<R: Read + Seek> RecordDeserializer<Position> for CsvMessagePackDeserializer<R> {
+impl<R: Read> RecordDeserializer for csv::Reader<R> {
     fn deserialize<D: DeserializeOwned>(&mut self) -> Result<Option<D>, anyhow::Error> {
-        // TODO: This isn't great, need to somehow maintain the state/position
-        match self.reader.deserialize().next() {
+        match self.deserialize().next() {
             None => Ok(Option::None),
             Some(result) => match result {
                 Ok(ok) => Ok(Option::Some(ok)),
@@ -56,56 +74,13 @@ impl<R: Read + Seek> RecordDeserializer<Position> for CsvMessagePackDeserializer
             },
         }
     }
 
-    fn position(&mut self, record: Position) -> anyhow::Result<()> {
-        self.reader.seek(record)?;
-        Ok(())
-    }
 }
 
-struct MessagePackDeserializer<R: Read> {
-    reader: Deserializer<ReadReader<R>>,
-    record_positions: Vec<u64>,
-}
-
-impl<R: Read + Seek> MessagePackDeserializer<R> {
-    fn new(reader: R) -> MessagePackDeserializer<R> {
-        MessagePackDeserializer {
-            reader: Deserializer::new(reader),
-            record_positions: vec![],
-        }
-    }
-}
-
-// TODO: These need tests
-impl<R: Read + Seek> RecordDeserializer<usize> for MessagePackDeserializer<R> {
+impl<R: Read + Seek> RecordDeserializer for Deserializer<ReadReader<R>> {
     fn deserialize<D: DeserializeOwned>(&mut self) -> Result<Option<D>, anyhow::Error> {
-        // Keep track of byte position of each record, in case we want to go back later
-        let current_position = self.reader.get_mut().stream_position()?;
-        if self
-            .record_positions
-            .last()
-            .map_or(true, |position| *position < current_position)
-        {
-            self.record_positions.push(current_position);
-        }
-        match Deserialize::deserialize(&mut self.reader) {
+        match Deserialize::deserialize(self) {
             Ok(value) => Ok(value),
             Err(value) => Err(anyhow::Error::from(value)),
         }
     }
-
-    fn position(&mut self, record: usize) -> anyhow::Result<()> {
-        let reader = self.reader.get_mut();
-        // Unsigned so can't be less than 0
-        if self.record_positions.len() > record {
-            // Go to position in reader
-            let position = self.record_positions[record];
-            reader.seek(std::io::SeekFrom::Start(position))?;
-        } else {
-            // read through the reader until we get to the correct record
-            bail!("Record hasn't been read yet, please use deserialize to find the record")
-        }
-        Ok(())
-    }
 }
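RecordSerializer and RecordDeserializer now abstract over CSV and MessagePack streams using BTreeMap rows, which is what lets filter_file stay format-agnostic. A small crate-internal sketch (the io module is private) of pumping records from any deserializer into any serializer; header handling is skipped for brevity:

    use std::collections::BTreeMap;

    use crate::io::{RecordDeserializer, RecordSerializer};

    fn copy_records(
        input: &mut impl RecordDeserializer,
        output: &mut impl RecordSerializer,
    ) -> anyhow::Result<()> {
        // Each record is a column-name -> value map; BTreeMap keeps column order stable.
        while let Some(record) = input.deserialize()? {
            let record: BTreeMap<String, String> = record;
            output.write_record(&record)?;
        }
        output.flush()
    }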
29  src/lib.rs
@@ -12,10 +12,12 @@ pub use self::products::csv::SourceType;
 mod shared_models;
 pub use self::shared_models::*;
 pub mod code_rule;
+pub mod derive;
 pub mod filter;
 mod graph;
 mod io;
 pub mod link;
+pub mod node;
 pub mod upload_to_db;
 
 #[no_mangle]
@@ -63,22 +65,21 @@ pub extern "C" fn move_money_from_file(
     output_path: *const c_char,
     use_numeric_accounts: bool,
 ) {
-    let mut output_writer = csv::Writer::from_writer(vec![]);
     let safe_rules = unwrap_c_char(rules_file);
     let safe_lines = unwrap_c_char(lines);
     let safe_accounts = unwrap_c_char(accounts);
     let safe_cost_centres = unwrap_c_char(cost_centres);
-    // move_money_2()
-    // move_money(
-    //     ,
-    //     &mut csv::Reader::from_reader(safe_lines.to_str().unwrap()),
-    //     &mut csv::Reader::from_reader(safe_accounts.to_bytes()),
-    //     &mut csv::Reader::from_reader(safe_cost_centres.to_bytes()),
-    //     &mut output_writer,
-    //     use_numeric_accounts,
-    //     false,
-    // )
-    // .expect("Failed to move money");
+    let output_path = unwrap_c_char(output_path);
+    move_money(
+        &mut csv::Reader::from_reader(safe_rules.to_bytes()),
+        &mut csv::Reader::from_reader(safe_lines.to_bytes()),
+        &mut csv::Reader::from_reader(safe_accounts.to_bytes()),
+        &mut csv::Reader::from_reader(safe_cost_centres.to_bytes()),
+        &mut csv::Writer::from_path(output_path.to_str().unwrap()).unwrap(),
+        use_numeric_accounts,
+        false,
+    )
+    .expect("Failed to move money");
 }
 
 #[no_mangle]
@@ -87,7 +88,7 @@ pub unsafe extern "C" fn move_money_from_text_free(s: *mut c_char) {
     if s.is_null() {
         return;
     }
-    CString::from_raw(s)
+    let _ = CString::from_raw(s);
     };
 }
 
@@ -181,6 +182,6 @@ pub unsafe extern "C" fn allocate_overheads_from_text_free(s: *mut c_char) {
     if s.is_null() {
         return;
     }
-    CString::from_raw(s)
+    let _ = CString::from_raw(s);
     };
 }
@@ -63,7 +63,7 @@ pub fn link(
         // TODO: Check this filters out correctly, as it's filtering on a reference, not a value
         .unique()
         .collect();
-    let mut source_date_columns: Vec<&String> = linking_rule
+    let source_date_columns: Vec<&String> = linking_rule
         .linking_rules
         .iter()
         .flat_map(|rule| {
@@ -86,7 +86,7 @@ pub fn link(
     // TODO: Merge with source_indexes?
     // Also store the actual date value rather than string, so we
     // don't need to convert as much later
-    let mut source_dates: Vec<HashMap<String, Vec<usize>>>;
+    let source_dates: Vec<HashMap<String, Vec<usize>>>;
     for source_record in source_reader.deserialize() {
         let source_record: HashMap<String, String> = source_record?;
         let current_idx = source_ids.len();
6  src/node.rs  (new file)
@@ -0,0 +1,6 @@
+pub trait RunnableNode {
+    // TODO: Get inputs/outputs to determine whether we can skip running this task
+
+    // TODO: Status
+    fn run(&self) -> anyhow::Result<()>;
+}
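Every node type in the graph ultimately implements this one-method trait, so the executor only ever deals with Box<dyn RunnableNode>. A minimal sketch, outside this commit, of a custom implementation (the struct is made up, and anyhow is assumed as a dependency of the calling crate):

    use coster_rs::node::RunnableNode;

    struct LogNode {
        message: String,
    }

    impl RunnableNode for LogNode {
        fn run(&self) -> anyhow::Result<()> {
            // A real node would read its input files and write its outputs here.
            println!("{}", self.message);
            Ok(())
        }
    }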
@@ -11,7 +11,7 @@ use polars::lazy::dsl::*;
 use polars::prelude::*;
 use serde::Serialize;
 
-use super::csv::{read_definitions, Component, Definition, FileJoin, SourceType};
+use super::csv::{read_definitions, Component, Definition};
 
 // TODO: Polars suggests this, but docs suggest it doesn't have very good platform support
 //use jemallocator::Jemalloc;
@@ -1,7 +1,12 @@
-use std::{collections::HashMap, io::Read};
+use std::collections::HashMap;
 
-use csv::Reader;
-use sqlx::{query, query_builder, Any, Mssql, Pool, QueryBuilder};
+use anyhow::bail;
+use serde::{Deserialize, Serialize};
+use sqlx::{Any, Pool, QueryBuilder};
+
+use crate::node::RunnableNode;
+
+const BIND_LIMIT: usize = 65535;
 
 // Note: right now this is set to mssql only, since sqlx 0.7 is requried to use the Any
 // type for sqlx 0.6 and earlier due to a query_builder lifetime issue,
@@ -12,44 +17,32 @@ use sqlx::{query, query_builder, Any, Mssql, Pool, QueryBuilder};
 // TODO: Add bulk insert options for non-mssql dbs
 // TODO: Add fallback insert when bulk insert fails (e.g. due to
 // permission errors)
-pub async fn upload_file_bulk(
-    pool: &Pool<sqlx::Mssql>,
-    file_name: &String,
-    table_name: &String,
-    // Mappings from column in file -> column in db
-    column_mappings: Option<HashMap<String, String>>,
-    post_script: Option<String>,
-) -> anyhow::Result<u64> {
-    // TODO: Test if the table already exists. If it doesn't, try creating the table
-
-    // First try a bulk insert command
-    // let result = match pool.any_kind() {
-    //     sqlx::any::AnyKind::Mssql => {
-    let result = sqlx::query(&format!("BULK INSERT {} FROM {}", table_name, file_name))
-        .execute(pool)
-        .await?;
-    //     }
-    // };
-
-    let mut rows_affected = result.rows_affected();
-
-    // let mut rows_affected = match &result {
-    //     Result::Ok(result) => result.rows_affected(),
-    //     // TODO: Log error
-    //     Err(error) => 0_u64,
-    // };
-
-    // TODO: Adjust for various dbmss
-    if rows_affected == 0 {
+pub async fn upload_file_bulk(pool: &Pool<Any>, upload_node: &UploadNode) -> anyhow::Result<u64> {
+    let mut rows_affected = None;
+    if upload_node.column_mappings.is_none() {
+        let insert_from_file_query = match pool.connect_options().database_url.scheme() {
+            "postgres" => Some(format!("COPY {} FROM $1", upload_node.table_name)),
+            "mysql" => Some(format!(
+                "LOAD DATA INFILE ? INTO {}",
+                upload_node.table_name,
+            )),
+            _ => None,
+        };
+        if let Some(insert_from_file_query) = insert_from_file_query {
+            let result = sqlx::query(&insert_from_file_query)
+                .bind(&upload_node.file_path)
+                .execute(pool)
+                .await?;
+            rows_affected = Some(result.rows_affected());
+        }
+    }
+
+    if rows_affected == None {
         let rows: Vec<HashMap<String, String>> = vec![];
-
-        let BIND_LIMIT: usize = 65535;
-        // TODO: Use csv to read from file
-
-        // TODO: When bulk insert fails, Fall back to sql batched insert
         // TODO: Columns to insert... needs some kind of mapping from file column name <-> db column
-        let mut query_builder = QueryBuilder::new(format!("INSERT INTO {}({}) ", table_name, ""));
+        let mut query_builder =
+            QueryBuilder::new(format!("INSERT INTO {}({}) ", upload_node.table_name, ""));
         // TODO: Iterate over all values in file, not the limit
         query_builder.push_values(&rows[0..BIND_LIMIT], |mut b, row| {
             b.push_bind(row.get("s"));
@@ -58,14 +51,40 @@ pub async fn upload_file_bulk(
         // TODO: Looks like this issue: https://github.com/launchbadge/sqlx/issues/1978
         // Turns out we need v0.7 for this to not bug out, however mssql is only supported in versions before v0.7, so right now can't use sqlx
         // to use this, unless we explicity specified mssql only, not Any as the db type...
+        // Can probably work around this by specifying an actual implementation?
         let query = query_builder.build();
         let result = query.execute(pool).await?;
-        rows_affected = result.rows_affected();
+        rows_affected = Some(result.rows_affected());
    }
 
-    if let Some(post_script) = post_script {
+    if let Some(post_script) = &upload_node.post_script {
         sqlx::query(&post_script).execute(pool).await?;
     }
 
-    Ok(rows_affected)
+    match rows_affected {
+        Some(rows_affected) => Ok(rows_affected),
+        None => bail!(
+            "Invalid state, rows_affected must be populated, or an error should have occurred"
+        ),
+    }
+}
+
+#[derive(Serialize, Deserialize, Clone)]
+pub struct UploadNode {
+    file_path: String,
+    table_name: String,
+    // Mappings from column in file -> column in db
+    column_mappings: Option<HashMap<String, String>>,
+    post_script: Option<String>,
+}
+
+pub struct UploadNodeRunner {
+    pub upload_node: UploadNode,
+}
+
+impl RunnableNode for UploadNodeRunner {
+    fn run(&self) -> anyhow::Result<()> {
+        // TODO: Get db connection from some kind of property manager/context
+        todo!()
+    }
 }
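UploadNode keeps its fields private but derives Serialize/Deserialize, so graph configurations can still describe it as data. A rough sketch of building one from JSON and running the bulk upload directly; serde_json is an assumed extra dependency, and the connection string, file and table names are placeholders:

    use coster_rs::upload_to_db::{upload_file_bulk, UploadNode};
    use sqlx::any::AnyPoolOptions;

    #[tokio::main]
    async fn main() -> anyhow::Result<()> {
        // Field names match the struct's serde representation.
        let node: UploadNode = serde_json::from_str(
            r#"{"file_path":"lines.csv","table_name":"Lines","column_mappings":null,"post_script":null}"#,
        )?;
        let pool = AnyPoolOptions::new()
            .connect("postgres://user:password@localhost/coster")
            .await?;
        let rows = upload_file_bulk(&pool, &node).await?;
        println!("uploaded {rows} rows");
        Ok(())
    }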