Created
February 21, 2025 05:37
-
-
Save STHITAPRAJNAS/30a7ba9f3d93f4e786c2f46b169d1602 to your computer and use it in GitHub Desktop.
A LangGraph agent example: routes questions to Confluence and/or Databricks retrieval, with query rewriting, SQL generation and retry, and a guardrail check.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from typing import TypedDict, Annotated, Literal, List | |
from langchain_aws import ChatBedrock | |
from langchain_community.vectorstores import PGVector | |
from langchain_core.messages import HumanMessage, AIMessage | |
from langchain_core.tools import tool | |
from langgraph.graph import StateGraph, END | |
from langgraph.checkpoint.memory import MemorySaver | |
from langchain_openai import OpenAIEmbeddings | |
import operator | |
# State schema shared by every node in the graph.
class AgentState(TypedDict):
    # Conversation history; operator.add tells LangGraph to append new messages.
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    # Routing label produced by parse_question.
    intent: Literal["confluence", "databricks", "both", "ambiguous", None]
    # Reranked document snippets retrieved from the Confluence store.
    confluence_context: List[str]
    # Table-metadata snippets; replaced by query results when SQL succeeds.
    databricks_context: List[str]
    # SQL produced by the generate_sql tool.
    generated_sql: str
    # Failed-execution counter used by the retry logic (capped at 3).
    sql_attempts: int
    # Last execution error message, fed back into SQL refinement.
    sql_error: str
    # Set by route_context when the intent is ambiguous.
    needs_clarification: bool
    # Guardrail-checked answer returned to the user.
    final_answer: str
    # Clarified version of the user question used for retrieval and SQL.
    rewritten_query: str
# Environment setup: Bedrock LLM plus two pgvector-backed retrieval stores.
os.environ["AWS_REGION"] = "us-east-1"
# NOTE(review): credentials are hard-coded in the connection strings — move to
# environment variables or a secrets manager before real use.
CONFLUENCE_CONNECTION_STRING = "postgresql+psycopg2://user:password@localhost:5432/confluence_db"
DATABRICKS_CONNECTION_STRING = "postgresql+psycopg2://user:password@localhost:5432/databricks_db"
# Claude 3.5 Sonnet via Bedrock; temperature 0.7 applies to every prompt below.
llm = ChatBedrock(model_id="anthropic.claude-3-5-sonnet-20240620", region_name="us-east-1", model_kwargs={"temperature": 0.7})
embeddings = OpenAIEmbeddings()
# One collection for Confluence documents, one for Databricks table metadata.
confluence_store = PGVector(collection_name="confluence_docs", connection_string=CONFLUENCE_CONNECTION_STRING, embedding_function=embeddings)
databricks_store = PGVector(collection_name="databricks_metadata", connection_string=DATABRICKS_CONNECTION_STRING, embedding_function=embeddings)
# Tools and utilities | |
@tool
def generate_sql(query: str, metadata: List[str], error: str | None = None) -> str:
    """Generate or refine a SQL query for a natural-language request.

    Args:
        query: The user's natural-language question.
        metadata: Table/column metadata snippets used to ground generation.
        error: Error message from a previous failed execution, if any; when
            present the LLM is asked to refine the prior SQL instead of
            generating from scratch.

    Returns:
        The LLM's raw response content, expected to be a SQL string.
    """
    # Fix: @tool requires a docstring to build the tool description, and the
    # default of None made the original `error: str` annotation incorrect.
    if error:
        prompt = f"Refine this SQL query '{query}' that caused error '{error}' using metadata: {metadata}"
    else:
        prompt = f"Generate a SQL query for '{query}' using metadata: {metadata}"
    return llm.invoke(prompt).content
def execute_databricks_sql(sql: str) -> dict:
    """Simulated Databricks execution.

    Any SQL whose text contains 'error' (case-insensitive) fails with a
    canned error payload; everything else returns canned results.
    """
    failed = "error" in sql.lower()
    if failed:
        return {"status": "error", "error": "Simulated SQL error"}
    return {"status": "success", "results": ["Simulated results"]}
def rewrite_query(query: str) -> str:
    """Ask the LLM for a clearer, more specific rewrite of the query."""
    rewrite_prompt = f"Rewrite this query to make it clearer and more specific: '{query}'"
    response = llm.invoke(rewrite_prompt)
    return response.content
def rerank_results(results: List[str], query: str) -> List[str]:
    """Placeholder reranker: return the results in reverse order.

    The query is accepted for interface compatibility but not used here.
    """
    return list(reversed(results))
def check_guardrails(answer: str) -> bool:
    """Return True when the answer contains none of the forbidden terms."""
    lowered = answer.lower()
    for word in ("harmful", "inappropriate"):
        if word in lowered:
            return False
    return True
# Nodes | |
def parse_question(state: AgentState) -> AgentState:
    """Classify the latest user message into a routing intent label."""
    question = state["messages"][-1].content
    prompt = f"""Classify this question into one of: 'confluence', 'databricks', 'both', or 'ambiguous':
Question: {question}
Provide a one-word response: """
    label = llm.invoke(prompt).content.strip()
    # Anything outside the known labels falls back to "ambiguous".
    if label not in ("confluence", "databricks", "both", "ambiguous"):
        label = "ambiguous"
    state["intent"] = label
    return state
def rewrite_query_node(state: AgentState) -> AgentState:
    """Store a clarified rewrite of the latest user message on the state."""
    latest = state["messages"][-1].content
    state["rewritten_query"] = rewrite_query(latest)
    return state
