Skip to content

Commit

Permalink
(feat) Add Bedrock Stability.ai Stable Diffusion 3 Image Generation m…
Browse files Browse the repository at this point in the history
…odels (BerriAI#6673)

* add bedrock image gen async support

* added async support for bedrock image gen

* move image gen testing

* add AmazonStability3Config

* add AmazonStability3Config config

* update AmazonStabilityConfig

* update get_optional_params_image_gen

* use 1 helper for _get_request_body

* add transform_response_dict_to_openai_response for stability3

* test sd3-large-v1:0

* unit testing for bedrock image gen

* fix load_vertex_ai_credentials

* fix test_aimage_generation_vertex_ai

* add stability.sd3-large-v1:0 to model cost map

* add stability.stability.sd3-large-v1:0 to docs
  • Loading branch information
ishaan-jaff authored Nov 9, 2024
1 parent 0871c33 commit 979dfe8
Show file tree
Hide file tree
Showing 14 changed files with 528 additions and 111 deletions.
1 change: 1 addition & 0 deletions docs/my-website/docs/providers/bedrock.md
Original file line number Diff line number Diff line change
Expand Up @@ -1082,5 +1082,6 @@ print(f"response: {response}")
| Model Name | Function Call |
|----------------------|---------------------------------------------|
| Stable Diffusion 3 - v0 | `image_generation(model="bedrock/stability.sd3-large-v1:0", prompt=prompt)` |
| Stable Diffusion - v0 | `image_generation(model="bedrock/stability.stable-diffusion-xl-v0", prompt=prompt)` |
| Stable Diffusion - v1 | `image_generation(model="bedrock/stability.stable-diffusion-xl-v1", prompt=prompt)` |
1 change: 1 addition & 0 deletions litellm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,6 +988,7 @@ class LlmProviders(str, Enum):
AmazonBedrockGlobalConfig,
)
from .llms.bedrock.image.amazon_stability1_transformation import AmazonStabilityConfig
from .llms.bedrock.image.amazon_stability3_transformation import AmazonStability3Config
from .llms.bedrock.embed.amazon_titan_g1_transformation import AmazonTitanG1Config
from .llms.bedrock.embed.amazon_titan_multimodal_transformation import (
AmazonTitanMultimodalEmbeddingG1Config,
Expand Down
35 changes: 35 additions & 0 deletions litellm/llms/bedrock/image/amazon_stability1_transformation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import types
from typing import List, Optional

from openai.types.image import Image

from litellm.types.utils import ImageResponse


class AmazonStabilityConfig:
"""
Expand Down Expand Up @@ -67,3 +71,34 @@ def get_config(cls):
)
and v is not None
}

@classmethod
def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
return ["size"]

@classmethod
def map_openai_params(
cls,
non_default_params: dict,
optional_params: dict,
):
_size = non_default_params.get("size")
if _size is not None:
width, height = _size.split("x")
optional_params["width"] = int(width)
optional_params["height"] = int(height)

return optional_params

@classmethod
def transform_response_dict_to_openai_response(
    cls, model_response: ImageResponse, response_dict: dict
) -> ImageResponse:
    """
    Copy the generated images from the provider response onto `model_response`.

    Each entry of `response_dict["artifacts"]` contributes one OpenAI
    `Image` built from the entry's "base64" value. The resulting list
    replaces `model_response.data`, and the same object is returned.
    """
    model_response.data = [
        Image(b64_json=artifact["base64"])
        for artifact in response_dict["artifacts"]
    ]
    return model_response
94 changes: 94 additions & 0 deletions litellm/llms/bedrock/image/amazon_stability3_transformation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import types
from typing import List, Optional

from openai.types.image import Image

from litellm.types.llms.bedrock import (
AmazonStability3TextToImageRequest,
AmazonStability3TextToImageResponse,
)
from litellm.types.utils import ImageResponse


class AmazonStability3Config:
    """
    Request / response transformations for Stability 3 image models on Bedrock.

    Stability API Ref: https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post
    """

    @classmethod
    def get_config(cls):
        """
        Collect the class-level config values.

        Dunder attributes, functions / classmethods / staticmethods and
        attributes set to None are excluded.
        """
        excluded_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config = {}
        for name, value in cls.__dict__.items():
            if name.startswith("__"):
                continue
            if isinstance(value, excluded_types) or value is None:
                continue
            config[name] = value
        return config

    @classmethod
    def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
        """
        Return the OpenAI params mapped for Stability 3 - currently none.
        """
        return []

    @classmethod
    def _is_stability_3_model(cls, model: Optional[str] = None) -> bool:
        """
        Tell whether `model` names a Stability 3 model.

        The check is a substring match on "sd3", which also covers the
        "sd3.5" family:
            sd3-large
            sd3-large-turbo
            sd3-medium
            sd3.5-large
            sd3.5-large-turbo

        A missing or empty model name is never a Stability 3 model.
        """
        if not model:
            return False
        return "sd3" in model

    @classmethod
    def transform_request_body(
        cls, prompt: str, optional_params: dict
    ) -> AmazonStability3TextToImageRequest:
        """
        Build the Stability 3 request body from the prompt plus every
        entry of `optional_params`, passed through as-is.
        """
        return AmazonStability3TextToImageRequest(prompt=prompt, **optional_params)

    @classmethod
    def map_openai_params(cls, non_default_params: dict, optional_params: dict) -> dict:
        """
        Map OpenAI params onto Bedrock params.

        Nothing is mapped for Stability 3, so `optional_params` is returned
        unchanged and `non_default_params` is ignored.
        """
        return optional_params

    @classmethod
    def transform_response_dict_to_openai_response(
        cls, model_response: ImageResponse, response_dict: dict
    ) -> ImageResponse:
        """
        Fill `model_response.data` with one OpenAI `Image` per entry of the
        response's "images" list (an absent key yields an empty list), then
        return the same `model_response`.
        """
        parsed_response = AmazonStability3TextToImageResponse(**response_dict)
        model_response.data = [
            Image(b64_json=encoded_image)
            for encoded_image in parsed_response.get("images", [])
        ]
        return model_response
85 changes: 57 additions & 28 deletions litellm/llms/bedrock/image/image_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,28 +183,9 @@ def _prepare_request(
boto3_credentials_info.aws_region_name,
)

# transform request
### FORMAT IMAGE GENERATION INPUT ###
provider = model.split(".")[0]
inference_params = copy.deepcopy(optional_params)
inference_params.pop(
"user", None
) # make sure user is not passed in for bedrock call
data = {}
if provider == "stability":
prompt = prompt.replace(os.linesep, " ")
## LOAD CONFIG
config = litellm.AmazonStabilityConfig.get_config()
for k, v in config.items():
if (
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
data = {"text_prompts": [{"text": prompt, "weight": 1}], **inference_params}
else:
raise BedrockError(
status_code=422, message=f"Unsupported model={model}, passed in"
)
data = self._get_request_body(
model=model, prompt=prompt, optional_params=optional_params
)

# Make POST Request
body = json.dumps(data).encode("utf-8")
Expand Down Expand Up @@ -239,6 +220,51 @@ def _prepare_request(
data=data,
)

def _get_request_body(
    self,
    model: str,
    prompt: str,
    optional_params: dict,
) -> dict:
    """
    Get the request body for the Bedrock Image Generation API

    Checks the model/provider and transforms the request body accordingly.
    The caller's `optional_params` is never mutated: a deep copy is taken
    and the `user` key is stripped from that copy before it is used for
    either Stability request format.

    Args:
        model: Bedrock model id; the text before the first "." is the provider.
        prompt: Text prompt for the image.
        optional_params: Provider params supplied by the caller.

    Returns:
        dict: The request body to use for the Bedrock Image Generation API

    Raises:
        BedrockError: status 422 when the provider is not "stability".
    """
    provider = model.split(".")[0]
    inference_params = copy.deepcopy(optional_params)
    inference_params.pop(
        "user", None
    )  # make sure user is not passed in for bedrock call

    if provider != "stability":
        raise BedrockError(
            status_code=422, message=f"Unsupported model={model}, passed in"
        )

    if litellm.AmazonStability3Config._is_stability_3_model(model):
        # Forward the sanitised copy (not the raw optional_params) so that
        # `user` is kept out of the Stability 3 request body as well.
        request_body = litellm.AmazonStability3Config.transform_request_body(
            prompt=prompt, optional_params=inference_params
        )
        return dict(request_body)

    prompt = prompt.replace(os.linesep, " ")
    ## LOAD CONFIG
    config = litellm.AmazonStabilityConfig.get_config()
    for k, v in config.items():
        if (
            k not in inference_params
        ):  # explicitly passed params take precedence over config defaults
            inference_params[k] = v
    return {
        "text_prompts": [{"text": prompt, "weight": 1}],
        **inference_params,
    }

def _transform_response_dict_to_openai_response(
self,
model_response: ImageResponse,
Expand All @@ -265,11 +291,14 @@ def _transform_response_dict_to_openai_response(
if response_dict is None:
raise ValueError("Error in response object format, got None")

image_list: List[Image] = []
for artifact in response_dict["artifacts"]:
_image = Image(b64_json=artifact["base64"])
image_list.append(_image)

model_response.data = image_list
config_class = (
litellm.AmazonStability3Config
if litellm.AmazonStability3Config._is_stability_3_model(model=model)
else litellm.AmazonStabilityConfig
)
config_class.transform_response_dict_to_openai_response(
model_response=model_response,
response_dict=response_dict,
)

return model_response

This file was deleted.

3 changes: 2 additions & 1 deletion litellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4448,6 +4448,7 @@ def image_generation( # noqa: PLR0915
k: v for k, v in kwargs.items() if k not in default_params
} # model-specific params - pass them straight to the model/provider
optional_params = get_optional_params_image_gen(
model=model,
n=n,
quality=quality,
response_format=response_format,
Expand Down Expand Up @@ -4540,7 +4541,7 @@ def image_generation( # noqa: PLR0915
elif custom_llm_provider == "bedrock":
if model is None:
raise Exception("Model needs to be set for bedrock")
model_response = bedrock_image_generation.image_generation(
model_response = bedrock_image_generation.image_generation( # type: ignore
model=model,
prompt=prompt,
timeout=timeout,
Expand Down
7 changes: 7 additions & 0 deletions litellm/model_prices_and_context_window_backup.json
Original file line number Diff line number Diff line change
Expand Up @@ -5611,6 +5611,13 @@
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"stability.stability.sd3-large-v1:0": {
"max_tokens": 77,
"max_input_tokens": 77,
"output_cost_per_image": 0.08,
"litellm_provider": "bedrock",
"mode": "image_generation"
},
"sagemaker/meta-textgeneration-llama-2-7b": {
"max_tokens": 4096,
"max_input_tokens": 4096,
Expand Down
29 changes: 29 additions & 0 deletions litellm/types/llms/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,3 +275,32 @@ class AmazonTitanMultimodalEmbeddingResponse(TypedDict):
AmazonTitanV2EmbeddingRequest,
AmazonTitanG1EmbeddingRequest,
]


class AmazonStability3TextToImageRequest(TypedDict, total=False):
    """
    Request body for the Amazon Bedrock Stability 3 text-to-image API.

    Every key is optional (`total=False`).

    Ref here: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-diffusion-3-text-image.html
    """

    prompt: str
    # Accepted ratio strings.
    aspect_ratio: Literal[
        "16:9",
        "1:1",
        "21:9",
        "2:3",
        "3:2",
        "4:5",
        "5:4",
        "9:16",
        "9:21",
    ]
    mode: Literal[
        "image-to-image",
        "text-to-image",
    ]
    output_format: Literal[
        "JPEG",
        "PNG",
    ]
    seed: int
    negative_prompt: str


class AmazonStability3TextToImageResponse(TypedDict, total=False):
    """
    Response body of the Amazon Bedrock Stability 3 text-to-image API.

    Every key is optional (`total=False`).

    Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-diffusion-3-text-image.html
    """

    # Each entry is handed to Image(b64_json=...) by AmazonStability3Config.
    images: List[str]
    seeds: List[str]
    finish_reasons: List[str]
Loading

0 comments on commit 979dfe8

Please sign in to comment.
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy