diff --git a/pgml-extension/src/api.rs b/pgml-extension/src/api.rs index ad952e485..5b8ddc4e7 100644 --- a/pgml-extension/src/api.rs +++ b/pgml-extension/src/api.rs @@ -4,6 +4,8 @@ use std::str::FromStr; use ndarray::Zip; use pgrx::iter::{SetOfIterator, TableIterator}; use pgrx::*; +use pyo3::prelude::*; +use pyo3::types::{IntoPyDict, PyDict}; #[cfg(feature = "python")] use serde_json::json; @@ -632,6 +634,75 @@ pub fn transform_string( } } +struct TransformStreamIterator { + locals: Py, +} + +impl TransformStreamIterator { + fn new(python_iter: Py) -> Self { + let locals = Python::with_gil(|py| -> Result, PyErr> { + Ok([("python_iter", python_iter)].into_py_dict(py).into()) + }) + .map_err(|e| error!("{e}")) + .unwrap(); + Self { locals } + } +} + +impl Iterator for TransformStreamIterator { + type Item = String; + fn next(&mut self) -> Option { + // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call + Python::with_gil(|py| -> Result, PyErr> { + let code = "next(python_iter)"; + let res: &PyAny = py.eval(code, Some(self.locals.as_ref(py)), None)?; + if res.is_none() { + Ok(None) + } else { + let res: String = res.extract()?; + Ok(Some(res)) + } + }) + .map_err(|e| error!("{e}")) + .unwrap() + } +} + +#[cfg(all(feature = "python", not(feature = "use_as_lib")))] +#[pg_extern(immutable, parallel_safe, name = "transform_stream")] +#[allow(unused_variables)] // cache is maintained for api compatibility +pub fn transform_stream_json( + task: JsonB, + args: default!(JsonB, "'{}'"), + input: default!(&str, "''"), + cache: default!(bool, false), +) -> SetOfIterator<'static, String> { + // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call + let python_iter = crate::bindings::transformers::transform_stream(&task.0, &args.0, input) + .map_err(|e| error!("{e}")) + .unwrap(); + let res = TransformStreamIterator::new(python_iter); + SetOfIterator::new(res) +} + +#[cfg(all(feature = "python", not(feature = "use_as_lib")))] +#[pg_extern(immutable, parallel_safe, name = "transform_stream")] +#[allow(unused_variables)] // cache is maintained for api compatibility +pub fn transform_stream_string( + task: String, + args: default!(JsonB, "'{}'"), + input: default!(&str, "''"), + cache: default!(bool, false), +) -> SetOfIterator<'static, String> { + let task_json = json!({ "task": task }); + // We can unwrap this becuase if there is an error the current transaction is aborted in the map_err call + let python_iter = crate::bindings::transformers::transform_stream(&task_json, &args.0, input) + .map_err(|e| error!("{e}")) + .unwrap(); + let res = TransformStreamIterator::new(python_iter); + SetOfIterator::new(res) +} + #[cfg(feature = "python")] #[pg_extern(immutable, parallel_safe, name = "generate")] fn generate(project_name: &str, inputs: &str, config: default!(JsonB, "'{}'")) -> String { diff --git a/pgml-extension/src/bindings/transformers/mod.rs b/pgml-extension/src/bindings/transformers/mod.rs index c4e262761..8871c8458 100644 --- a/pgml-extension/src/bindings/transformers/mod.rs +++ b/pgml-extension/src/bindings/transformers/mod.rs @@ -16,41 +16,10 @@ use super::TracebackError; pub mod whitelist; -create_pymodule!("/src/bindings/transformers/transformers.py"); - -pub fn transform( - task: &serde_json::Value, - args: &serde_json::Value, - inputs: Vec<&str>, -) -> Result { - crate::bindings::python::activate()?; - - whitelist::verify_task(task)?; - - let task = serde_json::to_string(task)?; - let args = serde_json::to_string(args)?; - let inputs = serde_json::to_string(&inputs)?; +mod transformers; +pub use transformers::*; - let results = Python::with_gil(|py| -> Result { - let transform: Py = get_module!(PY_MODULE) - .getattr(py, "transform") - .format_traceback(py)?; - - let output = transform - .call1( - py, - PyTuple::new( - py, - &[task.into_py(py), args.into_py(py), inputs.into_py(py)], - ), - ) - .format_traceback(py)?; - - output.extract(py).format_traceback(py) - })?; - - Ok(serde_json::from_str(&results)?) -} +create_pymodule!("/src/bindings/transformers/transformers.py"); pub fn get_model_from(task: &Value) -> Result { Python::with_gil(|py| -> Result { diff --git a/pgml-extension/src/bindings/transformers/transformers.py b/pgml-extension/src/bindings/transformers/transformers.py index 8b1d1a43d..2117cb9f6 100644 --- a/pgml-extension/src/bindings/transformers/transformers.py +++ b/pgml-extension/src/bindings/transformers/transformers.py @@ -2,6 +2,8 @@ import os import shutil import time +import queue +import sys import datasets from InstructorEmbedding import INSTRUCTOR @@ -38,7 +40,10 @@ PegasusTokenizer, TrainingArguments, Trainer, + TextStreamer, ) +from threading import Thread +from typing import Optional __cache_transformer_by_model_id = {} __cache_sentence_transformer_by_name = {} @@ -59,14 +64,17 @@ "bool": torch.bool, } + class PgMLException(Exception): pass + def orjson_default(obj): if isinstance(obj, numpy.float32): return float(obj) raise TypeError + def convert_dtype(kwargs): if "torch_dtype" in kwargs: kwargs["torch_dtype"] = DTYPE_MAP[kwargs["torch_dtype"]] @@ -86,30 +94,96 @@ def ensure_device(kwargs): else: kwargs["device"] = "cpu" +# A copy of HuggingFace's with small changes in the __next__ to not raise an exception +class TextIteratorStreamer(TextStreamer): + def __init__( + self, tokenizer, skip_prompt = False, timeout = None, **decode_kwargs + ): + super().__init__(tokenizer, skip_prompt, **decode_kwargs) + self.text_queue = queue.Queue() + self.stop_signal = None + self.timeout = timeout + + def on_finalized_text(self, text: str, stream_end: bool = False): + """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue.""" + self.text_queue.put(text, timeout=self.timeout) + if stream_end: + self.text_queue.put(self.stop_signal, timeout=self.timeout) + + def __iter__(self): + return self + + def __next__(self): + value = self.text_queue.get(timeout=self.timeout) + if value != self.stop_signal: + return value + class GPTQPipeline(object): def __init__(self, model_name, **task): from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig from huggingface_hub import snapshot_download + model_path = snapshot_download(model_name) quantized_config = BaseQuantizeConfig.from_pretrained(model_path) - self.model = AutoGPTQForCausalLM.from_quantized(model_path, quantized_config=quantized_config, **task) + self.model = AutoGPTQForCausalLM.from_quantized( + model_path, quantized_config=quantized_config, **task + ) if "use_fast_tokenizer" in task: - self.tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=task.pop("use_fast_tokenizer")) + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=task.pop("use_fast_tokenizer") + ) else: self.tokenizer = AutoTokenizer.from_pretrained(model_path) self.task = "text-generation" + def stream(self, inputs, **kwargs): + streamer = TextIteratorStreamer(self.tokenizer) + inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device) + generation_kwargs = dict(inputs, streamer=streamer, **kwargs) + thread = Thread(target=self.model.generate, kwargs=generation_kwargs) + thread.start() + return streamer + def __call__(self, inputs, **kwargs): outputs = [] for input in inputs: - tokens = self.tokenizer(input, return_tensors="pt").to(self.model.device).input_ids + tokens = ( + self.tokenizer(input, return_tensors="pt") + .to(self.model.device) + .input_ids + ) token_ids = self.model.generate(input_ids=tokens, **kwargs)[0] outputs.append(self.tokenizer.decode(token_ids)) return outputs +class ThreadedGeneratorIterator: + def __init__(self, output, starting_input): + self.output = output + self.done = False + self.q = queue.Queue() + self.q.put(starting_input) + + def do_work(): + for x in self.output: + self.q.put(x) + self.done = True + + thread = Thread(target=do_work) + thread.start() + + def __iter__(self): + return self + + def __next__(self): + if not self.done or not self.q.empty(): + v = self.q.get() + self.q.task_done() + return v + + class GGMLPipeline(object): def __init__(self, model_name, **task): import ctransformers @@ -117,10 +191,16 @@ def __init__(self, model_name, **task): task.pop("model") task.pop("task") task.pop("device") - self.model = ctransformers.AutoModelForCausalLM.from_pretrained(model_name, **task) + self.model = ctransformers.AutoModelForCausalLM.from_pretrained( + model_name, **task + ) self.tokenizer = None self.task = "text-generation" + def stream(self, inputs, **kwargs): + output = self.model(inputs[0], stream=True, **kwargs) + return ThreadedGeneratorIterator(output, inputs[0]) + def __call__(self, inputs, **kwargs): outputs = [] for input in inputs: @@ -134,31 +214,42 @@ def __init__(self, model_name, **kwargs): # to the model constructor, so we construct the model/tokenizer manually if possible, # but that is only possible when the task is passed in, since if you pass the model # to the pipeline constructor, the task will no longer be inferred from the default... - if "task" in kwargs and model_name is not None and kwargs["task"] in [ - "text-classification", - "question-answering", - "summarization", - "translation", - "text-generation" - ]: + if ( + "task" in kwargs + and model_name is not None + and kwargs["task"] + in [ + "text-classification", + "question-answering", + "summarization", + "translation", + "text-generation", + ] + ): self.task = kwargs.pop("task") kwargs.pop("model", None) if self.task == "text-classification": - self.model = AutoModelForSequenceClassification.from_pretrained(model_name, **kwargs) + self.model = AutoModelForSequenceClassification.from_pretrained( + model_name, **kwargs + ) elif self.task == "question-answering": - self.model = AutoModelForQuestionAnswering.from_pretrained(model_name, **kwargs) + self.model = AutoModelForQuestionAnswering.from_pretrained( + model_name, **kwargs + ) elif self.task == "summarization" or self.task == "translation": self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, **kwargs) elif self.task == "text-generation": self.model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs) else: raise PgMLException(f"Unhandled task: {self.task}") - + if "use_auth_token" in kwargs: - self.tokenizer = AutoTokenizer.from_pretrained(model_name,use_auth_token=kwargs["use_auth_token"]) + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, use_auth_token=kwargs["use_auth_token"] + ) else: self.tokenizer = AutoTokenizer.from_pretrained(model_name) - + self.pipe = transformers.pipeline( self.task, model=self.model, @@ -169,9 +260,19 @@ def __init__(self, model_name, **kwargs): self.task = self.pipe.task self.model = self.pipe.model if self.pipe.tokenizer is None: - self.pipe.tokenizer = AutoTokenizer.from_pretrained(self.model.name_or_path) + self.pipe.tokenizer = AutoTokenizer.from_pretrained( + self.model.name_or_path + ) self.tokenizer = self.pipe.tokenizer + def stream(self, inputs, **kwargs): + streamer = TextIteratorStreamer(self.tokenizer) + inputs = self.tokenizer(inputs, return_tensors="pt").to(self.model.device) + generation_kwargs = dict(inputs, streamer=streamer, **kwargs) + thread = Thread(target=self.model.generate, kwargs=generation_kwargs) + thread.start() + return streamer + def __call__(self, inputs, **kwargs): return self.pipe(inputs, **kwargs) @@ -180,7 +281,7 @@ def get_model_from(task): task = orjson.loads(task) if "model" in task: return task["model"] - + if "task" in task: model = transformers.pipelines.SUPPORTED_TASKS[task["task"]]["default"]["model"] ty = "tf" if "tf" in model else "pt" @@ -222,7 +323,7 @@ def transform_using(pipeline, args, inputs): return orjson.dumps(pipeline(inputs, **args), default=orjson_default).decode() -def transform(task, args, inputs): +def transform(task, args, inputs, stream=False): task = orjson.loads(task) args = orjson.loads(args) inputs = orjson.loads(inputs) @@ -238,12 +339,14 @@ def transform(task, args, inputs): inputs = [orjson.loads(input) for input in inputs] convert_eos_token(pipe.tokenizer, args) + if stream: + return pipe.stream(inputs, **args) return orjson.dumps(pipe(inputs, **args), default=orjson_default).decode() def create_embedding(transformer): instructor = transformer.startswith("hkunlp/instructor") - klass = INSTRUCTOR if instructor else SentenceTransformer + klass = INSTRUCTOR if instructor else SentenceTransformer return klass(transformer) @@ -257,7 +360,7 @@ def embed_using(model, transformer, inputs, kwargs): instruction = kwargs.pop("instruction") for text in inputs: texts_with_instructions.append([instruction, text]) - + inputs = texts_with_instructions return model.encode(inputs, **kwargs) @@ -269,7 +372,9 @@ def embed(transformer, inputs, kwargs): ensure_device(kwargs) if transformer not in __cache_sentence_transformer_by_name: - __cache_sentence_transformer_by_name[transformer] = create_embedding(transformer) + __cache_sentence_transformer_by_name[transformer] = create_embedding( + transformer + ) model = __cache_sentence_transformer_by_name[transformer] return embed_using(model, transformer, inputs, kwargs) @@ -734,5 +839,3 @@ def generate(model_id, data, config): ) all_preds.extend(decoded_preds) return all_preds - - diff --git a/pgml-extension/src/bindings/transformers/transformers.rs b/pgml-extension/src/bindings/transformers/transformers.rs new file mode 100644 index 000000000..55d59b070 --- /dev/null +++ b/pgml-extension/src/bindings/transformers/transformers.rs @@ -0,0 +1,77 @@ +use super::whitelist; +use super::TracebackError; +use anyhow::Result; +use pyo3::prelude::*; +use pyo3::types::PyTuple; +create_pymodule!("/src/bindings/transformers/transformers.py"); + +pub fn transform( + task: &serde_json::Value, + args: &serde_json::Value, + inputs: Vec<&str>, +) -> Result { + crate::bindings::python::activate()?; + + whitelist::verify_task(task)?; + + let task = serde_json::to_string(task)?; + let args = serde_json::to_string(args)?; + let inputs = serde_json::to_string(&inputs)?; + + let results = Python::with_gil(|py| -> Result { + let transform: Py = get_module!(PY_MODULE) + .getattr(py, "transform") + .format_traceback(py)?; + + let output = transform + .call1( + py, + PyTuple::new( + py, + &[task.into_py(py), args.into_py(py), inputs.into_py(py)], + ), + ) + .format_traceback(py)?; + + output.extract(py).format_traceback(py) + })?; + + Ok(serde_json::from_str(&results)?) +} + +pub fn transform_stream( + task: &serde_json::Value, + args: &serde_json::Value, + input: &str, +) -> Result> { + crate::bindings::python::activate()?; + + whitelist::verify_task(task)?; + + let task = serde_json::to_string(task)?; + let args = serde_json::to_string(args)?; + let inputs = serde_json::to_string(&vec![input])?; + + Python::with_gil(|py| -> Result> { + let transform: Py = get_module!(PY_MODULE) + .getattr(py, "transform") + .format_traceback(py)?; + + let output = transform + .call1( + py, + PyTuple::new( + py, + &[ + task.into_py(py), + args.into_py(py), + inputs.into_py(py), + true.into_py(py), + ], + ), + ) + .format_traceback(py)?; + + Ok(output) + }) +} pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy