diff --git a/src/openlayer/lib/integrations/langchain_callback.py b/src/openlayer/lib/integrations/langchain_callback.py
index d476dfb5..8f5dfd3f 100644
--- a/src/openlayer/lib/integrations/langchain_callback.py
+++ b/src/openlayer/lib/integrations/langchain_callback.py
@@ -163,13 +163,15 @@ def _process_and_upload_trace(self, root_step: steps.Step) -> None:
 
         if tracer._publish:
             try:
-                tracer._client.inference_pipelines.data.stream(
-                    inference_pipeline_id=utils.get_env_variable(
-                        "OPENLAYER_INFERENCE_PIPELINE_ID"
-                    ),
-                    rows=[trace_data],
-                    config=config,
-                )
+                client = tracer._get_client()
+                if client:
+                    client.inference_pipelines.data.stream(
+                        inference_pipeline_id=utils.get_env_variable(
+                            "OPENLAYER_INFERENCE_PIPELINE_ID"
+                        ),
+                        rows=[trace_data],
+                        config=config,
+                    )
             except Exception as err:  # pylint: disable=broad-except
                 tracer.logger.error("Could not stream data to Openlayer %s", err)
diff --git a/src/openlayer/lib/tracing/tracer.py b/src/openlayer/lib/tracing/tracer.py
index bc02ad88..d27771ad 100644
--- a/src/openlayer/lib/tracing/tracer.py
+++ b/src/openlayer/lib/tracing/tracer.py
@@ -17,6 +17,8 @@
 
 logger = logging.getLogger(__name__)
 
+# ----------------------------- Module setup and globals ----------------------------- #
+
 TRUE_LIST = ["true", "on", "1"]
 
 _publish = utils.get_env_variable("OPENLAYER_DISABLE_PUBLISH") not in TRUE_LIST
@@ -24,20 +26,33 @@
     utils.get_env_variable("OPENLAYER_VERIFY_SSL") or "true"
 ).lower() in TRUE_LIST
 _client = None
-if _publish:
-    if _verify_ssl:
-        _client = Openlayer()
-    else:
-        _client = Openlayer(
-            http_client=DefaultHttpxClient(
-                verify=False,
-            ),
-        )
+
+
+def _get_client() -> Optional[Openlayer]:
+    """Get or create the Openlayer client with lazy initialization."""
+    global _client
+    if not _publish:
+        return None
+
+    if _client is None:
+        # Lazy initialization - create client when first needed
+        if _verify_ssl:
+            _client = Openlayer()
+        else:
+            _client = Openlayer(
+                http_client=DefaultHttpxClient(
+                    verify=False,
+                ),
+            )
+    return _client
+
 
 _current_step = contextvars.ContextVar("current_step")
 _current_trace = contextvars.ContextVar("current_trace")
 _rag_context = contextvars.ContextVar("rag_context")
 
+# ----------------------------- Public API functions ----------------------------- #
+
 
 def get_current_trace() -> Optional[traces.Trace]:
     """Returns the current trace."""
@@ -64,26 +79,13 @@
     inference_pipeline_id: Optional[str] = None,
 ) -> Generator[steps.Step, None, None]:
     """Starts a trace and yields a Step object."""
-    new_step: steps.Step = steps.step_factory(
-        step_type=step_type, name=name, inputs=inputs, output=output, metadata=metadata
+    new_step, is_root_step, token = _create_and_initialize_step(
+        step_name=name,
+        step_type=step_type,
+        inputs=inputs,
+        output=output,
+        metadata=metadata,
     )
-    new_step.start_time = time.time()
-
-    parent_step: Optional[steps.Step] = get_current_step()
-    is_root_step: bool = parent_step is None
-
-    if parent_step is None:
-        logger.debug("Starting a new trace...")
-        current_trace = traces.Trace()
-        _current_trace.set(current_trace)  # Set the current trace in context
-        _rag_context.set(None)  # Reset the context
-        current_trace.add_step(new_step)
-    else:
-        logger.debug("Adding step %s to parent step %s", name, parent_step.name)
-        current_trace = get_current_trace()
-        parent_step.add_nested_step(new_step)
-
-    token = _current_step.set(new_step)
     try:
         yield new_step
     finally:
@@ -94,44 +96,11 @@
         new_step.latency = latency
 
         _current_step.reset(token)
-        if is_root_step:
- logger.debug("Ending the trace...") - trace_data, input_variable_names = post_process_trace(current_trace) - - config = dict( - ConfigLlmData( - output_column_name="output", - input_variable_names=input_variable_names, - latency_column_name="latency", - cost_column_name="cost", - timestamp_column_name="inferenceTimestamp", - inference_id_column_name="inferenceId", - num_of_token_column_name="tokens", - ) - ) - if "groundTruth" in trace_data: - config.update({"ground_truth_column_name": "groundTruth"}) - if "context" in trace_data: - config.update({"context_column_name": "context"}) - - if isinstance(new_step, steps.ChatCompletionStep): - config.update( - { - "prompt": new_step.inputs.get("prompt"), - } - ) - if _publish: - try: - _client.inference_pipelines.data.stream( - inference_pipeline_id=inference_pipeline_id - or utils.get_env_variable("OPENLAYER_INFERENCE_PIPELINE_ID"), - rows=[trace_data], - config=config, - ) - except Exception as err: # pylint: disable=broad-except - logger.error("Could not stream data to Openlayer %s", err) - else: - logger.debug("Ending step %s", name) + _handle_trace_completion( + is_root_step=is_root_step, + step_name=name, + inference_pipeline_id=inference_pipeline_id, + ) def add_chat_completion_step_to_trace(**kwargs) -> None: @@ -143,7 +112,6 @@ def add_chat_completion_step_to_trace(**kwargs) -> None: step.log(**kwargs) -# ----------------------------- Tracing decorator ---------------------------- # def trace( *step_args, inference_pipeline_id: Optional[str] = None, @@ -193,40 +161,25 @@ def decorator(func): def wrapper(*func_args, **func_kwargs): if step_kwargs.get("name") is None: step_kwargs["name"] = func.__name__ + with create_step( *step_args, inference_pipeline_id=inference_pipeline_id, **step_kwargs ) as step: output = exception = None try: output = func(*func_args, **func_kwargs) - # pylint: disable=broad-except except Exception as exc: - step.log(metadata={"Exceptions": str(exc)}) + _log_step_exception(step, exc) exception = exc - end_time = time.time() - latency = (end_time - step.start_time) * 1000 # in ms - - bound = func_signature.bind(*func_args, **func_kwargs) - bound.apply_defaults() - inputs = dict(bound.arguments) - inputs.pop("self", None) - inputs.pop("cls", None) - - if context_kwarg: - if context_kwarg in inputs: - log_context(inputs.get(context_kwarg)) - else: - logger.warning( - "Context kwarg `%s` not found in inputs of the " - "current function.", - context_kwarg, - ) - step.log( - inputs=inputs, + # Extract inputs and finalize logging using optimized helper + _process_wrapper_inputs_and_outputs( + step=step, + func_signature=func_signature, + func_args=func_args, + func_kwargs=func_kwargs, + context_kwarg=context_kwarg, output=output, - end_time=end_time, - latency=latency, ) if exception is not None: @@ -244,101 +197,188 @@ def trace_async( context_kwarg: Optional[str] = None, **step_kwargs, ): - """Decorator to trace a function. + """Decorator to trace async functions and async generators. + + This decorator automatically detects whether the function is a regular async function + or an async generator and handles both cases appropriately. Examples -------- - To trace a function, simply decorate it with the ``@trace()`` decorator. By doing - so, the functions inputs, outputs, and metadata will be automatically logged to your - Openlayer project. 
+    To trace a regular async function:
 
-    >>> import os
-    >>> from openlayer.tracing import tracer
-    >>>
-    >>> # Set the environment variables
-    >>> os.environ["OPENLAYER_API_KEY"] = "YOUR_OPENLAYER_API_KEY_HERE"
-    >>> os.environ["OPENLAYER_PROJECT_NAME"] = "YOUR_OPENLAYER_PROJECT_NAME_HERE"
-    >>>
-    >>> # Decorate all the functions you want to trace
     >>> @tracer.trace_async()
     >>> async def main(user_query: str) -> str:
     >>>     context = retrieve_context(user_query)
    >>>     answer = generate_answer(user_query, context)
     >>>     return answer
-    >>>
-    >>> @tracer.trace_async()
-    >>> def retrieve_context(user_query: str) -> str:
-    >>>     return "Some context"
-    >>>
+
+    To trace an async generator function:
+
     >>> @tracer.trace_async()
-    >>> def generate_answer(user_query: str, context: str) -> str:
-    >>>     return "Some answer"
-    >>>
-    >>> # Every time the main function is called, the data is automatically
-    >>> # streamed to your Openlayer project. E.g.:
-    >>> tracer.run_async_func(main("What is the meaning of life?"))
+    >>> async def stream_response(query: str):
+    >>>     async for chunk in openai_client.chat.completions.create(...):
+    >>>         yield chunk.choices[0].delta.content
     """
 
     def decorator(func):
         func_signature = inspect.signature(func)
 
-        @wraps(func)
-        async def wrapper(*func_args, **func_kwargs):
-            if step_kwargs.get("name") is None:
-                step_kwargs["name"] = func.__name__
-            with create_step(
-                *step_args, inference_pipeline_id=inference_pipeline_id, **step_kwargs
-            ) as step:
-                output = exception = None
-                try:
-                    output = await func(*func_args, **func_kwargs)
-                    # pylint: disable=broad-except
-                except Exception as exc:
-                    step.log(metadata={"Exceptions": str(exc)})
-                    exception = exc
-                end_time = time.time()
-                latency = (end_time - step.start_time) * 1000  # in ms
-
-                bound = func_signature.bind(*func_args, **func_kwargs)
-                bound.apply_defaults()
-                inputs = dict(bound.arguments)
-                inputs.pop("self", None)
-                inputs.pop("cls", None)
-
-                if context_kwarg:
-                    if context_kwarg in inputs:
-                        log_context(inputs.get(context_kwarg))
-                    else:
-                        logger.warning(
-                            "Context kwarg `%s` not found in inputs of the "
-                            "current function.",
-                            context_kwarg,
+        if step_kwargs.get("name") is None:
+            step_kwargs["name"] = func.__name__
+        step_name = step_kwargs["name"]
+
+        if asyncio.iscoroutinefunction(func) or inspect.isasyncgenfunction(func):
+            # Check whether it is specifically an async generator function
+            if inspect.isasyncgenfunction(func):
+                # For async generators, use a class-based approach to delay trace
+                # creation until actual iteration begins (not when the generator
+                # object is created)
+                @wraps(func)
+                def async_generator_wrapper(*func_args, **func_kwargs):
+                    class TracedAsyncGenerator:
+                        def __init__(self):
+                            self._original_gen = None
+                            self._step = None
+                            self._is_root_step = False
+                            self._token = None
+                            self._output_chunks = []
+                            self._trace_initialized = False
+
+                        def __aiter__(self):
+                            return self
+
+                        async def __anext__(self):
+                            # Initialize tracing on the first iteration only
+                            if not self._trace_initialized:
+                                self._original_gen = func(*func_args, **func_kwargs)
+                                self._step, self._is_root_step, self._token = _create_and_initialize_step(
+                                    step_name=step_name,
+                                    step_type=enums.StepType.USER_CALL,
+                                    inputs=None,
+                                    output=None,
+                                    metadata=None,
+                                )
+                                self._inputs = _extract_function_inputs(
+                                    func_signature=func_signature,
+                                    func_args=func_args,
+                                    func_kwargs=func_kwargs,
+                                    context_kwarg=context_kwarg,
+                                )
+                                self._trace_initialized = True
+
+                            try:
+                                chunk = await self._original_gen.__anext__()
+                                self._output_chunks.append(chunk)
+                                return chunk
+                            except StopAsyncIteration:
+                                # Finalize the trace when the generator is exhausted
+                                output = _join_output_chunks(self._output_chunks)
+                                _finalize_async_generator_step(
+                                    step=self._step,
+                                    token=self._token,
+                                    is_root_step=self._is_root_step,
+                                    step_name=step_name,
+                                    inputs=self._inputs,
+                                    output=output,
+                                    inference_pipeline_id=inference_pipeline_id,
+                                )
+                                raise
+                            except Exception as exc:
+                                # Log the exception, then finalize the trace before re-raising
+                                if self._step:
+                                    _log_step_exception(self._step, exc)
+                                output = _join_output_chunks(self._output_chunks)
+                                _finalize_async_generator_step(
+                                    step=self._step,
+                                    token=self._token,
+                                    is_root_step=self._is_root_step,
+                                    step_name=step_name,
+                                    inputs=self._inputs,
+                                    output=output,
+                                    inference_pipeline_id=inference_pipeline_id,
+                                )
+                                raise
+
+                    return TracedAsyncGenerator()
+
+                return async_generator_wrapper
+            else:
+                # Create a wrapper for regular async functions
+                @wraps(func)
+                async def async_function_wrapper(*func_args, **func_kwargs):
+                    with create_step(
+                        *step_args,
+                        inference_pipeline_id=inference_pipeline_id,
+                        **step_kwargs,
+                    ) as step:
+                        output = exception = None
+
+                        try:
+                            output = await func(*func_args, **func_kwargs)
+                        except Exception as exc:
+                            _log_step_exception(step, exc)
+                            exception = exc
+                            raise
+
+                        # Extract inputs and finalize logging
+                        _process_wrapper_inputs_and_outputs(
+                            step=step,
+                            func_signature=func_signature,
+                            func_args=func_args,
+                            func_kwargs=func_kwargs,
+                            context_kwarg=context_kwarg,
+                            output=output,
                        )
-                step.log(
-                    inputs=inputs,
-                    output=output,
-                    end_time=end_time,
-                    latency=latency,
-                )
+                        return output
 
-                if exception is not None:
-                    raise exception
-                return output
+                return async_function_wrapper
+        else:
+            # For sync functions, use the existing logic with optimizations
+            @wraps(func)
+            def sync_wrapper(*func_args, **func_kwargs):
+                with create_step(
+                    *step_args,
+                    inference_pipeline_id=inference_pipeline_id,
+                    **step_kwargs,
+                ) as step:
+                    output = exception = None
+                    try:
+                        output = func(*func_args, **func_kwargs)
+                    except Exception as exc:
+                        _log_step_exception(step, exc)
+                        exception = exc
+
+                    # Extract inputs and finalize logging
+                    _process_wrapper_inputs_and_outputs(
+                        step=step,
+                        func_signature=func_signature,
+                        func_args=func_args,
+                        func_kwargs=func_kwargs,
+                        context_kwarg=context_kwarg,
+                        output=output,
+                    )
 
-        return wrapper
+                    if exception is not None:
+                        raise exception
+                    return output
+
+            return sync_wrapper
 
     return decorator
 
 
-async def _invoke_with_context(
-    coroutine: Awaitable[Any],
-) -> Tuple[contextvars.Context, Any]:
-    """Runs a coroutine and preserves the context variables set within it."""
-    result = await coroutine
-    context = contextvars.copy_context()
-    return context, result
+def log_context(context: List[str]) -> None:
+    """Logs context information to the current step of the trace.
+
+    The `context` parameter should be a list of strings representing the
+    context chunks retrieved by the context retriever."""
+    current_step = get_current_step()
+    if current_step:
+        _rag_context.set(context)
+        current_step.log(metadata={"context": context})
+    else:
+        logger.warning("No current step found to log context.")
 
 
 def run_async_func(coroutine: Awaitable[Any]) -> Any:
@@ -351,20 +391,211 @@ def run_async_func(coroutine: Awaitable[Any]) -> Any:
     return result
 
 
-def log_context(context: List[str]) -> None:
-    """Logs context information to the current step of the trace.
+# ----------------------------- Helper functions for create_step ----------------------------- #
 
-    The `context` parameter should be a list of strings representing the
-    context chunks retrieved by the context retriever."""
-    current_step = get_current_step()
-    if current_step:
-        _rag_context.set(context)
-        current_step.log(metadata={"context": context})
+
+def _create_and_initialize_step(
+    step_name: str,
+    step_type: enums.StepType = enums.StepType.USER_CALL,
+    inputs: Optional[Any] = None,
+    output: Optional[Any] = None,
+    metadata: Optional[Dict[str, Any]] = None,
+) -> Tuple[steps.Step, bool, Any]:
+    """Create a new step and initialize trace/parent relationships.
+
+    Returns:
+        Tuple of (step, is_root_step, token)
+    """
+    new_step = steps.step_factory(
+        step_type=step_type,
+        name=step_name,
+        inputs=inputs,
+        output=output,
+        metadata=metadata,
+    )
+    new_step.start_time = time.time()
+
+    parent_step = get_current_step()
+    is_root_step = parent_step is None
+
+    if parent_step is None:
+        logger.debug("Starting a new trace...")
+        current_trace = traces.Trace()
+        _current_trace.set(current_trace)
+        _rag_context.set(None)
+        current_trace.add_step(new_step)
     else:
-        logger.warning("No current step found to log context.")
+        logger.debug("Adding step %s to parent step %s", step_name, parent_step.name)
+        current_trace = get_current_trace()
+        parent_step.add_nested_step(new_step)
+
+    token = _current_step.set(new_step)
+    return new_step, is_root_step, token
+
+
+def _handle_trace_completion(
+    is_root_step: bool, step_name: str, inference_pipeline_id: Optional[str] = None
+) -> None:
+    """Handle trace completion and data streaming."""
+    if is_root_step:
+        logger.debug("Ending the trace...")
+        current_trace = get_current_trace()
+        trace_data, input_variable_names = post_process_trace(current_trace)
+
+        config = dict(
+            ConfigLlmData(
+                output_column_name="output",
+                input_variable_names=input_variable_names,
+                latency_column_name="latency",
+                cost_column_name="cost",
+                timestamp_column_name="inferenceTimestamp",
+                inference_id_column_name="inferenceId",
+                num_of_token_column_name="tokens",
+            )
+        )
+        if "groundTruth" in trace_data:
+            config.update({"ground_truth_column_name": "groundTruth"})
+        if "context" in trace_data:
+            config.update({"context_column_name": "context"})
+
+        if isinstance(get_current_step(), steps.ChatCompletionStep):
+            config.update(
+                {
+                    "prompt": get_current_step().inputs.get("prompt"),
+                }
+            )
+        if _publish:
+            try:
+                client = _get_client()
+                if client:
+                    client.inference_pipelines.data.stream(
+                        inference_pipeline_id=inference_pipeline_id
+                        or utils.get_env_variable("OPENLAYER_INFERENCE_PIPELINE_ID"),
+                        rows=[trace_data],
+                        config=config,
+                    )
+            except Exception as err:  # pylint: disable=broad-except
+                logger.error("Could not stream data to Openlayer %s", err)
+    else:
+        logger.debug("Ending step %s", step_name)
+
+
+# ----------------------------- Helper functions for trace decorators ----------------------------- #
+
+
+def _log_step_exception(step: steps.Step, exception: Exception) -> None:
+    """Log exception metadata to a step."""
+    step.log(metadata={"Exceptions": str(exception)})
+
+
+def _process_wrapper_inputs_and_outputs(
+    step: steps.Step,
+    func_signature: inspect.Signature,
+    func_args: tuple,
+    func_kwargs: dict,
+    context_kwarg: Optional[str],
+    output: Any,
+) -> None:
+    """Extract function inputs and finalize step logging - common pattern across wrappers."""
+    inputs = _extract_function_inputs(
+        func_signature=func_signature,
+        func_args=func_args,
+        func_kwargs=func_kwargs,
+        context_kwarg=context_kwarg,
+    )
+    _finalize_step_logging(
+        step=step, inputs=inputs, output=output, start_time=step.start_time
+    )
+
+
+def _extract_function_inputs(
+    func_signature: inspect.Signature,
+    func_args: tuple,
+    func_kwargs: dict,
+    context_kwarg: Optional[str] = None,
+) -> dict:
+    """Extract and clean function inputs for logging."""
+    bound = func_signature.bind(*func_args, **func_kwargs)
+    bound.apply_defaults()
+    inputs = dict(bound.arguments)
+    inputs.pop("self", None)
+    inputs.pop("cls", None)
+
+    # Handle context kwarg if specified
+    if context_kwarg:
+        if context_kwarg in inputs:
+            log_context(inputs.get(context_kwarg))
+        else:
+            logger.warning(
+                "Context kwarg `%s` not found in inputs of the current function.",
+                context_kwarg,
+            )
+
+    return inputs
+
+
+def _finalize_step_logging(
+    step: steps.Step,
+    inputs: dict,
+    output: Any,
+    start_time: float,
+) -> None:
+    """Finalize step timing and logging."""
+    if step.end_time is None:
+        step.end_time = time.time()
+    if step.latency is None:
+        step.latency = (step.end_time - start_time) * 1000  # in ms
+
+    step.log(
+        inputs=inputs,
+        output=output,
+        end_time=step.end_time,
+        latency=step.latency,
+    )
+
+
+# ----------------------------- Async generator specific functions ----------------------------- #
+
+
+def _finalize_async_generator_step(
+    step: steps.Step,
+    token: Any,
+    is_root_step: bool,
+    step_name: str,
+    inputs: dict,
+    output: Any,
+    inference_pipeline_id: Optional[str] = None,
+) -> None:
+    """Finalize an async generator step - called when the generator is consumed."""
+    _current_step.reset(token)
+    _finalize_step_logging(
+        step=step, inputs=inputs, output=output, start_time=step.start_time
+    )
+    _handle_trace_completion(
+        is_root_step=is_root_step,
+        step_name=step_name,
+        inference_pipeline_id=inference_pipeline_id,
+    )
+
+
+def _join_output_chunks(output_chunks: List[Any]) -> str:
+    """Join output chunks into a single string, filtering out None values."""
+    return "".join(str(chunk) for chunk in output_chunks if chunk is not None)
+
+
+# ----------------------------- Utility functions ----------------------------- #
+
+
+async def _invoke_with_context(
+    coroutine: Awaitable[Any],
+) -> Tuple[contextvars.Context, Any]:
+    """Runs a coroutine and preserves the context variables set within it."""
+    result = await coroutine
+    context = contextvars.copy_context()
+    return context, result
 
 
-# --------------------- Helper post-processing functions --------------------- #
 def post_process_trace(
     trace_obj: traces.Trace,
 ) -> Tuple[Dict[str, Any], List[str]]:
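
A note on the lazy client introduced above: `_get_client()` moves construction of the `Openlayer` client from import time to the first publish attempt, so merely importing the tracer no longer requires credentials to be configured. A minimal standalone sketch of the same pattern, with hypothetical names; it assumes the package's top-level exports `Openlayer` and `DefaultHttpxClient`, which the diff itself uses:

from typing import Optional

from openlayer import DefaultHttpxClient, Openlayer  # assumed top-level exports

_cached_client: Optional[Openlayer] = None  # module-level cache, as in tracer.py


def get_client_lazily(publish: bool, verify_ssl: bool) -> Optional[Openlayer]:
    """Build the client on first use instead of at import time."""
    global _cached_client
    if not publish:
        return None  # publishing disabled: never construct a client
    if _cached_client is None:
        # Credentials (e.g. OPENLAYER_API_KEY) are only read here, on first use
        _cached_client = (
            Openlayer()
            if verify_ssl
            else Openlayer(http_client=DefaultHttpxClient(verify=False))
        )
    return _cached_client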
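
The step nesting in `_create_and_initialize_step` and `_finalize_async_generator_step` hinges on the `contextvars` set/reset discipline: `set()` returns a token that remembers the previous value, and `reset(token)` restores it, which keeps parent/child steps straight even across async boundaries. A self-contained sketch with illustrative names:

import contextvars

_current_step: contextvars.ContextVar = contextvars.ContextVar("current_step")


def enter_step(step: object) -> contextvars.Token:
    # set() returns a token that remembers the previous (parent) value
    return _current_step.set(step)


def exit_step(token: contextvars.Token) -> None:
    # reset() restores the parent step recorded in the token
    _current_step.reset(token)


root_token = enter_step("root")
child_token = enter_step("child")
print(_current_step.get())  # child
exit_step(child_token)
print(_current_step.get())  # root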
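
`_extract_function_inputs` relies on `inspect.Signature.bind` plus `apply_defaults()` to recover a complete name-to-value mapping of the call, including parameters the caller left at their defaults. A runnable illustration of that mechanism (the `example` function is hypothetical):

import inspect


def example(a, b=2, *, c=3):
    pass


bound = inspect.signature(example).bind(1, c=10)
bound.apply_defaults()
inputs = dict(bound.arguments)
inputs.pop("self", None)  # drop method receivers, as the helper does
inputs.pop("cls", None)
print(inputs)  # {'a': 1, 'b': 2, 'c': 10}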
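
The class-based `TracedAsyncGenerator` exists because calling an async generator function only creates a generator object; none of its body runs until the first `__anext__`. Deferring step creation to that point ties the trace's start time and context to actual iteration rather than object creation. A minimal sketch of the deferral pattern, independent of the tracing machinery (all names hypothetical):

import asyncio


class LazySetupAgen:
    """Defer setup work until the first __anext__, as TracedAsyncGenerator does."""

    def __init__(self, agen_factory):
        self._factory = agen_factory
        self._agen = None

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self._agen is None:
            print("setup runs now, on first iteration")  # e.g. start the trace
            self._agen = self._factory()
        # StopAsyncIteration propagates from here and ends the async-for loop
        return await self._agen.__anext__()


async def numbers():
    for i in range(3):
        yield i


async def main():
    gen = LazySetupAgen(numbers)  # no setup yet: just a wrapper object
    print([x async for x in gen])  # prints the setup message, then [0, 1, 2]


asyncio.run(main())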
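
Finally, a hedged end-to-end usage sketch of the new decorator on an async generator; the module path is assumed from the file layout, and with publishing enabled the trace is finalized and streamed once the generator is exhausted:

import asyncio

from openlayer.lib.tracing import tracer  # path assumed from src/openlayer/lib/tracing/tracer.py


@tracer.trace_async()
async def stream_words(prompt: str):
    for word in ["hello", " ", "world"]:
        yield word


async def main():
    # The trace starts on the first iteration, not when stream_words() is called,
    # and the joined chunks become the step's output once iteration finishes.
    chunks = [c async for c in stream_words("hi")]
    print("".join(chunks))  # hello world


asyncio.run(main())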
