From 0dca661154a167ab194f751bcd8d915605d6541f Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 15 Apr 2024 13:25:39 -0700 Subject: [PATCH 1/4] SDK - Fixed async iterator bug --- pgml-sdks/pgml/src/transformer_pipeline.rs | 139 +++++---------------- pgml-sdks/pgml/src/types.rs | 5 +- 2 files changed, 36 insertions(+), 108 deletions(-) diff --git a/pgml-sdks/pgml/src/transformer_pipeline.rs b/pgml-sdks/pgml/src/transformer_pipeline.rs index 43154615b..ce114ea95 100644 --- a/pgml-sdks/pgml/src/transformer_pipeline.rs +++ b/pgml-sdks/pgml/src/transformer_pipeline.rs @@ -1,12 +1,6 @@ use anyhow::Context; -use futures::Stream; use rust_bridge::{alias, alias_methods}; -use sqlx::{postgres::PgRow, Row}; -use sqlx::{Postgres, Transaction}; -use std::collections::VecDeque; -use std::future::Future; -use std::pin::Pin; -use std::task::Poll; +use sqlx::Row; use tracing::instrument; /// Provides access to builtin database methods @@ -22,99 +16,6 @@ use crate::{get_or_initialize_pool, types::Json}; #[cfg(feature = "python")] use crate::types::{GeneralJsonAsyncIteratorPython, JsonPython}; -#[allow(clippy::type_complexity)] -struct TransformerStream { - transaction: Option>, - future: Option, sqlx::Error>> + Send + 'static>>>, - commit: Option> + Send + 'static>>>, - done: bool, - query: String, - db_batch_size: i32, - results: VecDeque, -} - -impl std::fmt::Debug for TransformerStream { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("TransformerStream").finish() - } -} - -impl TransformerStream { - fn new(transaction: Transaction<'static, Postgres>, db_batch_size: i32) -> Self { - let query = format!("FETCH {} FROM c", db_batch_size); - Self { - transaction: Some(transaction), - future: None, - commit: None, - done: false, - query, - db_batch_size, - results: VecDeque::new(), - } - } -} - -impl Stream for TransformerStream { - type Item = anyhow::Result; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> Poll> { - if self.done { - if let Some(c) = self.commit.as_mut() { - if c.as_mut().poll(cx).is_ready() { - self.commit = None; - } - } - } else { - if self.future.is_none() { - unsafe { - let s = self.as_mut().get_unchecked_mut(); - let s: *mut Self = s; - let s = Box::leak(Box::from_raw(s)); - s.future = Some(Box::pin( - sqlx::query(&s.query).fetch_all(&mut **s.transaction.as_mut().unwrap()), - )); - } - } - - if let Poll::Ready(o) = self.as_mut().future.as_mut().unwrap().as_mut().poll(cx) { - let rows = o?; - if rows.len() < self.db_batch_size as usize { - self.done = true; - unsafe { - let s = self.as_mut().get_unchecked_mut(); - let transaction = std::mem::take(&mut s.transaction).unwrap(); - s.commit = Some(Box::pin(transaction.commit())); - } - } else { - unsafe { - let s = self.as_mut().get_unchecked_mut(); - let s: *mut Self = s; - let s = Box::leak(Box::from_raw(s)); - s.future = Some(Box::pin( - sqlx::query(&s.query).fetch_all(&mut **s.transaction.as_mut().unwrap()), - )); - } - } - for r in rows.into_iter() { - self.results.push_back(r) - } - } - } - - if !self.results.is_empty() { - let r = self.results.pop_front().unwrap(); - Poll::Ready(Some(Ok(r.get::(0)))) - } else if self.done { - Poll::Ready(None) - } else { - Poll::Pending - } - } -} - #[alias_methods(new, transform, transform_stream)] impl TransformerPipeline { /// Creates a new [TransformerPipeline] @@ -200,7 +101,7 @@ impl TransformerPipeline { ) -> anyhow::Result { let pool = get_or_initialize_pool(&self.database_url).await?; let args = args.unwrap_or_default(); - let batch_size = batch_size.unwrap_or(10); + let batch_size = batch_size.unwrap_or(1); let mut transaction = pool.begin().await?; // We set the task in the new constructor so we can unwrap here @@ -234,10 +135,36 @@ impl TransformerPipeline { .await?; } - Ok(GeneralJsonAsyncIterator(Box::pin(TransformerStream::new( - transaction, - batch_size, - )))) + let s = futures::stream::try_unfold(transaction, move |mut transaction| async move { + let query = format!("FETCH {} FROM c", batch_size); + let mut res: Vec = sqlx::query_scalar(&query) + .fetch_all(&mut *transaction) + .await?; + if !res.is_empty() { + if batch_size > 1 { + let res: Vec = res + .into_iter() + .map(|v| { + v.0.as_array() + .context("internal SDK error - cannot parse db value as array. Please post a new github issue") + .map(|v| { + v[0].as_str() + .context( + "internal SDK error - cannot parse db value as string. Please post a new github issue", + ) + .map(|v| v.to_owned()) + }) + }) + .collect::>>>()??; + Ok(Some((serde_json::json!(res).into(), transaction))) + } else { + Ok(Some((std::mem::take(&mut res[0]), transaction))) + } + } else { + Ok(None) + } + }); + Ok(GeneralJsonAsyncIterator(Box::pin(s))) } } @@ -305,7 +232,7 @@ mod tests { serde_json::json!("AI is going to").into(), Some( serde_json::json!({ - "max_new_tokens": 10 + "max_new_tokens": 30 }) .into(), ), diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs index 34d93be5c..386e4bc53 100644 --- a/pgml-sdks/pgml/src/types.rs +++ b/pgml-sdks/pgml/src/types.rs @@ -1,5 +1,5 @@ use anyhow::Context; -use futures::{Stream, StreamExt}; +use futures::{stream::BoxStream, Stream, StreamExt}; use itertools::Itertools; use rust_bridge::alias_manual; use sea_query::Iden; @@ -126,7 +126,8 @@ impl IntoTableNameAndSchema for String { /// A wrapper around `std::pin::Pin> + Send>>` #[derive(alias_manual)] pub struct GeneralJsonAsyncIterator( - pub std::pin::Pin> + Send>>, + // pub std::pin::Pin> + Send>>, + pub BoxStream<'static, anyhow::Result>, ); impl Stream for GeneralJsonAsyncIterator { From bc5f906806b60175c420707df2d06f2f4dba6ec1 Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 15 Apr 2024 13:26:47 -0700 Subject: [PATCH 2/4] cleanup --- pgml-sdks/pgml/src/types.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pgml-sdks/pgml/src/types.rs b/pgml-sdks/pgml/src/types.rs index 386e4bc53..b1c14f88a 100644 --- a/pgml-sdks/pgml/src/types.rs +++ b/pgml-sdks/pgml/src/types.rs @@ -123,12 +123,9 @@ impl IntoTableNameAndSchema for String { } } -/// A wrapper around `std::pin::Pin> + Send>>` +/// A wrapper around `BoxStream<'static, anyhow::Result>` #[derive(alias_manual)] -pub struct GeneralJsonAsyncIterator( - // pub std::pin::Pin> + Send>>, - pub BoxStream<'static, anyhow::Result>, -); +pub struct GeneralJsonAsyncIterator(pub BoxStream<'static, anyhow::Result>); impl Stream for GeneralJsonAsyncIterator { type Item = anyhow::Result; From 30d094f957cabee4fa41b9aca4c16e750cfb1e9c Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 15 Apr 2024 13:35:09 -0700 Subject: [PATCH 3/4] make sure to end the transaction --- pgml-sdks/pgml/src/transformer_pipeline.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/pgml-sdks/pgml/src/transformer_pipeline.rs b/pgml-sdks/pgml/src/transformer_pipeline.rs index ce114ea95..7a6141675 100644 --- a/pgml-sdks/pgml/src/transformer_pipeline.rs +++ b/pgml-sdks/pgml/src/transformer_pipeline.rs @@ -161,6 +161,7 @@ impl TransformerPipeline { Ok(Some((std::mem::take(&mut res[0]), transaction))) } } else { + transaction.commit().await?; Ok(None) } }); From 351ed336de2735d1e57fe9ffb01f1163f521208f Mon Sep 17 00:00:00 2001 From: SilasMarvin <19626586+SilasMarvin@users.noreply.github.com> Date: Mon, 15 Apr 2024 13:37:03 -0700 Subject: [PATCH 4/4] Bump sdk version --- pgml-sdks/pgml/Cargo.toml | 2 +- pgml-sdks/pgml/javascript/package.json | 2 +- pgml-sdks/pgml/pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pgml-sdks/pgml/Cargo.toml b/pgml-sdks/pgml/Cargo.toml index 9424190d9..e78e7413a 100644 --- a/pgml-sdks/pgml/Cargo.toml +++ b/pgml-sdks/pgml/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgml" -version = "1.0.1" +version = "1.0.2" edition = "2021" authors = ["PosgresML "] homepage = "https://postgresml.org/" diff --git a/pgml-sdks/pgml/javascript/package.json b/pgml-sdks/pgml/javascript/package.json index 379b17fd0..1dd507712 100644 --- a/pgml-sdks/pgml/javascript/package.json +++ b/pgml-sdks/pgml/javascript/package.json @@ -1,6 +1,6 @@ { "name": "pgml", - "version": "1.0.1", + "version": "1.0.2", "description": "Open Source Alternative for Building End-to-End Vector Search Applications without OpenAI & Pinecone", "keywords": [ "postgres", diff --git a/pgml-sdks/pgml/pyproject.toml b/pgml-sdks/pgml/pyproject.toml index 6197b02c7..26fe8b4d0 100644 --- a/pgml-sdks/pgml/pyproject.toml +++ b/pgml-sdks/pgml/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "maturin" [project] name = "pgml" requires-python = ">=3.7" -version = "1.0.1" +version = "1.0.2" description = "Python SDK is designed to facilitate the development of scalable vector search applications on PostgreSQL databases." authors = [ {name = "PostgresML", email = "team@postgresml.org"}, pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy