@the-crypt-keeper
Created December 29, 2024 16:56
LiteLLM adapters for generating images with local Stable Diffusion APIs
model_list:
  - model_name: "sd_xl_turbo_1.0"
    litellm_params:
      model: "sdserver/sd_xl_turbo_1.0"
      api_base: "http://falcon:51524/"
    model_info:
      mode: image_generation
  - model_name: "Deliberate_v2"
    litellm_params:
      model: "sdserver/Deliberate_v2"
      api_base: "http://falcon:59275/"
    model_info:
      mode: image_generation
  - model_name: "fluxunchained-schnell-dev-merge-q8-0"
    litellm_params:
      model: "kobold/fluxunchained-schnell-dev-merge-q8-0"
      api_base: "http://falcon:52247/"
    model_info:
      mode: image_generation
  - model_name: "dall-e-3"
    litellm_params:
      model: "openai/dall-e-3"

litellm_settings:
  custom_provider_map:
    - {"provider": "kobold", "custom_handler": kobold_handler.kobold_cpp}
    - {"provider": "sdserver", "custom_handler": kobold_handler.sd_server}
    - {"provider": "llamabox", "custom_handler": kobold_handler.llama_box}

general_settings:
  pass_through_endpoints:
    - path: "/"
      target: "http://localhost:3333"

router_settings:
  routing_strategy: "least-busy"
  num_retries: 1
  timeout: 180
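
The custom_provider_map above loads its handlers from a kobold_handler module, so the Python file below would be saved as kobold_handler.py next to the config. Assuming the YAML is saved as config.yaml (the filename is an assumption, not part of the gist), the proxy is started the usual way:

litellm --config config.yaml

kobold_handler.py: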
import litellm
import time
import httpx
from typing import Any, Optional, Union
from litellm import CustomLLM
from litellm.types.utils import ImageResponse, ImageObject

class KoboldCpp(CustomLLM):
    async def aimage_generation(self, model: str, prompt: str, model_response: ImageResponse, optional_params: dict, logging_obj: Any, timeout: Optional[Union[float, httpx.Timeout]] = None, client=None, **kwargs) -> ImageResponse:
        if client is None:
            client = httpx.AsyncClient(base_url=kwargs['api_base'], timeout=timeout)

        # Health check. We don't actually need to generate anything, just check the server is alive.
        if prompt == "test from litellm":
            response = await client.get("/api/v1/info/version")
            response.raise_for_status()
            return ImageResponse()

        # Map the OpenAI-style parameters onto the automatic1111-compatible txt2img payload.
        if optional_params.get('n', 1) != 1:
            raise ValueError('n parameter is not supported by the proxy')
        if 'size' in optional_params:
            width, height = optional_params.pop('size').split('x')
            optional_params['width'], optional_params['height'] = int(width), int(height)
        if 'steps' not in optional_params:
            optional_params['steps'] = 8
        optional_params['prompt'] = prompt

        print("koboldcpp aimage_generation() called:", optional_params, "timeout=", timeout)

        response = None
        try:
            response = await client.post("/sdapi/v1/txt2img", json=optional_params)
            response.raise_for_status()
        except Exception:
            raise ValueError("HTTP error", response.status_code if response is not None else "Call failed")
        try:
            result = response.json()
        except Exception:
            raise ValueError("JSON parse error", response.text)

        return ImageResponse(
            created=int(time.time()),
            data=[ImageObject(b64_json=img_data) for img_data in result["images"]]
        )

class SDServer(CustomLLM):
    async def aimage_generation(self, model: str, prompt: str, model_response: ImageResponse, optional_params: dict, logging_obj: Any, timeout: Optional[Union[float, httpx.Timeout]] = None, client=None, **kwargs) -> ImageResponse:
        if client is None:
            client = httpx.AsyncClient(base_url=kwargs['api_base'], timeout=timeout)

        # Health check. We don't actually need to generate anything, just check the server is alive.
        if prompt == "test from litellm":
            response = await client.get("/")
            assert response.status_code == 404
            return ImageResponse()

        # Translate the OpenAI-style parameters into the sd-server txt2img payload.
        if 'n' in optional_params:
            optional_params['batch_count'] = optional_params.pop('n')
        if 'size' in optional_params:
            optional_params['width'] = int(optional_params['size'].split('x')[0])
            optional_params['height'] = int(optional_params['size'].split('x')[1])
            del optional_params['size']
        if 'steps' not in optional_params:
            optional_params['steps'] = 8
        optional_params['sample_steps'] = optional_params.pop('steps')
        # if 'negative_prompt' not in optional_params: optional_params['negative_prompt'] = "Bad quality, ugly"
        if 'seed' not in optional_params:
            optional_params['seed'] = -1
        optional_params['prompt'] = prompt

        print("sd-server aimage_generation() called:", optional_params, "timeout=", timeout)

        response = None
        try:
            response = await client.post("/txt2img", json=optional_params)
            response.raise_for_status()
        except Exception:
            raise ValueError("HTTP error", response.status_code if response is not None else "Call failed")
        try:
            result = response.json()
        except Exception:
            raise ValueError("JSON parse error", response.text)

        return ImageResponse(
            created=int(time.time()),
            data=[ImageObject(b64_json=img_data["data"]) for img_data in result]
        )

class LlamaBox(CustomLLM):
    async def aimage_generation(self, model: str, prompt: str, model_response: ImageResponse, optional_params: dict, logging_obj: Any, timeout: Optional[Union[float, httpx.Timeout]] = None, client=None, **kwargs) -> ImageResponse:
        if client is None:
            client = httpx.AsyncClient(base_url=kwargs['api_base'], timeout=timeout)

        # Health check. We don't actually need to generate anything, just check the server is alive.
        if prompt == "test from litellm":
            response = await client.get("/")
            assert response.status_code == 404
            return ImageResponse()

        # Build the llama-box image generation payload from the OpenAI-style parameters.
        params = {}
        params['n'] = optional_params.get('n', 1)
        if 'size' in optional_params:
            params['width'] = int(optional_params['size'].split('x')[0])
            params['height'] = int(optional_params['size'].split('x')[1])
        if 'steps' in optional_params:
            params['sampling_steps'] = optional_params.get('steps')
        if 'negative_prompt' in optional_params:
            params['negative_prompt'] = optional_params['negative_prompt']
        if 'seed' in optional_params:
            params['seed'] = optional_params['seed']
        params['prompt'] = prompt
        params['sample_method'] = 'Euler'
        if 'sampler' in optional_params:
            params['sample_method'] = optional_params['sampler']

        print("llamabox aimage_generation() called:", params, "timeout=", timeout)

        response = None
        try:
            response = await client.post("/v1/images/generations", json=params)
            response.raise_for_status()
        except Exception:
            raise ValueError("HTTP error", response.status_code if response is not None else "Call failed")
        try:
            result = response.json()
        except Exception:
            raise ValueError("JSON parse error", response.text if response is not None else "Call failed")

        return ImageResponse(
            created=int(time.time()),
            data=[ImageObject(b64_json=img_data["b64_json"]) for img_data in result['data']]
        )

kobold_cpp = KoboldCpp()
sd_server = SDServer()
llama_box = LlamaBox()

# Without this, optional_params doesn't work, see https://github.com/BerriAI/litellm/blob/main/litellm/utils.py#L2171
litellm.openai_compatible_providers.append('kobold')
litellm.openai_compatible_providers.append('sdserver')
litellm.openai_compatible_providers.append('llamabox')

# We want to use pass_through_endpoints on / so we need to yank any default handlers.
app = litellm.proxy.proxy_server.app
app.router.routes = [x for x in app.router.routes if x.path != '/']

This is a raw dump of LiteLLM image-generation adapters for kobold (which is automatic1111-compatible), sd-server, and llamabox.

Only the /v1/images/generations endpoint is supported by the adapters.

To use:

  1. Register the custom handlers in custom_provider_map
  2. Define endpoints for kobold/automatic1111, sd-server and llamabox as needed
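
Once the proxy is up, any OpenAI-compatible client can hit that endpoint. A minimal sketch using the openai Python package, assuming the proxy listens on its default port 4000 (the port and the placeholder API key are assumptions; the model name comes from the config above):

import base64
from openai import OpenAI

# Any api_key works here; the proxy routes the request to the local SD backend.
client = OpenAI(base_url="http://localhost:4000/v1", api_key="sk-anything")

resp = client.images.generate(
    model="sd_xl_turbo_1.0",   # one of the model_name entries from the config above
    prompt="a watercolor lighthouse at dusk",
    size="512x512",
)

# The adapters always return base64-encoded image data in the DALL-E response shape.
with open("out.png", "wb") as f:
    f.write(base64.b64decode(resp.data[0].b64_json))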

Things I hate about this:

  1. The openai_compatible_providers hack, without which optional_params don't get passed to the adapter and just get eaten (why?)
  2. This entire approach is bad: trying to pack the SD-API inside the DALL-E API was misguided, since the SD-API is the superset.

I ended up writing my own SD-API proxy: https://github.com/the-crypt-keeper/modelzoo/blob/main/proxy.py
