diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index 9424190d9..e78e7413a 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "1.0.1" +version = "1.0.2" edition = "2021" authors = ["PosgresML "] homepage = "https://postgresml.org/" diff --git a/pgml-sdks/pgml/javascript/package.json b/pgml-sdks/pgml/javascript/package.json index 379b17fd0..1dd507712 100644 --- a/pgml-sdks/pgml/javascript/package.json +++ b/pgml-sdks/pgml/javascript/package.json @@ -1,6 +1,6 @@ { "name": "pgml", - "version": "1.0.1", + "version": "1.0.2", "description": "Open Source Alternative for Building End-to-End Vector Search Applications without OpenAI & Pinecone", "keywords": [ "postgres", diff --git a/pgml-sdks/pgml/pyproject.toml b/pgml-sdks/pgml/pyproject.toml index 6197b02c7..26fe8b4d0 100644 --- a/pgml-sdks/pgml/pyproject.toml +++ b/pgml-sdks/pgml/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "maturin" [project] name = "pgml" requires-python = ">=3.7" -version = "1.0.1" +version = "1.0.2" description = "Python SDK is designed to facilitate the development of scalable vector search applications on PostgreSQL databases." authors = [ {name = "PostgresML", email = "team@postgresml.org"}, diff --git a/pgml-sdks/pgml/src/transformer_pipeline.rs b/pgml-sdks/pgml/src/transformer_pipeline.rs index 43154615b..7a6141675 100644 --- a/pgml-sdks/pgml/src/transformer_pipeline.rs +++ b/pgml-sdks/pgml/src/transformer_pipeline.rs @@ -1,12 +1,6 @@ use anyhow::Context; -use futures::Stream; use rust_bridge::{alias, alias_methods}; -use sqlx::{postgres::PgRow, Row}; -use sqlx::{Postgres, Transaction}; -use std::collections::VecDeque; -use std::future::Future; -use std::pin::Pin; -use std::task::Poll; +use sqlx::Row; use tracing::instrument; /// Provides access to builtin database methods @@ -22,99 +16,6 @@ use crate::{get_or_initialize_pool, types::Json}; #[cfg(feature = "python")] use crate::types::{GeneralJsonAsyncIteratorPython, JsonPython}; -#[allow(clippy::type_complexity)] -struct TransformerStream { - transaction: Option>, - future: Option, sqlx::Error>> + Send + 'static>>>, - commit: Option> + Send + 'static>>>, - done: bool, - query: String, - db_batch_size: i32, - results: VecDeque, -} - -impl std::fmt::Debug for TransformerStream { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("TransformerStream").finish() - } -} - -impl TransformerStream { - fn new(transaction: Transaction<'static, Postgres>, db_batch_size: i32) -> Self { - let query = format!("FETCH {} FROM c", db_batch_size); - Self { - transaction: Some(transaction), - future: None, - commit: None, - done: false, - query, - db_batch_size, - results: VecDeque::new(), - } - } -} - -impl Stream for TransformerStream { - type Item = anyhow::Result; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - if self.done { - if let Some(c) = self.commit.as_mut() { - if c.as_mut().poll(cx).is_ready() { - self.commit = None; - } - } - } else { - if self.future.is_none() { - unsafe { - let s = self.as_mut().get_unchecked_mut(); - let s: *mut Self = s; - let s = Box::leak(Box::from_raw(s)); - s.future = Some(Box::pin( - sqlx::query(&s.query).fetch_all(&mut **s.transaction.as_mut().unwrap()), - )); - } - } - - if let Poll::Ready(o) = self.as_mut().future.as_mut().unwrap().as_mut().poll(cx) { - let rows = o?; - if rows.len() < self.db_batch_size as usize { - self.done = true; - unsafe { - let s = self.as_mut().get_unchecked_mut(); - let transaction = std::mem::take(&mut s.transaction).unwrap(); - s.commit = Some(Box::pin(transaction.commit())); - } - } else { - unsafe { - let s = self.as_mut().get_unchecked_mut(); - let s: *mut Self = s; - let s = Box::leak(Box::from_raw(s)); - s.future = Some(Box::pin( - sqlx::query(&s.query).fetch_all(&mut **s.transaction.as_mut().unwrap()), - )); - } - } - for r in rows.into_iter() { - self.results.push_back(r) - } - } - } - - if !self.results.is_empty() { - let r = self.results.pop_front().unwrap(); - Poll::Ready(Some(Ok(r.get::(0)))) - } else if self.done { - Poll::Ready(None) - } else { - Poll::Pending - } - } -} - #[alias_methods(new, transform, transform_stream)] impl TransformerPipeline { /// Creates a new [TransformerPipeline] @@ -200,7 +101,7 @@ impl TransformerPipeline { ) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; let args = args.unwrap_or_default(); - let batch_size = batch_size.unwrap_or(10); + let batch_size = batch_size.unwrap_or(1); let mut transaction = pool.begin().await?; // We set the task in the new constructor so we can unwrap here @@ -234,10 +135,37 @@ impl TransformerPipeline { .await?; } - Ok(GeneralJsonAsyncIterator(Box::pin(TransformerStream::new( - transaction, - batch_size, - )))) + let s = futures::stream::try_unfold(transaction, move |mut transaction| async move { + let query = format!("FETCH {} FROM c", batch_size); + let mut res: Vec = sqlx::query_scalar(&query) + .fetch_all(&mut *transaction) + .await?; + if !res.is_empty() { + if batch_size > 1 { + let res: Vec = res + .into_iter() + .map(|v| { + v.0.as_array() + .context("internal SDK error - cannot parse db value as array. Please post a new github issue") + .map(|v| { + v[0].as_str() + .context( + "internal SDK error - cannot parse db value as string. Please post a new github issue", + ) + .map(|v| v.to_owned()) + }) + }) + .collect::>>>()??; + Ok(Some((serde_json::json!(res).into(), transaction))) + } else { + Ok(Some((std::mem::take(&mut res[0]), transaction))) + } + } else { + transaction.commit().await?; + Ok(None) + } + }); + Ok(GeneralJsonAsyncIterator(Box::pin(s))) } } @@ -305,7 +233,7 @@ mod tests { serde_json::json!("AI is going to").into(), Some( serde_json::json!({ - "max_new_tokens": 10 + "max_new_tokens": 30 }) .into(), ), diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs index 34d93be5c..b1c14f88a 100644 --- a/pgml-sdks/pgml/src/types.rs +++ b/pgml-sdks/pgml/src/types.rs @@ -1,5 +1,5 @@ use anyhow::Context; -use futures::{Stream, StreamExt}; +use futures::{stream::BoxStream, Stream, StreamExt}; use itertools::Itertools; use rust_bridge::alias_manual; use sea_query::Iden; @@ -123,11 +123,9 @@ impl IntoTableNameAndSchema for String { } } -/// A wrapper around `std::pin::Pin> + Send>>` +/// A wrapper around `BoxStream<'static, anyhow::Result>` #[derive(alias_manual)] -pub struct GeneralJsonAsyncIterator( - pub std::pin::Pin> + Send>>, -); +pub struct GeneralJsonAsyncIterator(pub BoxStream<'static, anyhow::Result>); impl Stream for GeneralJsonAsyncIterator { type Item = anyhow::Result; pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy