@STHITAPRAJNAS
Created February 21, 2025 05:37
A smart LangGraph example: a multi-source RAG agent that classifies a question's intent, rewrites the query, retrieves from Confluence docs and Databricks metadata via pgvector, generates and retries SQL, and answers behind a simple guardrail.
import operator
import os
from typing import Annotated, List, Literal, Optional, TypedDict

from langchain_aws import ChatBedrock
from langchain_community.vectorstores import PGVector
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.tools import tool
from langchain_openai import OpenAIEmbeddings
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, StateGraph
# State schema. `messages` has operator.add as its reducer, so nodes append by
# returning only the new messages; every other field is overwritten on update.
# total=False because the initial state supplies only `messages`.
class AgentState(TypedDict, total=False):
    messages: Annotated[List[HumanMessage | AIMessage], operator.add]
    intent: Literal["confluence", "databricks", "both", "ambiguous", None]
    confluence_context: List[str]
    databricks_context: List[str]
    generated_sql: str
    sql_attempts: int
    sql_error: Optional[str]
    needs_clarification: bool
    final_answer: str
    rewritten_query: str
# Environment setup (placeholder connection strings; replace with real ones)
os.environ["AWS_REGION"] = "us-east-1"
CONFLUENCE_CONNECTION_STRING = "postgresql+psycopg2://user:password@localhost:5432/confluence_db"
DATABRICKS_CONNECTION_STRING = "postgresql+psycopg2://user:password@localhost:5432/databricks_db"

llm = ChatBedrock(
    model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",  # full Bedrock model id
    region_name="us-east-1",
    model_kwargs={"temperature": 0.7},
)
embeddings = OpenAIEmbeddings()  # requires OPENAI_API_KEY in the environment
confluence_store = PGVector(collection_name="confluence_docs", connection_string=CONFLUENCE_CONNECTION_STRING, embedding_function=embeddings)
databricks_store = PGVector(collection_name="databricks_metadata", connection_string=DATABRICKS_CONNECTION_STRING, embedding_function=embeddings)
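
# The vector stores are assumed to be pre-populated. For a quick local test
# you could seed them like this (a sketch; the sample texts are made up):
# confluence_store.add_texts(["Onboarding: new hires complete the IT checklist, then meet their buddy."])
# databricks_store.add_texts(["Table sales(order_id BIGINT, amount DOUBLE, order_date DATE)"])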
# Tools and utilities
@tool
def generate_sql(query: str, metadata: List[str], error: Optional[str] = None) -> str:
    """Generate (or repair) a SQL query for the question using table metadata."""
    if error:
        prompt = f"The SQL generated for '{query}' failed with error '{error}'. Generate a corrected SQL query using metadata: {metadata}"
    else:
        prompt = f"Generate a SQL query for '{query}' using metadata: {metadata}"
    return llm.invoke(prompt).content
def execute_databricks_sql(sql: str) -> dict:
    """Simulated Databricks execution; see the real-connector sketch below."""
    if "error" in sql.lower():
        return {"status": "error", "error": "Simulated SQL error"}
    return {"status": "success", "results": ["Simulated results"]}
def rewrite_query(query: str) -> str:
    prompt = f"Rewrite this query to make it clearer and more specific: '{query}'"
    return llm.invoke(prompt).content
def rerank_results(results: List[str], query: str) -> List[str]:
    # Placeholder: simply reverses retrieval order; see the sketch below for a
    # similarity-based alternative.
    return results[::-1]
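
# A similarity-based alternative to the placeholder reranker (a sketch; reuses
# the OpenAI embeddings defined above with plain cosine similarity):
# import numpy as np
# def rerank_by_similarity(results: List[str], query: str) -> List[str]:
#     q = np.array(embeddings.embed_query(query))
#     docs = np.array(embeddings.embed_documents(results))
#     scores = docs @ q / (np.linalg.norm(docs, axis=1) * np.linalg.norm(q) + 1e-9)
#     return [r for _, r in sorted(zip(scores, results), key=lambda p: -p[0])]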
def check_guardrails(answer: str) -> bool:
    # Naive keyword guardrail; a production system would use a moderation model.
    forbidden_words = ["harmful", "inappropriate"]
    return not any(word in answer.lower() for word in forbidden_words)
# Nodes. Each node returns only the keys it updates; returning the full state
# would re-trigger the operator.add reducer and duplicate `messages`.
def parse_question(state: AgentState) -> dict:
    question = state["messages"][-1].content
    prompt = f"""Classify this question into one of: 'confluence', 'databricks', 'both', or 'ambiguous':
Question: {question}
Provide a one-word response: """
    response = llm.invoke(prompt).content.strip().lower()
    intent = response if response in ["confluence", "databricks", "both", "ambiguous"] else "ambiguous"
    return {"intent": intent}
def rewrite_query_node(state: AgentState) -> dict:
    question = state["messages"][-1].content
    return {"rewritten_query": rewrite_query(question)}
def route_context(state: AgentState) -> dict:
    query = state["rewritten_query"]
    updates: dict = {}
    if state["intent"] in ["confluence", "both"]:
        results = confluence_store.similarity_search(query, k=5)
        updates["confluence_context"] = rerank_results([doc.page_content for doc in results], query)
    if state["intent"] in ["databricks", "both"]:
        results = databricks_store.similarity_search(query, k=5)
        updates["databricks_context"] = rerank_results([doc.page_content for doc in results], query)
    if state["intent"] == "ambiguous":
        updates["needs_clarification"] = True
    return updates
def clarify_question(state: AgentState) -> dict:
    if not state.get("needs_clarification", False):
        return {}
    prompt = f"""The question '{state["messages"][-1].content}' is ambiguous. Ask a clarifying question."""
    clarification = llm.invoke(prompt).content
    # The reducer appends this message to the conversation history.
    return {"messages": [AIMessage(content=clarification)], "needs_clarification": False}
def generate_sql_node(state: AgentState) -> dict:
    if state["intent"] not in ["databricks", "both"] or not state.get("databricks_context"):
        return {}
    sql = generate_sql.invoke({
        "query": state["rewritten_query"],
        "metadata": state["databricks_context"],
        "error": state.get("sql_error"),
    })
    return {"generated_sql": sql}
def execute_sql_node(state: AgentState) -> dict:
    sql = state.get("generated_sql")
    if not sql:
        return {}
    # Retry up to 3 attempts, asking the LLM to repair the SQL after each failure
    # (an iterative version of the original recursive retry).
    for _ in range(3):
        result = execute_databricks_sql(sql)
        if result["status"] == "success":
            return {"databricks_context": result["results"], "sql_attempts": 0, "sql_error": None}
        sql = generate_sql.invoke({
            "query": state["rewritten_query"],
            "metadata": state["databricks_context"],
            "error": result["error"],
        })
    return {
        "databricks_context": ["Unable to retrieve data due to persistent errors."],
        "sql_attempts": 0,
        "sql_error": None,
    }
def generate_answer(state: AgentState) -> dict:
    query = state["rewritten_query"]
    confluence_ctx = "\n".join(state.get("confluence_context", []))
    databricks_ctx = "\n".join(state.get("databricks_context", []))
    prompt = f"""Answer this question: '{query}'
Using Confluence context: {confluence_ctx}
And Databricks context: {databricks_ctx}
Provide a concise, accurate response."""
    answer = llm.invoke(prompt).content
    if not check_guardrails(answer):
        answer = "I'm sorry, but I can't provide that information."
    return {"final_answer": answer, "messages": [AIMessage(content=answer)]}
# Workflow
workflow = StateGraph(AgentState)
workflow.add_node("parse_question", parse_question)
workflow.add_node("rewrite_query", rewrite_query_node)
workflow.add_node("route_context", route_context)
workflow.add_node("clarify_question", clarify_question)
workflow.add_node("generate_sql", generate_sql_node)
workflow.add_node("execute_sql", execute_sql_node)
workflow.add_node("generate_answer", generate_answer)

workflow.set_entry_point("parse_question")
workflow.add_edge("parse_question", "rewrite_query")
workflow.add_edge("rewrite_query", "route_context")
workflow.add_edge("route_context", "clarify_question")
workflow.add_conditional_edges(
    "clarify_question",
    lambda state: "generate_sql" if state["intent"] in ["databricks", "both"] else "generate_answer",
    {"generate_sql": "generate_sql", "generate_answer": "generate_answer"},
)
workflow.add_edge("generate_sql", "execute_sql")
workflow.add_edge("execute_sql", "generate_answer")
workflow.add_edge("generate_answer", END)

checkpointer = MemorySaver()
graph = workflow.compile(checkpointer=checkpointer)
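
# Optional: inspect the compiled topology (available in recent langgraph versions):
# print(graph.get_graph().draw_mermaid())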
# Run agent
def run_agent(question: str, thread_id: str = "thread_1") -> str:
    initial_state = {"messages": [HumanMessage(content=question)]}
    result = graph.invoke(initial_state, config={"configurable": {"thread_id": thread_id}})
    return result["messages"][-1].content
# Test cases (separate threads so the conversations stay independent)
print(run_agent("What is the process for onboarding in Confluence?", thread_id="t1"))
print(run_agent("How many rows are in the sales table in Databricks?", thread_id="t2"))
print(run_agent("What's the latest update on project X?", thread_id="t3"))