Created
August 12, 2024 23:55
-
-
Save skrawcz/a95989aa4fd1d9647b9c2633dc97301c to your computer and use it in GitHub Desktop.
Shows how to wrap a Burr application so that its execution can be delegated to Ray. This is one possible strategy for running Burr applications on Ray.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
from IPython.display import Image, display | |
from IPython.core.display import HTML | |
import openai | |
from burr.core import ApplicationBuilder, State, default, graph, when | |
from burr.core.action import action | |
from burr.tracking import LocalTrackingClient | |
# Response modes the router can choose, mapped to the content type of the
# response each one produces. Key order is significant: choose_mode lists the
# keys, in this order, inside its LLM prompt.
MODES = dict(
    answer_question="text",
    generate_image="image",
    generate_code="code",
    unknown="text",
)
@action(reads=[], writes=["chat_history", "prompt"])
def process_prompt(state: State, prompt: str) -> State:
    """Record the user's prompt for this turn.

    Appends the prompt to ``chat_history`` as a user text message and stores
    the raw prompt under ``prompt`` so later actions can read it.
    """
    user_message = {"role": "user", "content": prompt, "type": "text"}
    return state.append(chat_history=user_message).update(prompt=prompt)
@action(reads=["prompt"], writes=["mode"])
def choose_mode(state: State) -> State:
    """Ask the LLM which response mode best fits the current prompt.

    The model is asked to reply with a single key of ``MODES``. The reply is
    normalised before lookup: surrounding whitespace, quotes, backticks and a
    trailing period are stripped, and it is lower-cased. Anything that is
    still not a known mode, including an empty reply, maps to ``"unknown"``.

    Writes:
        mode: one of the keys of ``MODES``.
    """
    prompt = (
        f"You are a chatbot. You've been prompted this: {state['prompt']}. "
        f"You have the capability of responding in the following modes: {', '.join(MODES)}. "
        "Please respond with *only* a single word representing the mode that most accurately "
        "corresponds to the prompt. For instance, if the prompt is 'draw a picture of a cat', "
        "the mode would be 'generate_image'. If the prompt is "
        # Trailing space below keeps the two sentences from running together.
        "'what is the capital of France', the mode would be 'answer_question'. "
        "If none of these modes apply, please respond with 'unknown'."
    )
    llm_result = openai.Client().chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": prompt},
        ],
    )
    # message.content is Optional in the OpenAI SDK; treat None as an empty reply.
    content = llm_result.choices[0].message.content or ""
    # LLMs often pad a "single word" answer with whitespace, quotes or a period,
    # which would otherwise make a correct answer miss the MODES lookup.
    mode = content.strip().strip("'\"`.").strip().lower()
    if mode not in MODES:
        mode = "unknown"
    return state.update(mode=mode)
@action(reads=["prompt", "chat_history"], writes=["response"])
def prompt_for_more(state: State) -> State:
    """Fallback when no supported mode applies: ask the user to clarify.

    Writes a fixed assistant text message to ``response``. The declared reads
    are not used by the body.
    """
    clarification = (
        "None of the response modes I support apply to your question. "
        "Please clarify?"
    )
    return state.update(
        response={"content": clarification, "type": "text", "role": "assistant"}
    )
@action(reads=["prompt", "chat_history", "mode"], writes=["response"])
def chat_response(
    state: State, prepend_prompt: str, model: str = "gpt-3.5-turbo"
) -> State:
    """Answer the latest user message with a chat-completion call.

    The whole ``chat_history`` is sent to the model, keeping only the ``role``
    and ``content`` of each entry. ``prepend_prompt`` is prefixed to the last
    message's content. The prefixed text exists only in the request payload;
    ``chat_history`` in state is left untouched.

    Args:
        prepend_prompt: Instruction prefixed to the latest message. It is bound
            per graph node (code generation vs. question answering).
        model: OpenAI chat model name.

    Writes:
        response: Assistant message whose ``type`` is ``MODES[state["mode"]]``.
    """
    # Build fresh request dicts rather than deep-copying the whole history.
    # Only role/content are sent, and new dicts already keep state unmodified.
    messages = [
        {"role": chat["role"], "content": chat["content"]}
        for chat in state["chat_history"]
    ]
    messages[-1]["content"] = f"{prepend_prompt}: {messages[-1]['content']}"
    completion = openai.Client().chat.completions.create(
        model=model,
        messages=messages,
    )
    text_response = completion.choices[0].message.content
    return state.update(
        response={
            "content": text_response,
            "type": MODES[state["mode"]],
            "role": "assistant",
        }
    )
