Skip to content

Instantly share code, notes, and snippets.

@virattt
Last active December 15, 2024 19:34
Show Gist options
  • Save virattt/2604e735810422fd0bf4e3d7f424313a to your computer and use it in GitHub Desktop.
Save virattt/2604e735810422fd0bf4e3d7f424313a to your computer and use it in GitHub Desktop.
simple-agent-backtesting.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"collapsed_sections": [
"DsvZo7gjpMGu"
],
"authorship_tag": "ABX9TyMwSFEu+BRStxiUiDQFWALN",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/virattt/2604e735810422fd0bf4e3d7f424313a/simple-agent-backtesting.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "jDTXcDmUj0bV"
},
"outputs": [],
"source": [
"!pip install -U --quiet langgraph langchain_openai"
]
},
{
"cell_type": "code",
"source": [
"import getpass\n",
"import os\n",
"\n",
"\n",
"def _set_if_undefined(var: str):\n",
" if not os.environ.get(var):\n",
" os.environ[var] = getpass.getpass(f\"Please provide your {var}\")\n",
"\n",
"\n",
"_set_if_undefined(\"FINANCIAL_DATASETS_API_KEY\") # For getting financial data. Get from https://financialdatasets.ai\n",
"_set_if_undefined(\"OPENAI_API_KEY\") # For getting financial data. Get from https://financialdatasets.ai"
],
"metadata": {
"id": "azDJXqYSl2sF"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"import requests\n",
"import os\n",
"import re\n",
"from datetime import datetime, timedelta\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Import your agent's dependencies\n",
"from langchain_openai.chat_models import ChatOpenAI\n",
"from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage\n",
"from typing import TypedDict, Annotated, Sequence\n",
"import operator\n",
"from langgraph.graph import StateGraph, MessagesState\n"
],
"metadata": {
"id": "lP3IylZfm744"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# 1. Create the Agent"
],
"metadata": {
"id": "-adeYrdHpHvx"
}
},
{
"cell_type": "code",
"source": [
"# Initialize the OpenAI model\n",
"gpt_4o_model = ChatOpenAI(model=\"gpt-4o\", temperature=0)\n",
"\n",
"# Define the system prompt\n",
"system_prompt = \"\"\"\n",
"You are a financial trading agent.\n",
"Based on the provided historical stock price data, make a trading decision for today.\n",
"Your decision should be one of:\n",
"- 'buy'\n",
"- 'sell'\n",
"- 'hold'\n",
"Only output the decision, without any additional text.\n",
"\"\"\"\n",
"\n",
"# Define the function that calls the model\n",
"def call_agent(state: MessagesState):\n",
" prompt = SystemMessage(content=system_prompt)\n",
" # Get the messages\n",
" messages = state[\"messages\"]\n",
"\n",
" # Check if the first message is the prompt\n",
" if messages and messages[0].content != system_prompt:\n",
" # Add the prompt to the start of the messages\n",
" messages.insert(0, prompt)\n",
"\n",
" # Invoke the model and return the response\n",
" return {\"messages\": [gpt_4o_model.invoke(messages)]}\n",
"\n",
"# Define the agent graph\n",
"workflow = StateGraph(MessagesState)\n",
"workflow.add_node(\"agent\", call_agent)\n",
"workflow.set_entry_point(\"agent\")\n",
"app = workflow.compile()\n",
"\n",
"# Run the agent\n",
"def run_agent(content: str):\n",
" final_state = app.invoke(\n",
" {\"messages\": [HumanMessage(content=content)]},\n",
" config={\"configurable\": {\"thread_id\": 42}}\n",
" )\n",
" return final_state[\"messages\"][-1].content\n",
"\n"
],
"metadata": {
"id": "tuMK7Jrxj3eY"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"run_agent(\"hello\")"
],
"metadata": {
"id": "UnlmOdL8j7QD"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# 2. Get Price Data"
],
"metadata": {
"id": "DsvZo7gjpMGu"
}
},
{
"cell_type": "code",
"source": [
"def get_price_data(ticker, start_date, end_date):\n",
" # Add your API key to the headers\n",
" headers = {\n",
" \"X-API-KEY\": os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n",
" }\n",
"\n",
" # Create the URL\n",
" url = (\n",
" f'https://api.financialdatasets.ai/prices/'\n",
" f'?ticker={ticker}'\n",
" f'&interval=day'\n",
" f'&interval_multiplier=1'\n",
" f'&start_date={start_date}'\n",
" f'&end_date={end_date}'\n",
" )\n",
"\n",
" # Make API request\n",
" response = requests.get(url, headers=headers)\n",
"\n",
" # Check for successful response\n",
" if response.status_code != 200:\n",
" raise Exception(f\"Error fetching data: {response.status_code} - {response.text}\")\n",
"\n",
" # Parse prices from the response\n",
" data = response.json()\n",
" prices = data.get('prices')\n",
" if not prices:\n",
" raise ValueError(\"No price data returned\")\n",
"\n",
" # Convert prices to DataFrame\n",
" df = pd.DataFrame(prices)\n",
"\n",
" # Convert 'time' to datetime and set as index\n",
" df['Date'] = pd.to_datetime(df['time'])\n",
" df.set_index('Date', inplace=True)\n",
"\n",
" # Ensure numeric data types\n",
" numeric_cols = ['open', 'close', 'high', 'low', 'volume']\n",
" for col in numeric_cols:\n",
" df[col] = pd.to_numeric(df[col], errors='coerce')\n",
"\n",
" # Sort by date\n",
" df.sort_index(inplace=True)\n",
"\n",
" return df\n"
],
"metadata": {
"id": "tCYkY3EDmCLM"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# 3. Create a backtester"
],
"metadata": {
"id": "eib4alsrpQ3h"
}
},
{
"cell_type": "code",
"source": [
"class Backtester:\n",
" def __init__(self, agent, ticker, start_date, end_date, initial_capital):\n",
" self.agent = agent\n",
" self.ticker = ticker\n",
" self.start_date = start_date\n",
" self.end_date = end_date\n",
" self.initial_capital = initial_capital\n",
" self.data = None\n",
" self.portfolio = None\n",
" self.portfolio_values = []\n",
"\n",
" def load_data(self):\n",
" self.data = get_price_data(self.ticker, self.start_date, self.end_date)\n",
" # Ensure data is sorted by date\n",
" self.data.sort_index(inplace=True)\n",
"\n",
" def initialize_portfolio(self):\n",
" self.portfolio = {\n",
" 'cash': self.initial_capital,\n",
" 'stock': 0,\n",
" 'portfolio_value': self.initial_capital\n",
" }\n",
"\n",
" def parse_action(self, agent_output):\n",
" # Use regular expressions to find 'buy', 'sell', or 'hold' in the agent's output\n",
" match = re.search(r'\\b(buy|sell|hold)\\b', agent_output.lower())\n",
" if match:\n",
" return match.group(1)\n",
" else:\n",
" # If no valid action is found, default to 'hold'\n",
" return 'hold'\n",
"\n",
" def execute_trade(self, action, current_price):\n",
" if action == 'buy' and self.portfolio['cash'] >= current_price:\n",
" # Buy as many shares as possible with available cash\n",
" shares_to_buy = int(self.portfolio['cash'] // current_price)\n",
" self.portfolio['stock'] += shares_to_buy\n",
" self.portfolio['cash'] -= shares_to_buy * current_price\n",
" elif action == 'sell' and self.portfolio['stock'] > 0:\n",
" # Sell all shares\n",
" self.portfolio['cash'] += self.portfolio['stock'] * current_price\n",
" self.portfolio['stock'] = 0\n",
" # else 'hold' or not enough cash/stock to trade\n",
"\n",
" def run_backtest(self):\n",
" window_size = 5 # Number of days of historical data to provide to the agent\n",
" data = self.data\n",
"\n",
" # Ensure there are enough data points\n",
" if len(data) < window_size:\n",
" raise ValueError(\"Not enough data to perform backtest.\")\n",
"\n",
" print(\"\\nStarting backtest...\")\n",
" print(f\"{'Date':<12} {'Action':<6} {'Price':>8} {'Shares':>8} {'Cash':>12} {'Total Value':>12}\")\n",
" print(\"-\" * 60)\n",
"\n",
" for idx in range(window_size, len(data)):\n",
" current_date = data.index[idx]\n",
" # Get historical data up to the current date\n",
" historical_data = data.iloc[idx - window_size:idx]\n",
" historical_prices = historical_data['close'].tolist()\n",
" price_history_str = ', '.join([f\"{price:.2f}\" for price in historical_prices])\n",
"\n",
" content = f\"\"\"\n",
" Here is the closing price data for the last {window_size} days: {price_history_str}.\n",
"\n",
" Based on this data, what is your trading decision for today? Please respond with only one word: 'buy', 'sell', or 'hold'.\n",
" \"\"\"\n",
"\n",
" # Agent makes a decision\n",
" agent_output = self.agent(content)\n",
" action = self.parse_action(agent_output)\n",
" current_price = data.loc[current_date, 'close']\n",
"\n",
" # Execute the agent's action\n",
" self.execute_trade(action, current_price)\n",
"\n",
" # Update total portfolio value\n",
" total_value = self.portfolio['cash'] + self.portfolio['stock'] * current_price\n",
" self.portfolio['portfolio_value'] = total_value\n",
"\n",
" # Log the current state\n",
" print(f\"{current_date.strftime('%Y-%m-%d'):<12} {action:<6} {current_price:>8.2f} {self.portfolio['stock']:>8d} {self.portfolio['cash']:>12.2f} {total_value:>12.2f}\")\n",
"\n",
" # Record the portfolio value\n",
" self.portfolio_values.append({\n",
" 'Date': current_date,\n",
" 'Portfolio Value': total_value\n",
" })\n",
"\n",
" def analyze_performance(self):\n",
" # Convert portfolio values to DataFrame\n",
" performance_df = pd.DataFrame(self.portfolio_values).set_index('Date')\n",
"\n",
" # Calculate total return\n",
" total_return = (self.portfolio['portfolio_value'] - self.initial_capital) / self.initial_capital\n",
" print(f\"Total Return: {total_return * 100:.2f}%\")\n",
"\n",
" # Plot the portfolio value over time\n",
" performance_df['Portfolio Value'].plot(title='Portfolio Value Over Time', figsize=(12,6))\n",
" plt.ylabel('Portfolio Value ($)')\n",
" plt.xlabel('Date')\n",
" plt.show()\n",
"\n",
" # Compute daily returns\n",
" performance_df['Daily Return'] = performance_df['Portfolio Value'].pct_change()\n",
"\n",
" # Calculate Sharpe Ratio (assuming 252 trading days in a year)\n",
" mean_daily_return = performance_df['Daily Return'].mean()\n",
" std_daily_return = performance_df['Daily Return'].std()\n",
" sharpe_ratio = (mean_daily_return / std_daily_return) * (252 ** 0.5)\n",
" print(f\"Sharpe Ratio: {sharpe_ratio:.2f}\")\n",
"\n",
" # Calculate Maximum Drawdown\n",
" rolling_max = performance_df['Portfolio Value'].cummax()\n",
" drawdown = performance_df['Portfolio Value'] / rolling_max - 1\n",
" max_drawdown = drawdown.min()\n",
" print(f\"Maximum Drawdown: {max_drawdown * 100:.2f}%\")\n",
"\n",
" return performance_df"
],
"metadata": {
"id": "sycJCgYunBfq"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# 4. Run the Backtest"
],
"metadata": {
"id": "8gbjCJT-pTa0"
}
},
{
"cell_type": "code",
"source": [
"# Define parameters\n",
"ticker = 'AAPL' # Example ticker symbol\n",
"start_date = '2024-01-01' # Adjust as needed\n",
"end_date = '2024-01-31' # Adjust as needed\n",
"initial_capital = 100000 # $100,000\n",
"\n",
"# Create an instance of Backtester\n",
"backtester = Backtester(\n",
" agent=run_agent,\n",
" ticker=ticker,\n",
" start_date=start_date,\n",
" end_date=end_date,\n",
" initial_capital=initial_capital\n",
")\n",
"\n",
"# Run the backtesting process\n",
"backtester.load_data()\n",
"backtester.initialize_portfolio()\n",
"backtester.run_backtest()\n",
"performance_df = backtester.analyze_performance()"
],
"metadata": {
"id": "ro6_juA_nHl-"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment