diff --git a/Cargo.lock b/Cargo.lock index 965fa53..c889bfe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,7 +24,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom", + "getrandom 0.2.15", "once_cell", "version_check", "zerocopy", @@ -139,6 +139,42 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76" +[[package]] +name = "async-native-tls" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d57d4cec3c647232e1094dc013546c0b33ce785d8aeb251e1f20dfaf8a9a13fe" +dependencies = [ + "futures-util", + "native-tls", + "thiserror", + "url", +] + +[[package]] +name = "async-trait" +version = "0.1.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + +[[package]] +name = "asynchronous-codec" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4057f2c32adbb2fc158e22fb38433c8e9bbf76b75a4732c7c0cbaf695fb65568" +dependencies = [ + "bytes", + "futures-sink", + "futures-util", + "memchr", + "pin-project-lite", +] + [[package]] name = "atoi" version = "2.0.0" @@ -375,12 +411,28 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "connection-string" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "510ca239cf13b7f8d16a2b48f263de7b4f8c566f0af58d901031473c76afb1e3" + [[package]] name = "const-oid" version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.6" @@ -392,10 +444,13 @@ name = "coster-rs" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", "chrono", "clap", "csv", "env_logger", + "futures", + "futures-io", "itertools", "log", "nalgebra", @@ -409,7 +464,9 @@ dependencies = [ "serde_json", "sqlx", "tempfile", + "tiberius", "tokio", + "tokio-util", ] [[package]] @@ -576,6 +633,15 @@ dependencies = [ "serde", ] +[[package]] +name = "encoding_rs" +version = "0.8.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b45de904aa0b010bce2ab45264d0631681847fa7b6f2eaa7dab7619943bc4f59" +dependencies = [ + "cfg-if", +] + [[package]] name = "enum_dispatch" version = "0.3.13" @@ -588,6 +654,26 @@ dependencies = [ "syn 2.0.72", ] +[[package]] +name = "enumflags2" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d232db7f5956f3f14313dc2f87985c58bd2c695ce124c8cdd984e08e15ac133d" +dependencies = [ + "enumflags2_derive", +] + +[[package]] +name = "enumflags2_derive" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de0d48a183585823424a4ce1aa132d174a6a81bd540895822eb4c8373a8e49e8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + [[package]] name = "env_filter" version = "0.1.2" @@ -684,6 +770,21 @@ dependencies = [ "spin", ] +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "foreign_vec" version = "0.1.0" @@ -699,6 +800,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.30" @@ -743,6 +859,17 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + [[package]] name = "futures-sink" version = "0.3.30" @@ -761,8 +888,10 @@ version = "0.3.30" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" dependencies = [ + "futures-channel", "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -781,6 +910,17 @@ dependencies = [ "version_check", ] +[[package]] +name = "getrandom" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.9.0+wasi-snapshot-preview1", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -790,7 +930,7 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "wasm-bindgen", ] @@ -1060,6 +1200,12 @@ dependencies = [ "digest", ] +[[package]] +name = "md5" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e6bcd6433cff03a4bfc3d9834d504467db1f1cf6d0ea765d37d330249ed629d" + [[package]] name = "memchr" version = "2.7.4" @@ -1098,7 +1244,7 @@ checksum = "4569e456d394deccd22ce1c1913e6ea0e54519f577285001215d33557431afe4" dependencies = [ "hermit-abi", "libc", - "wasi", + "wasi 0.11.0+wasi-snapshot-preview1", "windows-sys 0.52.0", ] @@ -1151,6 +1297,23 @@ dependencies = [ "syn 2.0.72", ] +[[package]] +name = "native-tls" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nom" version = "7.1.3" @@ -1201,7 +1364,7 @@ dependencies = [ "num-integer", "num-iter", "num-traits", - "rand", + "rand 0.8.5", "smallvec", "zeroize", ] @@ -1281,6 +1444,50 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" +[[package]] +name = "openssl" +version = "0.10.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9529f4786b70a3e8c61e11179af17ab6188ad8d0ded78c5529441ed39d4bd9c1" +dependencies = [ + "bitflags 2.6.0", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.72", +] + +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "openssl-sys" +version = "0.9.103" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f9e8deee91df40a943c71b917e5874b951d32a802526c85721ce3b776c929d6" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "parking" version = "2.2.0" @@ -1372,7 +1579,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0" dependencies = [ "phf_shared", - "rand", + "rand 0.8.5", ] [[package]] @@ -1438,7 +1645,7 @@ version = "0.41.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e3351ea4570e54cd556e6755b78fe7a2c85368d820c0307cca73c96e796a7ba" dependencies = [ - "getrandom", + "getrandom 0.2.15", "polars-arrow", "polars-core", "polars-error", @@ -1469,7 +1676,7 @@ dependencies = [ "ethnum", "fast-float", "foreign_vec", - "getrandom", + "getrandom 0.2.15", "hashbrown", "itoa", "itoap", @@ -1535,7 +1742,7 @@ dependencies = [ "polars-error", "polars-row", "polars-utils", - "rand", + "rand 0.8.5", "rand_distr", "rayon", "regex", @@ -1780,7 +1987,7 @@ dependencies = [ "polars-ops", "polars-plan", "polars-time", - "rand", + "rand 0.8.5", "serde", "serde_json", "sqlparser", @@ -1838,6 +2045,12 @@ dependencies = [ "zerocopy-derive", ] +[[package]] +name = "pretty-hex" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6fa0831dd7cc608c38a5e323422a0077678fa5744aa2be4ad91c4ece8eec8d5" + [[package]] name = "proc-macro2" version = "1.0.86" @@ -1865,6 +2078,19 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" +dependencies = [ + "getrandom 0.1.16", + "libc", + "rand_chacha 0.2.2", + "rand_core 0.5.1", + "rand_hc", +] + [[package]] name = "rand" version = "0.8.5" @@ -1872,8 +2098,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" +dependencies = [ + "ppv-lite86", + "rand_core 0.5.1", ] [[package]] @@ -1883,7 +2119,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_core" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +dependencies = [ + "getrandom 0.1.16", ] [[package]] @@ -1892,7 +2137,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.15", ] [[package]] @@ -1902,7 +2147,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" dependencies = [ "num-traits", - "rand", + "rand 0.8.5", +] + +[[package]] +name = "rand_hc" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" +dependencies = [ + "rand_core 0.5.1", ] [[package]] @@ -2015,7 +2269,7 @@ checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" dependencies = [ "cc", "cfg-if", - "getrandom", + "getrandom 0.2.15", "libc", "spin", "untrusted", @@ -2057,7 +2311,7 @@ dependencies = [ "num-traits", "pkcs1", "pkcs8", - "rand_core", + "rand_core 0.6.4", "signature", "spki", "subtle", @@ -2134,6 +2388,15 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "schannel" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "schemars" version = "0.8.21" @@ -2175,6 +2438,29 @@ dependencies = [ "untrusted", ] +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags 2.6.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "serde" version = "1.0.204" @@ -2268,7 +2554,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" dependencies = [ "digest", - "rand_core", + "rand_core 0.6.4", ] [[package]] @@ -2496,7 +2782,7 @@ dependencies = [ "memchr", "once_cell", "percent-encoding", - "rand", + "rand 0.8.5", "rsa", "serde", "sha1", @@ -2535,7 +2821,7 @@ dependencies = [ "md-5", "memchr", "once_cell", - "rand", + "rand 0.8.5", "serde", "serde_json", "sha2", @@ -2726,6 +3012,33 @@ dependencies = [ "syn 2.0.72", ] +[[package]] +name = "tiberius" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1446cb4198848d1562301a3340424b4f425ef79f35ef9ee034769a9dd92c10d" +dependencies = [ + "async-native-tls", + "async-trait", + "asynchronous-codec", + "byteorder", + "bytes", + "chrono", + "connection-string", + "encoding_rs", + "enumflags2", + "futures-util", + "num-traits", + "once_cell", + "pin-project-lite", + "pretty-hex", + "thiserror", + "tokio", + "tracing", + "uuid", + "winauth", +] + [[package]] name = "tinyvec" version = "1.8.0" @@ -2781,6 +3094,20 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" +dependencies = [ + "bytes", + "futures-core", + "futures-io", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "tracing" version = "0.1.40" @@ -2902,7 +3229,7 @@ version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" dependencies = [ - "getrandom", + "getrandom 0.2.15", ] [[package]] @@ -2917,6 +3244,12 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" +[[package]] +name = "wasi" +version = "0.9.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -3031,6 +3364,19 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "winauth" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f820cd208ce9c6b050812dc2d724ba98c6c1e9db5ce9b3f58d925ae5723a5e6" +dependencies = [ + "bitflags 1.3.2", + "byteorder", + "md5", + "rand 0.7.3", + "winapi", +] + [[package]] name = "windows" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index 48070d6..2935cb3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,11 @@ num_cpus = "1.16.0" schemars = {version = "0.8.21", features = ["chrono"]} log = "0.4.22" env_logger = "0.11.5" +tiberius = {version = "0.12.3", features = ["chrono", "tokio"]} +futures-io = "0.3.30" +futures = "0.3.30" +tokio-util = {version = "0.7.11", features = ["compat"]} +async-trait = "0.1.81" # More info on targets: https://doc.rust-lang.org/cargo/reference/cargo-targets.html#configuring-a-target [lib] diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 31a4e8e..5140a07 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -174,9 +174,11 @@ impl Cli { let reader = BufReader::new(file); let graph = serde_json::from_reader(reader)?; let graph = RunnableGraph::from_graph(graph); + // TODO: Possible to await here? graph.run_default_tasks(threads, |id, status| { info!("Node with id {} finished with status {:?}", id, status) - }) + }); + Ok(()) } Commands::GenerateSchema { output } => { let schema = schema_for!(Graph); diff --git a/src/graph/derive.rs b/src/graph/derive.rs index fc6fffa..dbc4808 100644 --- a/src/graph/derive.rs +++ b/src/graph/derive.rs @@ -1,5 +1,6 @@ use std::{collections::BTreeMap, str::FromStr}; +use async_trait::async_trait; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -298,8 +299,9 @@ pub struct DeriveNodeRunner { derive_node: DeriveNode, } +#[async_trait] impl RunnableNode for DeriveNodeRunner { - fn run(&self) -> anyhow::Result<()> { + async fn run(&self) -> anyhow::Result<()> { let mut reader = csv::Reader::from_path(&self.derive_node.input_file_path)?; let mut writer = csv::Writer::from_path(&self.derive_node.output_file_path)?; let rules: anyhow::Result> = self diff --git a/src/graph/filter.rs b/src/graph/filter.rs index db45359..dacafdd 100644 --- a/src/graph/filter.rs +++ b/src/graph/filter.rs @@ -1,5 +1,6 @@ use std::collections::BTreeMap; +use async_trait::async_trait; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -49,8 +50,9 @@ pub struct FilterNodeRunner { pub filter_node: FilterNode, } +#[async_trait] impl RunnableNode for FilterNodeRunner { - fn run(&self) -> anyhow::Result<()> { + async fn run(&self) -> anyhow::Result<()> { let mut reader = csv::Reader::from_path(&self.filter_node.input_file_path)?; let mut writer = csv::Writer::from_path(&self.filter_node.output_file_path)?; let rules = derive::to_filter_rules(&self.filter_node.filters)?; diff --git a/src/graph/mod.rs b/src/graph/mod.rs index cbc9410..c10fab2 100644 --- a/src/graph/mod.rs +++ b/src/graph/mod.rs @@ -2,16 +2,16 @@ use std::{ cmp::{min, Ordering}, collections::{HashMap, HashSet}, sync::{ - mpsc::{self, Sender}, - Arc, + mpsc, Arc }, - thread, }; use chrono::Local; +use futures::lock::Mutex; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use split::{SplitNode, SplitNodeRunner}; +use tokio::sync::mpsc::Sender; use { derive::DeriveNode, @@ -24,7 +24,9 @@ use { mod derive; mod filter; mod node; +mod pull_from_db; mod split; +mod sql; mod sql_rule; mod upload_to_db; @@ -131,7 +133,7 @@ impl Node { } } -fn get_runnable_node(node: Node) -> Box { +fn get_runnable_node(node: Node) -> Box { match node.info.configuration { NodeConfiguration::FileNode => todo!(), NodeConfiguration::MoveMoneyNode(_) => todo!(), @@ -169,7 +171,7 @@ impl RunnableGraph { RunnableGraph { graph } } - pub fn run_default_tasks(&self, num_threads: usize, status_changed: F) -> anyhow::Result<()> + pub async fn run_default_tasks(&self, num_threads: usize, status_changed: F) -> anyhow::Result<()> where F: Fn(i64, NodeStatus), { @@ -177,18 +179,18 @@ impl RunnableGraph { num_threads, Box::new(|node| get_runnable_node(node)), status_changed, - ) + ).await } // Make this not mutable, emit node status when required in a function or some other message - pub fn run<'a, F, StatusChanged>( + pub async fn run<'a, F, StatusChanged>( &self, num_threads: usize, get_node_fn: F, node_status_changed_fn: StatusChanged, ) -> anyhow::Result<()> where - F: Fn(Node) -> Box + Send + Sync + 'static, + F: Fn(Node) -> Box + Send + Sync + 'static, StatusChanged: Fn(i64, NodeStatus), { let mut nodes = self.graph.nodes.clone(); @@ -209,7 +211,7 @@ impl RunnableGraph { if num_threads < 2 { for node in &nodes { node_status_changed_fn(node.id, NodeStatus::Running); - match get_node_fn(node.clone()).run() { + match get_node_fn(node.clone()).run().await { Ok(_) => node_status_changed_fn(node.id, NodeStatus::Completed), Err(err) => node_status_changed_fn(node.id, NodeStatus::Failed(err)), }; @@ -226,15 +228,18 @@ impl RunnableGraph { let node_fn = Arc::new(get_node_fn); for n in 0..num_threads { let finish_task = finish_task.clone(); - let (tx, rx) = mpsc::channel(); + // let finish_task = finish_task.clone(); + let (tx, mut rx) = tokio::sync::mpsc::channel(32); senders.push(tx); let node_fn = node_fn.clone(); - let handle = thread::spawn(move || { - for node in rx { - let status = match node_fn(node.clone()).run() { + // TODO: Think this needs to be all reworked to be more inline with async + let handle = tokio::spawn(async move { + for node in rx.recv().await { + let status = match node_fn(node.clone()).run().await { Ok(_) => NodeStatus::Completed, Err(err) => NodeStatus::Failed(err), }; + let status = status; finish_task .send((n, node, status)) .expect("Failed to notify node status completion"); @@ -258,7 +263,7 @@ impl RunnableGraph { let node = nodes.remove(i); node_status_changed_fn(node.id, NodeStatus::Running); running_nodes.insert(node.id); - senders[i % senders.len()].send(node)?; + senders[i % senders.len()].send(node).await?; } } @@ -280,7 +285,7 @@ impl RunnableGraph { let node = nodes.remove(i); for i in 0..num_threads { if !running_threads.contains(&i) { - senders[i].send(node)?; + senders[i].send(node).await?; break; } } @@ -296,7 +301,7 @@ impl RunnableGraph { } for handle in handles { - handle.join().expect("Failed to join thread"); + handle.await.expect("Failed to join thread"); } println!("Process finished"); Ok(()) @@ -310,8 +315,8 @@ mod tests { use super::{NodeConfiguration, RunnableGraph}; - #[test] - fn test_basic() -> anyhow::Result<()> { + #[tokio::test] + async fn test_basic() -> anyhow::Result<()> { let graph = RunnableGraph { graph: super::Graph { name: "Test".to_owned(), @@ -332,7 +337,7 @@ mod tests { }], }, }; - graph.run_default_tasks(2, |_, _| {})?; + graph.run_default_tasks(2, |_, _| {}).await?; Ok(()) } } diff --git a/src/graph/node.rs b/src/graph/node.rs index 4e41188..ca9efde 100644 --- a/src/graph/node.rs +++ b/src/graph/node.rs @@ -1,6 +1,11 @@ +use async_trait::async_trait; + +#[async_trait] pub trait RunnableNode { // TODO: Get inputs/outputs to determine whether we can skip running this task // TODO: Status - fn run(&self) -> anyhow::Result<()>; + + // TODO: Is it possible to make this async? + async fn run(&self) -> anyhow::Result<()>; } diff --git a/src/graph/pull_from_db.rs b/src/graph/pull_from_db.rs new file mode 100644 index 0000000..a0210f4 --- /dev/null +++ b/src/graph/pull_from_db.rs @@ -0,0 +1,12 @@ +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; + +use super::sql::QueryExecutor; + +/** + * Pull data from a db using a db query into a csv file that can be used by another node + */ +fn pull_from_db(executor: &mut impl QueryExecutor, node: PullFromDBNode) {} + +#[derive(Serialize, Deserialize, Clone, JsonSchema)] +pub struct PullFromDBNode {} diff --git a/src/graph/split.rs b/src/graph/split.rs index b57a27c..bbf7c96 100644 --- a/src/graph/split.rs +++ b/src/graph/split.rs @@ -1,5 +1,6 @@ -use std::{collections::BTreeMap, fs::File}; +use std::collections::BTreeMap; +use async_trait::async_trait; use chrono::DateTime; use polars::{ io::SerWriter, @@ -9,7 +10,7 @@ use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use tempfile::tempfile; -use crate::io::{RecordDeserializer, RecordSerializer}; +use crate::io::RecordSerializer; use super::{ derive::{self, DataValidator, DeriveFilter}, @@ -149,8 +150,9 @@ fn split( Ok(()) } +#[async_trait] impl RunnableNode for SplitNodeRunner { - fn run(&self) -> anyhow::Result<()> { + async fn run(&self) -> anyhow::Result<()> { let mut output = csv::Writer::from_path(&self.split_node.output_file_path)?; let rules: anyhow::Result> = self .split_node diff --git a/src/graph/sql.rs b/src/graph/sql.rs new file mode 100644 index 0000000..357d4dd --- /dev/null +++ b/src/graph/sql.rs @@ -0,0 +1,101 @@ +use std::borrow::Borrow; + +use futures::TryStreamExt; +use futures_io::{AsyncRead, AsyncWrite}; +use itertools::Itertools; +use sqlx::{Any, AnyPool, Column, Pool, Row}; +use tiberius::{Client, Query}; + +// TODO: This doesn't seem to work. Suggestion by compiler is to instead create an enum and implement +// the trait on the enum (basically use a match in the implementation depending on which enum we have) +pub trait QueryExecutor { + // TODO: Params binding for filtering the same query? + // Retrieve data from a database + async fn get_rows( + &mut self, + query: &str, + params: &Vec, + ) -> anyhow::Result>>; + + // Run a query that returns no results (e.g. bulk insert, insert) + async fn execute_query(&mut self, query: &str, params: &Vec) -> anyhow::Result; +} + +impl QueryExecutor for Client { + async fn get_rows( + &mut self, + query: &str, + params: &Vec, + ) -> anyhow::Result>> { + let mut query = Query::new(query); + for param in params { + query.bind(param); + } + let query_result = query.query(self).await?; + let results = query_result.into_first_result().await?; + let results = results + .into_iter() + .map(|row| { + row.columns() + .into_iter() + .map(|column| { + ( + column.name().to_owned(), + match row.get(column.name()) { + Some(value) => value, + None => "", + } + .to_owned(), + ) + }) + .collect_vec() + }) + .collect(); + Ok(results) + } + + async fn execute_query(&mut self, query: &str, params: &Vec) -> anyhow::Result { + let mut query = Query::new(query); + for param in params { + query.bind(param); + } + let result = query.execute(self).await?; + if result.rows_affected().len() == 0 { + return Ok(0); + } + Ok(result.rows_affected()[0]) + } +} + +impl QueryExecutor for Pool { + async fn get_rows( + &mut self, + query: &str, + params: &Vec, + ) -> anyhow::Result>> { + let mut query = sqlx::query(query); + for param in params { + query = query.bind(param); + } + let mut rows = query.fetch(self.borrow()); + let mut results = vec![]; + while let Some(row) = rows.try_next().await? { + results.push( + row.columns() + .into_iter() + .map(|column| (column.name().to_owned(), row.get(column.name()))) + .collect(), + ); + } + Ok(results) + } + + async fn execute_query(&mut self, query: &str, params: &Vec) -> anyhow::Result { + let mut query = sqlx::query(query); + for param in params { + query = query.bind(param); + } + let result = query.execute(self.borrow()).await?; + Ok(result.rows_affected()) + } +} diff --git a/src/graph/sql_rule.rs b/src/graph/sql_rule.rs index 70a237f..d3e9c9a 100644 --- a/src/graph/sql_rule.rs +++ b/src/graph/sql_rule.rs @@ -1,5 +1,6 @@ use std::fs::File; +use async_trait::async_trait; use polars::{ io::SerWriter, prelude::{CsvWriter, LazyCsvReader, LazyFileListReader}, @@ -42,8 +43,9 @@ pub struct SQLNodeRunner { pub sql_node: SQLNode, } +#[async_trait] impl RunnableNode for SQLNodeRunner { - fn run(&self) -> anyhow::Result<()> { + async fn run(&self) -> anyhow::Result<()> { run_sql( &self.sql_node.files, &self.sql_node.output_file, diff --git a/src/graph/upload_to_db.rs b/src/graph/upload_to_db.rs index 2e9e957..f284bea 100644 --- a/src/graph/upload_to_db.rs +++ b/src/graph/upload_to_db.rs @@ -1,65 +1,89 @@ use std::collections::HashMap; use anyhow::bail; +use async_trait::async_trait; +use futures::executor; +use itertools::Itertools; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -use sqlx::{Any, Pool, QueryBuilder}; +use sqlx::{AnyPool}; +use tiberius::Config; +use tokio::{ task}; +use tokio_util::compat::TokioAsyncWriteCompatExt; -use super::node::RunnableNode; +use super::{node::RunnableNode, sql::QueryExecutor}; const BIND_LIMIT: usize = 65535; -// Note: right now this is set to mssql only, since sqlx 0.7 is requried to use the Any -// type for sqlx 0.6 and earlier due to a query_builder lifetime issue, -// however sqlx >=0.7 currently doesn't support mssql. - -// Upload data in a file to a db table, with an optional post-script to run, -// such as to move data from the upload table into other tables -// TODO: Add bulk insert options for non-mssql dbs -// TODO: Add fallback insert when bulk insert fails (e.g. due to -// permission errors) -pub async fn upload_file_bulk(pool: &Pool, upload_node: &UploadNode) -> anyhow::Result { +pub async fn upload_file_bulk( + executor: &mut impl QueryExecutor, + upload_node: &UploadNode, +) -> anyhow::Result { let mut rows_affected = None; if upload_node.column_mappings.is_none() { - let insert_from_file_query = match pool.connect_options().database_url.scheme() { - "postgres" => Some(format!("COPY {} FROM $1", upload_node.table_name)), - "mysql" => Some(format!( + let insert_from_file_query = match upload_node.db_type { + DBType::Postgres => Some(format!("COPY {} FROM $1", upload_node.table_name)), + DBType::Mysql => Some(format!( "LOAD DATA INFILE ? INTO {}", upload_node.table_name, )), + DBType::Mssql => Some(format!("BULK INSERT {} FROM ?", upload_node.table_name)), _ => None, }; if let Some(insert_from_file_query) = insert_from_file_query { - let result = sqlx::query(&insert_from_file_query) - .bind(&upload_node.file_path) - .execute(pool) + let result = executor + .execute_query( + &insert_from_file_query, + &vec![upload_node.file_path.clone()], + ) .await?; - rows_affected = Some(result.rows_affected()); + rows_affected = Some(result); } } if rows_affected == None { - let rows: Vec> = vec![]; + let mut file_reader = csv::Reader::from_path(upload_node.file_path.clone())?; - // TODO: Columns to insert... needs some kind of mapping from file column name <-> db column - let mut query_builder = - QueryBuilder::new(format!("INSERT INTO {}({}) ", upload_node.table_name, "")); - // TODO: Iterate over all values in file, not the limit - query_builder.push_values(&rows[0..BIND_LIMIT], |mut b, row| { - b.push_bind(row.get("s")); - }); - let mut query_builder = query_builder; - // TODO: Looks like this issue: https://github.com/launchbadge/sqlx/issues/1978 - // Turns out we need v0.7 for this to not bug out, however mssql is only supported in versions before v0.7, so right now can't use sqlx - // to use this, unless we explicity specified mssql only, not Any as the db type... - // Can probably work around this by specifying an actual implementation? - let query = query_builder.build(); - let result = query.execute(pool).await?; - rows_affected = Some(result.rows_affected()); + let csv_columns = file_reader.headers()?.iter().map(|header| header.to_owned()).collect_vec(); + let table_columns = if let Some(column_mappings) = &upload_node.column_mappings { + csv_columns + .iter() + .map(|column| { + column_mappings + .get(column).unwrap_or(column) + .clone() + }) + .collect_vec() + } else { + csv_columns.clone() + }; + let query_template = format!("INSERT INTO {}({}) \n", upload_node.table_name, table_columns.join(",")); + let mut params = vec![]; + let mut insert_query = "".to_owned(); + let mut num_params = 0; + let mut running_row_total = 0; + + for result in file_reader.records() { + let result = result?; + insert_query = insert_query + format!("VALUES ({})", result.iter().map(|_| "?").join(",")).as_str(); + let mut values = result.iter().map(|value| value.to_owned()).collect_vec(); + params.append(&mut values); + num_params += csv_columns.len(); + if num_params == BIND_LIMIT { + running_row_total += executor.execute_query(&query_template, ¶ms).await?; + insert_query = "".to_owned(); + params = vec![]; + num_params = 0; + } + } + if !insert_query.is_empty() { + running_row_total += executor.execute_query(&query_template, ¶ms).await?; + } + rows_affected = Some(running_row_total); } if let Some(post_script) = &upload_node.post_script { - sqlx::query(&post_script).execute(pool).await?; + executor.execute_query(post_script, &vec![]).await?; } match rows_affected { @@ -70,6 +94,14 @@ pub async fn upload_file_bulk(pool: &Pool, upload_node: &UploadNode) -> any } } +#[derive(Serialize, Deserialize, Clone, JsonSchema, PartialEq)] +pub enum DBType { + Mysql, + Postgres, + Mssql, + Sqlite, +} + #[derive(Serialize, Deserialize, Clone, JsonSchema)] pub struct UploadNode { file_path: String, @@ -77,15 +109,39 @@ pub struct UploadNode { // Mappings from column in file -> column in db column_mappings: Option>, post_script: Option, + db_type: DBType, + connection_string: String, } pub struct UploadNodeRunner { pub upload_node: UploadNode, } +#[async_trait] impl RunnableNode for UploadNodeRunner { - fn run(&self) -> anyhow::Result<()> { - // TODO: Get db connection from some kind of property manager/context - todo!() + async fn run(&self) -> anyhow::Result<()> { + let upload_node = self.upload_node.clone(); + if upload_node.db_type == DBType::Mssql { + let mut config = Config::from_jdbc_string(&upload_node.connection_string); + if let Ok(mut config) = config { + let tcp = tokio::net::TcpStream::connect(config.get_addr()).await; + if let Ok(tcp) = tcp { + tcp.set_nodelay(true); + let client = tiberius::Client::connect(config, tcp.compat_write()).await; + if let Ok(mut client) = client { + upload_file_bulk(&mut client, &upload_node).await; + } + } + } + }else { + let mut pool = AnyPool::connect(&upload_node.connection_string).await; + if let Ok(mut pool) = pool { + upload_file_bulk(&mut pool, &upload_node).await; + } + } + // TODO: Message to listen for task completing since join handle doesn't include this + // Alternative is to make run signature async, though that may add more complexity + // to graph mode. + Ok(()) } }