
@hex-plex
Created May 26, 2024 07:54
Load balancer to serve multiple models behind a single OpenAI-compatible API interface
from flask import Flask, request, jsonify
import requests

app = Flask(__name__)

# Map each model name to the vLLM server that hosts it.
endpoints = {
    'model_provider/model_1': 'http://0.0.0.0:8001',
    'model_provider/model_2': 'http://0.0.0.0:8002',
    'model_provider/model_3': 'http://0.0.0.0:8003',
    'model_provider/model_4': 'http://0.0.0.0:8004'
}

# Start one vLLM OpenAI-compatible server per GPU, e.g.:
# CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server --model model_provider/model_1 --tensor-parallel-size 1 --chat-template ./chat.jinja --port 8001
# CUDA_VISIBLE_DEVICES=1 python -m vllm.entrypoints.openai.api_server --model model_provider/model_2 --tensor-parallel-size 1 --chat-template ./chat.jinja --port 8002
# CUDA_VISIBLE_DEVICES=2 python -m vllm.entrypoints.openai.api_server --model model_provider/model_3 --tensor-parallel-size 1 --chat-template ./chat.jinja --port 8003
# CUDA_VISIBLE_DEVICES=3 python -m vllm.entrypoints.openai.api_server --model model_provider/model_4 --tensor-parallel-size 1 --chat-template ./chat.jinja --port 8004


def check_model(req, avail_endpoints, suffix_used, headers):
    """Route a POST request to the endpoint serving the requested model."""
    data = req.get_json(silent=True) or {}
    model_name = data.get("model")
    if model_name not in avail_endpoints:
        return {"error": "Model name missing or not served by any endpoint"}, 400
    full_url = f"{avail_endpoints[model_name]}/v1/{suffix_used}"
    response = requests.post(full_url, json=data, headers=headers)
    return response.json(), response.status_code


def merge_across_models(req, avail_endpoints, suffix_used, headers):
    """Fan a GET request out to every endpoint and merge the 'data' lists."""
    final_response = None
    final_status_code = None
    for model_name, selected_endpoint in avail_endpoints.items():
        full_url = f"{selected_endpoint}/v1/{suffix_used}"
        response = requests.get(full_url, params=req.args, headers=headers)
        if final_response is None:
            final_response = response.json()
        else:
            final_response["data"].extend(response.json()["data"])
        final_status_code = final_status_code or response.status_code
    return final_response, final_status_code


@app.route('/v1/<path:suffix>', methods=['POST', 'GET'])
def api_load_balancer(suffix):
    # Forward the incoming headers, minus the ones the requests library sets itself.
    excluded_headers = ['Host', 'Content-Length', 'Content-Type']
    headers = {key: value for key, value in request.headers.items() if key not in excluded_headers}
    if request.method == 'POST':
        response_json, response_status_code = check_model(request, endpoints, suffix, headers)
    else:  # GET, e.g. /v1/models
        response_json, response_status_code = merge_across_models(request, endpoints, suffix, headers)
    return jsonify(response_json), response_status_code


if __name__ == '__main__':
    app.run(debug=True, port=8000)
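
A minimal client sketch against the balancer on port 8000 (the model name and prompt below are placeholders; any OpenAI-compatible client pointed at http://localhost:8000/v1 should behave the same way):

import requests

# GET /v1/models fans out to every backend and merges the "data" lists.
print(requests.get("http://localhost:8000/v1/models").json())

# POST /v1/chat/completions is routed to the backend that hosts the requested model.
payload = {
    "model": "model_provider/model_1",
    "messages": [{"role": "user", "content": "Hello!"}],
}
print(requests.post("http://localhost:8000/v1/chat/completions", json=payload).json())
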
@wakusoftware

Shouldn't this be done with async functions? Maybe best using FastAPI?

@hex-plex

I agree that async support could be added, along with streaming. This was sufficient for my use case. I'm not really experienced with FastAPI, but I'll give it a try and make an update. Thanks for the feedback!
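
In the meantime, here's a rough, untested sketch of what an async variant could look like, using FastAPI and httpx (neither is used in the gist above, and only the POST path is shown, so treat it purely as a starting point):

import httpx
from fastapi import FastAPI, Request

app = FastAPI()

# Same model-to-server mapping as in the gist above.
endpoints = {
    'model_provider/model_1': 'http://0.0.0.0:8001',
    'model_provider/model_2': 'http://0.0.0.0:8002',
}

@app.post("/v1/{suffix:path}")
async def proxy(suffix: str, request: Request):
    data = await request.json()
    base = endpoints[data["model"]]
    async with httpx.AsyncClient() as client:
        resp = await client.post(f"{base}/v1/{suffix}", json=data)
    return resp.json()

# Run with: uvicorn balancer_async:app --port 8000 (assuming the file is saved as balancer_async.py).
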

@linchen111

Can I use this across different machines? For example, model_1 served from three machines?

@hex-plex

Yes, it will work. Just change 0.0.0.0 to the IP address of the machine and make sure the port is open.
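
For serving the same model from several machines, the endpoints table could list replicas and pick one per request. An untested sketch with placeholder IPs:

import random

# Hypothetical: model_1 served from three machines.
replicas = {
    'model_provider/model_1': [
        'http://192.168.1.10:8001',
        'http://192.168.1.11:8001',
        'http://192.168.1.12:8001',
    ],
}

def pick_endpoint(model_name):
    # Random choice is the simplest policy; round robin or least-busy would also work.
    return random.choice(replicas[model_name])
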

@matthew-at-qamcom

That was useful to me. Thanks!

@yijicheng

I agree that async support could be added, along with streaming. This was sufficient for my use case. I'm not really experienced with FastAPI, but I'll give it a try and make an update. Thanks for the feedback!

Hi, have you implemented the async version?