@action(reads=["prompt", "chat_history", "mode"], writes=["response"])
def image_response(state: State, model: str = "dall-e-2") -> State:
    """Generate an image for the current prompt and respond with its URL.

    Calls the OpenAI images API with ``state["prompt"]``, requesting one
    1024x1024 image at standard quality. The returned URL is stored as the
    assistant's response; nothing is saved locally.

    Args:
        model: OpenAI image model name.

    Writes:
        response: Assistant message whose ``content`` is the image URL and
            whose ``type`` is ``MODES[state["mode"]]``.
    """
    generated = openai.Client().images.generate(
        model=model, prompt=state["prompt"], size="1024x1024", quality="standard", n=1
    )
    image_url = generated.data[0].url
    return state.update(
        response={"content": image_url, "type": MODES[state["mode"]], "role": "assistant"}
    )
@action(reads=["response", "mode"], writes=["chat_history"])
def response(state: State) -> State:
    """Commit the pending ``response`` message to ``chat_history``.

    Every mode-specific action transitions here. This is the place to add
    post-processing that depends on prior state.
    """
    return state.append(chat_history=state["response"])
# Build the agent graph once, at import time; run_agent reuses it inside each
# Ray task. Actions are the nodes. Transitions are the edges: either
# state-conditioned (``when(...)``) or unconditional (``default``).
def _build_base_graph():
    """Return the chatbot graph: prompt -> decide_mode -> <mode action> -> response -> prompt."""
    nodes = {
        "prompt": process_prompt,
        "decide_mode": choose_mode,
        "generate_image": image_response,
        "generate_code": chat_response.bind(
            prepend_prompt="Please respond with *only* code and no other text (at all) to the following:",
        ),
        "answer_question": chat_response.bind(
            prepend_prompt="Please answer the following question:",
        ),
        "prompt_for_more": prompt_for_more,
        "response": response,
    }
    # Modes that route from decide_mode to the action of the same name.
    routed_modes = ("generate_image", "generate_code", "answer_question")
    edges = [
        ("prompt", "decide_mode", default),
        *[("decide_mode", mode_name, when(mode=mode_name)) for mode_name in routed_modes],
        # Anything not matched above (i.e. mode == "unknown") asks for clarification.
        ("decide_mode", "prompt_for_more", default),
        (
            ["generate_image", "answer_question", "generate_code", "prompt_for_more"],
            "response",
        ),
        ("response", "prompt", default),
    ]
    return graph.GraphBuilder().with_actions(**nodes).with_transitions(*edges).build()


base_graph = _build_base_graph()
# Call base_graph.visualize() to render this graph.
import ray | |
@ray.remote
def run_agent(user_input: str, app_id: str) -> State:
    """Run one conversational turn of the agent as a Ray task.

    Rebuilds the Burr application from ``base_graph`` on the worker. The
    conversation identified by ``app_id`` is then run until the ``response``
    action has completed.

    The tracking client doubles as the persister. ``initialize_from`` resumes
    from the state saved for this ``app_id`` when one exists; otherwise it
    starts at ``prompt`` with an empty chat history. ``with_tracker`` records
    each step, so a later call with the same ``app_id`` continues the
    conversation.

    NOTE(review): LocalTrackingClient presumably stores data on the local
    filesystem. If so, tasks scheduled on different Ray nodes would not share
    history, and a shared persister (e.g. Postgres, which also supports
    partition keys) may be needed. Confirm before deploying to a cluster.

    Args:
        user_input: The user's message for this turn.
        app_id: Conversation identifier; reuse it to continue a conversation.

    Returns:
        The application state after the ``response`` action.
    """
    tracker = LocalTrackingClient(project="agent-demo-ray")
    app = (
        ApplicationBuilder()
        .with_graph(base_graph)
        .initialize_from(
            tracker,
            resume_at_next_action=True,
            default_state={"chat_history": []},
            default_entrypoint="prompt",
        )
        .with_identifiers(app_id=app_id)
        .with_tracker(tracker)  # tracking + checkpointing/persisting in one call
        .build()
    )
    # app.run returns (last_action, action_result, state); only the state is needed.
    _, _, app_state = app.run(
        halt_after=["response"],
        inputs={"prompt": user_input},
    )
    return app_state
if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)
    # Both turns share one app_id, so the second call resumes the conversation
    # (and its chat history) persisted by the first.
    conversation_id = "test1234"
    for user_message in ("what is the capital of france?", "write hello world in java"):
        # ray.get blocks until the remote turn finishes, so the turns run in order.
        print(ray.get(run_agent.remote(user_message, conversation_id)))
You could also pass the state directly to the function instead of using the persister.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
I tested this in a notebook and copy-pasted it here, so it should work.