Created
June 19, 2024 22:56
-
-
Save skrawcz/05d85faebb905cba36cbe1f37a5c155d to your computer and use it in GitHub Desktop.
Gist for the Hamilton, Burr, FalkorDB blog post
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@action(
    reads=["chat_history"],
    writes=["chat_history"],
)
def AI_generate_response(state: State, client: openai.Client) -> tuple[dict, State]:
    """Ask the model for a plain-text answer based on the whole conversation.

    The full ``chat_history`` (including any tool responses already appended to
    it) is sent as the prompt, so the model can see the function response.

    :param state: Burr state; ``chat_history`` is sent verbatim as ``messages``.
    :param client: OpenAI client used for the chat completion request.
    :return: A result dict with the assistant text (``ai_response``) and the
        token ``usage``, plus the new state with the assistant message appended
        to ``chat_history``.
    """
    completion = client.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=state["chat_history"],
    )
    reply = completion.choices[0].message
    result = {
        "ai_response": reply.content,
        "usage": completion.usage.to_dict(),
    }
    return result, state.append(chat_history=reply.to_dict())
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@action(
    reads=["question", "chat_history"],
    writes=["chat_history", "tool_calls"],
)
def AI_create_cypher_query(state: State,
                           client: openai.Client) -> tuple[dict, State]:
    """Let the model decide whether to query the knowledge graph.

    The conversation is sent with ``run_cypher_query_tool_description`` offered
    as a tool (``tool_choice="auto"``). The assistant message is appended to
    ``chat_history``. If the model requested tool calls, they are stored under
    ``tool_calls``; otherwise ``tool_calls`` is left as it was.

    NOTE(review): "question" is declared in ``reads`` but never read here.

    :param state: Burr state holding ``chat_history``.
    :param client: OpenAI client used for the chat completion request.
    :return: A result dict with the assistant text (``ai_response``, may be
        ``None`` when only tool calls are returned) and token ``usage``, plus the
        new state.
    """
    completion = client.chat.completions.create(
        model="gpt-4-turbo-preview",
        messages=state["chat_history"],
        tools=[run_cypher_query_tool_description],
        tool_choice="auto",
    )
    reply = completion.choices[0].message
    updated = state.append(chat_history=reply.to_dict())
    requested_calls = reply.tool_calls
    if requested_calls:
        updated = updated.update(tool_calls=requested_calls)
    return (
        {"ai_response": reply.content, "usage": completion.usage.to_dict()},
        updated,
    )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@action(reads=["count"], writes=["count"])
