LiteLLM Minor Fixes & Improvements (11/12/2024) (BerriAI#6705)
* fix(caching): convert args to equivalent kwargs in llm caching handler

prevent unexpected errors when callers pass positional arguments (see the first sketch after this list)

* fix(caching_handler.py): don't pass args to caching

* fix(caching): remove all *args from caching.py

* fix(caching): consistent function signatures + abc method

* test(caching_unit_tests.py): add unit tests for llm caching

ensures coverage for common caching scenarios across different implementations

* refactor(litellm_logging.py): move to using cache key from hidden params instead of regenerating one

* fix(router.py): drop redis password requirement

* fix(proxy_server.py): fix faulty slack alerting check

* fix(langfuse.py): avoid copying functions/thread lock objects in metadata

fixes the metadata copy error that occurred when a parent OTEL span was present in metadata (see the second sketch after this list)

* test: update test
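
Two of the fixes above lend themselves to small sketches. First, the args-to-kwargs conversion from the caching-handler fix. The helper below is a hypothetical illustration; the function name and the completion-style signature are assumptions, not the repo's code.

import inspect
from typing import Any, Callable, Dict, Tuple


def args_to_kwargs(func: Callable, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical helper: bind positional args to their parameter names so the
    # caching layer only ever receives keyword arguments.
    bound = inspect.signature(func).bind_partial(*args, **kwargs)
    return dict(bound.arguments)


def completion(model, messages=None, **kwargs):  # illustrative stand-in signature
    ...


print(args_to_kwargs(completion, ("gpt-4o",), {"messages": [{"role": "user", "content": "hi"}]}))
# -> {'model': 'gpt-4o', 'messages': [{'role': 'user', 'content': 'hi'}]}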
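Second, the langfuse metadata fix. The general idea is to copy only metadata values that deep-copy cleanly and skip functions, locks, and span objects; this sketch illustrates the approach, not the actual langfuse.py code.

import copy
import threading
from typing import Any, Dict


def safe_copy_metadata(metadata: Dict[str, Any]) -> Dict[str, Any]:
    # Illustrative only: keep values that deep-copy cleanly, drop the rest
    # (thread locks, callables, OTEL span objects, and similar).
    copied: Dict[str, Any] = {}
    for key, value in metadata.items():
        if callable(value) or isinstance(value, type(threading.Lock())):
            continue
        try:
            copied[key] = copy.deepcopy(value)
        except TypeError:
            continue
    return copied


print(safe_copy_metadata({"user": "abc", "lock": threading.Lock(), "callback": print}))
# -> {'user': 'abc'}
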
krrishdholakia authored Nov 12, 2024
1 parent d39fd60 commit 9160d80
Showing 23 changed files with 525 additions and 204 deletions.
7 changes: 6 additions & 1 deletion litellm/caching/base_cache.py
@@ -8,6 +8,7 @@
- async_get_cache
"""

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
@@ -18,7 +19,7 @@
Span = Any


class BaseCache:
class BaseCache(ABC):
def __init__(self, default_ttl: int = 60):
self.default_ttl = default_ttl

@@ -37,6 +38,10 @@ def set_cache(self, key, value, **kwargs):
async def async_set_cache(self, key, value, **kwargs):
raise NotImplementedError

@abstractmethod
async def async_set_cache_pipeline(self, cache_list, **kwargs):
pass

def get_cache(self, key, **kwargs):
raise NotImplementedError

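With async_set_cache_pipeline now an abstract method on BaseCache, every concrete cache backend has to provide a bulk-write path. A minimal in-memory sketch of a conforming subclass (illustrative only, not one of the repo's backends; the import path assumes the module layout shown above):

from litellm.caching.base_cache import BaseCache


class DictCache(BaseCache):
    # Illustrative subclass: a plain dict store that satisfies the new abstract method.
    def __init__(self, default_ttl: int = 60):
        super().__init__(default_ttl=default_ttl)
        self.store = {}

    def set_cache(self, key, value, **kwargs):
        self.store[key] = value

    async def async_set_cache(self, key, value, **kwargs):
        self.store[key] = value

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        # Bulk write: each entry is a (key, value) pair, matching the cache_list
        # that caching.py builds in async_add_cache_pipeline.
        for key, value in cache_list:
            self.store[key] = value

    def get_cache(self, key, **kwargs):
        return self.store.get(key)

    async def async_get_cache(self, key, **kwargs):
        return self.store.get(key)
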
75 changes: 35 additions & 40 deletions litellm/caching/caching.py
@@ -233,19 +233,18 @@ def __init__(
if self.namespace is not None and isinstance(self.cache, RedisCache):
self.cache.namespace = self.namespace

def get_cache_key(self, *args, **kwargs) -> str:
def get_cache_key(self, **kwargs) -> str:
"""
Get the cache key for the given arguments.
Args:
*args: args to litellm.completion() or embedding()
**kwargs: kwargs to litellm.completion() or embedding()
Returns:
str: The cache key generated from the arguments, or None if no cache key could be generated.
"""
cache_key = ""
verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
# verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)

preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
if preset_cache_key is not None:
@@ -521,7 +520,7 @@ def _get_cache_logic(
return cached_response
return cached_result

def get_cache(self, *args, **kwargs):
def get_cache(self, **kwargs):
"""
Retrieves the cached result for the given arguments.
@@ -533,13 +532,13 @@ def get_cache(self, *args, **kwargs):
The cached result if it exists, otherwise None.
"""
try: # never block execution
if self.should_use_cache(*args, **kwargs) is not True:
if self.should_use_cache(**kwargs) is not True:
return
messages = kwargs.get("messages", [])
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(*args, **kwargs)
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
cache_control_args = kwargs.get("cache", {})
max_age = cache_control_args.get(
@@ -553,45 +552,44 @@
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None

async def async_get_cache(self, *args, **kwargs):
async def async_get_cache(self, **kwargs):
"""
Async get cache implementation.
Used for embedding calls in async wrapper
"""

try: # never block execution
if self.should_use_cache(*args, **kwargs) is not True:
if self.should_use_cache(**kwargs) is not True:
return

kwargs.get("messages", [])
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(*args, **kwargs)
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
cache_control_args = kwargs.get("cache", {})
max_age = cache_control_args.get(
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
)
cached_result = await self.cache.async_get_cache(
cache_key, *args, **kwargs
)
cached_result = await self.cache.async_get_cache(cache_key, **kwargs)
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
except Exception:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None

def _add_cache_logic(self, result, *args, **kwargs):
def _add_cache_logic(self, result, **kwargs):
"""
Common implementation across sync + async add_cache functions
"""
try:
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
cache_key = self.get_cache_key(*args, **kwargs)
cache_key = self.get_cache_key(**kwargs)
if cache_key is not None:
if isinstance(result, BaseModel):
result = result.model_dump_json()
@@ -613,7 +611,7 @@ def _add_cache_logic(self, result, *args, **kwargs):
except Exception as e:
raise e

def add_cache(self, result, *args, **kwargs):
def add_cache(self, result, **kwargs):
"""
Adds a result to the cache.
@@ -625,41 +623,42 @@ def add_cache(self, result, *args, **kwargs):
None
"""
try:
if self.should_use_cache(*args, **kwargs) is not True:
if self.should_use_cache(**kwargs) is not True:
return
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, *args, **kwargs
result=result, **kwargs
)
self.cache.set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")

async def async_add_cache(self, result, *args, **kwargs):
async def async_add_cache(self, result, **kwargs):
"""
Async implementation of add_cache
"""
try:
if self.should_use_cache(*args, **kwargs) is not True:
if self.should_use_cache(**kwargs) is not True:
return
if self.type == "redis" and self.redis_flush_size is not None:
# high traffic - fill in results in memory and then flush
await self.batch_cache_write(result, *args, **kwargs)
await self.batch_cache_write(result, **kwargs)
else:
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, *args, **kwargs
result=result, **kwargs
)

await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")

async def async_add_cache_pipeline(self, result, *args, **kwargs):
async def async_add_cache_pipeline(self, result, **kwargs):
"""
Async implementation of add_cache for Embedding calls
Does a bulk write, to prevent using too many clients
"""
try:
if self.should_use_cache(*args, **kwargs) is not True:
if self.should_use_cache(**kwargs) is not True:
return

# set default ttl if not set
@@ -668,29 +667,27 @@ async def async_add_cache_pipeline(self, result, *args, **kwargs):

cache_list = []
for idx, i in enumerate(kwargs["input"]):
preset_cache_key = self.get_cache_key(*args, **{**kwargs, "input": i})
preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
kwargs["cache_key"] = preset_cache_key
embedding_response = result.data[idx]
cache_key, cached_data, kwargs = self._add_cache_logic(
result=embedding_response,
*args,
**kwargs,
)
cache_list.append((cache_key, cached_data))
async_set_cache_pipeline = getattr(
self.cache, "async_set_cache_pipeline", None
)
if async_set_cache_pipeline:
await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
else:
tasks = []
for val in cache_list:
tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
await asyncio.gather(*tasks)

await self.cache.async_set_cache_pipeline(cache_list=cache_list, **kwargs)
# if async_set_cache_pipeline:
# await async_set_cache_pipeline(cache_list=cache_list, **kwargs)
# else:
# tasks = []
# for val in cache_list:
# tasks.append(self.cache.async_set_cache(val[0], val[1], **kwargs))
# await asyncio.gather(*tasks)
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")

def should_use_cache(self, *args, **kwargs):
def should_use_cache(self, **kwargs):
"""
Returns true if we should use the cache for LLM API calls
@@ -708,10 +705,8 @@ def should_use_cache(self, *args, **kwargs):
return True
return False

async def batch_cache_write(self, result, *args, **kwargs):
cache_key, cached_data, kwargs = self._add_cache_logic(
result=result, *args, **kwargs
)
async def batch_cache_write(self, result, **kwargs):
cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)

async def ping(self):
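After this change every cache call is keyword-only. A usage sketch, assuming the class in this file is litellm's Cache and that a bare Cache() falls back to the in-memory backend; per-call cache controls ride in under the cache kwarg, which get_cache reads for s-maxage as shown above.

from litellm.caching.caching import Cache

llm_cache = Cache()  # assumed to default to the in-memory ("local") cache type

call_kwargs = {
    "model": "gpt-4o",
    "messages": [{"role": "user", "content": "hi"}],
    "cache": {"s-maxage": 300},  # only accept hits written within the last 300 seconds
}

hit = llm_cache.get_cache(**call_kwargs)  # keyword arguments only after this change
if hit is None:
    # ...call the model, then write the fresh result back, e.g.:
    # llm_cache.add_cache(result, **call_kwargs)
    pass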
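A companion sketch for the embedding path that feeds the new pipeline write: async_add_cache_pipeline keys each input string separately and flushes all (key, value) pairs in one call. The result object and configuration here are stand-ins for illustration.

import asyncio
from types import SimpleNamespace

from litellm.caching.caching import Cache


async def main():
    llm_cache = Cache()  # assumed default in-memory configuration
    inputs = ["first text", "second text"]
    # Stand-in for an embedding response: one data entry per input string.
    fake_result = SimpleNamespace(data=[[0.1, 0.2], [0.3, 0.4]])
    # async_add_cache_pipeline keys each input separately, then bulk-writes the
    # resulting pairs through the backend's async_set_cache_pipeline.
    await llm_cache.async_add_cache_pipeline(
        fake_result,
        model="text-embedding-3-small",
        input=inputs,
    )


asyncio.run(main())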