def route_context(state: AgentState) -> AgentState:
    """Retrieve and rerank top-5 context from the store(s) matching the intent.

    Ambiguous intents skip retrieval and flag the state for clarification.
    """
    intent = state["intent"]
    query = state["rewritten_query"]
    if intent in ("confluence", "both"):
        docs = confluence_store.similarity_search(query, k=5)
        state["confluence_context"] = rerank_results([d.page_content for d in docs], query)
    if intent in ("databricks", "both"):
        docs = databricks_store.similarity_search(query, k=5)
        state["databricks_context"] = rerank_results([d.page_content for d in docs], query)
    if intent == "ambiguous":
        state["needs_clarification"] = True
    return state
def clarify_question(state: AgentState) -> AgentState:
    """When flagged ambiguous, append an LLM-written clarifying question."""
    if not state.get("needs_clarification", False):
        return state
    prompt = f"""The question '{state["messages"][-1].content}' is ambiguous. Ask a clarifying question."""
    state["messages"].append(AIMessage(content=llm.invoke(prompt).content))
    state["needs_clarification"] = False
    return state
def generate_sql_node(state: AgentState) -> AgentState:
    """Produce SQL via the generate_sql tool when Databricks data is needed.

    Passes any prior sql_error through so the tool refines rather than
    regenerates; a no-op when the intent doesn't involve Databricks or no
    metadata context was retrieved.
    """
    wants_sql = state["intent"] in ("databricks", "both")
    if wants_sql and state["databricks_context"]:
        tool_args = {
            "query": state["rewritten_query"],
            "metadata": state["databricks_context"],
            "error": state.get("sql_error"),
        }
        state["generated_sql"] = generate_sql.invoke(tool_args)
    return state
def execute_sql_node(state: AgentState) -> AgentState:
    """Execute the generated SQL, retrying with LLM-refined SQL on failure.

    Runs the query in ``state["generated_sql"]`` (a no-op when absent). On
    success, the results replace ``databricks_context`` and the retry
    bookkeeping is cleared. On failure, the error is recorded,
    ``generate_sql_node`` is asked for a refined query, and execution is
    retried — giving up after 3 attempts with a placeholder context message.
    """
    if not state.get("generated_sql"):
        return state
    # Iterative retry loop replaces the original recursive self-calls into
    # execute_sql_node: identical behavior (at most 3 executions), but the
    # control flow is flat and uses no nested stack frames.
    while True:
        result = execute_databricks_sql(state["generated_sql"])
        if result["status"] == "success":
            state["databricks_context"] = result["results"]
            state["sql_attempts"] = 0
            state["sql_error"] = None
            return state
        state["sql_error"] = result["error"]
        state["sql_attempts"] = state.get("sql_attempts", 0) + 1
        if state["sql_attempts"] >= 3:
            state["databricks_context"] = ["Unable to retrieve data due to persistent errors."]
            state["sql_attempts"] = 0
            state["sql_error"] = None
            return state
        # Feed the error back so the LLM can refine the SQL before retrying.
        state = generate_sql_node(state)
def generate_answer(state: AgentState) -> AgentState:
    """Compose the final answer from retrieved context, with a guardrail check.

    Answers that fail check_guardrails are replaced by a refusal message.
    """
    confluence_ctx = "\n".join(state.get("confluence_context", []))
    databricks_ctx = "\n".join(state.get("databricks_context", []))
    prompt = f"""Answer this question: '{state["rewritten_query"]}'
Using Confluence context: {confluence_ctx}
And Databricks context: {databricks_ctx}
Provide a concise, accurate response."""
    answer = llm.invoke(prompt).content
    if not check_guardrails(answer):
        answer = "I'm sorry, but I can't provide that information."
    state["final_answer"] = answer
    state["messages"].append(AIMessage(content=answer))
    return state
# Workflow: parse -> rewrite -> retrieve -> clarify -> (optional SQL path) -> answer
workflow = StateGraph(AgentState)
workflow.add_node("parse_question", parse_question)
workflow.add_node("rewrite_query", rewrite_query_node)
workflow.add_node("route_context", route_context)
workflow.add_node("clarify_question", clarify_question)
workflow.add_node("generate_sql", generate_sql_node)
workflow.add_node("execute_sql", execute_sql_node)
workflow.add_node("generate_answer", generate_answer)
workflow.add_edge("parse_question", "rewrite_query")
workflow.add_edge("rewrite_query", "route_context")
workflow.add_edge("route_context", "clarify_question")
# Databricks-flavored intents take the SQL path; everything else answers directly.
# NOTE(review): even when clarify_question appends a clarifying question, the
# graph continues straight to answer generation rather than pausing for user
# input — confirm this is intended.
workflow.add_conditional_edges(
    "clarify_question",
    lambda state: "generate_sql" if state["intent"] in ["databricks", "both"] else "generate_answer",
    {"generate_sql": "generate_sql", "generate_answer": "generate_answer"}
)
workflow.add_edge("generate_sql", "execute_sql")
workflow.add_edge("execute_sql", "generate_answer")
workflow.add_edge("generate_answer", END)
workflow.set_entry_point("parse_question")
# MemorySaver checkpoints per-thread state so conversations persist in memory.
checkpointer = MemorySaver()
graph = workflow.compile(checkpointer=checkpointer)
# Run agent | |
def run_agent(question: str, thread_id: str = "thread_1"):
    """Invoke the compiled graph for one question and return the last reply.

    The thread_id keys the checkpointer, so repeated calls with the same id
    share conversation state.
    """
    config = {"configurable": {"thread_id": thread_id}}
    final_state = graph.invoke({"messages": [HumanMessage(content=question)]}, config=config)
    return final_state["messages"][-1].content
# Demo calls — one per expected intent branch.
print(run_agent("What is the process for onboarding in Confluence?"))  # confluence path
print(run_agent("How many rows are in the sales table in Databricks?"))  # databricks/SQL path
print(run_agent("What’s the latest update on project X?"))  # likely classified ambiguous
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment