Add custom graph executor and implement filter node to test it (#2)

Reviewed-on: vato007/coster-rs#2
2024-07-28 16:41:49 +09:30
parent 25180d3616
commit 5acee8c889
12 changed files with 1123 additions and 500 deletions

Cargo.lock (generated)

File diff suppressed because it is too large.


@@ -22,7 +22,7 @@ chrono = {version = "0.4.31", features = ["default", "serde"]}
rayon = "1.6.0" rayon = "1.6.0"
tokio = { version = "1.26.0", features = ["full"] } tokio = { version = "1.26.0", features = ["full"] }
sqlx = { version = "0.6", features = [ "runtime-tokio-rustls", "mssql", "any" ] } sqlx = { version = "0.8", features = [ "runtime-tokio-rustls", "any" ] }
rmp-serde = "1.1.1" rmp-serde = "1.1.1"
tempfile = "3.7.0" tempfile = "3.7.0"
polars = {version = "0.32.1", features = ["lazy", "performant", "streaming", "cse", "dtype-datetime"]} polars = {version = "0.32.1", features = ["lazy", "performant", "streaming", "cse", "dtype-datetime"]}


@@ -1,5 +1,4 @@
-use coster_rs::upload_to_db;
-use sqlx::{any::AnyPoolOptions, mssql::MssqlPoolOptions};
use sqlx::any::AnyPoolOptions;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
@@ -7,13 +6,14 @@ async fn main() -> anyhow::Result<()> {
    let password = "";
    let host = "";
    let database = "";
-    // USing sqlx: https://github.com/launchbadge/sqlx
-    let connection_string = format!("mssq://{}:{}@{}/{}", user, password, host, database);
-    let pool = AnyPoolOptions::new()
    let database_type = "";
    let connection_string = format!(
        "{}://{}:{}@{}/{}",
        database_type, user, password, host, database
    );
    let _ = AnyPoolOptions::new()
        .max_connections(20)
        .connect(&connection_string)
        .await?;
-    // upload_to_db::upload_file_bulk(&pool, &"".to_owned(), &"".to_owned(), None, "".to_owned()).await?;
    Ok(())
}

src/derive.rs (new file)

@@ -0,0 +1,75 @@
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Clone)]
pub enum DeriveColumnType {
Column(String),
Constant(String),
}
#[derive(Serialize, Deserialize, Clone)]
pub struct MapOperation {
pub mapped_value: String,
}
#[derive(Serialize, Deserialize, Clone)]
pub enum DatePart {
Year,
Month,
Week,
Day,
Hour,
Minute,
Second,
}
#[derive(Serialize, Deserialize, Clone)]
pub enum SplitType {
DateTime(String, DatePart),
Numeric(String, isize),
}
#[derive(Serialize, Deserialize, Clone)]
pub enum MatchComparisonType {
Equal,
GreaterThan,
LessThan,
NotEqual,
}
#[derive(Serialize, Deserialize, Clone)]
pub enum DeriveOperation {
Concat(Vec<DeriveColumnType>),
Add(Vec<DeriveColumnType>),
Multiply(Vec<DeriveColumnType>),
Subtract(DeriveColumnType, DeriveColumnType),
Divide(DeriveColumnType, DeriveColumnType),
Map(String, Vec<MapOperation>),
Split(String, SplitType),
}
#[derive(Serialize, Deserialize, Clone)]
pub enum ValueType {
String,
Integer,
Float,
Boolean,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct DeriveFilter {
pub column_name: String,
pub comparator: MatchComparisonType,
pub match_value: String,
pub value_type: ValueType,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct DeriveRule {
pub operations: Vec<DeriveOperation>,
pub filters: Vec<DeriveFilter>,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct DeriveNode {
pub rules: Vec<DeriveRule>,
}
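
The types above are plain serde data structures, so a node definition can be built in code and persisted with the existing rmp-serde dependency. A minimal sketch (the column name and value are made up for illustration):

use coster_rs::derive::{DeriveFilter, MatchComparisonType, ValueType};

fn main() -> anyhow::Result<()> {
    // Keep rows whose Quantity column is greater than 100 (column and value are illustrative).
    let filter = DeriveFilter {
        column_name: "Quantity".to_owned(),
        comparator: MatchComparisonType::GreaterThan,
        match_value: "100".to_owned(),
        value_type: ValueType::Integer,
    };
    // Everything here derives Serialize/Deserialize, so definitions can be persisted,
    // e.g. as MessagePack through the existing rmp-serde dependency.
    let bytes = rmp_serde::to_vec(&filter)?;
    let roundtrip: DeriveFilter = rmp_serde::from_slice(&bytes)?;
    assert_eq!(roundtrip.column_name, "Quantity");
    Ok(())
}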


@@ -1,6 +1,12 @@
-use std::{collections::HashMap, io::Read, str::FromStr};
use std::{collections::BTreeMap, str::FromStr};

-use crate::{graph::RunnableNode, io::RecordSerializer};
use serde::{Deserialize, Serialize};

use crate::{
    derive::{DeriveFilter, MatchComparisonType},
    io::{RecordDeserializer, RecordSerializer},
    node::RunnableNode,
};

pub enum Comparator<T: PartialOrd> {
    Equal(T),
@@ -56,30 +62,143 @@ impl<T: FromStr + PartialOrd> DataValidator for FilterRule<T> {
 * that don't satisfy the filter criteria
 */
pub fn filter_file(
-    rules: Vec<&dyn DataValidator>,
-    // TODO: Custom serialisers/deserialisers so we don't rely on csv only
-    input: &mut csv::Reader<impl Read>,
    rules: &Vec<Box<dyn DataValidator>>,
    input: &mut impl RecordDeserializer,
    output: &mut impl RecordSerializer,
) -> anyhow::Result<()> {
-    for line in input.deserialize() {
-        let line: HashMap<String, String> = line?;
-        if rules.iter().all(|rule| {
-            line.get(&rule.get_field_name()).map_or(true, |value| {
-                if value.trim().is_empty() {
-                    true
-                } else {
-                    rule.is_valid(value)
-                }
-            })
-        }) {
-            output.serialize(line)?;
-        }
-    }
    if let Some(line) = input.deserialize()? {
        let line: BTreeMap<String, String> = line;
        output.write_header(&line)?;
        output.write_record(&line)?;

        while let Some(line) = input.deserialize()? {
            let line: BTreeMap<String, String> = line;
            if rules.iter().all(|rule| {
                line.get(&rule.get_field_name()).map_or(true, |value| {
                    if value.trim().is_empty() {
                        true
                    } else {
                        rule.is_valid(value)
                    }
                })
            }) {
                output.write_record(&line)?;
            }
        }
        output.flush()?;
    }
    Ok(())
}
-pub struct FilterNodeRunner {}
#[derive(Serialize, Deserialize, Clone)]
pub struct FilterNode {
pub filters: Vec<DeriveFilter>,
pub input_file_path: String,
pub output_file_path: String,
}
impl FilterNode {
fn to_filter_rules(&self) -> anyhow::Result<Vec<Box<dyn DataValidator>>> {
self.filters
.iter()
// For some reason inlining to_filter_rules causes a compiler error, so leaving
// in a separate function (it is cleaner at least)
.map(|filter| to_filter_rule(filter))
.collect()
}
}
fn to_filter_rule(filter: &DeriveFilter) -> anyhow::Result<Box<dyn DataValidator>> {
let value = filter.match_value.clone();
match filter.value_type {
crate::derive::ValueType::String => Ok(Box::new(get_filter_rule(filter, value))),
crate::derive::ValueType::Integer => {
Ok(Box::new(get_filter_rule(filter, value.parse::<i64>()?)))
}
crate::derive::ValueType::Float => {
Ok(Box::new(get_filter_rule(filter, value.parse::<f64>()?)))
}
crate::derive::ValueType::Boolean => {
Ok(Box::new(get_filter_rule(filter, value.parse::<bool>()?)))
}
}
}
fn get_filter_rule<T: PartialOrd>(filter: &DeriveFilter, value: T) -> FilterRule<T> {
FilterRule {
column_name: filter.column_name.clone(),
comparator: match filter.comparator {
MatchComparisonType::Equal => Comparator::Equal(value),
MatchComparisonType::GreaterThan => Comparator::GreaterThan(value),
MatchComparisonType::LessThan => Comparator::LessThan(value),
MatchComparisonType::NotEqual => Comparator::NotEqual(value),
},
}
}
pub struct FilterNodeRunner {
pub filter_node: FilterNode,
}
impl RunnableNode for FilterNodeRunner {
-    fn run(&self) {}
    fn run(&self) -> anyhow::Result<()> {
let mut reader = csv::Reader::from_path(&self.filter_node.input_file_path)?;
let mut writer = csv::Writer::from_path(&self.filter_node.output_file_path)?;
let rules = self.filter_node.to_filter_rules()?;
filter_file(&rules, &mut reader, &mut writer)
}
}
#[cfg(test)]
mod tests {
use crate::filter::FilterRule;
use super::filter_file;
#[test]
fn no_filters_passes_through() -> anyhow::Result<()> {
let records = "Column1,Column2
Value1,Value2
Value3,Value4
";
let mut reader: csv::Reader<&[u8]> = csv::Reader::from_reader(records.as_bytes());
let mut writer = csv::Writer::from_writer(vec![]);
filter_file(&vec![], &mut reader, &mut writer)?;
let result = String::from_utf8(writer.into_inner()?)?;
assert_eq!(
records, result,
"Should not modify input when no filters are defined"
);
Ok(())
}
#[test]
fn filters_data() -> anyhow::Result<()> {
let records = "Column1,Column2
Value1,Value2
Value3,Value4
";
let mut reader: csv::Reader<&[u8]> = csv::Reader::from_reader(records.as_bytes());
let mut writer = csv::Writer::from_writer(vec![]);
filter_file(
&vec![Box::new(FilterRule {
column_name: "Column1".to_owned(),
comparator: crate::filter::Comparator::NotEqual("Value3".to_owned()),
})],
&mut reader,
&mut writer,
)?;
let result = String::from_utf8(writer.into_inner()?)?;
assert_eq!(
"Column1,Column2
Value1,Value2
",
result,
"Should filter out second record due to filter rules"
);
Ok(())
}
#[test]
fn should_print_header_when_no_rules_pass() {}
}
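
A rough usage sketch for the new node type (the paths and column are placeholders): a FilterNode describes the filters plus input/output CSV paths, and the runner does the rest.

use coster_rs::derive::{DeriveFilter, MatchComparisonType, ValueType};
use coster_rs::filter::{FilterNode, FilterNodeRunner};
use coster_rs::node::RunnableNode;

fn main() -> anyhow::Result<()> {
    let filter_node = FilterNode {
        filters: vec![DeriveFilter {
            column_name: "Column1".to_owned(),
            comparator: MatchComparisonType::NotEqual,
            match_value: "Value3".to_owned(),
            value_type: ValueType::String,
        }],
        // Placeholder paths; the runner opens a csv::Reader/Writer on them.
        input_file_path: "input.csv".to_owned(),
        output_file_path: "filtered.csv".to_owned(),
    };
    // run() turns each DeriveFilter into a typed FilterRule and streams records through
    // filter_file, dropping any row that fails a rule (blank values always pass).
    FilterNodeRunner { filter_node }.run()
}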


@@ -1,9 +1,23 @@
-use itertools::Itertools;
use std::{
    cmp::{min, Ordering},
    collections::{HashMap, HashSet},
    sync::{
        mpsc::{self, Sender},
        Arc,
    },
    thread,
};

use chrono::Local;
use serde::{Deserialize, Serialize};

-use crate::filter::FilterNodeRunner;
use crate::{
    derive::DeriveNode,
    filter::{FilterNode, FilterNodeRunner},
    node::RunnableNode,
    upload_to_db::{UploadNode, UploadNodeRunner},
};

// TODO: Break all this up into separate files in the graph module
#[derive(Serialize, Deserialize, Clone)]
pub enum NodeConfiguration {
    FileNode,
@@ -12,6 +26,14 @@ pub enum NodeConfiguration {
    DeriveNode(DeriveNode),
    CodeRuleNode(CodeRuleNode),
    FilterNode(FilterNode),
    UploadNode(UploadNode),
    Dynamic,
}

#[derive(Serialize, Deserialize, Clone)]
pub struct DynamicConfiguration {
    pub node_type: String,
    pub parameters: HashMap<String, String>,
}
#[derive(Serialize, Deserialize, Clone)]
@@ -19,6 +41,7 @@ pub struct NodeInfo {
    pub name: String,
    pub output_files: Vec<String>,
    pub configuration: NodeConfiguration,
    pub dynamic_configuration: Option<DynamicConfiguration>,
}
#[derive(Serialize, Deserialize, Clone)]
@@ -65,70 +88,6 @@ pub struct MergeNode {
    pub joins: Vec<MergeJoin>,
}
-#[derive(Serialize, Deserialize, Clone)]
-pub enum DeriveColumnType {
-    Column(String),
-    Constant(String),
-}

-#[derive(Serialize, Deserialize, Clone)]
-pub struct MapOperation {
-    pub mapped_value: String,
-}

-#[derive(Serialize, Deserialize, Clone)]
-pub enum DatePart {
-    Year,
-    Month,
-    Week,
-    Day,
-    Hour,
-    Minute,
-    Second,
-}

-#[derive(Serialize, Deserialize, Clone)]
-pub enum SplitType {
-    DateTime(String, DatePart),
-    Numeric(String, isize),
-}

-#[derive(Serialize, Deserialize, Clone)]
-pub enum MatchComparisonType {
-    Equal,
-    GreaterThan,
-    LessThan,
-}

-#[derive(Serialize, Deserialize, Clone)]
-pub enum DeriveOperation {
-    Concat(Vec<DeriveColumnType>),
-    Add(Vec<DeriveColumnType>),
-    Multiply(Vec<DeriveColumnType>),
-    Subtract(DeriveColumnType, DeriveColumnType),
-    Divide(DeriveColumnType, DeriveColumnType),
-    Map(String, Vec<MapOperation>),
-    Split(String, SplitType),
-}

-#[derive(Serialize, Deserialize, Clone)]
-pub struct DeriveFilter {
-    pub column_name: String,
-    pub comparator: MatchComparisonType,
-    pub match_value: String,
-}

-#[derive(Serialize, Deserialize, Clone)]
-pub struct DeriveRule {
-    pub operations: Vec<DeriveOperation>,
-    pub filters: Vec<DeriveFilter>,
-}

-#[derive(Serialize, Deserialize, Clone)]
-pub struct DeriveNode {
-    pub rules: Vec<DeriveRule>,
-}
#[derive(Serialize, Deserialize, Clone)]
pub enum CodeRuleLanguage {
    Javascript,
@@ -142,16 +101,16 @@ pub struct CodeRuleNode {
    pub text: String,
}

-#[derive(Serialize, Deserialize, Clone)]
-pub struct FilterNode {
-    pub filters: Vec<DeriveFilter>,
-}

#[derive(Serialize, Deserialize, Clone)]
pub struct Node {
    pub id: i64,
    pub info: NodeInfo,
    pub dependent_node_ids: Vec<i64>,
    // Lets us work out whether a task should be rerun, by
    // inspecting the files at the output paths and check if
    // their timestamp is before this
    // TODO: Could just be seconds since unix epoch?
    pub last_modified: chrono::DateTime<Local>,
}
impl Node {
@@ -160,21 +119,16 @@ impl Node {
    }
}
-impl Into<RunnableGraphNode> for Node {
-    fn into(self) -> RunnableGraphNode {
-        RunnableGraphNode {
-            runnable_node: Box::new(FilterNodeRunner {}),
-            // TODO: Construct node objects
-            // runnable_node: match &self.info.configuration {
-            //     NodeConfiguration::FileNode => todo!(),
-            //     NodeConfiguration::MoveMoneyNode(_) => todo!(),
-            //     NodeConfiguration::MergeNode(_) => todo!(),
-            //     NodeConfiguration::DeriveNode(_) => todo!(),
-            //     NodeConfiguration::CodeRuleNode(_) => todo!(),
-            //     NodeConfiguration::FilterNode(_) => todo!(),
-            // },
-            node: self,
-        }
-    }
-}
fn get_runnable_node(node: Node) -> Box<dyn RunnableNode> {
    match node.info.configuration {
        NodeConfiguration::FileNode => todo!(),
        NodeConfiguration::MoveMoneyNode(_) => todo!(),
        NodeConfiguration::MergeNode(_) => todo!(),
        NodeConfiguration::DeriveNode(_) => todo!(),
        NodeConfiguration::CodeRuleNode(_) => todo!(),
        NodeConfiguration::FilterNode(filter_node) => Box::new(FilterNodeRunner { filter_node }),
        NodeConfiguration::UploadNode(upload_node) => Box::new(UploadNodeRunner { upload_node }),
        NodeConfiguration::Dynamic => todo!(),
    }
}
@@ -184,32 +138,176 @@ pub struct Graph {
    pub nodes: Vec<Node>,
}

-pub trait RunnableNode {
-    fn run(&self);
-}

-pub struct RunnableGraphNode {
-    pub runnable_node: Box<dyn RunnableNode>,
-    pub node: Node,
-}
pub enum NodeStatus {
    Completed,
    Running,
    // TODO: Error code?
    Failed,
}

pub struct RunnableGraph {
-    pub name: String,
-    pub nodes: Vec<RunnableGraphNode>,
    pub graph: Graph,
    pub node_statuses: HashMap<i64, NodeStatus>,
}
impl RunnableGraph {
-    pub fn from_graph(graph: &Graph) -> RunnableGraph {
    pub fn from_graph(graph: Graph) -> RunnableGraph {
        RunnableGraph {
-            name: graph.name.clone(),
-            nodes: graph
-                .nodes
-                .iter()
-                .map(|node| {
-                    let runnable_graph_node: RunnableGraphNode = node.clone().into();
-                    runnable_graph_node
-                })
-                .collect_vec(),
            graph,
            node_statuses: HashMap::new(),
        }
    }
pub fn run_default_tasks(&mut self, num_threads: usize) -> anyhow::Result<()> {
self.run(num_threads, Box::new(|node| get_runnable_node(node)))
}
// Make this not mutable, emit node status when required in a function or some other message
pub fn run(
&mut self,
num_threads: usize,
get_node_fn: Box<dyn Fn(Node) -> Box<dyn RunnableNode> + Send + Sync>,
) -> anyhow::Result<()> {
let mut nodes = self.graph.nodes.clone();
// 1. nodes the nodes based on dependencies (i.e. nodes without dependencies go first)
nodes.sort_by(|a, b| {
if b.dependent_node_ids.contains(&a.id) {
return Ordering::Greater;
}
if a.dependent_node_ids.contains(&b.id) {
return Ordering::Less;
}
Ordering::Equal
});
let num_threads = min(num_threads, nodes.len());
// Sync version
if num_threads < 2 {
for node in &nodes {
self.node_statuses.insert(node.id, NodeStatus::Running);
match get_node_fn(node.clone()).run() {
Ok(_) => self.node_statuses.insert(node.id, NodeStatus::Completed),
Err(_) => self.node_statuses.insert(node.id, NodeStatus::Failed),
};
}
return Ok(());
}
let mut running_nodes = HashSet::new();
let mut completed_nodes = HashSet::new();
let mut senders: Vec<Sender<Node>> = vec![];
let mut handles = vec![];
let (finish_task, listen_finish_task) = mpsc::sync_channel(num_threads);
let node_fn = Arc::new(get_node_fn);
for n in 0..num_threads {
let finish_task = finish_task.clone();
let (tx, rx) = mpsc::channel();
senders.push(tx);
let node_fn = node_fn.clone();
let handle = thread::spawn(move || {
for node in rx {
node_fn(node.clone()).run();
finish_task.send((n, node));
}
println!("Thread {} finished", n);
});
handles.push(handle);
}
let mut running_threads = HashSet::new();
// Run nodes without dependencies. There'll always be at least one
for i in 0..nodes.len() {
// Ensure we don't overload threads
if i >= num_threads {
break;
}
if !nodes[i].has_dependent_nodes() {
running_threads.insert(i);
// Run all nodes that have no dependencies
let node = nodes.remove(i);
self.node_statuses.insert(node.id, NodeStatus::Running);
running_nodes.insert(node.id);
senders[i % senders.len()].send(node);
}
}
// Run each dependent node after a graph above finishes.
for (n, node) in listen_finish_task {
running_threads.remove(&n);
// TODO: Add error check here
self.node_statuses.insert(node.id, NodeStatus::Completed);
running_nodes.remove(&node.id);
completed_nodes.insert(node.id);
// Run all the nodes that can be run and aren't in completed
let mut i = 0;
while running_threads.len() < num_threads && i < nodes.len() {
if !running_nodes.contains(&nodes[i].id)
&& nodes[i]
.dependent_node_ids
.iter()
.all(|id| completed_nodes.contains(id))
{
let node = nodes.remove(i);
for i in 0..num_threads {
if !running_threads.contains(&i) {
senders[i].send(node);
break;
}
}
}
i += 1;
}
if nodes.is_empty() {
break;
}
}
for sender in senders {
drop(sender);
}
for handle in handles {
handle.join().expect("Failed to join thread");
}
println!("Process finished");
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use chrono::Local;
use super::{NodeConfiguration, RunnableGraph};
#[test]
fn test_basic() -> anyhow::Result<()> {
let mut graph = RunnableGraph {
node_statuses: HashMap::new(),
graph: super::Graph {
name: "Test".to_owned(),
nodes: vec![super::Node {
id: 1,
dependent_node_ids: vec![],
info: super::NodeInfo {
name: "Hello".to_owned(),
configuration: NodeConfiguration::FilterNode(super::FilterNode {
filters: vec![],
input_file_path: "".to_owned(),
output_file_path: "".to_owned(),
}),
output_files: vec![],
dynamic_configuration: None,
},
last_modified: Local::now(),
}],
},
};
graph.run_default_tasks(2)?;
Ok(())
}
}
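
A rough sketch of how the executor is meant to be driven when there is an actual dependency between nodes. Since graph is a private module this would live next to the existing test in graph.rs; the helper closures, ids and file names below are purely illustrative.

#[test]
fn runs_dependent_nodes_in_order() -> anyhow::Result<()> {
    use chrono::Local;

    use crate::filter::FilterNode;

    use super::{Graph, Node, NodeConfiguration, NodeInfo, RunnableGraph};

    // Helper for a filter node configuration (no filters, so it just copies input to output).
    let filter = |input: &str, output: &str| {
        NodeConfiguration::FilterNode(FilterNode {
            filters: vec![],
            input_file_path: input.to_owned(),
            output_file_path: output.to_owned(),
        })
    };
    // Helper for wrapping a configuration into a Node.
    let node = |id: i64, deps: Vec<i64>, configuration| Node {
        id,
        dependent_node_ids: deps,
        info: NodeInfo {
            name: format!("node-{}", id),
            configuration,
            output_files: vec![],
            dynamic_configuration: None,
        },
        last_modified: Local::now(),
    };

    // Node 2 depends on node 1, so the executor only dispatches it once node 1 has finished.
    let mut graph = RunnableGraph::from_graph(Graph {
        name: "Example".to_owned(),
        nodes: vec![
            node(1, vec![], filter("in.csv", "stage1.csv")),
            node(2, vec![1], filter("stage1.csv", "out.csv")),
        ],
    });
    // Two worker threads; nodes without dependencies are dispatched first and
    // dependants are queued as their parents complete.
    graph.run_default_tasks(2)
}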

src/io.rs

@@ -1,12 +1,21 @@
-use std::io::{Read, Seek, Write};
use std::{
    collections::BTreeMap,
    io::{Read, Seek, Write},
};

use anyhow::bail;
-use csv::Position;
use rmp_serde::{decode::ReadReader, Deserializer, Serializer};
use serde::{de::DeserializeOwned, Deserialize, Serialize};

pub trait RecordSerializer {
    fn serialize(&mut self, record: impl Serialize) -> anyhow::Result<()>;
    // For when serde serialization can't be used. Forcing BTreeMap to ensure keys/values are
    // sorted consistently
    fn write_header(&mut self, record: &BTreeMap<String, String>) -> anyhow::Result<()>;
    fn write_record(&mut self, record: &BTreeMap<String, String>) -> anyhow::Result<()>;
    fn flush(&mut self) -> anyhow::Result<()>;
}

impl<W: Write> RecordSerializer for csv::Writer<W> {
@@ -14,6 +23,21 @@ impl<W: Write> RecordSerializer for csv::Writer<W> {
        self.serialize(record)?;
        Ok(())
    }
fn flush(&mut self) -> anyhow::Result<()> {
self.flush()?;
Ok(())
}
fn write_header(&mut self, record: &BTreeMap<String, String>) -> anyhow::Result<()> {
self.write_record(record.keys())?;
Ok(())
}
fn write_record(&mut self, record: &BTreeMap<String, String>) -> anyhow::Result<()> {
self.write_record(record.values())?;
Ok(())
}
}

impl<W: Write> RecordSerializer for Serializer<W> {
@@ -21,34 +45,28 @@ impl<W: Write> RecordSerializer for Serializer<W> {
        record.serialize(self)?;
        Ok(())
    }

    fn flush(&mut self) -> anyhow::Result<()> {
        Ok(())
    }

    fn write_header(&mut self, _: &BTreeMap<String, String>) -> anyhow::Result<()> {
        Ok(())
    }

    fn write_record(&mut self, record: &BTreeMap<String, String>) -> anyhow::Result<()> {
        self.serialize(record)?;
        Ok(())
    }
}

-// TODO: I still don't like this api, should split deserialize and position at the least,
-// and we need a way to get the current position (otherwise it's left to consumers to track current)
-// position
-pub trait RecordDeserializer<P> {
-    fn deserialize<D: DeserializeOwned>(&mut self) -> Result<Option<D>, anyhow::Error>;
-    // Move the deserializer to the specified position in the underlying reader
-    fn position(&mut self, record: P) -> anyhow::Result<()>;
-}

-struct CsvMessagePackDeserializer<R> {
-    reader: csv::Reader<R>,
-}

-impl<R: Read> CsvMessagePackDeserializer<R> {
-    fn new(reader: R) -> CsvMessagePackDeserializer<R> {
-        CsvMessagePackDeserializer {
-            reader: csv::Reader::from_reader(reader),
-        }
-    }
-}

-impl<R: Read + Seek> RecordDeserializer<Position> for CsvMessagePackDeserializer<R> {
pub trait RecordDeserializer {
    fn deserialize<D: DeserializeOwned>(&mut self) -> Result<Option<D>, anyhow::Error>;
}

impl<R: Read> RecordDeserializer for csv::Reader<R> {
    fn deserialize<D: DeserializeOwned>(&mut self) -> Result<Option<D>, anyhow::Error> {
-        // TODO: This isn't great, need to somehow maintain the state/position
-        match self.reader.deserialize().next() {
        match self.deserialize().next() {
            None => Ok(Option::None),
            Some(result) => match result {
                Ok(ok) => Ok(Option::Some(ok)),

@@ -56,56 +74,13 @@ impl<R: Read + Seek> RecordDeserializer<Position> for CsvMessagePackDeserializer
            },
        }
    }

-    fn position(&mut self, record: Position) -> anyhow::Result<()> {
-        self.reader.seek(record)?;
-        Ok(())
-    }
}

-struct MessagePackDeserializer<R: Read> {
-    reader: Deserializer<ReadReader<R>>,
-    record_positions: Vec<u64>,
-}

-impl<R: Read + Seek> MessagePackDeserializer<R> {
-    fn new(reader: R) -> MessagePackDeserializer<R> {
-        MessagePackDeserializer {
-            reader: Deserializer::new(reader),
-            record_positions: vec![],
-        }
-    }
-}

-// TODO: These need tests
-impl<R: Read + Seek> RecordDeserializer<usize> for MessagePackDeserializer<R> {
impl<R: Read + Seek> RecordDeserializer for Deserializer<ReadReader<R>> {
    fn deserialize<D: DeserializeOwned>(&mut self) -> Result<Option<D>, anyhow::Error> {
-        // Keep track of byte position of each record, in case we want to go back later
-        let current_position = self.reader.get_mut().stream_position()?;
-        if self
-            .record_positions
-            .last()
-            .map_or(true, |position| *position < current_position)
-        {
-            self.record_positions.push(current_position);
-        }
-        match Deserialize::deserialize(&mut self.reader) {
        match Deserialize::deserialize(self) {
            Ok(value) => Ok(value),
            Err(value) => Err(anyhow::Error::from(value)),
        }
    }

-    fn position(&mut self, record: usize) -> anyhow::Result<()> {
-        let reader = self.reader.get_mut();
-        // Unsigned so can't be less than 0
-        if self.record_positions.len() > record {
-            // Go to position in reader
-            let position = self.record_positions[record];
-            reader.seek(std::io::SeekFrom::Start(position))?;
-        } else {
-            // read through the reader until we get to the correct record
-            bail!("Record hasn't been read yet, please use deserialize to find the record")
-        }
-        Ok(())
-    }
}
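
A small sketch of how the reworked traits compose, written as it could appear in an in-crate test (io is a private module); the CSV content is illustrative. Any RecordDeserializer can feed any RecordSerializer through the BTreeMap-based methods.

#[test]
fn csv_round_trip() -> anyhow::Result<()> {
    use std::collections::BTreeMap;

    use crate::io::{RecordDeserializer, RecordSerializer};

    let input = "Name,Qty\nWidget,3\n";
    let mut reader = csv::Reader::from_reader(input.as_bytes());
    let mut writer = csv::Writer::from_writer(vec![]);

    // The first deserialized record supplies the header (its keys) as well as the first row.
    // Fully qualified calls are used because csv::Reader/Writer have inherent methods
    // with the same names as the trait methods.
    let first: Option<BTreeMap<String, String>> = RecordDeserializer::deserialize(&mut reader)?;
    if let Some(record) = first {
        RecordSerializer::write_header(&mut writer, &record)?;
        RecordSerializer::write_record(&mut writer, &record)?;
    }
    RecordSerializer::flush(&mut writer)?;

    assert_eq!(input, String::from_utf8(writer.into_inner()?)?);
    Ok(())
}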


@@ -12,10 +12,12 @@ pub use self::products::csv::SourceType;
mod shared_models;
pub use self::shared_models::*;
pub mod code_rule;
pub mod derive;
pub mod filter;
mod graph;
mod io;
pub mod link;
pub mod node;
pub mod upload_to_db;

#[no_mangle]
@@ -63,22 +65,21 @@ pub extern "C" fn move_money_from_file(
    output_path: *const c_char,
    use_numeric_accounts: bool,
) {
-    let mut output_writer = csv::Writer::from_writer(vec![]);
    let safe_rules = unwrap_c_char(rules_file);
    let safe_lines = unwrap_c_char(lines);
    let safe_accounts = unwrap_c_char(accounts);
    let safe_cost_centres = unwrap_c_char(cost_centres);
-    // move_money_2()
-    // move_money(
-    //     ,
-    //     &mut csv::Reader::from_reader(safe_lines.to_str().unwrap()),
-    //     &mut csv::Reader::from_reader(safe_accounts.to_bytes()),
-    //     &mut csv::Reader::from_reader(safe_cost_centres.to_bytes()),
-    //     &mut output_writer,
-    //     use_numeric_accounts,
-    //     false,
-    // )
-    // .expect("Failed to move money");
    let output_path = unwrap_c_char(output_path);
    move_money(
        &mut csv::Reader::from_reader(safe_rules.to_bytes()),
        &mut csv::Reader::from_reader(safe_lines.to_bytes()),
        &mut csv::Reader::from_reader(safe_accounts.to_bytes()),
        &mut csv::Reader::from_reader(safe_cost_centres.to_bytes()),
        &mut csv::Writer::from_path(output_path.to_str().unwrap()).unwrap(),
        use_numeric_accounts,
        false,
    )
    .expect("Failed to move money");
}
#[no_mangle]
@@ -87,7 +88,7 @@ pub unsafe extern "C" fn move_money_from_text_free(s: *mut c_char) {
        if s.is_null() {
            return;
        }
-        CString::from_raw(s)
        let _ = CString::from_raw(s);
    };
}
@@ -181,6 +182,6 @@ pub unsafe extern "C" fn allocate_overheads_from_text_free(s: *mut c_char) {
        if s.is_null() {
            return;
        }
-        CString::from_raw(s)
        let _ = CString::from_raw(s);
    };
}


@@ -63,7 +63,7 @@ pub fn link(
        // TODO: Check this filters out correctly, as it's filtering on a reference, not a value
        .unique()
        .collect();
-    let mut source_date_columns: Vec<&String> = linking_rule
    let source_date_columns: Vec<&String> = linking_rule
        .linking_rules
        .iter()
        .flat_map(|rule| {
@@ -86,7 +86,7 @@ pub fn link(
    // TODO: Merge with source_indexes?
    // Also store the actual date value rather than string, so we
    // don't need to convert as much later
-    let mut source_dates: Vec<HashMap<String, Vec<usize>>>;
    let source_dates: Vec<HashMap<String, Vec<usize>>>;
    for source_record in source_reader.deserialize() {
        let source_record: HashMap<String, String> = source_record?;
        let current_idx = source_ids.len();

src/node.rs (new file)

@@ -0,0 +1,6 @@
pub trait RunnableNode {
// TODO: Get inputs/outputs to determine whether we can skip running this task
// TODO: Status
fn run(&self) -> anyhow::Result<()>;
}
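
Custom node types only need this trait to take part in the executor; a minimal sketch with a made-up LogNode:

use coster_rs::node::RunnableNode;

// Made-up node type: anything implementing RunnableNode can be returned from the
// closure handed to the graph executor's run().
struct LogNode {
    message: String,
}

impl RunnableNode for LogNode {
    fn run(&self) -> anyhow::Result<()> {
        println!("{}", self.message);
        Ok(())
    }
}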


@@ -11,7 +11,7 @@ use polars::lazy::dsl::*;
use polars::prelude::*;
use serde::Serialize;

-use super::csv::{read_definitions, Component, Definition, FileJoin, SourceType};
use super::csv::{read_definitions, Component, Definition};

// TODO: Polars suggests this, but docs suggest it doesn't have very good platform support
//use jemallocator::Jemalloc;


@@ -1,7 +1,12 @@
-use std::{collections::HashMap, io::Read};
use std::collections::HashMap;

-use csv::Reader;
use anyhow::bail;
-use sqlx::{query, query_builder, Any, Mssql, Pool, QueryBuilder};
use serde::{Deserialize, Serialize};
use sqlx::{Any, Pool, QueryBuilder};
use crate::node::RunnableNode;
const BIND_LIMIT: usize = 65535;
// Note: right now this is set to mssql only, since sqlx 0.7 is requried to use the Any
// type for sqlx 0.6 and earlier due to a query_builder lifetime issue,
@@ -12,44 +17,32 @@ use sqlx::{query, query_builder, Any, Mssql, Pool, QueryBuilder};
// TODO: Add bulk insert options for non-mssql dbs
// TODO: Add fallback insert when bulk insert fails (e.g. due to
// permission errors)
-pub async fn upload_file_bulk(
-    pool: &Pool<sqlx::Mssql>,
-    file_name: &String,
-    table_name: &String,
-    // Mappings from column in file -> column in db
-    column_mappings: Option<HashMap<String, String>>,
-    post_script: Option<String>,
-) -> anyhow::Result<u64> {
-    // TODO: Test if the table already exists. If it doesn't, try creating the table
-    // First try a bulk insert command
-    // let result = match pool.any_kind() {
-    //     sqlx::any::AnyKind::Mssql => {
-    let result = sqlx::query(&format!("BULK INSERT {} FROM {}", table_name, file_name))
-        .execute(pool)
-        .await?;
-    //     }
-    // };
-    let mut rows_affected = result.rows_affected();
-    // let mut rows_affected = match &result {
-    //     Result::Ok(result) => result.rows_affected(),
-    //     // TODO: Log error
-    //     Err(error) => 0_u64,
-    // };
-    // TODO: Adjust for various dbmss
-    if rows_affected == 0 {
-        let BIND_LIMIT: usize = 65535;
-        // TODO: Use csv to read from file
-        // TODO: When bulk insert fails, Fall back to sql batched insert
-        let mut query_builder = QueryBuilder::new(format!("INSERT INTO {}({}) ", table_name, ""));
pub async fn upload_file_bulk(pool: &Pool<Any>, upload_node: &UploadNode) -> anyhow::Result<u64> {
    let mut rows_affected = None;
    if upload_node.column_mappings.is_none() {
        let insert_from_file_query = match pool.connect_options().database_url.scheme() {
            "postgres" => Some(format!("COPY {} FROM $1", upload_node.table_name)),
            "mysql" => Some(format!(
                "LOAD DATA INFILE ? INTO {}",
                upload_node.table_name,
            )),
            _ => None,
        };
        if let Some(insert_from_file_query) = insert_from_file_query {
            let result = sqlx::query(&insert_from_file_query)
                .bind(&upload_node.file_path)
                .execute(pool)
                .await?;
            rows_affected = Some(result.rows_affected());
        }
    }
    if rows_affected == None {
        let rows: Vec<HashMap<String, String>> = vec![];
        // TODO: Columns to insert... needs some kind of mapping from file column name <-> db column
        let mut query_builder =
            QueryBuilder::new(format!("INSERT INTO {}({}) ", upload_node.table_name, ""));
        // TODO: Iterate over all values in file, not the limit
        query_builder.push_values(&rows[0..BIND_LIMIT], |mut b, row| {
            b.push_bind(row.get("s"));
@@ -58,14 +51,40 @@ pub async fn upload_file_bulk(
        // TODO: Looks like this issue: https://github.com/launchbadge/sqlx/issues/1978
        // Turns out we need v0.7 for this to not bug out, however mssql is only supported in versions before v0.7, so right now can't use sqlx
        // to use this, unless we explicity specified mssql only, not Any as the db type...
        // Can probably work around this by specifying an actual implementation?
        let query = query_builder.build();
        let result = query.execute(pool).await?;
-        rows_affected = result.rows_affected();
        rows_affected = Some(result.rows_affected());
    }
-    if let Some(post_script) = post_script {
    if let Some(post_script) = &upload_node.post_script {
        sqlx::query(&post_script).execute(pool).await?;
    }
-    Ok(rows_affected)
    match rows_affected {
        Some(rows_affected) => Ok(rows_affected),
        None => bail!(
            "Invalid state, rows_affected must be populated, or an error should have occurred"
        ),
    }
}
#[derive(Serialize, Deserialize, Clone)]
pub struct UploadNode {
file_path: String,
table_name: String,
// Mappings from column in file -> column in db
column_mappings: Option<HashMap<String, String>>,
post_script: Option<String>,
}
pub struct UploadNodeRunner {
pub upload_node: UploadNode,
}
impl RunnableNode for UploadNodeRunner {
fn run(&self) -> anyhow::Result<()> {
// TODO: Get db connection from some kind of property manager/context
todo!()
}
}
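
A rough sketch of how the bulk path might be exercised. UploadNode's fields are private, so this would sit inside upload_to_db.rs itself, e.g. as an ignored test; the connection string, table and file names are placeholders.

// Sketch only: exercising the bulk upload path from inside upload_to_db.rs.
#[cfg(test)]
mod upload_sketch {
    use sqlx::any::AnyPoolOptions;

    use super::{upload_file_bulk, UploadNode};

    #[tokio::test]
    #[ignore] // needs a reachable database and the matching sqlx driver feature
    async fn bulk_upload_example() -> anyhow::Result<()> {
        // Depending on the sqlx version/features, sqlx::any::install_default_drivers()
        // may need to be called before connecting through the Any driver.
        let pool = AnyPoolOptions::new()
            .max_connections(1)
            .connect("postgres://user:password@localhost/costing")
            .await?;
        let node = UploadNode {
            file_path: "/tmp/costs.csv".to_owned(),
            table_name: "costs".to_owned(),
            column_mappings: None, // None lets the COPY / LOAD DATA INFILE fast path run
            post_script: Some("ANALYZE costs".to_owned()),
        };
        let rows_affected = upload_file_bulk(&pool, &node).await?;
        println!("uploaded {} rows", rows_affected);
        Ok(())
    }
}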