def counter(state: State) -> State:
    """Increment ``count`` in state by one, treating a missing value as 0.

    The key written must be ``count`` to match ``writes=["count"]``. Writing a
    different key leaves ``count`` unchanged, so a transition such as
    ``expr("count < 10")`` would never become false.
    """
    return state.update(count=state.get("count", 0) + 1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from burr.core import ApplicationBuilder, default, expr

# Toy application: repeat ``counter`` while ``count < 10``, then go to ``done``.
app = (
    ApplicationBuilder()
    .with_actions(
        # The keyword names are the action names; they must match the names used
        # in the transitions and the entrypoint ("counter", "done").
        counter=counter,
        done=done,  # implementation left out above
    )
    .with_transitions(
        ("counter", "counter", expr("count < 10")),  # Keep counting if the counter is less than 10
        ("counter", "done", default),  # Otherwise, we're done
    )
    .with_state(count=0)
    .with_entrypoint("counter")  # we have to start somewhere
    .build()
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Chat application: the human asks, the model optionally issues a Cypher tool
# call, the tool result is fed back, and the model answers.
burr_application = (
    ApplicationBuilder()
    # Nodes. External dependencies (OpenAI client, graph handle) are bound up
    # front, so each action is invoked with state only.
    .with_actions(
        AI_create_cypher_query.bind(client=openai_client),
        tool_call.bind(graph=graph),
        AI_generate_response.bind(client=openai_client),
        human_converse,
    )
    # Edges. From AI_create_cypher_query the conditional edge is listed before
    # the ``default`` fallback back to the human.
    .with_transitions(
        ("human_converse", "AI_create_cypher_query", default),
        ("AI_create_cypher_query", "tool_call", expr("len(tool_calls)>0")),
        ("AI_create_cypher_query", "human_converse", default),
        ("tool_call", "AI_generate_response", default),
        ("AI_generate_response", "human_converse", default),
    )
    .with_identifiers(app_id=application_run_id)
    .with_state(chat_history=base_messages, tool_calls=[])  # initial state
    .with_entrypoint("human_converse")
    .with_tracker(tracker)
    .build()
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def run_cypher_query(graph, query):
    """Run a read-only Cypher query and return the outcome as text for the LLM.

    Failures are reported back as text rather than raised, so the model can
    retry with a different query.

    :param graph: Graph handle exposing ``ro_query(query)``, whose return value
        has a ``result_set`` attribute.
    :param query: Cypher query text (typically produced by the model).
    :return: ``str`` of the result rows, or ``str`` of an ``{"error": ...}`` dict
        when the query raises or returns no rows.
    """
    try:
        results = graph.ro_query(query).result_set
    except Exception:
        # Only ordinary errors (bad syntax, unknown labels, ...) are turned into
        # a message; KeyboardInterrupt / SystemExit still propagate.
        return str({"error": "Query failed please try a different variation of this query"})
    if not results:  # also covers a missing (None) result set
        results = {
            "error": "The query did not return any data, please make sure you're using the right edge "
                     "directions and you're following the correct graph schema"}
    return str(results)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Tool schema (OpenAI "tools" format) advertising ``run_cypher_query`` to the
# model: a function taking one required string argument, ``query``.
run_cypher_query_tool_description = dict(
    type="function",
    function=dict(
        name="run_cypher_query",
        description="Runs a Cypher query against the knowledge graph",
        parameters=dict(
            type="object",
            properties=dict(
                query=dict(type="string", description="Query to execute"),
            ),
            required=["query"],
        ),
    ),
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from hamilton import driver

import definitions  # contains node definitions, e.g. A, B, C from above

# Build a Hamilton driver over the functions defined in ``definitions``.
dr = (
    driver.Builder()
    .with_modules(definitions)
    .build()
)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Request node "C" with the external input supplied; returns a dictionary of results.
results = dr.execute(["C"], inputs={"external_input": 7})
# Request node "B"; returns a dictionary of results.
results = dr.execute(["B"], inputs={"external_input": 7})
# Request nodes "A", "B" and "C" in one call; returns a dictionary of results.
results = dr.execute(["A", "B", "C"], inputs={"external_input": 7})
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Interactive loop: each question is fed into the Burr application, which runs
# until it is about to return to ``human_converse`` (i.e. the AI has replied).
# Type "exit" (or send EOF) to quit.
while True:
    try:
        question = input("What can I help you with?\n")
    except EOFError:  # stdin closed / Ctrl-D: stop like "exit" instead of crashing
        break
    if question.strip() == "exit":
        break
    print(f"Human: {question}")
    # The returned action is discarded rather than bound to ``action``, so this
    # loop cannot shadow Burr's ``@action`` decorator in a shared module.
    _, _, state = burr_application.run(
        halt_before=["human_converse"],
        inputs={"user_question": question},
    )
    print(f"AI: {state['chat_history'][-1]['content']}\n")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def set_inital_chat_history(schema_prompt: str) -> list[dict]:
    """Build the initial chat history: one system message for the Cypher assistant.

    The system message says the model is a Cypher expert with graph access,
    includes the schema description, tells it to answer only from data
    retrieved from the knowledge graph, and gives one worked example query.

    :param schema_prompt: Text describing the graph schema, inserted verbatim.
    :return: ``[{"role": "system", "content": <system message>}]``.
    """
    # Sections are joined with newlines; plain concatenation ran them together
    # (e.g. "<schema>Query the ..." and "...knowledge.For example ...").
    sections = [
        "You are a Cypher expert with access to a directed knowledge graph",
        schema_prompt,
        "Query the knowledge graph to extract relevant information to help you answer the user's "
        "questions; base your answer only on the context retrieved from the knowledge graph "
        "and do not use preexisting knowledge.",
        "For example, to find out if two fighters had fought each other, e.g. did Conor McGregor "
        "ever compete against Jose Aldo, issue the following query: "
        "MATCH (a:Fighter)-[]->(f:Fight)<-[]-(b:Fighter) WHERE a.Name = 'Conor McGregor' AND "
        "b.Name = 'Jose Aldo' RETURN a, b",
    ]
    system_message = "\n".join(sections) + "\n"
    return [{"role": "system", "content": system_message}]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _run_requested_tool(call, graph) -> str:
    """Execute one model-requested tool call and return its textual response."""
    import json  # local import so this block does not rely on a module-level one

    function_name = call.function.name
    if function_name != "run_cypher_query":
        # Not an ``assert``: asserts vanish under ``python -O`` and otherwise
        # crash the app. An error response lets the model recover instead.
        return str({"error": f"Unknown function {function_name!r}, use run_cypher_query"})
    try:
        arguments = json.loads(call.function.arguments)
    except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
        arguments = None
    query = arguments.get("query") if isinstance(arguments, dict) else None
    if not isinstance(query, str) or not query.strip():
        return str({"error": "The tool call did not include a valid 'query' string argument"})
    return run_cypher_query(graph, query)


@action(
    reads=["tool_calls", "chat_history"],
    writes=["tool_calls", "chat_history"],
)
def tool_call(state: State, graph: falkordb.Graph) -> tuple[dict, State]:
    """Tool call step -- execute the pending tool calls requested by the model.

    For each entry in ``tool_calls``, a ``role="tool"`` message carrying the
    response is appended to ``chat_history``. Every call gets a response, even
    an error, because OpenAI's chat API expects one tool message per tool call
    id. ``tool_calls`` is then cleared.

    :param state: Burr state holding pending ``tool_calls`` and ``chat_history``.
    :param graph: FalkorDB graph the Cypher queries run against.
    :return: ``{"tool_calls": [{"tool_call_id": ..., "response": ...}, ...]}``
        and the new state.
    """
    new_state = state
    result = {"tool_calls": []}
    for call in state.get("tool_calls", []):
        function_response = _run_requested_tool(call, graph)
        new_state = new_state.append(chat_history={
            "tool_call_id": call.id,
            "role": "tool",
            "name": call.function.name,
            "content": function_response,
        })
        result["tool_calls"].append(
            {"tool_call_id": call.id, "response": function_response})
    # All pending calls are answered; clear them so they are not seen again.
    new_state = new_state.update(tool_calls=[])
    return result, new_state
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def write_to_graph(record: Collect[dict], graph: falkordb.Graph) -> int:
    """Push every collected fighter record to the graph in a single query.

    :param record: Fighter property dicts, collected via Hamilton's ``Collect``.
    :param graph: FalkorDB graph to write to.
    :return: The number of records sent.
    """
    fighters = [*record]
    # One UNWIND query creates one :Fighter node per record and copies the
    # record's keys onto it as properties. NOTE(review): CREATE (not MERGE), so
    # re-running duplicates nodes.
    graph.query(
        "UNWIND $fighters as fighter CREATE (f:Fighter) SET f = fighter",
        {'fighters': fighters},
    )
    return len(fighters)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Ingest one fight row (``_row``) into the graph (``_graph``): ensure the referee
# and card exist, create the :Fight node linked to its card, fighters and
# referee, then record who won and who lost.
# NOTE(review): ``_row`` / ``_graph`` come from an enclosing scope not shown here.

# Name columns may hold non-string values (presumably NaN for missing cells),
# which cannot be used as name parameters; fall back to "" consistently.
referee = _row.Referee if isinstance(_row.Referee, str) else ""
winner = _row.Winner if isinstance(_row.Winner, str) else ""
loser = _row.Loser if isinstance(_row.Loser, str) else ""

q = "MERGE (:Referee {Name: $name})"
_graph.query(q, {'name': referee})

q = "MERGE (c:Card {Date: $date, Location: $location})"
_graph.query(q, {'date': _row.date, 'location': _row.location})

q = """MATCH (c:Card {Date: $date, Location: $location})
MATCH (ref:Referee {Name: $referee})
MATCH (r:Fighter {Name:$R_fighter})
MATCH (b:Fighter {Name:$B_fighter})
CREATE (f:Fight)-[:PART_OF]->(c)
SET f = $fight
CREATE (f)-[:RED]->(r)
CREATE (f)-[:BLUE]->(b)
CREATE (ref)-[:REFEREED]->(f)
RETURN ID(f)
"""
f_id = _graph.query(q,
                    {'date': _row.date,
                     'location': _row.location,
                     'referee': referee,
                     'R_fighter': _row.R_fighter,
                     'B_fighter': _row.B_fighter,
                     'fight': {'Last_round': _row.last_round,
                               'Last_round_time': _row.last_round_time,
                               'Format': _row.Format,
                               'Fight_type': _row.Fight_type}
                     }
                    ).result_set[0][0]

# Only link a winner/loser when both names are known. With an empty name the
# MATCH finds no fighter, so the query would create nothing anyway.
if winner and loser:
    q = """MATCH (f:Fight) WHERE ID(f) = $fight_id
MATCH (l:Fighter {Name:$loser})
MATCH (w:Fighter {Name:$winner})
CREATE (w)-[:WON]->(f), (l)-[:LOST]->(f)
"""
    _graph.query(q,
                 {'fight_id': f_id,
                  'loser': loser,
                  'winner': winner}
                 )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment