Last active
December 15, 2024 19:34
-
-
Save virattt/2604e735810422fd0bf4e3d7f424313a to your computer and use it in GitHub Desktop.
simple-agent-backtesting.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"collapsed_sections": [ | |
"DsvZo7gjpMGu" | |
], | |
"authorship_tag": "ABX9TyMwSFEu+BRStxiUiDQFWALN", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/virattt/2604e735810422fd0bf4e3d7f424313a/simple-agent-backtesting.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "jDTXcDmUj0bV" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Install the agent framework (LangGraph) and the OpenAI chat-model bindings.\n", | |
"# %pip (rather than !pip) installs into the environment of the running kernel.\n", | |
"%pip install -U --quiet langgraph langchain_openai" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import getpass\n", | |
"import os\n", | |
"\n", | |
"\n", | |
"def _set_if_undefined(var: str):\n", | |
"    \"\"\"Ask for `var` via getpass and export it, unless it is already set to a non-empty value.\"\"\"\n", | |
"    if os.environ.get(var):\n", | |
"        return  # keep the existing value\n", | |
"    os.environ[var] = getpass.getpass(f\"Please provide your {var}\")\n", | |
"\n", | |
"\n", | |
"# Credentials come from the environment; missing ones are prompted for without echoing.\n", | |
"_set_if_undefined(\"FINANCIAL_DATASETS_API_KEY\")  # For fetching historical prices. Get from https://financialdatasets.ai\n", | |
"_set_if_undefined(\"OPENAI_API_KEY\")  # For the GPT-4o trading agent. Get from https://platform.openai.com/api-keys" | |
], | |
"metadata": { | |
"id": "azDJXqYSl2sF" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import pandas as pd\n", | |
"import requests\n", | |
"import os\n", | |
"import re\n", | |
"from datetime import datetime, timedelta\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"# Import your agent's dependencies\n", | |
"from langchain_openai.chat_models import ChatOpenAI\n", | |
"from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage\n", | |
"from typing import TypedDict, Annotated, Sequence\n", | |
"import operator\n", | |
"from langgraph.graph import StateGraph, MessagesState\n" | |
], | |
"metadata": { | |
"id": "lP3IylZfm744" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# 1. Create the Agent" | |
], | |
"metadata": { | |
"id": "-adeYrdHpHvx" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Initialize the OpenAI chat model; temperature=0 to minimize randomness in its replies\n", | |
"gpt_4o_model = ChatOpenAI(model=\"gpt-4o\", temperature=0)\n", | |
"\n", | |
"# System prompt restricting the model to a single-word action.\n", | |
"# call_agent compares the first message's content against this exact string,\n", | |
"# and Backtester.parse_action falls back to 'hold' if the reply contains\n", | |
"# none of the three action words.\n", | |
"system_prompt = \"\"\"\n", | |
"You are a financial trading agent.\n", | |
"Based on the provided historical stock price data, make a trading decision for today.\n", | |
"Your decision should be one of:\n", | |
"- 'buy'\n", | |
"- 'sell'\n", | |
"- 'hold'\n", | |
"Only output the decision, without any additional text.\n", | |
"\"\"\"\n", | |
"\n", | |
"# Define the function that calls the model\n", | |
"def call_agent(state: MessagesState):\n", | |
"    \"\"\"LangGraph node: ask GPT-4o for a decision, with the system prompt first.\n", | |
"\n", | |
"    Returns {\"messages\": [reply]} as this node's state update.\n", | |
"    \"\"\"\n", | |
"    prompt = SystemMessage(content=system_prompt)\n", | |
"    messages = state[\"messages\"]\n", | |
"\n", | |
"    # Prepend the system prompt unless it is already the first message.\n", | |
"    # Build a new list instead of insert()-ing so the graph's state list is\n", | |
"    # not mutated in place; this also covers an empty message list.\n", | |
"    if not messages or messages[0].content != system_prompt:\n", | |
"        messages = [prompt, *messages]\n", | |
"\n", | |
"    # Invoke the model and return its reply as the state update\n", | |
"    return {\"messages\": [gpt_4o_model.invoke(messages)]}\n", | |
"\n", | |
"# Define the agent graph: a single \"agent\" node that is also the entry point.\n", | |
"# NOTE(review): no outgoing edges are added, so each app.invoke presumably runs\n", | |
"# call_agent once and stops; confirm against the installed LangGraph version.\n", | |
"workflow = StateGraph(MessagesState)\n", | |
"workflow.add_node(\"agent\", call_agent)\n", | |
"workflow.set_entry_point(\"agent\")\n", | |
"app = workflow.compile()\n", | |
"\n", | |
"# Run the agent\n", | |
"def run_agent(content: str):\n", | |
"    \"\"\"Send `content` as one user message through the graph and return the reply text.\"\"\"\n", | |
"    inputs = {\"messages\": [HumanMessage(content=content)]}\n", | |
"    # NOTE(review): the graph is compiled without a checkpointer, so thread_id\n", | |
"    # presumably has no effect; confirm before relying on it for memory.\n", | |
"    run_config = {\"configurable\": {\"thread_id\": 42}}\n", | |
"    result = app.invoke(inputs, config=run_config)\n", | |
"    last_message = result[\"messages\"][-1]\n", | |
"    return last_message.content\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "tuMK7Jrxj3eY" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Smoke test: check the agent responds before backtesting (the system prompt asks for a one-word reply)\n", | |
"run_agent(\"hello\")" | |
], | |
"metadata": { | |
"id": "UnlmOdL8j7QD" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# 2. Get Price Data" | |
], | |
"metadata": { | |
"id": "DsvZo7gjpMGu" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def get_price_data(ticker, start_date, end_date, timeout=30):\n", | |
"    \"\"\"Fetch daily price bars for `ticker` from the Financial Datasets API.\n", | |
"\n", | |
"    Args:\n", | |
"        ticker: Stock symbol, e.g. 'AAPL'.\n", | |
"        start_date: First date of the range, e.g. '2024-01-01'.\n", | |
"        end_date: Last date of the range, e.g. '2024-01-31'.\n", | |
"        timeout: Seconds to wait for the HTTP response (default 30).\n", | |
"\n", | |
"    Returns:\n", | |
"        DataFrame indexed by 'Date' (parsed from the API's 'time' field), sorted\n", | |
"        ascending, with open/close/high/low/volume coerced to numbers\n", | |
"        (unparseable values become NaN).\n", | |
"\n", | |
"    Raises:\n", | |
"        Exception: if the API responds with a non-200 status.\n", | |
"        ValueError: if the response contains no 'prices'.\n", | |
"    \"\"\"\n", | |
"    # Add your API key to the headers\n", | |
"    headers = {\n", | |
"        \"X-API-KEY\": os.environ.get(\"FINANCIAL_DATASETS_API_KEY\")\n", | |
"    }\n", | |
"\n", | |
"    # Let requests build and URL-encode the query string rather than pasting\n", | |
"    # raw values into the URL.\n", | |
"    params = {\n", | |
"        'ticker': ticker,\n", | |
"        'interval': 'day',\n", | |
"        'interval_multiplier': 1,\n", | |
"        'start_date': start_date,\n", | |
"        'end_date': end_date,\n", | |
"    }\n", | |
"\n", | |
"    # Without a timeout a stalled connection would hang the cell indefinitely.\n", | |
"    response = requests.get(\n", | |
"        'https://api.financialdatasets.ai/prices/',\n", | |
"        headers=headers,\n", | |
"        params=params,\n", | |
"        timeout=timeout,\n", | |
"    )\n", | |
"\n", | |
"    # Check for successful response\n", | |
"    if response.status_code != 200:\n", | |
"        raise Exception(f\"Error fetching data: {response.status_code} - {response.text}\")\n", | |
"\n", | |
"    # Parse prices from the response\n", | |
"    prices = response.json().get('prices')\n", | |
"    if not prices:\n", | |
"        raise ValueError(\"No price data returned\")\n", | |
"\n", | |
"    df = pd.DataFrame(prices)\n", | |
"\n", | |
"    # Ensure numeric data types\n", | |
"    for col in ['open', 'close', 'high', 'low', 'volume']:\n", | |
"        df[col] = pd.to_numeric(df[col], errors='coerce')\n", | |
"\n", | |
"    # Index by parsed timestamp and sort chronologically (no inplace mutation)\n", | |
"    df['Date'] = pd.to_datetime(df['time'])\n", | |
"    return df.set_index('Date').sort_index()\n" | |
], | |
"metadata": { | |
"id": "tCYkY3EDmCLM" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# 3. Create a Backtester" | |
], | |
"metadata": { | |
"id": "eib4alsrpQ3h" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"class Backtester:\n", | |
"    \"\"\"Day-by-day backtest of a text-in/text-out trading agent on one ticker.\n", | |
"\n", | |
"    On each trading day after the first `window_size` rows, the agent is shown\n", | |
"    the previous `window_size` closing prices (not today's) and its reply is\n", | |
"    parsed into 'buy', 'sell' or 'hold'. Trades fill at that day's close:\n", | |
"    'buy' spends as much cash as whole shares allow, 'sell' liquidates the\n", | |
"    whole position.\n", | |
"\n", | |
"    Typical use: load_data() -> initialize_portfolio() -> run_backtest()\n", | |
"    -> analyze_performance().\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self, agent, ticker, start_date, end_date, initial_capital, window_size=5):\n", | |
"        \"\"\"\n", | |
"        Args:\n", | |
"            agent: Callable taking a prompt string and returning a text reply.\n", | |
"            ticker: Symbol passed to get_price_data.\n", | |
"            start_date: Start of the range passed to get_price_data.\n", | |
"            end_date: End of the range passed to get_price_data.\n", | |
"            initial_capital: Starting cash.\n", | |
"            window_size: Number of prior closing prices shown to the agent\n", | |
"                each day (default 5).\n", | |
"        \"\"\"\n", | |
"        self.agent = agent\n", | |
"        self.ticker = ticker\n", | |
"        self.start_date = start_date\n", | |
"        self.end_date = end_date\n", | |
"        self.initial_capital = initial_capital\n", | |
"        self.window_size = window_size\n", | |
"        self.data = None\n", | |
"        self.portfolio = None\n", | |
"        self.portfolio_values = []\n", | |
"\n", | |
"    def load_data(self):\n", | |
"        \"\"\"Fetch prices for the configured ticker and date range, sorted by date.\"\"\"\n", | |
"        self.data = get_price_data(self.ticker, self.start_date, self.end_date).sort_index()\n", | |
"\n", | |
"    def initialize_portfolio(self):\n", | |
"        \"\"\"Reset to all cash and clear the recorded value history.\"\"\"\n", | |
"        self.portfolio = {\n", | |
"            'cash': self.initial_capital,\n", | |
"            'stock': 0,\n", | |
"            'portfolio_value': self.initial_capital\n", | |
"        }\n", | |
"        # Clear history too, so re-running a backtest does not append to the\n", | |
"        # previous run's values and skew analyze_performance.\n", | |
"        self.portfolio_values = []\n", | |
"\n", | |
"    def parse_action(self, agent_output):\n", | |
"        \"\"\"Return the first 'buy'/'sell'/'hold' word in `agent_output`, else 'hold'.\"\"\"\n", | |
"        match = re.search(r'\\b(buy|sell|hold)\\b', agent_output.lower())\n", | |
"        if match:\n", | |
"            return match.group(1)\n", | |
"        # If no valid action is found, default to 'hold'\n", | |
"        return 'hold'\n", | |
"\n", | |
"    def execute_trade(self, action, current_price):\n", | |
"        \"\"\"Apply `action` to the portfolio at `current_price`.\n", | |
"\n", | |
"        'buy' converts as much cash as possible into whole shares; 'sell'\n", | |
"        sells every share held. Any other action, or a buy/sell that cannot\n", | |
"        be filled (too little cash / no shares), leaves the portfolio unchanged.\n", | |
"        \"\"\"\n", | |
"        if action == 'buy' and self.portfolio['cash'] >= current_price:\n", | |
"            shares_to_buy = int(self.portfolio['cash'] // current_price)\n", | |
"            self.portfolio['stock'] += shares_to_buy\n", | |
"            self.portfolio['cash'] -= shares_to_buy * current_price\n", | |
"        elif action == 'sell' and self.portfolio['stock'] > 0:\n", | |
"            self.portfolio['cash'] += self.portfolio['stock'] * current_price\n", | |
"            self.portfolio['stock'] = 0\n", | |
"\n", | |
"    def run_backtest(self):\n", | |
"        \"\"\"Query the agent day by day, execute its trades and record portfolio value.\n", | |
"\n", | |
"        Loads data / initializes the portfolio first if that has not been done.\n", | |
"        Prints one log line per trading day and appends\n", | |
"        {'Date', 'Portfolio Value'} records to self.portfolio_values.\n", | |
"\n", | |
"        Raises:\n", | |
"            ValueError: if there are not more than `window_size` rows of data,\n", | |
"                i.e. no trading day remains after the lookback window.\n", | |
"        \"\"\"\n", | |
"        if self.data is None:\n", | |
"            self.load_data()\n", | |
"        if self.portfolio is None:\n", | |
"            self.initialize_portfolio()\n", | |
"\n", | |
"        window_size = self.window_size\n", | |
"        data = self.data\n", | |
"\n", | |
"        # Need at least one row after the lookback window; with exactly\n", | |
"        # `window_size` rows the loop would never run and analyze_performance\n", | |
"        # would have nothing to work with.\n", | |
"        if len(data) <= window_size:\n", | |
"            raise ValueError(\"Not enough data to perform backtest.\")\n", | |
"\n", | |
"        closes = data['close']\n", | |
"\n", | |
"        print(\"\\nStarting backtest...\")\n", | |
"        print(f\"{'Date':<12} {'Action':<6} {'Price':>8} {'Shares':>8} {'Cash':>12} {'Total Value':>12}\")\n", | |
"        print(\"-\" * 60)\n", | |
"\n", | |
"        for idx in range(window_size, len(data)):\n", | |
"            current_date = data.index[idx]\n", | |
"            # Only closes strictly before today are shown, so the agent never\n", | |
"            # sees the price it is about to trade at.\n", | |
"            historical_prices = closes.iloc[idx - window_size:idx].tolist()\n", | |
"            price_history_str = ', '.join([f\"{price:.2f}\" for price in historical_prices])\n", | |
"\n", | |
"            content = f\"\"\"\n", | |
"        Here is the closing price data for the last {window_size} days: {price_history_str}.\n", | |
"\n", | |
"        Based on this data, what is your trading decision for today? Please respond with only one word: 'buy', 'sell', or 'hold'.\n", | |
"        \"\"\"\n", | |
"\n", | |
"            # Agent makes a decision\n", | |
"            agent_output = self.agent(content)\n", | |
"            action = self.parse_action(agent_output)\n", | |
"            # Positional lookup: a label lookup would return a Series if the\n", | |
"            # index contained duplicate dates.\n", | |
"            current_price = closes.iloc[idx]\n", | |
"\n", | |
"            self.execute_trade(action, current_price)\n", | |
"\n", | |
"            # Mark the portfolio to today's close\n", | |
"            total_value = self.portfolio['cash'] + self.portfolio['stock'] * current_price\n", | |
"            self.portfolio['portfolio_value'] = total_value\n", | |
"\n", | |
"            print(f\"{current_date.strftime('%Y-%m-%d'):<12} {action:<6} {current_price:>8.2f} {self.portfolio['stock']:>8d} {self.portfolio['cash']:>12.2f} {total_value:>12.2f}\")\n", | |
"\n", | |
"            self.portfolio_values.append({\n", | |
"                'Date': current_date,\n", | |
"                'Portfolio Value': total_value\n", | |
"            })\n", | |
"\n", | |
"    def analyze_performance(self):\n", | |
"        \"\"\"Print total return, Sharpe ratio and maximum drawdown; plot portfolio value.\n", | |
"\n", | |
"        Returns:\n", | |
"            DataFrame indexed by Date with 'Portfolio Value' and 'Daily Return'.\n", | |
"\n", | |
"        Raises:\n", | |
"            ValueError: if run_backtest has not recorded any portfolio values.\n", | |
"        \"\"\"\n", | |
"        if not self.portfolio_values:\n", | |
"            raise ValueError(\"No portfolio values recorded; call run_backtest() first.\")\n", | |
"\n", | |
"        performance_df = pd.DataFrame(self.portfolio_values).set_index('Date')\n", | |
"\n", | |
"        total_return = (self.portfolio['portfolio_value'] - self.initial_capital) / self.initial_capital\n", | |
"        print(f\"Total Return: {total_return * 100:.2f}%\")\n", | |
"\n", | |
"        fig, ax = plt.subplots(figsize=(12, 6))\n", | |
"        performance_df['Portfolio Value'].plot(ax=ax, title='Portfolio Value Over Time')\n", | |
"        ax.set_ylabel('Portfolio Value ($)')\n", | |
"        ax.set_xlabel('Date')\n", | |
"        plt.show()\n", | |
"\n", | |
"        performance_df['Daily Return'] = performance_df['Portfolio Value'].pct_change()\n", | |
"\n", | |
"        # Annualized Sharpe ratio with a zero risk-free rate (252 trading days/year)\n", | |
"        mean_daily_return = performance_df['Daily Return'].mean()\n", | |
"        std_daily_return = performance_df['Daily Return'].std()\n", | |
"        # A flat equity curve (e.g. the agent only holds) or a single day of\n", | |
"        # data has zero/undefined volatility; report NaN instead of dividing by it.\n", | |
"        if pd.isna(std_daily_return) or std_daily_return == 0:\n", | |
"            sharpe_ratio = float('nan')\n", | |
"        else:\n", | |
"            sharpe_ratio = (mean_daily_return / std_daily_return) * (252 ** 0.5)\n", | |
"        print(f\"Sharpe Ratio: {sharpe_ratio:.2f}\")\n", | |
"\n", | |
"        # Maximum drawdown: worst fall from a running peak\n", | |
"        rolling_max = performance_df['Portfolio Value'].cummax()\n", | |
"        drawdown = performance_df['Portfolio Value'] / rolling_max - 1\n", | |
"        max_drawdown = drawdown.min()\n", | |
"        print(f\"Maximum Drawdown: {max_drawdown * 100:.2f}%\")\n", | |
"\n", | |
"        return performance_df" | |
], | |
"metadata": { | |
"id": "sycJCgYunBfq" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# 4. Run the Backtest" | |
], | |
"metadata": { | |
"id": "8gbjCJT-pTa0" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Backtest parameters: one month of AAPL daily prices, $100,000 starting cash\n", | |
"ticker = 'AAPL'\n", | |
"start_date, end_date = '2024-01-01', '2024-01-31'  # date range requested from the price API\n", | |
"initial_capital = 100_000\n", | |
"\n", | |
"backtester = Backtester(\n", | |
"    agent=run_agent,\n", | |
"    ticker=ticker,\n", | |
"    start_date=start_date,\n", | |
"    end_date=end_date,\n", | |
"    initial_capital=initial_capital,\n", | |
")\n", | |
"\n", | |
"# Fetch prices, start from all cash, simulate day by day, then report metrics\n", | |
"backtester.load_data()\n", | |
"backtester.initialize_portfolio()\n", | |
"backtester.run_backtest()\n", | |
"performance_df = backtester.analyze_performance()" | |
], | |
"metadata": { | |
"id": "ro6_juA_nHl-" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment