Setup code for using tools with an Ollama model in a LangChain chat system.
import requests
import arxiv
from typing import Any, Dict, Optional, TypedDict

from langchain_core.tools import tool, render_text_description
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
# ChatOpenAI setup with Ollama
chat_model = ChatOpenAI(
    api_key="ollama",
    model="llama3.1",
    base_url="http://localhost:11434/v1",
)
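
# Assumption: a local Ollama server is running on the default port (11434) with the
# llama3.1 model already pulled (e.g. via `ollama pull llama3.1`); the api_key value
# is a placeholder since Ollama's OpenAI-compatible endpoint does not check it.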

# Tool definitions
@tool
def multiply(x: float, y: float) -> float:
    """Multiply two numbers together."""
    return x * y


@tool
def add(x: float, y: float) -> float:
    """Add two numbers."""
    return x + y

@tool
def get_weather(location: str) -> Dict[str, Any]:
    """Get detailed weather information for a specified location as a JSON object."""
    url = f"https://wttr.in/{location}?format=j1"
    response = requests.get(url)
    if response.status_code == 200:
        weather_data = response.json()  # The response is already in JSON format
        # Extract the relevant fields from the current conditions
        current_condition = weather_data['current_condition'][0]
        weather_info = {
            "location": location,
            "temperature": {
                "celsius": current_condition['temp_C'],
                "fahrenheit": current_condition['temp_F']
            },
            "condition": current_condition['weatherDesc'][0]['value'],
            "humidity": current_condition['humidity'],
            "wind": {
                "speed_kmh": current_condition['windspeedKmph'],
                "direction": current_condition['winddir16Point']
            },
            "feels_like": {
                "celsius": current_condition['FeelsLikeC'],
                "fahrenheit": current_condition['FeelsLikeF']
            },
            "visibility": current_condition['visibility'],
            "pressure": current_condition['pressure'],
            "precipitation": current_condition['precipMM'],
            "cloud_cover": current_condition['cloudcover']
        }
        return weather_info
    else:
        return {"error": f"Unable to retrieve weather information for {location}"}

@tool
def query_arxiv(query: str, include_abstract: bool = False) -> Dict[str, Any]:
    """
    Query arXiv for papers based on author or article name.
    Optionally include the abstract in the results.
    """
    client = arxiv.Client()
    search = arxiv.Search(
        query=query,
        max_results=5,
        sort_by=arxiv.SortCriterion.Relevance
    )
    results = []
    for paper in client.results(search):
        paper_info = {
            "title": paper.title,
            "authors": [author.name for author in paper.authors],
            "published": paper.published.strftime("%Y-%m-%d"),
            "url": paper.pdf_url
        }
        if include_abstract:
            paper_info["abstract"] = paper.summary
        results.append(paper_info)
    return {
        "query": query,
        "results": results,
        "total_results": len(results)
    }
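
# Quick sanity check (a minimal sketch; the location string is just an example):
# tools created with @tool can also be invoked directly with a dict of arguments,
# outside the chain defined below, e.g.
#   print(multiply.invoke({"x": 3.0, "y": 4.0}))
#   print(get_weather.invoke({"location": "London"}))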

tools = [multiply, add, get_weather, query_arxiv]

# Tool inspection
for t in tools:
    print("____________________________________________")
    print(t.name)
    print(t.description)
    print(t.args)

# Render tool descriptions
rendered_tools = render_text_description(tools)

# System prompt
system_prompt = f"""\
You are an assistant that has access to the following set of tools.
Here are the names and descriptions for each tool:

{rendered_tools}

Given the user input, return the name and input of the tool to use.
Return your response as a JSON blob with 'name' and 'arguments' keys.
The 'arguments' value should be a dictionary whose keys are the argument
names and whose values are the requested values.

For the query_arxiv tool, include 'include_abstract': true in the arguments
only if the user explicitly asks for the abstract; otherwise include
'include_abstract': false.
"""

prompt = ChatPromptTemplate.from_messages(
    [("system", system_prompt), ("user", "{input}")]
)

# Tool invocation setup
class ToolCallRequest(TypedDict):
    name: str
    arguments: Dict[str, Any]


def invoke_tool(
    tool_call_request: ToolCallRequest, config: Optional[Dict[str, Any]] = None
):
    """Look up the requested tool by name and invoke it with the parsed arguments."""
    tool_name_to_tool = {t.name: t for t in tools}
    name = tool_call_request["name"]
    requested_tool = tool_name_to_tool[name]
    return requested_tool.invoke(tool_call_request["arguments"], config=config)

# Chain setup: prompt -> model -> parse the JSON tool call -> run the chosen tool
chain = (
    prompt
    | chat_model
    | JsonOutputParser()
    | RunnablePassthrough.assign(output=invoke_tool)
)

# Example usage
result = chain.invoke({"input": "What's the weather like in Thodupuzha?"})
print(result)

result = chain.invoke({"input": "Find recent papers by Richard Feynman, don't include their abstracts"})
print(result)

result = chain.invoke({"input": "Find papers about 'Optimal Control and Ultrastrong coupling' and don't include their abstracts"})
print(result)
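
# One more illustrative query (a minimal sketch; whether the model routes this to
# the multiply tool depends on the output of the local llama3.1 model):
result = chain.invoke({"input": "What is 7 multiplied by 6?"})
print(result)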