Last active
September 28, 2024 04:06
-
-
Save anpigon/d263b35c88f97a97557710bf12dc8d97 to your computer and use it in GitHub Desktop.
랭그래프에 도구(Tool) 추가하기
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 랭그래프에 도구(Tool) 추가하기" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 65, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 65, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from dotenv import load_dotenv\n", | |
"\n", | |
"load_dotenv()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 66, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from langchain_core.messages import AIMessage\n", | |
"from langchain_core.tools import tool\n", | |
"\n", | |
"from langgraph.prebuilt import ToolNode" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 도구 추가하기" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 67, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"@tool(\"get_weather\")\n", | |
"def get_weather(location: str) -> str:\n", | |
"    \"\"\"Call to get the weather in a given location.\"\"\"\n", | |
"    # Mock data for the demo: Seoul (Korean or English spelling) is sunny, everything else is cloudy.\n", | |
"    if location in [\"서울\", \"Seoul\"]:\n", | |
"        return \"기온은 24도이고, 날씨가 좋아요!\"\n", | |
"    return \"기온은 18도이고, 날씨가 흐려요!\"\n", | |
"\n", | |
"\n", | |
"@tool(\"coolest_cities\")\n", | |
"def coolest_cities() -> list[str]:\n", | |
"    \"\"\"Get a list of coolest cities.\"\"\"\n", | |
"    # Return an actual list so the value matches the tool description above.\n", | |
"    return [\"서울\", \"부산\"]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 68, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"name='get_weather' description='Call to get the weather in a given location.' args_schema=<class 'langchain_core.utils.pydantic.get_weather'> func=<function get_weather at 0x11da67ec0>\n", | |
"name='coolest_cities' description='Get a list of coolest cities.' args_schema=<class 'langchain_core.utils.pydantic.coolest_cities'> func=<function coolest_cities at 0x11da66b60>\n" | |
] | |
} | |
], | |
"source": [ | |
"print(get_weather)\n", | |
"print(coolest_cities)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 69, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"tools = [get_weather, coolest_cities]\n", | |
"tool_node = ToolNode(tools)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 70, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tools(tags=None, recurse=True, func_accepts_config=True, func_accepts={'writer': False}, tools_by_name={'get_weather': StructuredTool(name='get_weather', description='Call to get the weather in a given location.', args_schema=<class 'langchain_core.utils.pydantic.get_weather'>, func=<function get_weather at 0x11da67ec0>), 'coolest_cities': StructuredTool(name='coolest_cities', description='Get a list of coolest cities.', args_schema=<class 'langchain_core.utils.pydantic.coolest_cities'>, func=<function coolest_cities at 0x11da66b60>)}, handle_tool_errors=True)" | |
] | |
}, | |
"execution_count": 70, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tool_node" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 71, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Key 'title' is not supported in schema, ignoring\n", | |
"Key 'title' is not supported in schema, ignoring\n", | |
"Key 'title' is not supported in schema, ignoring\n" | |
] | |
} | |
], | |
"source": [ | |
"from langchain_google_genai import ChatGoogleGenerativeAI\n", | |
"\n", | |
"llm_with_tools = ChatGoogleGenerativeAI(\n", | |
" model=\"gemini-1.5-flash\",\n", | |
" temperature=0,\n", | |
").bind_tools(tools)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 72, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[{'name': 'get_weather',\n", | |
" 'args': {'location': '서울'},\n", | |
" 'id': '95002805-0006-4584-822e-75f0a7589374',\n", | |
" 'type': 'tool_call'}]" | |
] | |
}, | |
"execution_count": 72, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"llm_with_tools.invoke(\"서울 날씨는?\").tool_calls" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 73, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[{'name': 'coolest_cities',\n", | |
" 'args': {},\n", | |
" 'id': '01446505-dba7-476a-a7de-e9905783f909',\n", | |
" 'type': 'tool_call'}]" | |
] | |
}, | |
"execution_count": 73, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"llm_with_tools.invoke(\"한국에서 가장 추운 도시?\").tool_calls" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 74, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"{'messages': [ToolMessage(content='기온은 24도이고, 날씨가 좋아요!', name='get_weather', tool_call_id='216899ba-a3c5-4ac1-b3cf-cbc1b410a5f0')]}" | |
] | |
}, | |
"execution_count": 74, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"tool_node.invoke({\"messages\": [llm_with_tools.invoke(\"서울 날씨는?\")]})" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 75, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from typing import Annotated, Union, Literal, TypedDict\n", | |
"from langchain_core.messages import HumanMessage\n", | |
"from langchain_core.tools import tool\n", | |
"from langgraph.graph import START, END, StateGraph, MessagesState" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 76, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def should_continue(state: MessagesState) -> Literal[\"tools\", END]:\n", | |
"    \"\"\"Route to the tool node when the last message requests tool calls, otherwise finish.\"\"\"\n", | |
"    last_message = state[\"messages\"][-1]\n", | |
"    # Only AI messages carry `tool_calls`; getattr avoids AttributeError for other message types.\n", | |
"    if getattr(last_message, \"tool_calls\", None):\n", | |
"        return \"tools\"\n", | |
"    return END" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 77, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def call_llm(state: MessagesState):\n", | |
"    \"\"\"Invoke the tool-bound LLM on the conversation so far and return its reply as a state update.\"\"\"\n", | |
"    # The returned message is added to the running history (see the stream_mode=\"values\" output below).\n", | |
"    return {\"messages\": [llm_with_tools.invoke(state[\"messages\"])]}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 78, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"workflow = StateGraph(MessagesState)\n", | |
"\n", | |
"# Nodes: the LLM agent, and the prebuilt node that executes the requested tool calls.\n", | |
"workflow.add_node(\"agent\", call_llm)\n", | |
"workflow.add_node(\"tools\", tool_node)\n", | |
"\n", | |
"# Every run starts at the agent (equivalent to workflow.set_entry_point(\"agent\")).\n", | |
"workflow.add_edge(START, \"agent\")\n", | |
"\n", | |
"# After the agent responds, should_continue decides between running tools and finishing.\n", | |
"workflow.add_conditional_edges(\n", | |
"    \"agent\",\n", | |
"    should_continue,\n", | |
"    {\"tools\": \"tools\", END: END},\n", | |
")\n", | |
"\n", | |
"# Tool results always flow back to the agent so it can use them in its next reply.\n", | |
"workflow.add_edge(\"tools\", \"agent\")\n", | |
"\n", | |
"graph = workflow.compile()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 79, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"image/jpeg": "", | |
"text/plain": [ | |
"<IPython.core.display.Image object>" | |
] | |
}, | |
"execution_count": 79, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from IPython import display\n", | |
"\n", | |
"display.Image(graph.get_graph().draw_mermaid_png())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 103, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"'서울은 지금 날씨가 좋아요! 기온은 24도입니다. \\n'" | |
] | |
}, | |
"execution_count": 103, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"final_state = graph.invoke({\"messages\": [HumanMessage(content=\"서울은 지금 추워? 더워?\")]})\n", | |
"# The last message in the final state is the agent's answer.\n", | |
"final_state[\"messages\"][-1].content" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 93, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'agent': {'messages': [AIMessage(content='', additional_kwargs={'function_call': {'name': 'get_weather', 'arguments': '{\"location\": \"\\\\uc11c\\\\uc6b8\"}'}}, response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]}, id='run-e6bfc12e-d3e4-4c4b-8de5-47929c246219-0', tool_calls=[{'name': 'get_weather', 'args': {'location': '서울'}, 'id': '234cb4d0-a4e6-48e6-8279-b7e9846edfdd', 'type': 'tool_call'}], usage_metadata={'input_tokens': 76, 'output_tokens': 15, 'total_tokens': 91})]}}\n", | |
"{'tools': {'messages': [ToolMessage(content='기온은 24도이고, 날씨가 좋아요!', name='get_weather', id='9be8c10d-d949-4cda-ae23-b4e6a0805247', tool_call_id='234cb4d0-a4e6-48e6-8279-b7e9846edfdd')]}}\n", | |
"{'agent': {'messages': [AIMessage(content='기온은 24도이고, 날씨가 좋아요! \\n', additional_kwargs={}, response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]}, id='run-2ad6d3c0-45de-4f44-bfc3-0140ad1ad8d1-0', usage_metadata={'input_tokens': 122, 'output_tokens': 15, 'total_tokens': 137})]}}\n" | |
] | |
} | |
], | |
"source": [ | |
"for node_update in graph.stream({\"messages\": [HumanMessage(content=\"서울 날씨는?\")]}):\n", | |
"    print(node_update)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 92, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"{'messages': [HumanMessage(content='서울 날씨는?', additional_kwargs={}, response_metadata={}, id='a51e588f-b2ae-420e-ac51-26c19df43cae')]}\n", | |
"{'messages': [HumanMessage(content='서울 날씨는?', additional_kwargs={}, response_metadata={}, id='a51e588f-b2ae-420e-ac51-26c19df43cae'), AIMessage(content='', additional_kwargs={'function_call': {'name': 'get_weather', 'arguments': '{\"location\": \"\\\\uc11c\\\\uc6b8\"}'}}, response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]}, id='run-f1b0a2b4-58ef-4a51-a515-955d230f97a7-0', tool_calls=[{'name': 'get_weather', 'args': {'location': '서울'}, 'id': '6a5d3978-7f18-4dd3-a673-b0c521090092', 'type': 'tool_call'}], usage_metadata={'input_tokens': 76, 'output_tokens': 15, 'total_tokens': 91})]}\n", | |
"{'messages': [HumanMessage(content='서울 날씨는?', additional_kwargs={}, response_metadata={}, id='a51e588f-b2ae-420e-ac51-26c19df43cae'), AIMessage(content='', additional_kwargs={'function_call': {'name': 'get_weather', 'arguments': '{\"location\": \"\\\\uc11c\\\\uc6b8\"}'}}, response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]}, id='run-f1b0a2b4-58ef-4a51-a515-955d230f97a7-0', tool_calls=[{'name': 'get_weather', 'args': {'location': '서울'}, 'id': '6a5d3978-7f18-4dd3-a673-b0c521090092', 'type': 'tool_call'}], usage_metadata={'input_tokens': 76, 'output_tokens': 15, 'total_tokens': 91}), ToolMessage(content='기온은 24도이고, 날씨가 좋아요!', name='get_weather', id='9c440713-d8cd-4dda-9f9e-8ff85c6a9ced', tool_call_id='6a5d3978-7f18-4dd3-a673-b0c521090092')]}\n", | |
"{'messages': [HumanMessage(content='서울 날씨는?', additional_kwargs={}, response_metadata={}, id='a51e588f-b2ae-420e-ac51-26c19df43cae'), AIMessage(content='', additional_kwargs={'function_call': {'name': 'get_weather', 'arguments': '{\"location\": \"\\\\uc11c\\\\uc6b8\"}'}}, response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]}, id='run-f1b0a2b4-58ef-4a51-a515-955d230f97a7-0', tool_calls=[{'name': 'get_weather', 'args': {'location': '서울'}, 'id': '6a5d3978-7f18-4dd3-a673-b0c521090092', 'type': 'tool_call'}], usage_metadata={'input_tokens': 76, 'output_tokens': 15, 'total_tokens': 91}), ToolMessage(content='기온은 24도이고, 날씨가 좋아요!', name='get_weather', id='9c440713-d8cd-4dda-9f9e-8ff85c6a9ced', tool_call_id='6a5d3978-7f18-4dd3-a673-b0c521090092'), AIMessage(content='기온은 24도이고, 날씨가 좋아요! \\n', additional_kwargs={}, response_metadata={'prompt_feedback': {'block_reason': 0, 'safety_ratings': []}, 'finish_reason': 'STOP', 'safety_ratings': [{'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability': 'NEGLIGIBLE', 'blocked': False}]}, id='run-61a3a23a-109f-4331-b96a-997688a7a62e-0', usage_metadata={'input_tokens': 122, 'output_tokens': 15, 'total_tokens': 137})]}\n" | |
] | |
} | |
], | |
"source": [ | |
"for snapshot in graph.stream({\"messages\": [HumanMessage(content=\"서울 날씨는?\")]}, stream_mode=\"values\"):\n", | |
"    print(snapshot)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 96, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"================================\u001b[1m Human Message \u001b[0m=================================\n", | |
"\n", | |
"서울 날씨는?\n", | |
"==================================\u001b[1m Ai Message \u001b[0m==================================\n", | |
"Tool Calls:\n", | |
" get_weather (78e19e92-67e8-4b6a-99db-5db3affa4620)\n", | |
" Call ID: 78e19e92-67e8-4b6a-99db-5db3affa4620\n", | |
" Args:\n", | |
" location: 서울\n", | |
"=================================\u001b[1m Tool Message \u001b[0m=================================\n", | |
"Name: get_weather\n", | |
"\n", | |
"기온은 24도이고, 날씨가 좋아요!\n", | |
"==================================\u001b[1m Ai Message \u001b[0m==================================\n", | |
"\n", | |
"기온은 24도이고, 날씨가 좋아요!\n" | |
] | |
} | |
], | |
"source": [ | |
"for snapshot in graph.stream({\"messages\": [HumanMessage(content=\"서울 날씨는?\")]}, stream_mode=\"values\"):\n", | |
"    snapshot[\"messages\"][-1].pretty_print()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 109, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"================================\u001b[1m Human Message \u001b[0m=================================\n", | |
"\n", | |
"가장 추운 도시의 날씨는 어때?\n", | |
"==================================\u001b[1m Ai Message \u001b[0m==================================\n", | |
"Tool Calls:\n", | |
" coolest_cities (dc36e609-ddd3-4dac-893c-58d6f300496c)\n", | |
" Call ID: dc36e609-ddd3-4dac-893c-58d6f300496c\n", | |
" Args:\n", | |
"=================================\u001b[1m Tool Message \u001b[0m=================================\n", | |
"Name: coolest_cities\n", | |
"\n", | |
"[\"서울\", \"부산\"]\n", | |
"==================================\u001b[1m Ai Message \u001b[0m==================================\n", | |
"Tool Calls:\n", | |
" get_weather (059461fb-88c0-4565-aacc-7b3426316184)\n", | |
" Call ID: 059461fb-88c0-4565-aacc-7b3426316184\n", | |
" Args:\n", | |
" location: 서울\n", | |
"=================================\u001b[1m Tool Message \u001b[0m=================================\n", | |
"Name: get_weather\n", | |
"\n", | |
"기온은 24도이고, 날씨가 좋아요!\n", | |
"==================================\u001b[1m Ai Message \u001b[0m==================================\n", | |
"\n", | |
"서울의 기온은 24도이고, 날씨가 좋아요!\n" | |
] | |
} | |
], | |
"source": [ | |
"for snapshot in graph.stream({\"messages\": [HumanMessage(content=\"가장 추운 도시의 날씨는 어때?\")]}, stream_mode=\"values\"):\n", | |
"    snapshot[\"messages\"][-1].pretty_print()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": ".venv", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.11.9" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment