Skip to content

Commit

Permalink
fix raise correct error 404 when /key/info is called on non-existent …
Browse files Browse the repository at this point in the history
…key (BerriAI#6653)

* fix raise correct error on /key/info

* add not_found_error error

* fix key not found in DB error

* use 1 helper for checking token hash

* fix error code on key info

* fix test key gen prisma

* test_generate_and_call_key_info

* test fix test_call_with_valid_model_using_all_models

* fix key info tests
  • Loading branch information
ishaan-jaff authored Nov 12, 2024
1 parent 25bae4c commit de2f9ae
Show file tree
Hide file tree
Showing 8 changed files with 3,593 additions and 57 deletions.
1 change: 1 addition & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1894,6 +1894,7 @@ class ProxyErrorTypes(str, enum.Enum):
auth_error = "auth_error"
internal_server_error = "internal_server_error"
bad_request_error = "bad_request_error"
not_found_error = "not_found_error"


class SSOUserDefinedValues(TypedDict):
Expand Down
10 changes: 2 additions & 8 deletions litellm/proxy/auth/route_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,8 @@ def non_proxy_admin_allowed_routes_check(
route in LiteLLMRoutes.info_routes.value
): # check if user allowed to call an info route
if route == "/key/info":
# check if user can access this route
query_params = request.query_params
key = query_params.get("key")
if key is not None and hash_token(token=key) != api_key:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="user not allowed to access this key's info",
)
# handled by function itself
pass
elif route == "/user/info":
# check if user can access this route
query_params = request.query_params
Expand Down
57 changes: 51 additions & 6 deletions litellm/proxy/management_endpoints/key_management_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import _duration_in_seconds
from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed
from litellm.secret_managers.main import get_secret

router = APIRouter()
Expand Down Expand Up @@ -734,13 +734,37 @@ async def info_key_fn(
raise Exception(
"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
if key is None:
key = user_api_key_dict.api_key
key_info = await prisma_client.get_data(token=key)

# default to using Auth token if no key is passed in
key = key or user_api_key_dict.api_key
hashed_key: Optional[str] = key
if key is not None:
hashed_key = _hash_token_if_needed(token=key)
key_info = await prisma_client.db.litellm_verificationtoken.find_unique(
where={"token": hashed_key}, # type: ignore
include={"litellm_budget_table": True},
)
if key_info is None:
raise ProxyException(
message="Key not found in database",
type=ProxyErrorTypes.not_found_error,
param="key",
code=status.HTTP_404_NOT_FOUND,
)

if (
_can_user_query_key_info(
user_api_key_dict=user_api_key_dict,
key=key,
key_info=key_info,
)
is not True
):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={"message": "No keys found"},
status_code=status.HTTP_403_FORBIDDEN,
detail="You are not allowed to access this key's info. Your role={}".format(
user_api_key_dict.user_role
),
)
## REMOVE HASHED TOKEN INFO BEFORE RETURNING ##
try:
Expand Down Expand Up @@ -1540,6 +1564,27 @@ async def key_health(
)


def _can_user_query_key_info(
user_api_key_dict: UserAPIKeyAuth,
key: Optional[str],
key_info: LiteLLM_VerificationToken,
) -> bool:
"""
Helper to check if the user has access to the key's info
"""
if (
user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value
):
return True
elif user_api_key_dict.api_key == key:
return True
# user can query their own key info
elif key_info.user_id == user_api_key_dict.user_id:
return True
return False


async def test_key_logging(
user_api_key_dict: UserAPIKeyAuth,
request: Request,
Expand Down
26 changes: 16 additions & 10 deletions litellm/proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,9 +1424,7 @@ async def get_data( # noqa: PLR0915
# check if plain text or hash
if token is not None:
if isinstance(token, str):
hashed_token = token
if token.startswith("sk-"):
hashed_token = self.hash_token(token=token)
hashed_token = _hash_token_if_needed(token=token)
verbose_proxy_logger.debug(
f"PrismaClient: find_unique for token: {hashed_token}"
)
Expand Down Expand Up @@ -1493,8 +1491,7 @@ async def get_data( # noqa: PLR0915
if token is not None:
where_filter["token"] = {}
if isinstance(token, str):
if token.startswith("sk-"):
token = self.hash_token(token=token)
token = _hash_token_if_needed(token=token)
where_filter["token"]["in"] = [token]
elif isinstance(token, list):
hashed_tokens = []
Expand Down Expand Up @@ -1630,9 +1627,7 @@ async def get_data( # noqa: PLR0915
# check if plain text or hash
if token is not None:
if isinstance(token, str):
hashed_token = token
if token.startswith("sk-"):
hashed_token = self.hash_token(token=token)
hashed_token = _hash_token_if_needed(token=token)
verbose_proxy_logger.debug(
f"PrismaClient: find_unique for token: {hashed_token}"
)
Expand Down Expand Up @@ -1912,8 +1907,7 @@ async def update_data( # noqa: PLR0915
if token is not None:
print_verbose(f"token: {token}")
# check if plain text or hash
if token.startswith("sk-"):
token = self.hash_token(token=token)
token = _hash_token_if_needed(token=token)
db_data["token"] = token
response = await self.db.litellm_verificationtoken.update(
where={"token": token}, # type: ignore
Expand Down Expand Up @@ -2424,6 +2418,18 @@ def hash_token(token: str):
return hashed_token


def _hash_token_if_needed(token: str) -> str:
"""
Hash the token if it's a string and starts with "sk-"
Else return the token as is
"""
if token.startswith("sk-"):
return hash_token(token=token)
else:
return token


def _extract_from_regex(duration: str) -> Tuple[int, str]:
match = re.match(r"(\d+)(mo|[smhd]?)", duration)

Expand Down
Loading

0 comments on commit de2f9ae

Please sign in to comment.
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy