diff --git a/pgml-extension/Cargo.lock b/pgml-extension/Cargo.lock index acf5e52f2..fbbb90e9d 100644 --- a/pgml-extension/Cargo.lock +++ b/pgml-extension/Cargo.lock @@ -34,9 +34,9 @@ checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" [[package]] name = "anyhow" -version = "1.0.77" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9d19de80eff169429ac1e9f48fffb163916b448a44e8e046186232046d9e1f9" +checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" [[package]] name = "approx" @@ -120,13 +120,13 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.75" +version = "0.1.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdf6721fb0140e4f897002dd086c06f6c27775df19cfe1fccb21181a48fd2c98" +checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -210,11 +210,11 @@ dependencies = [ "peeking_take_while", "prettyplease", "proc-macro2", - "quote 1.0.33", + "quote 1.0.35", "regex", "rustc-hash", "shlex", - "syn 2.0.43", + "syn 2.0.46", "which", ] @@ -358,9 +358,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clang-sys" -version = "1.6.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c688fc74432808e3eb684cae8830a86be1d66a2bd58e1f248ed0960a590baf6f" +checksum = "67523a3b4be3ce1989d607a828d036249522dd9c1c8de7f4dd2dae43a37369d1" dependencies = [ "glob", "libc", @@ -405,8 +405,8 @@ checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442" dependencies = [ "heck", "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -564,7 +564,7 @@ dependencies = [ "fnv", "ident_case", "proc-macro2", - "quote 1.0.33", + "quote 1.0.35", "strsim", "syn 1.0.109", ] @@ -576,15 +576,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" dependencies = [ "darling_core", - "quote 1.0.33", + "quote 1.0.35", "syn 1.0.109", ] [[package]] name = "deranged" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8eb30d70a07a3b04884d2677f06bec33509dc67ca60d92949e5535352d3191dc" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ "powerfmt", ] @@ -627,7 +627,7 @@ checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" dependencies = [ "darling", "proc-macro2", - "quote 1.0.33", + "quote 1.0.35", "syn 1.0.109", ] @@ -742,8 +742,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f282cfdfe92516eb26c2af8589c274c7c17681f5ecc03c18255fe741c6aa64eb" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -754,9 +754,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "erased-serde" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4adbf0983fe06bd3a5c19c8477a637c2389feb0994eca7a59e3b961054aa7c0a" +checksum = "55d05712b2d8d88102bc9868020c9e5c7a1f5527c452b9b97450a1d006140ba7" dependencies = [ "serde", ] @@ -886,8 +886,8 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -1053,7 +1053,6 @@ checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", "hashbrown 0.12.3", - "serde", ] [[package]] @@ -1064,6 +1063,7 @@ checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f" dependencies = [ "equivalent", "hashbrown 0.14.3", + "serde", ] [[package]] @@ -1098,9 +1098,9 @@ dependencies = [ [[package]] name = "itertools" -version = "0.11.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" dependencies = [ "either", ] @@ -1149,12 +1149,12 @@ checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" [[package]] name = "libloading" -version = "0.7.4" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f" +checksum = "c571b676ddfc9a8c12f1f3d3085a7b163966a8fd8098a90640953ce5f6170161" dependencies = [ "cfg-if", - "winapi", + "windows-sys 0.48.0", ] [[package]] @@ -1609,8 +1609,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -1731,8 +1731,8 @@ dependencies = [ "csv", "flate2", "heapless", - "indexmap 1.9.3", - "itertools 0.11.0", + "indexmap 2.1.0", + "itertools 0.12.0", "lightgbm", "linfa", "linfa-linear", @@ -1790,7 +1790,7 @@ checksum = "a18ac8628b7de2f29a93d0abdbdcaee95a0e0ef4b59fd4de99cc117e166e843b" dependencies = [ "pgrx-sql-entity-graph", "proc-macro2", - "quote 1.0.33", + "quote 1.0.35", "syn 1.0.109", ] @@ -1828,7 +1828,7 @@ dependencies = [ "pgrx-pg-config", "pgrx-sql-entity-graph", "proc-macro2", - "quote 1.0.33", + "quote 1.0.35", "serde", "shlex", "sptr", @@ -1846,7 +1846,7 @@ dependencies = [ "eyre", "petgraph", "proc-macro2", - "quote 1.0.33", + "quote 1.0.35", "syn 1.0.109", "unescape", ] @@ -1968,19 +1968,19 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "prettyplease" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae005bd773ab59b4725093fd7df83fd7892f7d8eafb48dbd7de6e024e4215f9d" +checksum = "a41cf62165e97c7f814d2221421dbb9afcbcdb0a88068e5ea206e19951c2cbb5" dependencies = [ "proc-macro2", - "syn 2.0.43", + "syn 2.0.46", ] [[package]] name = "proc-macro2" -version = "1.0.71" +version = "1.0.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75cb1540fadbd5b8fbccc4dddad2734eba435053f725621c070711a14bb5f4b8" +checksum = "2de98502f212cfcea8d0bb305bd0f49d7ebdd75b64ba0a68f937d888f4e0d6db" dependencies = [ "unicode-ident", ] @@ -2007,9 +2007,9 @@ dependencies = [ [[package]] name = "pyo3" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04e8453b658fe480c3e70c8ed4e3d3ec33eb74988bd186561b0cc66b85c3bc4b" +checksum = "e82ad98ce1991c9c70c3464ba4187337b9c45fcbbb060d46dca15f0c075e14e2" dependencies = [ 
"cfg-if", "indoc", @@ -2024,9 +2024,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a96fe70b176a89cff78f2fa7b3c930081e163d5379b4dcdf993e3ae29ca662e5" +checksum = "5503d0b3aee2c7a8dbb389cd87cd9649f675d4c7f60ca33699a3e3859d81a891" dependencies = [ "once_cell", "target-lexicon", @@ -2034,9 +2034,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "214929900fd25e6604661ed9cf349727c8920d47deff196c4e28165a6ef2a96b" +checksum = "18a79e8d80486a00d11c0dcb27cd2aa17c022cc95c677b461f01797226ba8f41" dependencies = [ "libc", "pyo3-build-config", @@ -2044,26 +2044,26 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dac53072f717aa1bfa4db832b39de8c875b7c7af4f4a6fe93cdbf9264cf8383b" +checksum = "1f4b0dc7eaa578604fab11c8c7ff8934c71249c61d4def8e272c76ed879f03d4" dependencies = [ "proc-macro2", "pyo3-macros-backend", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] name = "pyo3-macros-backend" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7774b5a8282bd4f25f803b1f0d945120be959a36c72e08e7cd031c792fdfd424" +checksum = "816a4f709e29ddab2e3cdfe94600d554c5556cad0ddfeea95c47b580c3247fa4" dependencies = [ "heck", "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -2080,9 +2080,9 @@ checksum = "7a6e920b65c65f10b2ae65c831a81a073a89edd28c7cce89475bff467ab4167a" [[package]] name = "quote" -version = "1.0.33" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] @@ -2279,7 +2279,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" dependencies = [ - "semver 1.0.20", + "semver 1.0.21", ] [[package]] @@ -2404,9 +2404,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.20" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "836fa6a3e1e547f9a2c4040802ec865b5d85f4014efe00555d7090a3dcaa1090" +checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" [[package]] name = "semver-parser" @@ -2425,9 +2425,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.193" +version = "1.0.194" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" +checksum = "0b114498256798c94a0689e1a15fec6005dee8ac1f41de56404b67afc2a4b773" dependencies = [ "serde_derive", ] @@ -2444,20 +2444,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.193" +version = "1.0.194" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" +checksum = "a3385e45322e8f9931410f01b3031ec534c3947d0e94c18049af4d9f9907d4e0" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + 
"quote 1.0.35", + "syn 2.0.46", ] [[package]] name = "serde_json" -version = "1.0.108" +version = "1.0.110" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" +checksum = "6fbd975230bada99c8bb618e0c365c2eefa219158d5c6c29610fd09ff1833257" dependencies = [ "indexmap 2.1.0", "itoa", @@ -2659,18 +2659,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" dependencies = [ "proc-macro2", - "quote 1.0.33", + "quote 1.0.35", "unicode-ident", ] [[package]] name = "syn" -version = "2.0.43" +version = "2.0.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee659fb5f3d355364e1f3e5bc10fb82068efbf824a1e9d1c9504244a6469ad53" +checksum = "89456b690ff72fddcecf231caedbe615c59480c93358a93dfae7fc29e3ebbf0e" dependencies = [ "proc-macro2", - "quote 1.0.33", + "quote 1.0.35", "unicode-ident", ] @@ -2723,9 +2723,9 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.12" +version = "0.12.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c39fd04924ca3a864207c66fc2cd7d22d7c016007f9ce846cbb9326331930a" +checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" [[package]] name = "tempfile" @@ -2753,22 +2753,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.52" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83a48fd946b02c0a526b2e9481c8e2a17755e47039164a86c4070446e3a4614d" +checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.52" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7fbe9b594d6568a6a1443250a7e67d80b74e1e96f6d1715e1e21cc1888291d3" +checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -2943,9 +2943,9 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "typetag" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "196976efd4a62737b3a2b662cda76efb448d099b1049613d7a5d72743c611ce0" +checksum = "c43148481c7b66502c48f35b8eef38b6ccdc7a9f04bd4cc294226d901ccc9bc7" dependencies = [ "erased-serde", "inventory", @@ -2956,13 +2956,13 @@ dependencies = [ [[package]] name = "typetag-impl" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2eea6765137e2414c44c7b1e07c73965a118a72c46148e1e168b3fc9d3ccf3aa" +checksum = "291db8a81af4840c10d636e047cac67664e343be44e24dfdbd1492df9a5d3390" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", ] [[package]] @@ -3125,8 +3125,8 @@ dependencies = [ "log", "once_cell", "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", "wasm-bindgen-shared", ] @@ -3136,7 +3136,7 @@ version = "0.2.89" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" dependencies = [ - "quote 1.0.33", + "quote 1.0.35", "wasm-bindgen-macro-support", ] @@ -3147,8 +3147,8 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", - "quote 1.0.33", - "syn 2.0.43", + "quote 1.0.35", + "syn 2.0.46", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3356,9 +3356,9 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "winnow" -version = "0.5.31" +version = "0.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97a4882e6b134d6c28953a387571f1acdd3496830d5e36c5e3a1075580ea641c" +checksum = "8434aeec7b290e8da5c3f0d628cb0eac6cabcb31d14bb74f779a08109a5914d6" dependencies = [ "memchr", ] diff --git a/pgml-extension/Cargo.toml b/pgml-extension/Cargo.toml index ab5fb00dc..362bb017b 100644 --- a/pgml-extension/Cargo.toml +++ b/pgml-extension/Cargo.toml @@ -24,8 +24,8 @@ csv = "1.2" flate2 = "1.0" blas = { version = "0.22" } blas-src = { version = "0.9", features = ["openblas"] } -indexmap = { version = "1.0", features = ["serde"] } -itertools = "0.11" +indexmap = { version = "2.1", features = ["serde"] } +itertools = "0.12" heapless = "0.7" lightgbm = { git = "https://github.com/postgresml/lightgbm-rs", branch = "main" } linfa = { path = "deps/linfa" } diff --git a/pgml-extension/rustfmt.toml b/pgml-extension/rustfmt.toml new file mode 100644 index 000000000..94ac875fa --- /dev/null +++ b/pgml-extension/rustfmt.toml @@ -0,0 +1 @@ +max_width=120 diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index 5a4d8a29a..380bfb330 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -163,21 +163,30 @@ fn train_joint( let task = task.map(|t| Task::from_str(t).unwrap()); let project = match Project::find_by_name(project_name) { Some(project) => project, - None => Project::create(project_name, match task { - Some(task) => task, - None => error!("Project `{}` does not exist. To create a new project, you must specify a `task`.", project_name), - }), + None => Project::create( + project_name, + match task { + Some(task) => task, + None => error!( + "Project `{}` does not exist. To create a new project, you must specify a `task`.", + project_name + ), + }, + ), }; if task.is_some() && task.unwrap() != project.task { - error!("Project `{:?}` already exists with a different task: `{:?}`. Create a new project instead.", project.name, project.task); + error!( + "Project `{:?}` already exists with a different task: `{:?}`. 
Create a new project instead.", + project.name, project.task + ); } let mut snapshot = match relation_name { None => { - let snapshot = project - .last_snapshot() - .expect("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model."); + let snapshot = project.last_snapshot().expect( + "You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model.", + ); info!("Using existing snapshot from {}", snapshot.snapshot_name(),); @@ -302,7 +311,7 @@ fn train_joint( #[pg_extern(name = "deploy")] fn deploy_model( - model_id: i64 + model_id: i64, ) -> TableIterator< 'static, ( @@ -319,8 +328,7 @@ fn deploy_model( ) .unwrap(); - let project_id = - project_id.unwrap_or_else(|| error!("Project does not exist.")); + let project_id = project_id.unwrap_or_else(|| error!("Project does not exist.")); let project = Project::find(project_id).unwrap(); project.deploy(model_id, Strategy::specific); @@ -351,8 +359,7 @@ fn deploy_strategy( ) .unwrap(); - let project_id = - project_id.unwrap_or_else(|| error!("Project named `{}` does not exist.", project_name)); + let project_id = project_id.unwrap_or_else(|| error!("Project named `{}` does not exist.", project_name)); let task = Task::from_str(&task.unwrap()).unwrap(); @@ -367,11 +374,7 @@ fn deploy_strategy( } match strategy { Strategy::best_score => { - let _ = write!( - sql, - "{predicate}\n{}", - task.default_target_metric_sql_order() - ); + let _ = write!(sql, "{predicate}\n{}", task.default_target_metric_sql_order()); } Strategy::most_recent => { @@ -401,22 +404,16 @@ fn deploy_strategy( _ => error!("invalid strategy"), } sql += "\nLIMIT 1"; - let (model_id, algorithm) = Spi::get_two_with_args::( - &sql, - vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())], - ) - .unwrap(); + let (model_id, algorithm) = + Spi::get_two_with_args::(&sql, vec![(PgBuiltInOids::TEXTOID.oid(), project_name.into_datum())]) + .unwrap(); let model_id = model_id.expect("No qualified models exist for this deployment."); let algorithm = algorithm.expect("No qualified models exist for this deployment."); let project = Project::find(project_id).unwrap(); project.deploy(model_id, strategy); - TableIterator::new(vec![( - project_name.to_string(), - strategy.to_string(), - algorithm, - )]) + TableIterator::new(vec![(project_name.to_string(), strategy.to_string(), algorithm)]) } #[pg_extern(immutable, parallel_safe, strict, name = "predict")] @@ -446,10 +443,7 @@ fn predict_i64(project_name: &str, features: Vec) -> f32 { #[pg_extern(immutable, parallel_safe, strict, name = "predict")] fn predict_bool(project_name: &str, features: Vec) -> f32 { - predict_f32( - project_name, - features.iter().map(|&i| i as u8 as f32).collect(), - ) + predict_f32(project_name, features.iter().map(|&i| i as u8 as f32).collect()) } #[pg_extern(immutable, parallel_safe, strict, name = "predict_proba")] @@ -507,8 +501,7 @@ fn predict_model_row(model_id: i64, row: pgrx::datum::AnyElement) -> f32 { let features_width = snapshot.features_width(); let mut processed = vec![0_f32; features_width]; - let feature_data = - ndarray::ArrayView2::from_shape((1, features_width), &numeric_encoded_features).unwrap(); + let feature_data = ndarray::ArrayView2::from_shape((1, features_width), &numeric_encoded_features).unwrap(); Zip::from(feature_data.columns()) .and(&snapshot.feature_positions) @@ -555,12 +548,10 @@ fn load_dataset( "linnerud" => dataset::load_linnerud(limit), "wine" => dataset::load_wine(limit), _ => { - let rows = - match 
crate::bindings::transformers::load_dataset(source, subset, limit, &kwargs.0) - { - Ok(rows) => rows, - Err(e) => error!("{e}"), - }; + let rows = match crate::bindings::transformers::load_dataset(source, subset, limit, &kwargs.0) { + Ok(rows) => rows, + Err(e) => error!("{e}"), + }; (source.into(), rows as i64) } }; @@ -579,11 +570,7 @@ pub fn embed(transformer: &str, text: &str, kwargs: default!(JsonB, "'{}'")) -> #[cfg(all(feature = "python", not(feature = "use_as_lib")))] #[pg_extern(immutable, parallel_safe, name = "embed")] -pub fn embed_batch( - transformer: &str, - inputs: Vec<&str>, - kwargs: default!(JsonB, "'{}'"), -) -> Vec<Vec<f32>> { +pub fn embed_batch(transformer: &str, inputs: Vec<&str>, kwargs: default!(JsonB, "'{}'")) -> Vec<Vec<f32>> { match crate::bindings::transformers::embed(transformer, inputs, &kwargs.0) { Ok(output) => output, Err(e) => error!("{e}"), } @@ -673,13 +660,8 @@ pub fn transform_conversational_json( inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"), cache: default!(bool, false), ) -> JsonB { - if !task.0["task"] - .as_str() - .is_some_and(|v| v == "conversational") - { - error!( - "ARRAY[]::JSONB inputs for transform should only be used with a conversational task" - ); + if !task.0["task"].as_str().is_some_and(|v| v == "conversational") { + error!("ARRAY[]::JSONB inputs for transform should only be used with a conversational task"); } match crate::bindings::transformers::transform(&task.0, &args.0, inputs) { Ok(output) => JsonB(output), @@ -697,9 +679,7 @@ pub fn transform_conversational_string( cache: default!(bool, false), ) -> JsonB { if task != "conversational" { - error!( - "ARRAY[]::JSONB inputs for transform should only be used with a conversational task" - ); + error!("ARRAY[]::JSONB inputs for transform should only be used with a conversational task"); } let task_json = json!({ "task": task }); match crate::bindings::transformers::transform(&task_json, &args.0, inputs) { Ok(output) => JsonB(output), @@ -718,10 +698,9 @@ pub fn transform_stream_json( cache: default!(bool, false), ) -> SetOfIterator<'static, JsonB> { // We can unwrap this because if there is an error the current transaction is aborted in the map_err call - let python_iter = - crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input) - .map_err(|e| error!("{e}")) - .unwrap(); + let python_iter = crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, input) + .map_err(|e| error!("{e}")) + .unwrap(); SetOfIterator::new(python_iter) } @@ -736,10 +715,9 @@ pub fn transform_stream_string( ) -> SetOfIterator<'static, JsonB> { let task_json = json!({ "task": task }); // We can unwrap this because if there is an error the current transaction is aborted in the map_err call - let python_iter = - crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input) - .map_err(|e| error!("{e}")) - .unwrap(); + let python_iter = crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, input) + .map_err(|e| error!("{e}")) + .unwrap(); SetOfIterator::new(python_iter) } @@ -752,19 +730,13 @@ pub fn transform_stream_conversational_json( inputs: default!(Vec<JsonB>, "ARRAY[]::JSONB[]"), cache: default!(bool, false), ) -> SetOfIterator<'static, JsonB> { - if !task.0["task"] - .as_str() - .is_some_and(|v| v == "conversational") - { - error!( - "ARRAY[]::JSONB inputs for transform_stream should only be used with a conversational task" - ); + if !task.0["task"].as_str().is_some_and(|v| v == "conversational") { + error!("ARRAY[]::JSONB inputs for transform_stream should only be used 
with a conversational task"); } // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call - let python_iter = - crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, inputs) - .map_err(|e| error!("{e}")) - .unwrap(); + let python_iter = crate::bindings::transformers::transform_stream_iterator(&task.0, &args.0, inputs) + .map_err(|e| error!("{e}")) + .unwrap(); SetOfIterator::new(python_iter) } @@ -778,16 +750,13 @@ pub fn transform_stream_conversational_string( cache: default!(bool, false), ) -> SetOfIterator<'static, JsonB> { if task != "conversational" { - error!( - "ARRAY::JSONB inputs for transform_stream should only be used with a conversational task" - ); + error!("ARRAY::JSONB inputs for transform_stream should only be used with a conversational task"); } let task_json = json!({ "task": task }); // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call - let python_iter = - crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, inputs) - .map_err(|e| error!("{e}")) - .unwrap(); + let python_iter = crate::bindings::transformers::transform_stream_iterator(&task_json, &args.0, inputs) + .map_err(|e| error!("{e}")) + .unwrap(); SetOfIterator::new(python_iter) } @@ -802,16 +771,8 @@ fn generate(project_name: &str, inputs: &str, config: default!(JsonB, "'{}'")) - #[cfg(feature = "python")] #[pg_extern(immutable, parallel_safe, name = "generate")] -fn generate_batch( - project_name: &str, - inputs: Vec<&str>, - config: default!(JsonB, "'{}'"), -) -> Vec { - match crate::bindings::transformers::generate( - Project::get_deployed_model_id(project_name), - inputs, - config, - ) { +fn generate_batch(project_name: &str, inputs: Vec<&str>, config: default!(JsonB, "'{}'")) -> Vec { + match crate::bindings::transformers::generate(Project::get_deployed_model_id(project_name), inputs, config) { Ok(output) => output, Err(e) => error!("{e}"), } @@ -857,14 +818,17 @@ fn tune( }; if task.is_some() && task.unwrap() != project.task { - error!("Project `{:?}` already exists with a different task: `{:?}`. Create a new project instead.", project.name, project.task); + error!( + "Project `{:?}` already exists with a different task: `{:?}`. 
Create a new project instead.", + project.name, project.task + ); } let mut snapshot = match relation_name { None => { - let snapshot = project - .last_snapshot() - .expect("You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model."); + let snapshot = project.last_snapshot().expect( + "You must pass a `relation_name` and `y_column_name` to snapshot the first time you train a model.", + ); info!("Using existing snapshot from {}", snapshot.snapshot_name(),); @@ -980,20 +944,13 @@ pub fn sklearn_r2_score(ground_truth: Vec, y_hat: Vec) -> f32 { #[cfg(feature = "python")] #[pg_extern(name = "sklearn_regression_metrics")] pub fn sklearn_regression_metrics(ground_truth: Vec, y_hat: Vec) -> JsonB { - let metrics = unwrap_or_error!(crate::bindings::sklearn::regression_metrics( - &ground_truth, - &y_hat, - )); + let metrics = unwrap_or_error!(crate::bindings::sklearn::regression_metrics(&ground_truth, &y_hat,)); JsonB(json!(metrics)) } #[cfg(feature = "python")] #[pg_extern(name = "sklearn_classification_metrics")] -pub fn sklearn_classification_metrics( - ground_truth: Vec, - y_hat: Vec, - num_classes: i64, -) -> JsonB { +pub fn sklearn_classification_metrics(ground_truth: Vec, y_hat: Vec, num_classes: i64) -> JsonB { let metrics = unwrap_or_error!(crate::bindings::sklearn::classification_metrics( &ground_truth, &y_hat, @@ -1006,32 +963,16 @@ pub fn sklearn_classification_metrics( #[pg_extern] pub fn dump_all(path: &str) { let p = std::path::Path::new(path).join("projects.csv"); - Spi::run(&format!( - "COPY pgml.projects TO '{}' CSV HEADER", - p.to_str().unwrap() - )) - .unwrap(); + Spi::run(&format!("COPY pgml.projects TO '{}' CSV HEADER", p.to_str().unwrap())).unwrap(); let p = std::path::Path::new(path).join("snapshots.csv"); - Spi::run(&format!( - "COPY pgml.snapshots TO '{}' CSV HEADER", - p.to_str().unwrap() - )) - .unwrap(); + Spi::run(&format!("COPY pgml.snapshots TO '{}' CSV HEADER", p.to_str().unwrap())).unwrap(); let p = std::path::Path::new(path).join("models.csv"); - Spi::run(&format!( - "COPY pgml.models TO '{}' CSV HEADER", - p.to_str().unwrap() - )) - .unwrap(); + Spi::run(&format!("COPY pgml.models TO '{}' CSV HEADER", p.to_str().unwrap())).unwrap(); let p = std::path::Path::new(path).join("files.csv"); - Spi::run(&format!( - "COPY pgml.files TO '{}' CSV HEADER", - p.to_str().unwrap() - )) - .unwrap(); + Spi::run(&format!("COPY pgml.files TO '{}' CSV HEADER", p.to_str().unwrap())).unwrap(); let p = std::path::Path::new(path).join("deployments.csv"); Spi::run(&format!( @@ -1044,11 +985,7 @@ pub fn dump_all(path: &str) { #[pg_extern] pub fn load_all(path: &str) { let p = std::path::Path::new(path).join("projects.csv"); - Spi::run(&format!( - "COPY pgml.projects FROM '{}' CSV HEADER", - p.to_str().unwrap() - )) - .unwrap(); + Spi::run(&format!("COPY pgml.projects FROM '{}' CSV HEADER", p.to_str().unwrap())).unwrap(); let p = std::path::Path::new(path).join("snapshots.csv"); Spi::run(&format!( @@ -1058,18 +995,10 @@ pub fn load_all(path: &str) { .unwrap(); let p = std::path::Path::new(path).join("models.csv"); - Spi::run(&format!( - "COPY pgml.models FROM '{}' CSV HEADER", - p.to_str().unwrap() - )) - .unwrap(); + Spi::run(&format!("COPY pgml.models FROM '{}' CSV HEADER", p.to_str().unwrap())).unwrap(); let p = std::path::Path::new(path).join("files.csv"); - Spi::run(&format!( - "COPY pgml.files FROM '{}' CSV HEADER", - p.to_str().unwrap() - )) - .unwrap(); + Spi::run(&format!("COPY pgml.files FROM '{}' CSV HEADER", 
p.to_str().unwrap())).unwrap(); let p = std::path::Path::new(path).join("deployments.csv"); Spi::run(&format!( @@ -1630,9 +1559,7 @@ mod tests { // Modify postgresql.conf and add shared_preload_libraries = 'pgml' // to test deployments. - let setting = - Spi::get_one::("select setting from pg_settings where name = 'data_directory'") - .unwrap(); + let setting = Spi::get_one::("select setting from pg_settings where name = 'data_directory'").unwrap(); info!("Data directory: {}", setting.unwrap()); @@ -1670,9 +1597,7 @@ mod tests { // Modify postgresql.conf and add shared_preload_libraries = 'pgml' // to test deployments. - let setting = - Spi::get_one::("select setting from pg_settings where name = 'data_directory'") - .unwrap(); + let setting = Spi::get_one::("select setting from pg_settings where name = 'data_directory'").unwrap(); info!("Data directory: {}", setting.unwrap()); @@ -1710,9 +1635,7 @@ mod tests { // Modify postgresql.conf and add shared_preload_libraries = 'pgml' // to test deployments. - let setting = - Spi::get_one::("select setting from pg_settings where name = 'data_directory'") - .unwrap(); + let setting = Spi::get_one::("select setting from pg_settings where name = 'data_directory'").unwrap(); info!("Data directory: {}", setting.unwrap()); diff --git a/pgml-extension/src/bindings/langchain/mod.rs b/pgml-extension/src/bindings/langchain/mod.rs index 7d8d2582f..d17993df7 100644 --- a/pgml-extension/src/bindings/langchain/mod.rs +++ b/pgml-extension/src/bindings/langchain/mod.rs @@ -18,10 +18,7 @@ pub fn chunk(splitter: &str, text: &str, kwargs: &serde_json::Value) -> Result, - ) -> std::result::Result<(), std::fmt::Error> { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { formatter.debug_struct("Estimator").finish() } } @@ -28,10 +25,7 @@ pub fn fit_regression(dataset: &Dataset, hyperparams: &Hyperparams) -> Result Result> { +pub fn fit_classification(dataset: &Dataset, hyperparams: &Hyperparams) -> Result> { fit(dataset, hyperparams, Task::classification) } @@ -39,17 +33,11 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Result { - hyperparams.insert( - "objective".to_string(), - serde_json::Value::from("regression"), - ); + hyperparams.insert("objective".to_string(), serde_json::Value::from("regression")); } Task::classification => { if dataset.num_distinct_labels > 2 { - hyperparams.insert( - "objective".to_string(), - serde_json::Value::from("multiclass"), - ); + hyperparams.insert("objective".to_string(), serde_json::Value::from("multiclass")); hyperparams.insert( "num_class".to_string(), serde_json::Value::from(dataset.num_distinct_labels), @@ -61,12 +49,7 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Result error!("lightgbm only supports `regression` and `classification` tasks."), }; - let data = lightgbm::Dataset::from_vec( - &dataset.x_train, - &dataset.y_train, - dataset.num_features as i32, - ) - .unwrap(); + let data = lightgbm::Dataset::from_vec(&dataset.x_train, &dataset.y_train, dataset.num_features as i32).unwrap(); let estimator = lightgbm::Booster::train(data, &json! 
{hyperparams}).unwrap(); @@ -75,12 +58,7 @@ fn fit(dataset: &Dataset, hyperparams: &Hyperparams, task: Task) -> Result Result> { + fn predict(&self, features: &[f32], num_features: usize, num_classes: usize) -> Result<Vec<f32>> { let results = self.predict_proba(features, num_features)?; Ok(match num_classes { // TODO make lightgbm predict both classes like scikit and xgboost diff --git a/pgml-extension/src/bindings/linfa.rs b/pgml-extension/src/bindings/linfa.rs index d0dbeda47..c2a6fc437 100644 --- a/pgml-extension/src/bindings/linfa.rs +++ b/pgml-extension/src/bindings/linfa.rs @@ -20,11 +20,7 @@ impl LinearRegression { where Self: Sized, { - let records = ArrayView2::from_shape( - (dataset.num_train_rows, dataset.num_features), - &dataset.x_train, - ) - .unwrap(); + let records = ArrayView2::from_shape((dataset.num_train_rows, dataset.num_features), &dataset.x_train).unwrap(); let targets = ArrayView1::from_shape(dataset.num_train_rows, &dataset.y_train).unwrap(); @@ -34,8 +30,7 @@ impl LinearRegression { for (key, value) in hyperparams { match key.as_str() { "fit_intercept" => { - estimator = estimator - .with_intercept(value.as_bool().expect("fit_intercept must be boolean")) + estimator = estimator.with_intercept(value.as_bool().expect("fit_intercept must be boolean")) } _ => bail!("Unknown {}: {:?}", key.as_str(), value), }; @@ -52,14 +47,8 @@ impl LinearRegression { impl Bindings for LinearRegression { /// Predict a novel datapoint. - fn predict( - &self, - features: &[f32], - num_features: usize, - _num_classes: usize, - ) -> Result<Vec<f32>> { - let records = - ArrayView2::from_shape((features.len() / num_features, num_features), features)?; + fn predict(&self, features: &[f32], num_features: usize, _num_classes: usize) -> Result<Vec<f32>> { + let records = ArrayView2::from_shape((features.len() / num_features, num_features), features)?; Ok(self.estimator.predict(records).targets.into_raw_vec()) } @@ -96,11 +85,7 @@ impl LogisticRegression { where Self: Sized, { - let records = ArrayView2::from_shape( - (dataset.num_train_rows, dataset.num_features), - &dataset.x_train, - ) - .unwrap(); + let records = ArrayView2::from_shape((dataset.num_train_rows, dataset.num_features), &dataset.x_train).unwrap(); // Copy to convert to i32 because LogisticRegression doesn't support continuous targets. 
let y_train: Vec = dataset.y_train.iter().map(|x| *x as i32).collect(); @@ -114,22 +99,16 @@ impl LogisticRegression { for (key, value) in hyperparams { match key.as_str() { "fit_intercept" => { - estimator = estimator - .with_intercept(value.as_bool().expect("fit_intercept must be boolean")) - } - "alpha" => { - estimator = - estimator.alpha(value.as_f64().expect("alpha must be a float") as f32) + estimator = estimator.with_intercept(value.as_bool().expect("fit_intercept must be boolean")) } + "alpha" => estimator = estimator.alpha(value.as_f64().expect("alpha must be a float") as f32), "max_iterations" => { - estimator = estimator.max_iterations( - value.as_i64().expect("max_iterations must be an integer") as u64, - ) + estimator = + estimator.max_iterations(value.as_i64().expect("max_iterations must be an integer") as u64) } "gradient_tolerance" => { - estimator = estimator.gradient_tolerance( - value.as_f64().expect("gradient_tolerance must be a float") as f32, - ) + estimator = estimator + .gradient_tolerance(value.as_f64().expect("gradient_tolerance must be a float") as f32) } _ => bail!("Unknown {}: {:?}", key.as_str(), value), }; @@ -149,22 +128,16 @@ impl LogisticRegression { for (key, value) in hyperparams { match key.as_str() { "fit_intercept" => { - estimator = estimator - .with_intercept(value.as_bool().expect("fit_intercept must be boolean")) - } - "alpha" => { - estimator = - estimator.alpha(value.as_f64().expect("alpha must be a float") as f32) + estimator = estimator.with_intercept(value.as_bool().expect("fit_intercept must be boolean")) } + "alpha" => estimator = estimator.alpha(value.as_f64().expect("alpha must be a float") as f32), "max_iterations" => { - estimator = estimator.max_iterations( - value.as_i64().expect("max_iterations must be an integer") as u64, - ) + estimator = + estimator.max_iterations(value.as_i64().expect("max_iterations must be an integer") as u64) } "gradient_tolerance" => { - estimator = estimator.gradient_tolerance( - value.as_f64().expect("gradient_tolerance must be a float") as f32, - ) + estimator = estimator + .gradient_tolerance(value.as_f64().expect("gradient_tolerance must be a float") as f32) } _ => bail!("Unknown {}: {:?}", key.as_str(), value), }; @@ -187,16 +160,8 @@ impl Bindings for LogisticRegression { bail!("predict_proba is currently only supported by the Python runtime.") } - fn predict( - &self, - features: &[f32], - _num_features: usize, - _num_classes: usize, - ) -> Result> { - let records = ArrayView2::from_shape( - (features.len() / self.num_features, self.num_features), - features, - )?; + fn predict(&self, features: &[f32], _num_features: usize, _num_classes: usize) -> Result> { + let records = ArrayView2::from_shape((features.len() / self.num_features, self.num_features), features)?; Ok(if self.num_distinct_labels > 2 { self.estimator_multi @@ -244,11 +209,7 @@ pub struct Svm { impl Svm { pub fn fit(dataset: &Dataset, hyperparams: &Hyperparams) -> Result> { - let records = ArrayView2::from_shape( - (dataset.num_train_rows, dataset.num_features), - &dataset.x_train, - ) - .unwrap(); + let records = ArrayView2::from_shape((dataset.num_train_rows, dataset.num_features), &dataset.x_train).unwrap(); let targets = ArrayView1::from_shape(dataset.num_train_rows, &dataset.y_train).unwrap(); @@ -264,13 +225,8 @@ impl Svm { for (key, value) in hyperparams { match key.as_str() { - "eps" => { - estimator = estimator.eps(value.as_f64().expect("eps must be a float") as f32) - } - "shrinking" => { - estimator = - 
estimator.shrinking(value.as_bool().expect("shrinking must be a bool")) - } + "eps" => estimator = estimator.eps(value.as_f64().expect("eps must be a float") as f32), + "shrinking" => estimator = estimator.shrinking(value.as_bool().expect("shrinking must be a bool")), "kernel" => { match value.as_str().expect("kernel must be a string") { "poli" => estimator = estimator.polynomial_kernel(3.0, 1.0), // degree = 3, c = 1.0 as per Scikit @@ -298,14 +254,8 @@ impl Bindings for Svm { } /// Predict a novel datapoint. - fn predict( - &self, - features: &[f32], - num_features: usize, - _num_classes: usize, - ) -> Result> { - let records = - ArrayView2::from_shape((features.len() / num_features, num_features), features)?; + fn predict(&self, features: &[f32], num_features: usize, _num_classes: usize) -> Result> { + let records = ArrayView2::from_shape((features.len() / num_features, num_features), features)?; Ok(self.estimator.predict(records).targets.into_raw_vec()) } diff --git a/pgml-extension/src/bindings/mod.rs b/pgml-extension/src/bindings/mod.rs index 79e543490..d877f490a 100644 --- a/pgml-extension/src/bindings/mod.rs +++ b/pgml-extension/src/bindings/mod.rs @@ -11,19 +11,18 @@ use crate::orm::*; #[macro_export] macro_rules! create_pymodule { ($pyfile:literal) => { - pub static PY_MODULE: once_cell::sync::Lazy< - anyhow::Result>, - > = once_cell::sync::Lazy::new(|| { - pyo3::Python::with_gil(|py| -> anyhow::Result> { - use $crate::bindings::TracebackError; - let src = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), $pyfile)); - Ok( - pyo3::types::PyModule::from_code(py, src, "transformers.py", "__main__") - .format_traceback(py)? - .into(), - ) - }) - }); + pub static PY_MODULE: once_cell::sync::Lazy>> = + once_cell::sync::Lazy::new(|| { + pyo3::Python::with_gil(|py| -> anyhow::Result> { + use $crate::bindings::TracebackError; + let src = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), $pyfile)); + Ok( + pyo3::types::PyModule::from_code(py, src, "transformers.py", "__main__") + .format_traceback(py)? + .into(), + ) + }) + }); }; } @@ -59,12 +58,7 @@ pub type Fit = fn(dataset: &Dataset, hyperparams: &Hyperparams) -> Result Result>; + fn predict(&self, features: &[f32], num_features: usize, num_classes: usize) -> Result>; /// Predict the probability of each class. fn predict_proba(&self, features: &[f32], num_features: usize) -> Result>; diff --git a/pgml-extension/src/bindings/python/mod.rs b/pgml-extension/src/bindings/python/mod.rs index 9ab7300c0..84e7505b7 100644 --- a/pgml-extension/src/bindings/python/mod.rs +++ b/pgml-extension/src/bindings/python/mod.rs @@ -16,8 +16,7 @@ create_pymodule!("/src/bindings/python/python.py"); pub fn activate_venv(venv: &str) -> Result { Python::with_gil(|py| { let activate_venv: Py = get_module!(PY_MODULE).getattr(py, "activate_venv")?; - let result: Py = - activate_venv.call1(py, PyTuple::new(py, &[venv.to_string().into_py(py)]))?; + let result: Py = activate_venv.call1(py, PyTuple::new(py, &[venv.to_string().into_py(py)]))?; Ok(result.extract(py)?) }) @@ -39,9 +38,7 @@ pub fn pip_freeze() -> Result> Ok(result.extract(py)?) })?; - Ok(TableIterator::new( - packages.into_iter().map(|package| (package,)), - )) + Ok(TableIterator::new(packages.into_iter().map(|package| (package,)))) } pub fn validate_dependencies() -> Result { @@ -54,9 +51,7 @@ pub fn validate_dependencies() -> Result { match py.import(module) { Ok(_) => (), Err(e) => { - panic!( - "The {module} package is missing. 
Install it with `sudo pip3 install {module}`\n{e}" - ); + panic!("The {module} package is missing. Install it with `sudo pip3 install {module}`\n{e}"); } } } diff --git a/pgml-extension/src/bindings/sklearn/mod.rs b/pgml-extension/src/bindings/sklearn/mod.rs index 4b8ce6625..bee066b87 100644 --- a/pgml-extension/src/bindings/sklearn/mod.rs +++ b/pgml-extension/src/bindings/sklearn/mod.rs @@ -33,10 +33,7 @@ wrap_fit!(elastic_net_regression, "elastic_net_regression"); wrap_fit!(ridge_regression, "ridge_regression"); wrap_fit!(random_forest_regression, "random_forest_regression"); wrap_fit!(xgboost_regression, "xgboost_regression"); -wrap_fit!( - xgboost_random_forest_regression, - "xgboost_random_forest_regression" -); +wrap_fit!(xgboost_random_forest_regression, "xgboost_random_forest_regression"); wrap_fit!( orthogonal_matching_persuit_regression, "orthogonal_matching_persuit_regression" @@ -50,10 +47,7 @@ wrap_fit!( stochastic_gradient_descent_regression, "stochastic_gradient_descent_regression" ); -wrap_fit!( - passive_aggressive_regression, - "passive_aggressive_regression" -); +wrap_fit!(passive_aggressive_regression, "passive_aggressive_regression"); wrap_fit!(ransac_regression, "ransac_regression"); wrap_fit!(theil_sen_regression, "theil_sen_regression"); wrap_fit!(huber_regression, "huber_regression"); @@ -64,14 +58,8 @@ wrap_fit!(nu_svm_regression, "nu_svm_regression"); wrap_fit!(ada_boost_regression, "ada_boost_regression"); wrap_fit!(bagging_regression, "bagging_regression"); wrap_fit!(extra_trees_regression, "extra_trees_regression"); -wrap_fit!( - gradient_boosting_trees_regression, - "gradient_boosting_trees_regression" -); -wrap_fit!( - hist_gradient_boosting_regression, - "hist_gradient_boosting_regression" -); +wrap_fit!(gradient_boosting_trees_regression, "gradient_boosting_trees_regression"); +wrap_fit!(hist_gradient_boosting_regression, "hist_gradient_boosting_regression"); wrap_fit!(least_angle_regression, "least_angle_regression"); wrap_fit!(lasso_least_angle_regression, "lasso_least_angle_regression"); wrap_fit!(linear_svm_regression, "linear_svm_regression"); @@ -91,10 +79,7 @@ wrap_fit!( "stochastic_gradient_descent_classification" ); wrap_fit!(perceptron_classification, "perceptron_classification"); -wrap_fit!( - passive_aggressive_classification, - "passive_aggressive_classification" -); +wrap_fit!(passive_aggressive_classification, "passive_aggressive_classification"); wrap_fit!(gaussian_process, "gaussian_process"); wrap_fit!(nu_svm_classification, "nu_svm_classification"); wrap_fit!(ada_boost_classification, "ada_boost_classification"); @@ -124,47 +109,41 @@ wrap_fit!(spectral, "spectral_clustering"); wrap_fit!(spectral_bi, "spectral_biclustering"); wrap_fit!(spectral_co, "spectral_coclustering"); -fn fit( - dataset: &Dataset, - hyperparams: &Hyperparams, - algorithm_task: &'static str, -) -> Result> { +fn fit(dataset: &Dataset, hyperparams: &Hyperparams, algorithm_task: &'static str) -> Result> { let hyperparams = serde_json::to_string(hyperparams).unwrap(); - let (estimator, predict, predict_proba) = - Python::with_gil(|py| -> Result<(Py, Py, Py)> { - let module = get_module!(PY_MODULE); + let (estimator, predict, predict_proba) = Python::with_gil(|py| -> Result<(Py, Py, Py)> { + let module = get_module!(PY_MODULE); - let estimator: Py = module.getattr(py, "estimator")?; + let estimator: Py = module.getattr(py, "estimator")?; - let train: Py = estimator.call1( + let train: Py = estimator.call1( + py, + PyTuple::new( py, - PyTuple::new( - py, - &[ - 
String::from(algorithm_task).into_py(py), - dataset.num_features.into_py(py), - dataset.num_labels.into_py(py), - hyperparams.into_py(py), - ], - ), - )?; - - let estimator: Py = - train.call1(py, PyTuple::new(py, [&dataset.x_train, &dataset.y_train]))?; - - let predict: Py = module - .getattr(py, "predictor")? - .call1(py, PyTuple::new(py, [&estimator]))? - .extract(py)?; + &[ + String::from(algorithm_task).into_py(py), + dataset.num_features.into_py(py), + dataset.num_labels.into_py(py), + hyperparams.into_py(py), + ], + ), + )?; + + let estimator: Py = train.call1(py, PyTuple::new(py, [&dataset.x_train, &dataset.y_train]))?; + + let predict: Py = module + .getattr(py, "predictor")? + .call1(py, PyTuple::new(py, [&estimator]))? + .extract(py)?; - let predict_proba: Py = module - .getattr(py, "predictor_proba")? - .call1(py, PyTuple::new(py, [&estimator]))? - .extract(py)?; + let predict_proba: Py = module + .getattr(py, "predictor_proba")? + .call1(py, PyTuple::new(py, [&estimator]))? + .extract(py)?; - Ok((estimator, predict, predict_proba)) - })?; + Ok((estimator, predict, predict_proba)) + })?; Ok(Box::new(Estimator { estimator, @@ -183,28 +162,15 @@ unsafe impl Send for Estimator {} unsafe impl Sync for Estimator {} impl std::fmt::Debug for Estimator { - fn fmt( - &self, - formatter: &mut std::fmt::Formatter<'_>, - ) -> std::result::Result<(), std::fmt::Error> { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { formatter.debug_struct("Estimator").finish() } } impl Bindings for Estimator { /// Predict a novel datapoint. - fn predict( - &self, - features: &[f32], - _num_features: usize, - _num_classes: usize, - ) -> Result> { - Python::with_gil(|py| { - Ok(self - .predict - .call1(py, PyTuple::new(py, [features]))? - .extract(py)?) - }) + fn predict(&self, features: &[f32], _num_features: usize, _num_classes: usize) -> Result> { + Python::with_gil(|py| Ok(self.predict.call1(py, PyTuple::new(py, [features]))?.extract(py)?)) } fn predict_proba(&self, features: &[f32], _num_features: usize) -> Result> { @@ -220,9 +186,7 @@ impl Bindings for Estimator { fn to_bytes(&self) -> Result> { Python::with_gil(|py| { let save = get_module!(PY_MODULE).getattr(py, "save")?; - Ok(save - .call1(py, PyTuple::new(py, [&self.estimator]))? - .extract(py)?) + Ok(save.call1(py, PyTuple::new(py, [&self.estimator]))?.extract(py)?) }) } @@ -258,12 +222,8 @@ impl Bindings for Estimator { fn sklearn_metric(name: &str, ground_truth: &[f32], y_hat: &[f32]) -> Result { Python::with_gil(|py| { - let calculate_metric = get_module!(PY_MODULE) - .getattr(py, "calculate_metric") - .unwrap(); - let wrapper: Py = calculate_metric - .call1(py, PyTuple::new(py, [name]))? - .extract(py)?; + let calculate_metric = get_module!(PY_MODULE).getattr(py, "calculate_metric").unwrap(); + let wrapper: Py = calculate_metric.call1(py, PyTuple::new(py, [name]))?.extract(py)?; let score: f32 = wrapper .call1(py, PyTuple::new(py, [ground_truth, y_hat]))? 
@@ -315,11 +275,7 @@ pub fn regression_metrics(ground_truth: &[f32], y_hat: &[f32]) -> Result Result> { +pub fn classification_metrics(ground_truth: &[f32], y_hat: &[f32], num_classes: usize) -> Result> { let mut scores = Python::with_gil(|py| -> Result> { let calculate_metric = get_module!(PY_MODULE).getattr(py, "classification_metrics")?; let scores: HashMap = calculate_metric @@ -337,11 +293,7 @@ pub fn classification_metrics( Ok(scores) } -pub fn cluster_metrics( - num_features: usize, - inputs: &[f32], - labels: &[f32], -) -> Result> { +pub fn cluster_metrics(num_features: usize, inputs: &[f32], labels: &[f32]) -> Result> { Python::with_gil(|py| { let calculate_metric = get_module!(PY_MODULE).getattr(py, "cluster_metrics")?; diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index 9a8528ddb..b300d84e3 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -33,18 +33,12 @@ pub fn get_model_from(task: &Value) -> Result { }) } -pub fn embed( - transformer: &str, - inputs: Vec<&str>, - kwargs: &serde_json::Value, -) -> Result>> { +pub fn embed(transformer: &str, inputs: Vec<&str>, kwargs: &serde_json::Value) -> Result>> { crate::bindings::python::activate()?; let kwargs = serde_json::to_string(kwargs)?; Python::with_gil(|py| -> Result>> { - let embed: Py = get_module!(PY_MODULE) - .getattr(py, "embed") - .format_traceback(py)?; + let embed: Py = get_module!(PY_MODULE).getattr(py, "embed").format_traceback(py)?; let output = embed .call1( py, @@ -63,21 +57,14 @@ pub fn embed( }) } -pub fn tune( - task: &Task, - dataset: TextDataset, - hyperparams: &JsonB, - path: &Path, -) -> Result> { +pub fn tune(task: &Task, dataset: TextDataset, hyperparams: &JsonB, path: &Path) -> Result> { crate::bindings::python::activate()?; let task = task.to_string(); let hyperparams = serde_json::to_string(&hyperparams.0)?; Python::with_gil(|py| -> Result> { - let tune = get_module!(PY_MODULE) - .getattr(py, "tune") - .format_traceback(py)?; + let tune = get_module!(PY_MODULE).getattr(py, "tune").format_traceback(py)?; let path = path.to_string_lossy(); let output = tune .call1( @@ -102,9 +89,7 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result Result> { - let generate = get_module!(PY_MODULE) - .getattr(py, "generate") - .format_traceback(py)?; + let generate = get_module!(PY_MODULE).getattr(py, "generate").format_traceback(py)?; let config = serde_json::to_string(&config.0)?; // cloning inputs in case we have to re-call on error is rather unfortunate here // similarly, using a json string to pass kwargs is also unfortunate extra parsing @@ -130,14 +115,10 @@ pub fn generate(model_id: i64, inputs: Vec<&str>, config: JsonB) -> Result Result<()> { } std::fs::create_dir_all(&dir).context("failed to create directory while dumping model")?; Spi::connect(|client| -> Result<()> { - let result = client.select("SELECT path, part, data FROM pgml.files WHERE model_id = $1 ORDER BY path ASC, part ASC", - None, - Some(vec![ - (PgBuiltInOids::INT8OID.oid(), model_id.into_datum()), - ]) - )?; + let result = client.select( + "SELECT path, part, data FROM pgml.files WHERE model_id = $1 ORDER BY path ASC, part ASC", + None, + Some(vec![(PgBuiltInOids::INT8OID.oid(), model_id.into_datum())]), + )?; for row in result { let mut path = dir.clone(); path.push( row.get::(1)? .ok_or(anyhow!("row get ordinal 1 returned None"))?, ); - let data: Vec = row - .get(3)? 
- .ok_or(anyhow!("row get ordinal 3 returned None"))?; - let mut file = std::fs::OpenOptions::new() - .create(true) - .append(true) - .open(path)?; + let data: Vec = row.get(3)?.ok_or(anyhow!("row get ordinal 3 returned None"))?; + let mut file = std::fs::OpenOptions::new().create(true).append(true).open(path)?; let _num_bytes = file.write(&data)?; file.flush()?; @@ -217,9 +192,7 @@ pub fn load_dataset( // Columns are a (name: String, values: Vec) pair let json: serde_json::Value = serde_json::from_str(&dataset)?; - let json = json - .as_object() - .ok_or(anyhow!("dataset json is not object"))?; + let json = json.as_object().ok_or(anyhow!("dataset json is not object"))?; let types = json .get("types") .ok_or(anyhow!("dataset json missing `types` key"))? @@ -238,9 +211,7 @@ pub fn load_dataset( let column_types = types .iter() .map(|(name, type_)| -> Result { - let type_ = type_ - .as_str() - .ok_or(anyhow!("expected {type_} to be a json string"))?; + let type_ = type_.as_str().ok_or(anyhow!("expected {type_} to be a json string"))?; let type_ = match type_ { "string" => "TEXT", "dict" | "list" => "JSONB", @@ -276,16 +247,17 @@ pub fn load_dataset( .len(); // Avoid the existence warning by checking the schema for the table first - let table_count = Spi::get_one_with_args::("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![ - (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()) - ])?.ok_or(anyhow!("table count query returned None"))?; + let table_count = Spi::get_one_with_args::( + "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", + vec![(PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum())], + )? + .ok_or(anyhow!("table count query returned None"))?; if table_count == 1 { Spi::run(&format!(r#"DROP TABLE IF EXISTS {table_name}"#))?; } Spi::run(&format!(r#"CREATE TABLE {table_name} ({column_types})"#))?; - let insert = - format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({column_placeholders})"#); + let insert = format!(r#"INSERT INTO {table_name} ({column_names}) VALUES ({column_placeholders})"#); for i in 0..num_rows { let mut row = Vec::with_capacity(num_cols); for (name, values) in data { @@ -307,10 +279,7 @@ pub fn load_dataset( .ok_or_else(|| anyhow!("expected {value} to be string"))? 
.into_datum(), )), - "dict" | "list" => row.push(( - PgBuiltInOids::JSONBOID.oid(), - JsonB(value.clone()).into_datum(), - )), + "dict" | "list" => row.push((PgBuiltInOids::JSONBOID.oid(), JsonB(value.clone()).into_datum())), "int64" | "int32" | "int16" => row.push(( PgBuiltInOids::INT8OID.oid(), value diff --git a/pgml-extension/src/bindings/transformers/transform.rs b/pgml-extension/src/bindings/transformers/transform.rs index fa03984d9..21503f186 100644 --- a/pgml-extension/src/bindings/transformers/transform.rs +++ b/pgml-extension/src/bindings/transformers/transform.rs @@ -54,17 +54,12 @@ pub fn transform( let inputs = serde_json::to_string(&inputs)?; let results = Python::with_gil(|py| -> Result { - let transform: Py = get_module!(PY_MODULE) - .getattr(py, "transform") - .format_traceback(py)?; + let transform: Py = get_module!(PY_MODULE).getattr(py, "transform").format_traceback(py)?; let output = transform .call1( py, - PyTuple::new( - py, - &[task.into_py(py), args.into_py(py), inputs.into_py(py)], - ), + PyTuple::new(py, &[task.into_py(py), args.into_py(py), inputs.into_py(py)]), ) .format_traceback(py)?; @@ -87,21 +82,14 @@ pub fn transform_stream( let input = serde_json::to_string(&input)?; Python::with_gil(|py| -> Result> { - let transform: Py = get_module!(PY_MODULE) - .getattr(py, "transform") - .format_traceback(py)?; + let transform: Py = get_module!(PY_MODULE).getattr(py, "transform").format_traceback(py)?; let output = transform .call1( py, PyTuple::new( py, - &[ - task.into_py(py), - args.into_py(py), - input.into_py(py), - true.into_py(py), - ], + &[task.into_py(py), args.into_py(py), input.into_py(py), true.into_py(py)], ), ) .format_traceback(py)?; @@ -115,8 +103,6 @@ pub fn transform_stream_iterator( args: &serde_json::Value, input: T, ) -> Result { - let python_iter = transform_stream(task, args, input) - .map_err(|e| error!("{e}")) - .unwrap(); + let python_iter = transform_stream(task, args, input).map_err(|e| error!("{e}")).unwrap(); Ok(TransformStreamIterator::new(python_iter)) } diff --git a/pgml-extension/src/bindings/transformers/whitelist.rs b/pgml-extension/src/bindings/transformers/whitelist.rs index 3714091d1..0194180c0 100644 --- a/pgml-extension/src/bindings/transformers/whitelist.rs +++ b/pgml-extension/src/bindings/transformers/whitelist.rs @@ -17,8 +17,7 @@ pub fn verify_task(task: &Value) -> Result<(), Error> { }; let whitelisted_models = config_csv_list(CONFIG_HF_WHITELIST); - let model_is_allowed = - whitelisted_models.is_empty() || whitelisted_models.contains(&task_model); + let model_is_allowed = whitelisted_models.is_empty() || whitelisted_models.contains(&task_model); if !model_is_allowed { bail!("model {task_model} is not whitelisted. 
     }
@@ -45,13 +44,7 @@ fn config_csv_list(name: &str) -> Vec<String> {
         Some(value) => value
             .trim_matches('"')
             .split(',')
-            .filter_map(|s| {
-                if s.is_empty() {
-                    None
-                } else {
-                    Some(s.to_string())
-                }
-            })
+            .filter_map(|s| if s.is_empty() { None } else { Some(s.to_string()) })
             .collect(),
         None => vec![],
     }
 }
@@ -76,13 +69,10 @@ fn get_trust_remote_code(task: &Value) -> Option<bool> {
     // The JSON key for the trust remote code flag
     static TASK_REMOTE_CODE_KEY: &str = "trust_remote_code";
     match task {
-        Value::Object(map) => map.get(TASK_REMOTE_CODE_KEY).and_then(|v| {
-            if let Value::Bool(trust) = v {
-                Some(*trust)
-            } else {
-                None
-            }
-        }),
+        Value::Object(map) => {
+            map.get(TASK_REMOTE_CODE_KEY)
+                .and_then(|v| if let Value::Bool(trust) = v { Some(*trust) } else { None })
+        }
         _ => None,
     }
 }
diff --git a/pgml-extension/src/bindings/xgboost.rs b/pgml-extension/src/bindings/xgboost.rs
index be3d2b09f..3e533d5f3 100644
--- a/pgml-extension/src/bindings/xgboost.rs
+++ b/pgml-extension/src/bindings/xgboost.rs
@@ -128,9 +128,7 @@ fn get_tree_params(hyperparams: &Hyperparams) -> tree::TreeBoosterParameters {
             },
             "max_leaves" => params.max_leaves(value.as_u64().unwrap() as u32),
             "max_bin" => params.max_bin(value.as_u64().unwrap() as u32),
-            "booster" | "n_estimators" | "boost_rounds" | "eval_metric" | "objective" => {
-                &mut params
-            } // Valid but not relevant to this section
+            "booster" | "n_estimators" | "boost_rounds" | "eval_metric" | "objective" => &mut params, // Valid but not relevant to this section
             "nthread" => &mut params,
             "random_state" => &mut params,
             _ => panic!("Unknown hyperparameter {:?}: {:?}", key, value),
@@ -143,10 +141,7 @@ pub fn fit_regression(dataset: &Dataset, hyperparams: &Hyperparams) -> Result<Box<dyn Bindings>> {
-pub fn fit_classification(
-    dataset: &Dataset,
-    hyperparams: &Hyperparams,
-) -> Result<Box<dyn Bindings>> {
+pub fn fit_classification(dataset: &Dataset, hyperparams: &Hyperparams) -> Result<Box<dyn Bindings>> {
     fit(
         dataset,
         hyperparams,
@@ -187,12 +182,8 @@ fn objective_from_string(name: &str, dataset: &Dataset) -> learning::Objective {
         "gpu:binary:logitraw" => learning::Objective::GpuBinaryLogisticRaw,
         "count:poisson" => learning::Objective::CountPoisson,
         "survival:cox" => learning::Objective::SurvivalCox,
-        "multi:softmax" => {
-            learning::Objective::MultiSoftmax(dataset.num_distinct_labels.try_into().unwrap())
-        }
-        "multi:softprob" => {
-            learning::Objective::MultiSoftprob(dataset.num_distinct_labels.try_into().unwrap())
-        }
+        "multi:softmax" => learning::Objective::MultiSoftmax(dataset.num_distinct_labels.try_into().unwrap()),
+        "multi:softprob" => learning::Objective::MultiSoftprob(dataset.num_distinct_labels.try_into().unwrap()),
         "rank:pairwise" => learning::Objective::RankPairwise,
         "reg:gamma" => learning::Objective::RegGamma,
         "reg:tweedie" => learning::Objective::RegTweedie(Some(dataset.num_distinct_labels as f32)),
@@ -200,11 +191,7 @@ fn objective_from_string(name: &str, dataset: &Dataset) -> learning::Objective {
     }
 }
 
-fn fit(
-    dataset: &Dataset,
-    hyperparams: &Hyperparams,
-    objective: learning::Objective,
-) -> Result<Box<dyn Bindings>> {
+fn fit(dataset: &Dataset, hyperparams: &Hyperparams, objective: learning::Objective) -> Result<Box<dyn Bindings>> {
     // split the train/test data into DMatrix
     let mut dtrain = DMatrix::from_dense(&dataset.x_train, dataset.num_train_rows).unwrap();
     let mut dtest = DMatrix::from_dense(&dataset.x_test, dataset.num_test_rows).unwrap();
@@ -230,9 +217,7 @@ fn fit(
                     .collect(),
             )
         } else {
-            learning::Metrics::Custom(Vec::from([eval_metric_from_string(
-                metrics.as_str().unwrap(),
-            )]))
+            learning::Metrics::Custom(Vec::from([eval_metric_from_string(metrics.as_str().unwrap())]))
         }
     }
     None => learning::Metrics::Auto,
@@ -314,21 +299,13 @@ unsafe impl Send for Estimator {}
 unsafe impl Sync for Estimator {}
 
 impl std::fmt::Debug for Estimator {
-    fn fmt(
-        &self,
-        formatter: &mut std::fmt::Formatter<'_>,
-    ) -> std::result::Result<(), std::fmt::Error> {
+    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> {
         formatter.debug_struct("Estimator").finish()
     }
 }
 
 impl Bindings for Estimator {
-    fn predict(
-        &self,
-        features: &[f32],
-        num_features: usize,
-        num_classes: usize,
-    ) -> Result<Vec<f32>> {
+    fn predict(&self, features: &[f32], num_features: usize, num_classes: usize) -> Result<Vec<f32>> {
         let x = DMatrix::from_dense(features, features.len() / num_features)?;
         let y = self.estimator.predict(&x)?;
         Ok(match num_classes {
diff --git a/pgml-extension/src/lib.rs b/pgml-extension/src/lib.rs
index ce0bdbeb2..2bf5235d4 100644
--- a/pgml-extension/src/lib.rs
+++ b/pgml-extension/src/lib.rs
@@ -57,7 +57,9 @@ pub mod pg_test {
             let option = format!("pgml.venv = '{venv}'");
             options.push(Box::leak(option.into_boxed_str()));
         } else {
-            println!("If using virtualenv for Python depenencies, set the `PGML_VENV` environment variable for testing");
+            println!(
+                "If using virtualenv for Python depenencies, set the `PGML_VENV` environment variable for testing"
+            );
         }
         options
     }
 }
diff --git a/pgml-extension/src/metrics.rs b/pgml-extension/src/metrics.rs
index b3c1d2b5d..0d674668b 100644
--- a/pgml-extension/src/metrics.rs
+++ b/pgml-extension/src/metrics.rs
@@ -47,11 +47,7 @@ impl ConfusionMatrix {
     /// and the predictions.
     /// `num_classes` is passed it to ensure that all classes
     /// were present in the test set.
-    pub fn new(
-        ground_truth: &ArrayView1<usize>,
-        y_hat: &ArrayView1<usize>,
-        num_classes: usize,
-    ) -> ConfusionMatrix {
+    pub fn new(ground_truth: &ArrayView1<usize>, y_hat: &ArrayView1<usize>, num_classes: usize) -> ConfusionMatrix {
         // Distinct classes.
         let mut classes = ground_truth.iter().collect::<HashSet<_>>();
         classes.extend(&mut y_hat.iter().collect::<HashSet<_>>().into_iter());
@@ -115,22 +111,14 @@ impl ConfusionMatrix {
 
     /// Average recall.
     pub fn recall(&self) -> f32 {
-        let recalls = self
-            .metrics
-            .iter()
-            .map(|m| m.tp / (m.tp + m.fn_))
-            .collect::<Vec<f32>>();
+        let recalls = self.metrics.iter().map(|m| m.tp / (m.tp + m.fn_)).collect::<Vec<f32>>();
 
         recalls.iter().sum::<f32>() / recalls.len() as f32
     }
 
     /// Average precision.
     pub fn precision(&self) -> f32 {
-        let precisions = self
-            .metrics
-            .iter()
-            .map(|m| m.tp / (m.tp + m.fp))
-            .collect::<Vec<f32>>();
+        let precisions = self.metrics.iter().map(|m| m.tp / (m.tp + m.fp)).collect::<Vec<f32>>();
 
         precisions.iter().sum::<f32>() / precisions.len() as f32
     }
@@ -162,16 +150,8 @@ impl ConfusionMatrix {
     /// Calculate f1 using the average of class f1's.
     /// This gives equal opportunity to each class to impact the overall score.
     fn f1_macro(&self) -> f32 {
-        let recalls = self
-            .metrics
-            .iter()
-            .map(|m| m.tp / (m.tp + m.fn_))
-            .collect::<Vec<f32>>();
-        let precisions = self
-            .metrics
-            .iter()
-            .map(|m| m.tp / (m.tp + m.fp))
-            .collect::<Vec<f32>>();
+        let recalls = self.metrics.iter().map(|m| m.tp / (m.tp + m.fn_)).collect::<Vec<f32>>();
+        let precisions = self.metrics.iter().map(|m| m.tp / (m.tp + m.fp)).collect::<Vec<f32>>();
 
         let mut f1s = Vec::new();
@@ -194,11 +174,7 @@ mod test {
         let ground_truth = array![1, 2, 3, 4, 4];
         let y_hat = array![1, 2, 3, 4, 4];
 
-        let mat = ConfusionMatrix::new(
-            &ArrayView1::from(&ground_truth),
-            &ArrayView1::from(&y_hat),
-            4,
-        );
+        let mat = ConfusionMatrix::new(&ArrayView1::from(&ground_truth), &ArrayView1::from(&y_hat), 4);
 
         let f1 = mat.f1(Average::Macro);
         let f1_micro = mat.f1(Average::Micro);
diff --git a/pgml-extension/src/orm/algorithm.rs b/pgml-extension/src/orm/algorithm.rs
index a8a72d1fb..21a87e3bf 100644
--- a/pgml-extension/src/orm/algorithm.rs
+++ b/pgml-extension/src/orm/algorithm.rs
@@ -122,9 +122,7 @@ impl std::string::ToString for Algorithm {
             Algorithm::lasso_least_angle => "lasso_least_angle".to_string(),
             Algorithm::orthogonal_matching_pursuit => "orthogonal_matching_pursuit".to_string(),
             Algorithm::bayesian_ridge => "bayesian_ridge".to_string(),
-            Algorithm::automatic_relevance_determination => {
-                "automatic_relevance_determination".to_string()
-            }
+            Algorithm::automatic_relevance_determination => "automatic_relevance_determination".to_string(),
             Algorithm::stochastic_gradient_descent => "stochastic_gradient_descent".to_string(),
             Algorithm::perceptron => "perceptron".to_string(),
             Algorithm::passive_aggressive => "passive_aggressive".to_string(),
diff --git a/pgml-extension/src/orm/dataset.rs b/pgml-extension/src/orm/dataset.rs
index 9e22ef0ae..062886a5c 100644
--- a/pgml-extension/src/orm/dataset.rs
+++ b/pgml-extension/src/orm/dataset.rs
@@ -94,9 +94,12 @@ impl Display for TextDataset {
 
 fn drop_table_if_exists(table_name: &str) {
     // Avoid the existence for DROP TABLE IF EXISTS warning by checking the schema for the table first
-    let table_count = Spi::get_one_with_args::<i64>("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'", vec![
-        (PgBuiltInOids::TEXTOID.oid(), table_name.into_datum())
-    ]).unwrap().unwrap();
+    let table_count = Spi::get_one_with_args::<i64>(
+        "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'pgml'",
+        vec![(PgBuiltInOids::TEXTOID.oid(), table_name.into_datum())],
+    )
+    .unwrap()
+    .unwrap();
     if table_count == 1 {
         Spi::run(&format!(r#"DROP TABLE pgml.{table_name} CASCADE"#)).unwrap();
     }
@@ -476,15 +479,9 @@ pub fn load_iris(limit: Option<usize>) -> (String, i64) {
             VALUES ($1, $2, $3, $4, $5)
             ",
             Some(vec![
-                (
-                    PgBuiltInOids::FLOAT4OID.oid(),
-                    row.sepal_length.into_datum(),
-                ),
+                (PgBuiltInOids::FLOAT4OID.oid(), row.sepal_length.into_datum()),
                 (PgBuiltInOids::FLOAT4OID.oid(), row.sepal_width.into_datum()),
-                (
-                    PgBuiltInOids::FLOAT4OID.oid(),
-                    row.petal_length.into_datum(),
-                ),
+                (PgBuiltInOids::FLOAT4OID.oid(), row.petal_length.into_datum()),
                 (PgBuiltInOids::FLOAT4OID.oid(), row.petal_width.into_datum()),
                 (PgBuiltInOids::INT4OID.oid(), row.target.into_datum()),
             ]),
diff --git a/pgml-extension/src/orm/model.rs b/pgml-extension/src/orm/model.rs
index da1940f60..8deebe042 100644
--- a/pgml-extension/src/orm/model.rs
+++ b/pgml-extension/src/orm/model.rs
@@ -21,8 +21,7 @@ use crate::bindings::*;
 use crate::orm::*;
 
 #[allow(clippy::type_complexity)]
-static DEPLOYED_MODELS_BY_ID: Lazy<Mutex<HashMap<i64, Arc<Model>>>> =
-    Lazy::new(|| Mutex::new(HashMap::new()));
+static DEPLOYED_MODELS_BY_ID: Lazy<Mutex<HashMap<i64, Arc<Model>>>> = Lazy::new(|| Mutex::new(HashMap::new()));
 
 #[derive(Debug)]
 pub struct Model {
@@ -197,10 +196,7 @@ impl Model {
             hyperparams: result.get(6).unwrap().unwrap(),
             status: Status::from_str(result.get(7).unwrap().unwrap()).unwrap(),
             metrics: result.get(8).unwrap(),
-            search: result
-                .get(9)
-                .unwrap()
-                .map(|search| Search::from_str(search).unwrap()),
+            search: result.get(9).unwrap().map(|search| Search::from_str(search).unwrap()),
             search_params: result.get(10).unwrap().unwrap(),
             search_args: result.get(11).unwrap().unwrap(),
             created_at: result.get(12).unwrap().unwrap(),
@@ -251,11 +247,15 @@ impl Model {
                 "INSERT INTO pgml.files (model_id, path, part, data) VALUES($1, $2, $3, $4) RETURNING id",
                 vec![
                     (PgBuiltInOids::INT8OID.oid(), model.id.into_datum()),
-                    (PgBuiltInOids::TEXTOID.oid(), path.file_name().unwrap().to_str().into_datum()),
+                    (
+                        PgBuiltInOids::TEXTOID.oid(),
+                        path.file_name().unwrap().to_str().into_datum(),
+                    ),
                     (PgBuiltInOids::INT8OID.oid(), (i as i64).into_datum()),
                     (PgBuiltInOids::BYTEAOID.oid(), chunk.into_datum()),
                 ],
-            ).unwrap();
+            )
+            .unwrap();
         }
     }
 
@@ -360,10 +360,7 @@ impl Model {
             hyperparams: result.get(6).unwrap().unwrap(),
             status: Status::from_str(result.get(7).unwrap().unwrap()).unwrap(),
             metrics: result.get(8).unwrap(),
-            search: result
-                .get(9)
-                .unwrap()
-                .map(|search| Search::from_str(search).unwrap()),
+            search: result.get(9).unwrap().map(|search| Search::from_str(search).unwrap()),
             search_params: result.get(10).unwrap().unwrap(),
             search_args: result.get(11).unwrap().unwrap(),
             created_at: result.get(12).unwrap().unwrap(),
@@ -379,12 +376,7 @@ impl Model {
             Ok(())
         })?;
 
-        model.ok_or_else(|| {
-            anyhow!(
-                "pgml.models WHERE id = {:?} could not be loaded. Does it exist?",
-                id
-            )
-        })
+        model.ok_or_else(|| anyhow!("pgml.models WHERE id = {:?} could not be loaded. Does it exist?", id))
     }
 
     pub fn find_cached(id: i64) -> Result<Arc<Model>> {
@@ -443,16 +435,12 @@ impl Model {
                     Algorithm::random_forest => sklearn::random_forest_regression,
                     Algorithm::xgboost => sklearn::xgboost_regression,
                     Algorithm::xgboost_random_forest => sklearn::xgboost_random_forest_regression,
-                    Algorithm::orthogonal_matching_pursuit => {
-                        sklearn::orthogonal_matching_persuit_regression
-                    }
+                    Algorithm::orthogonal_matching_pursuit => sklearn::orthogonal_matching_persuit_regression,
                     Algorithm::bayesian_ridge => sklearn::bayesian_ridge_regression,
                     Algorithm::automatic_relevance_determination => {
                         sklearn::automatic_relevance_determination_regression
                     }
-                    Algorithm::stochastic_gradient_descent => {
-                        sklearn::stochastic_gradient_descent_regression
-                    }
+                    Algorithm::stochastic_gradient_descent => sklearn::stochastic_gradient_descent_regression,
                     Algorithm::passive_aggressive => sklearn::passive_aggressive_regression,
                     Algorithm::ransac => sklearn::ransac_regression,
                     Algorithm::theil_sen => sklearn::theil_sen_regression,
@@ -464,9 +452,7 @@ impl Model {
                     Algorithm::ada_boost => sklearn::ada_boost_regression,
                     Algorithm::bagging => sklearn::bagging_regression,
                     Algorithm::extra_trees => sklearn::extra_trees_regression,
-                    Algorithm::gradient_boosting_trees => {
-                        sklearn::gradient_boosting_trees_regression
-                    }
+                    Algorithm::gradient_boosting_trees => sklearn::gradient_boosting_trees_regression,
                     Algorithm::hist_gradient_boosting => sklearn::hist_gradient_boosting_regression,
                     Algorithm::least_angle => sklearn::least_angle_regression,
                     Algorithm::lasso_least_angle => sklearn::lasso_least_angle_regression,
@@ -481,12 +467,8 @@ impl Model {
                     Algorithm::ridge => sklearn::ridge_classification,
                     Algorithm::random_forest => sklearn::random_forest_classification,
                     Algorithm::xgboost => sklearn::xgboost_classification,
-                    Algorithm::xgboost_random_forest => {
-                        sklearn::xgboost_random_forest_classification
-                    }
-                    Algorithm::stochastic_gradient_descent => {
-                        sklearn::stochastic_gradient_descent_classification
-                    }
+                    Algorithm::xgboost_random_forest => sklearn::xgboost_random_forest_classification,
+                    Algorithm::stochastic_gradient_descent => sklearn::stochastic_gradient_descent_classification,
                     Algorithm::perceptron => sklearn::perceptron_classification,
                     Algorithm::passive_aggressive => sklearn::passive_aggressive_classification,
                     Algorithm::gaussian_process => sklearn::gaussian_process,
@@ -494,12 +476,8 @@ impl Model {
                     Algorithm::ada_boost => sklearn::ada_boost_classification,
                     Algorithm::bagging => sklearn::bagging_classification,
                     Algorithm::extra_trees => sklearn::extra_trees_classification,
-                    Algorithm::gradient_boosting_trees => {
-                        sklearn::gradient_boosting_trees_classification
-                    }
-                    Algorithm::hist_gradient_boosting => {
-                        sklearn::hist_gradient_boosting_classification
-                    }
+                    Algorithm::gradient_boosting_trees => sklearn::gradient_boosting_trees_classification,
+                    Algorithm::hist_gradient_boosting => sklearn::hist_gradient_boosting_classification,
                     Algorithm::linear_svm => sklearn::linear_svm_classification,
                     Algorithm::lightgbm => sklearn::lightgbm_classification,
                     Algorithm::catboost => sklearn::catboost_classification,
@@ -531,17 +509,17 @@ impl Model {
         }
         for (key, values) in self.search_params.0.as_object().unwrap() {
             if all_hyperparam_names.contains(key) {
-                error!("`{key}` cannot be present in both hyperparams and search_params. Please choose one or the other.");
+                error!(
+                    "`{key}` cannot be present in both hyperparams and search_params. Please choose one or the other."
+                );
             }
             all_hyperparam_names.push(key.to_string());
             all_hyperparam_values.push(values.as_array().unwrap().to_vec());
         }
 
         // The search space is all possible combinations
-        let all_hyperparam_values: Vec<Vec<serde_json::Value>> = all_hyperparam_values
-            .into_iter()
-            .multi_cartesian_product()
-            .collect();
+        let all_hyperparam_values: Vec<Vec<serde_json::Value>> =
+            all_hyperparam_values.into_iter().multi_cartesian_product().collect();
         let mut all_hyperparam_values = match self.search {
             Some(Search::random) => {
                 // TODO support things like ranges to be random sampled
@@ -587,17 +565,10 @@ impl Model {
             Task::regression => {
                 #[cfg(all(feature = "python", any(test, feature = "pg_test")))]
                 {
-                    let sklearn_metrics =
-                        crate::bindings::sklearn::regression_metrics(y_test, &y_hat).unwrap();
+                    let sklearn_metrics = crate::bindings::sklearn::regression_metrics(y_test, &y_hat).unwrap();
                     metrics.insert("sklearn_r2".to_string(), sklearn_metrics["r2"]);
-                    metrics.insert(
-                        "sklearn_mean_absolute_error".to_string(),
-                        sklearn_metrics["mae"],
-                    );
-                    metrics.insert(
-                        "sklearn_mean_squared_error".to_string(),
-                        sklearn_metrics["mse"],
-                    );
+                    metrics.insert("sklearn_mean_absolute_error".to_string(), sklearn_metrics["mae"]);
+                    metrics.insert("sklearn_mean_squared_error".to_string(), sklearn_metrics["mse"]);
                 }
 
                 let y_test = ArrayView1::from(&y_test);
@@ -616,12 +587,9 @@ impl Model {
             Task::classification => {
                 #[cfg(all(feature = "python", any(test, feature = "pg_test")))]
                 {
-                    let sklearn_metrics = crate::bindings::sklearn::classification_metrics(
-                        y_test,
-                        &y_hat,
-                        dataset.num_distinct_labels,
-                    )
-                    .unwrap();
+                    let sklearn_metrics =
+                        crate::bindings::sklearn::classification_metrics(y_test, &y_hat, dataset.num_distinct_labels)
+                            .unwrap();
 
                     if dataset.num_distinct_labels == 2 {
                         metrics.insert("sklearn_roc_auc".to_string(), sklearn_metrics["roc_auc"]);
@@ -629,10 +597,7 @@ impl Model {
 
                     metrics.insert("sklearn_f1".to_string(), sklearn_metrics["f1"]);
                     metrics.insert("sklearn_f1_micro".to_string(), sklearn_metrics["f1_micro"]);
-                    metrics.insert(
-                        "sklearn_precision".to_string(),
-                        sklearn_metrics["precision"],
-                    );
+                    metrics.insert("sklearn_precision".to_string(), sklearn_metrics["precision"]);
                     metrics.insert("sklearn_recall".to_string(), sklearn_metrics["recall"]);
                     metrics.insert("sklearn_accuracy".to_string(), sklearn_metrics["accuracy"]);
                     metrics.insert("sklearn_mcc".to_string(), sklearn_metrics["mcc"]);
@@ -646,10 +611,7 @@ impl Model {
                     let y_hat = ArrayView1::from(&y_hat).mapv(Pr::new);
                     let y_test: Vec<bool> = y_test.iter().map(|&i| i == 1.).collect();
 
-                    metrics.insert(
-                        "roc_auc".to_string(),
-                        y_hat.roc(&y_test).unwrap().area_under_curve(),
-                    );
+                    metrics.insert("roc_auc".to_string(), y_hat.roc(&y_test).unwrap().area_under_curve());
                     metrics.insert("log_loss".to_string(), y_hat.log_loss(&y_test).unwrap());
                 }
 
@@ -662,11 +624,8 @@ impl Model {
                 let confusion_matrix = y_hat.confusion_matrix(y_test).unwrap();
 
                 // This has to be identical to Scikit.
-                let pgml_confusion_matrix = crate::metrics::ConfusionMatrix::new(
-                    &y_test,
-                    &y_hat,
-                    dataset.num_distinct_labels,
-                );
+                let pgml_confusion_matrix =
+                    crate::metrics::ConfusionMatrix::new(&y_test, &y_hat, dataset.num_distinct_labels);
 
                 // These are validated against Scikit and seem to be correct.
                 metrics.insert(
@@ -683,12 +642,9 @@ impl Model {
             Task::cluster => {
                 #[cfg(feature = "python")]
                 {
-                    let sklearn_metrics = crate::bindings::sklearn::cluster_metrics(
-                        dataset.num_features,
-                        &dataset.x_test,
-                        &y_hat,
-                    )
-                    .unwrap();
+                    let sklearn_metrics =
+                        crate::bindings::sklearn::cluster_metrics(dataset.num_features, &dataset.x_test, &y_hat)
+                            .unwrap();
                     metrics.insert("silhouette".to_string(), sklearn_metrics["silhouette"]);
                 }
             }
@@ -703,10 +659,7 @@ impl Model {
         dataset: &Dataset,
         hyperparams: &Hyperparams,
     ) -> (Box<dyn Bindings>, IndexMap<String, f32>) {
-        info!(
-            "Hyperparams: {}",
-            serde_json::to_string_pretty(hyperparams).unwrap()
-        );
+        info!("Hyperparams: {}", serde_json::to_string_pretty(hyperparams).unwrap());
 
         let fit = self.get_fit_function();
         let now = Instant::now();
@@ -749,25 +702,11 @@ impl Model {
     }
 
     pub fn f1(&self) -> f32 {
-        self.metrics
-            .as_ref()
-            .unwrap()
-            .0
-            .get("f1")
-            .unwrap()
-            .as_f64()
-            .unwrap() as f32
+        self.metrics.as_ref().unwrap().0.get("f1").unwrap().as_f64().unwrap() as f32
     }
 
     pub fn r2(&self) -> f32 {
-        self.metrics
-            .as_ref()
-            .unwrap()
-            .0
-            .get("r2")
-            .unwrap()
-            .as_f64()
-            .unwrap() as f32
+        self.metrics.as_ref().unwrap().0.get("r2").unwrap().as_f64().unwrap() as f32
     }
 
     fn fit(&mut self, dataset: &Dataset) {
@@ -955,9 +894,13 @@ impl Model {
             "INSERT INTO pgml.files (model_id, path, part, data) VALUES($1, 'estimator.rmp', 0, $2) RETURNING id",
             vec![
                 (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()),
-                (PgBuiltInOids::BYTEAOID.oid(), self.bindings.as_ref().unwrap().to_bytes().into_datum()),
+                (
+                    PgBuiltInOids::BYTEAOID.oid(),
+                    self.bindings.as_ref().unwrap().to_bytes().into_datum(),
+                ),
             ],
-        ).unwrap();
+        )
+        .unwrap();
     }
 
     pub fn numeric_encode_features(&self, rows: &[pgrx::datum::AnyElement]) -> Vec<f32> {
@@ -976,68 +919,47 @@ impl Model {
                         pgrx_pg_sys::UNKNOWNOID => {
                             error!("Type information missing for column: {:?}. If this is intended to be a TEXT or other categorical column, you will need to explicitly cast it, e.g. change `{:?}` to `CAST({:?} AS TEXT)`.", column.name, column.name, column.name);
                         }
-                        pgrx_pg_sys::TEXTOID
-                        | pgrx_pg_sys::VARCHAROID
-                        | pgrx_pg_sys::BPCHAROID => {
+                        pgrx_pg_sys::TEXTOID | pgrx_pg_sys::VARCHAROID | pgrx_pg_sys::BPCHAROID => {
                             let element: Result<Option<String>, TryFromDatumError> = tuple.get_by_index(index);
-                            element
-                                .unwrap()
-                                .unwrap_or(snapshot::NULL_CATEGORY_KEY.to_string())
+                            element.unwrap().unwrap_or(snapshot::NULL_CATEGORY_KEY.to_string())
                         }
                         pgrx_pg_sys::BOOLOID => {
                             let element: Result<Option<bool>, TryFromDatumError> = tuple.get_by_index(index);
                             element
                                 .unwrap()
-                                .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
-                                    k.to_string()
-                                })
+                                .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string())
                         }
                         pgrx_pg_sys::INT2OID => {
-                            let element: Result<Option<i16>, TryFromDatumError> =
-                                tuple.get_by_index(index);
+                            let element: Result<Option<i16>, TryFromDatumError> = tuple.get_by_index(index);
                             element
                                 .unwrap()
-                                .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
-                                    k.to_string()
-                                })
+                                .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string())
                         }
                         pgrx_pg_sys::INT4OID => {
-                            let element: Result<Option<i32>, TryFromDatumError> =
-                                tuple.get_by_index(index);
+                            let element: Result<Option<i32>, TryFromDatumError> = tuple.get_by_index(index);
                             element
                                 .unwrap()
-                                .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
-                                    k.to_string()
-                                })
+                                .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string())
                         }
                         pgrx_pg_sys::INT8OID => {
-                            let element: Result<Option<i64>, TryFromDatumError> =
-                                tuple.get_by_index(index);
+                            let element: Result<Option<i64>, TryFromDatumError> = tuple.get_by_index(index);
                             element
                                 .unwrap()
-                                .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
-                                    k.to_string()
-                                })
+                                .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string())
                         }
                        pgrx_pg_sys::FLOAT4OID => {
-                            let element: Result<Option<f32>, TryFromDatumError> =
-                                tuple.get_by_index(index);
+                            let element: Result<Option<f32>, TryFromDatumError> = tuple.get_by_index(index);
                             element
                                 .unwrap()
-                                .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
-                                    k.to_string()
-                                })
+                                .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string())
                         }
                         pgrx_pg_sys::FLOAT8OID => {
-                            let element: Result<Option<f64>, TryFromDatumError> =
-                                tuple.get_by_index(index);
+                            let element: Result<Option<f64>, TryFromDatumError> = tuple.get_by_index(index);
                             element
                                 .unwrap()
-                                .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| {
-                                    k.to_string()
-                                })
+                                .map_or(snapshot::NULL_CATEGORY_KEY.to_string(), |k| k.to_string())
                         }
                         _ => error!(
                             "Unsupported type for categorical column: {:?}. oid: {:?}",
                             column.name, column.pg_type
                         ),
@@ -1055,38 +977,27 @@ impl Model {
                         pgrx_pg_sys::BOOLOID => {
                             let element: Result<Option<bool>, TryFromDatumError> = tuple.get_by_index(index);
-                            features.push(
-                                element.unwrap().map_or(f32::NAN, |v| v as u8 as f32),
-                            );
+                            features.push(element.unwrap().map_or(f32::NAN, |v| v as u8 as f32));
                         }
                         pgrx_pg_sys::INT2OID => {
-                            let element: Result<Option<i16>, TryFromDatumError> =
-                                tuple.get_by_index(index);
-                            features
-                                .push(element.unwrap().map_or(f32::NAN, |v| v as f32));
+                            let element: Result<Option<i16>, TryFromDatumError> = tuple.get_by_index(index);
+                            features.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
                         }
                         pgrx_pg_sys::INT4OID => {
-                            let element: Result<Option<i32>, TryFromDatumError> =
-                                tuple.get_by_index(index);
-                            features
-                                .push(element.unwrap().map_or(f32::NAN, |v| v as f32));
+                            let element: Result<Option<i32>, TryFromDatumError> = tuple.get_by_index(index);
+                            features.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
                         }
                         pgrx_pg_sys::INT8OID => {
-                            let element: Result<Option<i64>, TryFromDatumError> =
-                                tuple.get_by_index(index);
-                            features
-                                .push(element.unwrap().map_or(f32::NAN, |v| v as f32));
+                            let element: Result<Option<i64>, TryFromDatumError> = tuple.get_by_index(index);
+                            features.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
                         }
                         pgrx_pg_sys::FLOAT4OID => {
-                            let element: Result<Option<f32>, TryFromDatumError> =
-                                tuple.get_by_index(index);
+                            let element: Result<Option<f32>, TryFromDatumError> = tuple.get_by_index(index);
                             features.push(element.unwrap().map_or(f32::NAN, |v| v));
                         }
                         pgrx_pg_sys::FLOAT8OID => {
-                            let element: Result<Option<f64>, TryFromDatumError> =
-                                tuple.get_by_index(index);
-                            features
-                                .push(element.unwrap().map_or(f32::NAN, |v| v as f32));
+                            let element: Result<Option<f64>, TryFromDatumError> = tuple.get_by_index(index);
+                            features.push(element.unwrap().map_or(f32::NAN, |v| v as f32));
                         }
                         // TODO handle NULL to NaN for arrays
                         pgrx_pg_sys::BOOLARRAYOID => {
@@ -1140,9 +1051,7 @@ impl Model {
                             }
                         }
                     }
-                    _ => error!(
-                        "This preprocessing requires Postgres `record` types created with `row()`."
-                    ),
+                    _ => error!("This preprocessing requires Postgres `record` types created with `row()`."),
                 }
             }
         }
         features
@@ -1166,11 +1075,11 @@ impl Model {
     pub fn predict_joint(&self, features: &[f32]) -> Result<Vec<f32>> {
         match self.project.task {
-            Task::regression => self.bindings.as_ref().unwrap().predict(
-                features,
-                self.num_features,
-                self.num_classes,
-            ),
+            Task::regression => self
+                .bindings
+                .as_ref()
+                .unwrap()
+                .predict(features, self.num_features, self.num_classes),
             Task::classification => {
                 bail!("You can't predict joint probabilities for a classification model")
             }
diff --git a/pgml-extension/src/orm/project.rs b/pgml-extension/src/orm/project.rs
index a30db3169..ea23ba80e 100644
--- a/pgml-extension/src/orm/project.rs
+++ b/pgml-extension/src/orm/project.rs
@@ -8,10 +8,8 @@ use pgrx::*;
 
 use crate::orm::*;
 
-static PROJECT_ID_TO_DEPLOYED_MODEL_ID: PgLwLock<heapless::FnvIndexMap<i64, i64, 1024>> =
-    PgLwLock::new();
-static PROJECT_NAME_TO_PROJECT_ID: Lazy<Mutex<HashMap<String, i64>>> =
-    Lazy::new(|| Mutex::new(HashMap::new()));
+static PROJECT_ID_TO_DEPLOYED_MODEL_ID: PgLwLock<heapless::FnvIndexMap<i64, i64, 1024>> = PgLwLock::new();
+static PROJECT_NAME_TO_PROJECT_ID: Lazy<Mutex<HashMap<String, i64>>> = Lazy::new(|| Mutex::new(HashMap::new()));
 
 /// Initialize shared memory.
 /// # Note
@@ -56,23 +54,12 @@ impl Project {
                 );
                 let (project_id, model_id) = match result {
                     Ok(o) => o,
-                    Err(_) => error!(
-                        "No deployed model exists for the project named: `{}`",
-                        project_name
-                    ),
+                    Err(_) => error!("No deployed model exists for the project named: `{}`", project_name),
                 };
-                let project_id = project_id.unwrap_or_else(|| {
-                    error!(
-                        "No deployed model exists for the project named: `{}`",
-                        project_name
-                    )
-                });
-                let model_id = model_id.unwrap_or_else(|| {
-                    error!(
-                        "No deployed model exists for the project named: `{}`",
-                        project_name
-                    )
-                });
+                let project_id = project_id
+                    .unwrap_or_else(|| error!("No deployed model exists for the project named: `{}`", project_name));
+                let model_id = model_id
+                    .unwrap_or_else(|| error!("No deployed model exists for the project named: `{}`", project_name));
                 projects.insert(project_name.to_string(), project_id);
                 let mut projects = PROJECT_ID_TO_DEPLOYED_MODEL_ID.exclusive();
                 if projects.len() == 1024 {
@@ -83,10 +70,7 @@ impl Project {
                 project_id
             }
         };
-        *PROJECT_ID_TO_DEPLOYED_MODEL_ID
-            .share()
-            .get(&project_id)
-            .unwrap()
+        *PROJECT_ID_TO_DEPLOYED_MODEL_ID.share().get(&project_id).unwrap()
     }
 
     pub fn deploy(&self, model_id: i64, strategy: Strategy) {
@@ -111,12 +95,14 @@ impl Project {
         let mut project: Option<Project> = None;
 
         Spi::connect(|client| {
-            let result = client.select("SELECT id, name, task::TEXT, created_at, updated_at FROM pgml.projects WHERE id = $1 LIMIT 1;",
-                Some(1),
-                Some(vec![
-                    (PgBuiltInOids::INT8OID.oid(), id.into_datum()),
-                ])
-            ).unwrap().first();
+            let result = client
+                .select(
+                    "SELECT id, name, task::TEXT, created_at, updated_at FROM pgml.projects WHERE id = $1 LIMIT 1;",
+                    Some(1),
+                    Some(vec![(PgBuiltInOids::INT8OID.oid(), id.into_datum())]),
+                )
+                .unwrap()
+                .first();
             if !result.is_empty() {
                 project = Some(Project {
                     id: result.get(1).unwrap().unwrap(),
@@ -135,12 +121,14 @@ impl Project {
         let mut project = None;
 
         Spi::connect(|client| {
-            let result = client.select("SELECT id, name, task::TEXT, created_at, updated_at FROM pgml.projects WHERE name = $1 LIMIT 1;",
-                Some(1),
-                Some(vec![
-                    (PgBuiltInOids::TEXTOID.oid(), name.into_datum()),
-                ])
-            ).unwrap().first();
+            let result = client
+                .select(
+                    "SELECT id, name, task::TEXT, created_at, updated_at FROM pgml.projects WHERE name = $1 LIMIT 1;",
+                    Some(1),
+                    Some(vec![(PgBuiltInOids::TEXTOID.oid(), name.into_datum())]),
+                )
+                .unwrap()
+                .first();
             if !result.is_empty() {
                 project = Some(Project {
                     id: result.get(1).unwrap().unwrap(),
diff --git a/pgml-extension/src/orm/snapshot.rs b/pgml-extension/src/orm/snapshot.rs
index 85f697508..6a5973148 100644
--- a/pgml-extension/src/orm/snapshot.rs
+++ b/pgml-extension/src/orm/snapshot.rs
@@ -163,13 +163,10 @@ impl Column {
     pub(crate) fn scale(&self, value: f32) -> f32 {
         match self.preprocessor.scale {
             Scale::standard => (value - self.statistics.mean) / self.statistics.std_dev,
-            Scale::min_max => {
-                (value - self.statistics.min) / (self.statistics.max - self.statistics.min)
-            }
+            Scale::min_max => (value - self.statistics.min) / (self.statistics.max - self.statistics.min),
             Scale::max_abs => value / self.statistics.max_abs,
             Scale::robust => {
-                (value - self.statistics.median)
-                    / (self.statistics.ventiles[15] - self.statistics.ventiles[5])
+                (value - self.statistics.median) / (self.statistics.ventiles[15] - self.statistics.ventiles[5])
             }
             Scale::preserve => value,
         }
@@ -456,10 +453,7 @@ impl Snapshot {
                 LIMIT 1;
                 ",
                 Some(1),
-                Some(vec![(
-                    PgBuiltInOids::INT8OID.oid(),
-                    project_id.into_datum(),
-                )]),
+                Some(vec![(PgBuiltInOids::INT8OID.oid(), project_id.into_datum())]),
             )
             .unwrap()
             .first();
@@ -467,8 +461,7 @@ impl Snapshot {
                 let jsonb: JsonB = result.get(7).unwrap().unwrap();
                 let columns: Vec<Column> = serde_json::from_value(jsonb.0).unwrap();
                 let jsonb: JsonB = result.get(8).unwrap().unwrap();
-                let analysis: Option<IndexMap<String, f32>> =
-                    Some(serde_json::from_value(jsonb.0).unwrap());
+                let analysis: Option<IndexMap<String, f32>> = Some(serde_json::from_value(jsonb.0).unwrap());
                 let mut s = Snapshot {
                     id: result.get(1).unwrap().unwrap(),
@@ -505,8 +498,7 @@ impl Snapshot {
         // Validate table exists.
         let (schema_name, table_name) = Self::fully_qualified_table(relation_name);
 
-        let preprocessors: HashMap<String, Preprocessor> =
-            serde_json::from_value(preprocess.0).expect("is valid");
+        let preprocessors: HashMap<String, Preprocessor> = serde_json::from_value(preprocess.0).expect("is valid");
 
         Spi::connect(|mut client| {
             let mut columns: Vec<Column> = Vec::new();
@@ -674,9 +666,7 @@ impl Snapshot {
     }
 
     pub(crate) fn first_label(&self) -> &Column {
-        self.labels()
-            .find(|l| l.name == self.y_column_name[0])
-            .unwrap()
+        self.labels().find(|l| l.name == self.y_column_name[0]).unwrap()
     }
 
     pub(crate) fn num_classes(&self) -> usize {
@@ -716,9 +706,12 @@ impl Snapshot {
 
         match schema_name {
             None => {
-                let table_count = Spi::get_one_with_args::<i64>("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'public'", vec![
-                    (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum())
-                ]).unwrap().unwrap();
+                let table_count = Spi::get_one_with_args::<i64>(
+                    "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = 'public'",
+                    vec![(PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum())],
+                )
+                .unwrap()
+                .unwrap();
                let error = format!("Relation \"{}\" could not be found in the public schema. Please specify the table schema, e.g. pgml.{}", table_name, table_name);
@@ -730,18 +723,19 @@ impl Snapshot {
             }
 
             Some(schema_name) => {
-                let exists = Spi::get_one_with_args::<i64>("SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2", vec![
-                    (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()),
-                    (PgBuiltInOids::TEXTOID.oid(), schema_name.clone().into_datum()),
-                ]).unwrap();
+                let exists = Spi::get_one_with_args::<i64>(
+                    "SELECT COUNT(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2",
+                    vec![
+                        (PgBuiltInOids::TEXTOID.oid(), table_name.clone().into_datum()),
+                        (PgBuiltInOids::TEXTOID.oid(), schema_name.clone().into_datum()),
+                    ],
+                )
+                .unwrap();
 
                 if exists == Some(1) {
                     (schema_name, table_name)
                 } else {
-                    error!(
-                        "Relation \"{}\".\"{}\" doesn't exist",
-                        schema_name, table_name
-                    );
+                    error!("Relation \"{}\".\"{}\" doesn't exist", schema_name, table_name);
                 }
             }
         }
@@ -818,12 +812,10 @@ impl Snapshot {
                 };
 
                 match column.pg_type.as_str() {
-                    "bpchar" | "text" | "varchar" => {
-                        match row[column.position].value::<String>().unwrap() {
-                            Some(text) => vector.push(text),
-                            None => error!("NULL training text is not handled"),
-                        }
-                    }
+                    "bpchar" | "text" | "varchar" => match row[column.position].value::<String>().unwrap() {
+                        Some(text) => vector.push(text),
+                        None => error!("NULL training text is not handled"),
+                    },
                     _ => error!("only text type columns are supported"),
                 }
             }
@@ -906,24 +898,15 @@ impl Snapshot {
         }
 
         let mut analysis = IndexMap::new();
-        analysis.insert(
-            "samples".to_string(),
-            numeric_encoded_dataset.num_rows as f32,
-        );
+        analysis.insert("samples".to_string(), numeric_encoded_dataset.num_rows as f32);
         self.analysis = Some(analysis);
 
         // Record the analysis
         Spi::run_with_args(
             "UPDATE pgml.snapshots SET analysis = $1, columns = $2 WHERE id = $3",
             Some(vec![
-                (
-                    PgBuiltInOids::JSONBOID.oid(),
-                    JsonB(json!(self.analysis)).into_datum(),
-                ),
-                (
-                    PgBuiltInOids::JSONBOID.oid(),
-                    JsonB(json!(self.columns)).into_datum(),
-                ),
+                (PgBuiltInOids::JSONBOID.oid(), JsonB(json!(self.analysis)).into_datum()),
+                (PgBuiltInOids::JSONBOID.oid(), JsonB(json!(self.columns)).into_datum()),
                 (PgBuiltInOids::INT8OID.oid(), self.id.into_datum()),
             ]),
         )
@@ -1001,14 +984,19 @@ impl Snapshot {
                     // Categorical encoding types
                     Some(categories) => {
                         let key = match column.pg_type.as_str() {
-                            "bool" => row[column.position].value::<bool>().unwrap().map(|v| v.to_string() ),
-                            "int2" => row[column.position].value::<i16>().unwrap().map(|v| v.to_string() ),
-                            "int4" => row[column.position].value::<i32>().unwrap().map(|v| v.to_string() ),
-                            "int8" => row[column.position].value::<i64>().unwrap().map(|v| v.to_string() ),
-                            "float4" => row[column.position].value::<f32>().unwrap().map(|v| v.to_string() ),
-                            "float8" => row[column.position].value::<f64>().unwrap().map(|v| v.to_string() ),
-                            "bpchar" | "text" | "varchar" => row[column.position].value::<String>().unwrap().map(|v| v.to_string() ),
-                            _ => error!("Unhandled type for categorical variable: {} {:?}", column.name, column.pg_type)
+                            "bool" => row[column.position].value::<bool>().unwrap().map(|v| v.to_string()),
+                            "int2" => row[column.position].value::<i16>().unwrap().map(|v| v.to_string()),
+                            "int4" => row[column.position].value::<i32>().unwrap().map(|v| v.to_string()),
+                            "int8" => row[column.position].value::<i64>().unwrap().map(|v| v.to_string()),
+                            "float4" => row[column.position].value::<f32>().unwrap().map(|v| v.to_string()),
+                            "float8" => row[column.position].value::<f64>().unwrap().map(|v| v.to_string()),
+                            "bpchar" | "text" | "varchar" => {
+                                row[column.position].value::<String>().unwrap().map(|v| v.to_string())
+                            }
+                            _ => error!(
+                                "Unhandled type for categorical variable: {} {:?}",
+                                column.name, column.pg_type
+                            ),
                         };
                         let key = key.unwrap_or_else(|| NULL_CATEGORY_KEY.to_string());
                         if i < num_train_rows {
@@ -1018,16 +1006,18 @@ impl Snapshot {
                                     NULL_CATEGORY_KEY => 0_f32, // NULL values are always Category 0
                                     _ => match &column.preprocessor.encode {
                                         Encode::target | Encode::native | Encode::one_hot { .. } => len as f32,
-                                        Encode::ordinal(values) => match values.iter().position(|v| v == key.as_str()) {
-                                            Some(i) => (i + 1) as f32,
-                                            None => error!("value is not present in ordinal: {:?}. Valid values: {:?}", key, values),
+                                        Encode::ordinal(values) => {
+                                            match values.iter().position(|v| v == key.as_str()) {
+                                                Some(i) => (i + 1) as f32,
+                                                None => error!(
+                                                    "value is not present in ordinal: {:?}. Valid values: {:?}",
+                                                    key, values
+                                                ),
+                                            }
                                         }
-                                    }
+                                    },
                                 };
-                                Category {
-                                    value,
-                                    members: 0
-                                }
+                                Category { value, members: 0 }
                             });
                             category.members += 1;
                             vector.push(category.value);
@@ -1088,9 +1078,13 @@ impl Snapshot {
                                     vector.push(j as f32)
                                 }
                             }
-                            _ => error!("Unhandled type for quantitative array column: {} {:?}", column.name, column.pg_type)
+                            _ => error!(
+                                "Unhandled type for quantitative array column: {} {:?}",
+                                column.name, column.pg_type
+                            ),
                         }
-                    } else { // scalar
+                    } else {
+                        // scalar
                         let float = match column.pg_type.as_str() {
                             "bool" => row[column.position].value::<bool>().unwrap().map(|v| v as u8 as f32),
                             "int2" => row[column.position].value::<i16>().unwrap().map(|v| v as f32),
                             "int4" => row[column.position].value::<i32>().unwrap().map(|v| v as f32),
                             "int8" => row[column.position].value::<i64>().unwrap().map(|v| v as f32),
                             "float4" => row[column.position].value::<f32>().unwrap(),
                             "float8" => row[column.position].value::<f64>().unwrap().map(|v| v as f32),
-                            _ => error!("Unhandled type for quantitative scalar column: {} {:?}", column.name, column.pg_type)
+                            _ => error!(
+                                "Unhandled type for quantitative scalar column: {} {:?}",
+                                column.name, column.pg_type
+                            ),
                         };
                         match float {
                             Some(f) => vector.push(f),
@@ -1114,7 +1111,7 @@ impl Snapshot {
             let num_features = self.num_features();
             let num_labels = self.num_labels();
 
-            data = Some(Dataset{
+            data = Some(Dataset {
                 x_train,
                 y_train,
                 x_test,
@@ -1129,7 +1126,8 @@ impl Snapshot {
             });
 
             Ok::<Option<()>, i64>(Some(())) // this return type is nonsense
-        }).unwrap();
+        })
+        .unwrap();
 
         let data = data.unwrap();
diff --git a/pgml-extension/src/vectors.rs b/pgml-extension/src/vectors.rs
index ccaafa28a..b2114b7dd 100644
--- a/pgml-extension/src/vectors.rs
+++ b/pgml-extension/src/vectors.rs
@@ -115,18 +115,12 @@ fn divide_vector_d(vector: Array<f64>, dividend: Array<f64>) -> Vec<f64> {
 
 #[pg_extern(immutable, parallel_safe, strict, name = "norm_l0")]
 fn norm_l0_s(vector: Array<f32>) -> f32 {
-    vector
-        .iter_deny_null()
-        .map(|a| if a == 0.0 { 0.0 } else { 1.0 })
-        .sum()
+    vector.iter_deny_null().map(|a| if a == 0.0 { 0.0 } else { 1.0 }).sum()
 }
 
 #[pg_extern(immutable, parallel_safe, strict, name = "norm_l0")]
 fn norm_l0_d(vector: Array<f64>) -> f64 {
-    vector
-        .iter_deny_null()
-        .map(|a| if a == 0.0 { 0.0 } else { 1.0 })
-        .sum()
+    vector.iter_deny_null().map(|a| if a == 0.0 { 0.0 } else { 1.0 }).sum()
 }
 
 #[pg_extern(immutable, parallel_safe, strict, name = "norm_l1")]
@@ -334,11 +328,7 @@ impl Aggregate for SumS {
     type Finalize = Vec<f32>;
 
     #[pgrx(immutable, parallel_safe)]
-    fn state<'a>(
-        mut current: Self::State,
-        arg: Self::Args,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn state<'a>(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match arg {
             None => {}
             Some(arg) => match current {
@@ -356,11 +346,7 @@ impl Aggregate for SumS {
     }
 
     #[pgrx(immutable, parallel_safe)]
-    fn combine(
-        mut first: Self::State,
-        second: Self::State,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match (&mut first, &second) {
             (None, None) => None,
             (Some(_), None) => first,
@@ -397,11 +383,7 @@ impl Aggregate for SumD {
     type Finalize = Vec<f64>;
 
     #[pgrx(immutable, parallel_safe)]
-    fn state(
-        mut current: Self::State,
-        arg: Self::Args,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match arg {
             None => {}
             Some(arg) => match current {
@@ -419,11 +401,7 @@ impl Aggregate for SumD {
     }
 
     #[pgrx(immutable, parallel_safe)]
-    fn combine(
-        mut first: Self::State,
-        second: Self::State,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match (&mut first, &second) {
             (None, None) => None,
             (Some(_), None) => first,
@@ -460,11 +438,7 @@ impl Aggregate for MaxAbsS {
     type Finalize = Vec<f32>;
 
     #[pgrx(immutable, parallel_safe)]
-    fn state(
-        mut current: Self::State,
-        arg: Self::Args,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match arg {
             None => {}
             Some(arg) => match current {
@@ -484,11 +458,7 @@ impl Aggregate for MaxAbsS {
     }
 
     #[pgrx(immutable, parallel_safe)]
-    fn combine(
-        mut first: Self::State,
-        second: Self::State,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match (&mut first, &second) {
             (None, None) => None,
             (Some(_), None) => first,
@@ -527,11 +497,7 @@ impl Aggregate for MaxAbsD {
     type Finalize = Vec<f64>;
 
     #[pgrx(immutable, parallel_safe)]
-    fn state(
-        mut current: Self::State,
-        arg: Self::Args,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match arg {
             None => {}
             Some(arg) => match current {
@@ -551,11 +517,7 @@ impl Aggregate for MaxAbsD {
     }
 
     #[pgrx(immutable, parallel_safe)]
-    fn combine(
-        mut first: Self::State,
-        second: Self::State,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match (&mut first, &second) {
             (None, None) => None,
             (Some(_), None) => first,
@@ -594,11 +556,7 @@ impl Aggregate for MaxS {
     type Finalize = Vec<f32>;
 
     #[pgrx(immutable, parallel_safe)]
-    fn state(
-        mut current: Self::State,
-        arg: Self::Args,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match arg {
             None => {}
             Some(arg) => match current {
@@ -618,11 +576,7 @@ impl Aggregate for MaxS {
     }
 
     #[pgrx(immutable, parallel_safe)]
-    fn combine(
-        mut first: Self::State,
-        second: Self::State,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match (&mut first, &second) {
             (None, None) => None,
             (Some(_), None) => first,
@@ -661,11 +615,7 @@ impl Aggregate for MaxD {
     type Finalize = Vec<f64>;
     #[pgrx(immutable, parallel_safe)]
-    fn state(
-        mut current: Self::State,
-        arg: Self::Args,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match arg {
             None => {}
             Some(arg) => match current {
@@ -685,11 +635,7 @@ impl Aggregate for MaxD {
     }
 
     #[pgrx(immutable, parallel_safe)]
-    fn combine(
-        mut first: Self::State,
-        second: Self::State,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match (&mut first, &second) {
             (None, None) => None,
             (Some(_), None) => first,
@@ -728,11 +674,7 @@ impl Aggregate for MinS {
     type Finalize = Vec<f32>;
 
     #[pgrx(immutable, parallel_safe)]
-    fn state(
-        mut current: Self::State,
-        arg: Self::Args,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match arg {
             None => {}
             Some(arg) => match current {
@@ -752,11 +694,7 @@ impl Aggregate for MinS {
     }
 
     #[pgrx(immutable, parallel_safe)]
-    fn combine(
-        mut first: Self::State,
-        second: Self::State,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match (&mut first, &second) {
             (None, None) => None,
             (Some(_), None) => first,
@@ -795,11 +733,7 @@ impl Aggregate for MinD {
     type Finalize = Vec<f64>;
 
     #[pgrx(immutable, parallel_safe)]
-    fn state(
-        mut current: Self::State,
-        arg: Self::Args,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match arg {
             None => {}
             Some(arg) => match current {
@@ -819,11 +753,7 @@ impl Aggregate for MinD {
     }
 
     #[pgrx(immutable, parallel_safe)]
-    fn combine(
-        mut first: Self::State,
-        second: Self::State,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match (&mut first, &second) {
             (None, None) => None,
             (Some(_), None) => first,
@@ -862,11 +792,7 @@ impl Aggregate for MinAbsS {
     type Finalize = Vec<f32>;
 
     #[pgrx(immutable, parallel_safe)]
-    fn state(
-        mut current: Self::State,
-        arg: Self::Args,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match arg {
             None => {}
             Some(arg) => match current {
@@ -886,11 +812,7 @@ impl Aggregate for MinAbsS {
     }
 
     #[pgrx(immutable, parallel_safe)]
-    fn combine(
-        mut first: Self::State,
-        second: Self::State,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match (&mut first, &second) {
             (None, None) => None,
             (Some(_), None) => first,
@@ -929,11 +851,7 @@ impl Aggregate for MinAbsD {
     type Finalize = Vec<f64>;
 
     #[pgrx(immutable, parallel_safe)]
-    fn state(
-        mut current: Self::State,
-        arg: Self::Args,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn state(mut current: Self::State, arg: Self::Args, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match arg {
             None => {}
             Some(arg) => match current {
@@ -953,11 +871,7 @@ impl Aggregate for MinAbsD {
     }
 
     #[pgrx(immutable, parallel_safe)]
-    fn combine(
-        mut first: Self::State,
-        second: Self::State,
-        _fcinfo: pg_sys::FunctionCallInfo,
-    ) -> Self::State {
+    fn combine(mut first: Self::State, second: Self::State, _fcinfo: pg_sys::FunctionCallInfo) -> Self::State {
         match (&mut first, &second) {
             (None, None) => None,
             (Some(_), None) => first,
@@ -1043,65 +957,57 @@ mod tests {
 
     #[pg_test]
     fn test_add_vector_s() {
-        let result = Spi::get_one::<Vec<f32>>(
-            "SELECT pgml.add(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])",
-        );
+        let result =
+            Spi::get_one::<Vec<f32>>("SELECT pgml.add(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])");
         assert_eq!(result, Ok(Some([2.0, 4.0, 6.0].to_vec())));
     }
 
     #[pg_test]
     fn test_add_vector_d() {
-        let result = Spi::get_one::<Vec<f64>>(
-            "SELECT pgml.add(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])",
-        );
+        let result =
+            Spi::get_one::<Vec<f64>>("SELECT pgml.add(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])");
         assert_eq!(result, Ok(Some([2.0, 4.0, 6.0].to_vec())));
     }
 
     #[pg_test]
     fn test_subtract_vector_s() {
-        let result = Spi::get_one::<Vec<f32>>(
-            "SELECT pgml.subtract(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])",
-        );
+        let result =
+            Spi::get_one::<Vec<f32>>("SELECT pgml.subtract(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])");
         assert_eq!(result, Ok(Some([0.0, 0.0, 0.0].to_vec())));
     }
 
     #[pg_test]
     fn test_subtract_vector_d() {
-        let result = Spi::get_one::<Vec<f64>>(
-            "SELECT pgml.subtract(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])",
-        );
+        let result =
+            Spi::get_one::<Vec<f64>>("SELECT pgml.subtract(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])");
        assert_eq!(result, Ok(Some([0.0, 0.0, 0.0].to_vec())));
     }
 
     #[pg_test]
     fn test_multiply_vector_s() {
-        let result = Spi::get_one::<Vec<f32>>(
-            "SELECT pgml.subtract(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])",
-        );
+        let result =
+            Spi::get_one::<Vec<f32>>("SELECT pgml.subtract(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])");
         assert_eq!(result, Ok(Some([0.0, 0.0, 0.0].to_vec())));
     }
 
     #[pg_test]
     fn test_multiply_vector_d() {
-        let result = Spi::get_one::<Vec<f64>>(
-            "SELECT pgml.multiply(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])",
-        );
+        let result =
+            Spi::get_one::<Vec<f64>>("SELECT pgml.multiply(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])");
         assert_eq!(result, Ok(Some([1.0, 4.0, 9.0].to_vec())));
     }
 
     #[pg_test]
     fn test_divide_vector_s() {
-        let result = Spi::get_one::<Vec<f32>>(
-            "SELECT pgml.divide(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])",
-        );
+        let result =
+            Spi::get_one::<Vec<f32>>("SELECT pgml.divide(ARRAY[1,2,3]::float4[], ARRAY[1.0, 2.0, 3.0]::float4[])");
         assert_eq!(result, Ok(Some([1.0, 1.0, 1.0].to_vec())));
     }
 
     #[pg_test]
     fn test_divide_vector_d() {
-        let result = Spi::get_one::<Vec<f64>>(
-            "SELECT pgml.divide(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])",
-        );
+        let result =
            Spi::get_one::<Vec<f64>>("SELECT pgml.divide(ARRAY[1,2,3]::float8[], ARRAY[1.0, 2.0, 3.0]::float8[])");
         assert_eq!(result, Ok(Some([1.0, 1.0, 1.0].to_vec())));
     }
 
@@ -1178,9 +1084,7 @@ mod tests {
         let result = Spi::get_one::<Vec<f64>>("SELECT pgml.normalize_l1(ARRAY[1,2,3]::float8[])");
         assert_eq!(
             result,
-            Ok(Some(
-                [0.16666666666666666, 0.3333333333333333, 0.5].to_vec()
-            ))
+            Ok(Some([0.16666666666666666, 0.3333333333333333, 0.5].to_vec()))
         );
     }
 
@@ -1217,67 +1121,48 @@ mod tests {
     #[pg_test]
     fn test_normalize_max_d() {
         let result = Spi::get_one::<Vec<f64>>("SELECT pgml.normalize_max(ARRAY[1,2,3]::float8[])");
-        assert_eq!(
-            result,
-            Ok(Some([0.3333333333333333, 0.6666666666666666, 1.0].to_vec()))
-        );
+        assert_eq!(result, Ok(Some([0.3333333333333333, 0.6666666666666666, 1.0].to_vec())));
     }
 
     #[pg_test]
     fn test_distance_l1_s() {
-        let result = Spi::get_one::<f32>(
-            "SELECT pgml.distance_l1(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])",
-        );
+        let result = Spi::get_one::<f32>("SELECT pgml.distance_l1(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])");
         assert_eq!(result, Ok(Some(0.0)));
     }
 
     #[pg_test]
     fn test_distance_l1_d() {
-        let result = Spi::get_one::<f64>(
-            "SELECT pgml.distance_l1(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])",
-        );
+        let result = Spi::get_one::<f64>("SELECT pgml.distance_l1(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])");
         assert_eq!(result, Ok(Some(0.0)));
     }
 
     #[pg_test]
     fn test_distance_l2_s() {
-        let result = Spi::get_one::<f32>(
-            "SELECT pgml.distance_l2(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])",
-        );
+        let result = Spi::get_one::<f32>("SELECT pgml.distance_l2(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])");
         assert_eq!(result, Ok(Some(0.0)));
     }
 
     #[pg_test]
     fn test_distance_l2_d() {
-        let result = Spi::get_one::<f64>(
-            "SELECT pgml.distance_l2(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])",
-        );
+        let result = Spi::get_one::<f64>("SELECT pgml.distance_l2(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])");
         assert_eq!(result, Ok(Some(0.0)));
     }
 
     #[pg_test]
     fn test_dot_product_s() {
-        let result = Spi::get_one::<f32>(
-            "SELECT pgml.dot_product(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])",
-        );
+        let result = Spi::get_one::<f32>("SELECT pgml.dot_product(ARRAY[1,2,3]::float4[],ARRAY[1,2,3]::float4[])");
         assert_eq!(result, Ok(Some(14.0)));
 
-        let result = Spi::get_one::<f32>(
-            "SELECT pgml.dot_product(ARRAY[1,2,3]::float4[],ARRAY[2,3,4]::float4[])",
-        );
+        let result = Spi::get_one::<f32>("SELECT pgml.dot_product(ARRAY[1,2,3]::float4[],ARRAY[2,3,4]::float4[])");
         assert_eq!(result, Ok(Some(20.0)));
     }
 
     #[pg_test]
     fn test_dot_product_d() {
-        let result = Spi::get_one::<f64>(
-            "SELECT pgml.dot_product(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])",
-        );
+        let result = Spi::get_one::<f64>("SELECT pgml.dot_product(ARRAY[1,2,3]::float8[],ARRAY[1,2,3]::float8[])");
         assert_eq!(result, Ok(Some(14.0)));
 
-        let result = Spi::get_one::<f64>(
-            "SELECT pgml.dot_product(ARRAY[1,2,3]::float8[],ARRAY[2,3,4]::float8[])",
-        );
+        let result = Spi::get_one::<f64>("SELECT pgml.dot_product(ARRAY[1,2,3]::float8[],ARRAY[2,3,4]::float8[])");
         assert_eq!(result, Ok(Some(20.0)));
     }
 
@@ -1299,7 +1184,10 @@ mod tests {
         let want = 0.9925833;
         assert!((got - want).abs() < F32_TOLERANCE);
 
-        let got = Spi::get_one::<f32>("SELECT pgml.cosine_similarity(ARRAY[1,1,1,1,1,0,0]::float4[], ARRAY[0,0,1,1,0,1,1]::float4[])").unwrap()
+        let got = Spi::get_one::<f32>(
+            "SELECT pgml.cosine_similarity(ARRAY[1,1,1,1,1,0,0]::float4[], ARRAY[0,0,1,1,0,1,1]::float4[])",
+        )
+        .unwrap()
         .unwrap();
         let want = 0.4472136;
         assert!((got - want).abs() < F32_TOLERANCE);
@@ -1323,7 +1211,11 @@ mod tests {
         let want = 0.9925833339709303;
         assert!((got - want).abs() < F64_TOLERANCE);
 
-        let got = Spi::get_one::<f64>("SELECT pgml.cosine_similarity(ARRAY[1,1,1,1,1,0,0]::float8[], ARRAY[0,0,1,1,0,1,1]::float8[])").unwrap().unwrap();
+        let got = Spi::get_one::<f64>(
+            "SELECT pgml.cosine_similarity(ARRAY[1,1,1,1,1,0,0]::float8[], ARRAY[0,0,1,1,0,1,1]::float8[])",
+        )
+        .unwrap()
+        .unwrap();
         let want = 0.4472135954999579;
         assert!((got - want).abs() < F64_TOLERANCE);
     }