Skip to content

Instantly share code, notes, and snippets.

@eromoe
Forked from Burntt/demo_purgedkfoldcv.ipynb
Created March 10, 2023 06:16
Show Gist options
  • Save eromoe/ef30bc4670f7f5b66f4ad6fb282a64eb to your computer and use it in GitHub Desktop.
Save eromoe/ef30bc4670f7f5b66f4ad6fb282a64eb to your computer and use it in GitHub Desktop.
Demo_PurgedKFoldCV.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Demo_PurgedKFoldCV.ipynb",
"provenance": [],
"collapsed_sections": [],
"authorship_tag": "ABX9TyMOK1b4162zSa6YVa3alK99",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/Burntt/f26e5414205542207949aeb9e9cc1ddb/demo_purgedkfoldcv.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"# The Combinatorial Purged Cross-Validation method: indexing example on crypto\n",
"\n",
"*By Berend Gort*\n",
"\n",
"www.medium.com/@CoderBurnt\n",
"\n",
"www.linkedin.com/in/berendgort/\n",
"\n",
"www.twitter.com/CoderBurnt\n",
"\n",
"\n"
],
"metadata": {
"id": "KFmER7NGbpgZ"
}
},
{
"cell_type": "markdown",
"source": [
"### Packages"
],
"metadata": {
"id": "RXozZT8vckj-"
}
},
{
"cell_type": "code",
"source": [
"# Install required packages\n",
"\n",
"%cd /\n",
"!git clone https://github.com/AI4Finance-Foundation/FinRL-Meta\n",
"%cd /FinRL-Meta/\n",
"!pip install git+https://github.com/AI4Finance-LLC/ElegantRL.git\n",
"!pip install git+https://github.com/AI4Finance-LLC/FinRL-Library.git\n",
"!pip install gputil\n",
"!pip install trading_calendars\n",
"!pip install fracdiff\n",
"!pip install timeseriescv\n",
"\n",
"#install TA-lib (technical analysis)\n",
"!wget http://prdownloads.sourceforge.net/ta-lib/ta-lib-0.4.0-src.tar.gz \n",
"!tar xvzf ta-lib-0.4.0-src.tar.gz\n",
"import os\n",
"os.chdir('ta-lib') \n",
"!./configure --prefix=/usr\n",
"!make \n",
"!make install\n",
"os.chdir('../')\n",
"!pip install TA-Lib\n",
"!pip install python-binance"
],
"metadata": {
"id": "RQGyziIBbmpk"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Imports"
],
"metadata": {
"id": "OA9PXZ2Ccpnd"
}
},
{
"cell_type": "code",
"source": [
"# Other imports\n",
"\n",
"import scipy as sp\n",
"import math\n",
"import pandas as pd\n",
"import requests\n",
"import json\n",
"import matplotlib.dates as mdates\n",
"import numpy as np\n",
"import pickle\n",
"import shutil\n",
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import itertools as itt\n",
"import numbers\n",
"import datetime\n",
"\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"import seaborn as sns\n",
"\n",
"from datetime import datetime, timedelta\n",
"from talib.abstract import *\n",
"from binance.client import Client\n",
"from pandas.testing import assert_frame_equal\n",
"from sklearn import metrics\n",
"from sklearn.metrics import classification_report\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"from sklearn.preprocessing import MinMaxScaler \n",
"from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler\n",
"from IPython.display import display, HTML\n",
"\n",
"from itertools import combinations\n",
"from abc import abstractmethod\n",
"from typing import Iterable, Tuple, List\n",
"\n",
"#from google.colab import files"
],
"metadata": {
"id": "FRaRl37NbOpo"
},
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Plot settings\n",
"\n",
"SCALE_FACTOR = 1\n",
"\n",
"# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8*' and later\n",
"# releases removed the old name, so use whichever of the two is available.\n",
"for _style in ('seaborn-v0_8', 'seaborn'):\n",
"    if _style in plt.style.available:\n",
"        plt.style.use(_style)\n",
"        break\n",
"\n",
"plt.rcParams['figure.figsize'] = [5 * SCALE_FACTOR, 2 * SCALE_FACTOR]\n",
"plt.rcParams['figure.dpi'] = 300 * SCALE_FACTOR\n",
"plt.rcParams['font.size'] = 5 * SCALE_FACTOR\n",
"plt.rcParams['axes.labelsize'] = 5 * SCALE_FACTOR\n",
"plt.rcParams['axes.titlesize'] = 6 * SCALE_FACTOR\n",
"plt.rcParams['xtick.labelsize'] = 4 * SCALE_FACTOR\n",
"plt.rcParams['ytick.labelsize'] = 4 * SCALE_FACTOR\n",
"plt.rcParams['font.family'] = 'serif'"
],
"metadata": {
"id": "vywKmrJ0bQFJ"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## This requires Binance API keys\n",
"\n",
"Video of how to get them easily:\n",
"\n",
"https://www.youtube.com/watch?v=qg-oboAY8rM"
],
"metadata": {
"id": "dzCD3zmCcrOL"
}
},
{
"cell_type": "code",
"source": [
"# Set your Binance data API keys!\n",
"# Read from the BINANCE_API_KEY / BINANCE_API_SECRET environment variables when\n",
"# set; otherwise prompt with getpass so the keys are never echoed into (and saved\n",
"# with) the cell output. Both names are checked so a missing secret is re-prompted.\n",
"\n",
"import os\n",
"from getpass import getpass\n",
"\n",
"if 'API_KEY_Binance' not in globals() or 'API_SECRET_Binance' not in globals():\n",
"    API_KEY_Binance = os.environ.get('BINANCE_API_KEY') or getpass('Please enter your main API key:')\n",
"    API_SECRET_Binance = os.environ.get('BINANCE_API_SECRET') or getpass('Please enter your secret API key:')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Aml-mqY5bRbL",
"outputId": "155047f4-15aa-4d59-a044-8dd4aa0daf81"
},
"execution_count": 4,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Please enter your main API key:\n",
"<redacted>\n",
"Please enter your secret API key:\n",
"<redacted>\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"def get_features_for_each_coin(tic_df):\n",
"    \"\"\"Add RSI, MACD (line/signal/hist), CCI and DX columns to ``tic_df`` in place.\n",
"\n",
"    Called with a single-ticker slice that has 'high', 'low' and 'close' columns;\n",
"    the same (mutated) frame is returned for convenience.\n",
"    \"\"\"\n",
"    close, high, low = tic_df['close'], tic_df['high'], tic_df['low']\n",
"\n",
"    tic_df['rsi'] = RSI(close, timeperiod=14)\n",
"\n",
"    macd_line, macd_signal, macd_hist = MACD(close, fastperiod=12, slowperiod=26, signalperiod=9)\n",
"    tic_df['macd'] = macd_line\n",
"    tic_df['macd_signal'] = macd_signal\n",
"    tic_df['macd_hist'] = macd_hist\n",
"\n",
"    tic_df['cci'] = CCI(high, low, close, timeperiod=14)\n",
"    tic_df['dx'] = DX(high, low, close, timeperiod=14)\n",
"    return tic_df"
],
"metadata": {
"id": "hkcnv_DcbUmb"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Get data "
],
"metadata": {
"id": "dJP3-VGIdEwv"
}
},
{
"cell_type": "code",
"source": [
"class BinanceProcessor():\n",
"    \"\"\"Download Binance klines for a list of tickers and attach technical indicators.\"\"\"\n",
"\n",
"    def __init__(self, api_key_binance, api_secret_binance):\n",
"        self.binance_api_key = api_key_binance  # Enter your own API-key here\n",
"        self.binance_api_secret = api_secret_binance  # Enter your own API-secret here\n",
"        self.binance_client = Client(api_key=api_key_binance, api_secret=api_secret_binance)\n",
"\n",
"    def run(self, ticker_list, start_date, end_date, time_interval, technical_indicator_list, if_vix):\n",
"        \"\"\"Download, clean and featurize the data; return one long DataFrame indexed by 'time'.\n",
"\n",
"        ``technical_indicator_list`` is accepted for API compatibility, but\n",
"        ``add_technical_indicator`` currently always uses its own fixed list.\n",
"        \"\"\"\n",
"        data = self.download_data(ticker_list, start_date, end_date, time_interval)\n",
"        data = self.clean_data(data)\n",
"        data = self.add_technical_indicator(data, technical_indicator_list)\n",
"        data.index = data['time']\n",
"\n",
"        if if_vix:\n",
"            data = self.add_vix(data)\n",
"\n",
"        # Built only for its shape assertion; the returned value is the DataFrame itself.\n",
"        self.df_to_array(data, if_vix)\n",
"\n",
"        return data\n",
"\n",
"    # main functions\n",
"    def download_data(self, ticker_list, start_date, end_date, time_interval):\n",
"        \"\"\"Fetch klines for every ticker and stack them into one DataFrame with a 'tic' column.\"\"\"\n",
"        self.start_time = start_date\n",
"        self.end_time = end_date\n",
"        self.interval = time_interval\n",
"        self.ticker_list = ticker_list\n",
"\n",
"        frames = []\n",
"        for tic in ticker_list:\n",
"            hist_data = self.get_binance_bars(self.start_time, self.end_time, self.interval, symbol=tic)\n",
"            # The last bar is discarded (presumably because it may still be forming).\n",
"            # .copy() avoids writing 'tic' into a slice of hist_data (SettingWithCopyWarning).\n",
"            df = hist_data.iloc[:-1].dropna().copy()\n",
"            df['tic'] = tic\n",
"            frames.append(df)\n",
"\n",
"        # DataFrame.append was removed in pandas 2.0: concatenate once instead.\n",
"        return pd.concat(frames) if frames else pd.DataFrame()\n",
"\n",
"    def clean_data(self, df):\n",
"        \"\"\"Drop rows containing any NaN.\"\"\"\n",
"        return df.dropna()\n",
"\n",
"    def add_technical_indicator(self, df, tech_indicator_list):\n",
"        \"\"\"Add indicator columns per ticker via get_features_for_each_coin.\n",
"\n",
"        ``tech_indicator_list`` is ignored; the fixed list below (stored on self\n",
"        for df_to_array) is used instead.\n",
"        \"\"\"\n",
"        self.tech_indicator_list = ['open', 'high', 'low', 'close', 'volume',\n",
"                                    'macd', 'macd_signal', 'macd_hist',\n",
"                                    'rsi', 'cci', 'dx']\n",
"\n",
"        frames = [get_features_for_each_coin(df[df.tic == tic].copy())\n",
"                  for tic in df.tic.unique()]\n",
"        return pd.concat(frames) if frames else pd.DataFrame()\n",
"\n",
"    def add_turbulence(self, df):\n",
"        print('Turbulence not supported yet. Return original DataFrame.')\n",
"\n",
"        return df\n",
"\n",
"    def add_vix(self, df):\n",
"        print('VIX is not applicable for cryptocurrencies. Return original DataFrame')\n",
"\n",
"        return df\n",
"\n",
"    def df_to_array(self, df, if_vix):\n",
"        \"\"\"Return (price_array, tech_array, empty turbulence array, time_array).\n",
"\n",
"        Per-ticker columns are placed side by side with np.hstack; time_array\n",
"        comes from the first ticker in self.ticker_list.\n",
"        \"\"\"\n",
"        per_tic = [df[df.tic == tic] for tic in df.tic.unique()]\n",
"        price_array = np.hstack([tic_df[['close']].values for tic_df in per_tic])\n",
"        tech_array = np.hstack([tic_df[self.tech_indicator_list].values for tic_df in per_tic])\n",
"        time_array = df[df.tic == self.ticker_list[0]]['time'].values\n",
"\n",
"        assert price_array.shape[0] == tech_array.shape[0]\n",
"\n",
"        return price_array, tech_array, np.array([]), time_array\n",
"\n",
"    # helper functions\n",
"    def stringify_dates(self, date: datetime):\n",
"        return str(int(date.timestamp() * 1000))\n",
"\n",
"    def get_binance_bars(self, start_date, end_date, kline_size, symbol):\n",
"        \"\"\"Download klines for one symbol as numeric OHLCV columns plus 'timestamp' and 'time'.\"\"\"\n",
"        klines = self.binance_client.get_historical_klines(symbol, kline_size, start_date, end_date)\n",
"        data_df = pd.DataFrame(klines,\n",
"                               columns=['timestamp', 'open', 'high', 'low', 'close', 'volume', 'close_time', 'quote_av',\n",
"                                        'trades', 'tb_base_av', 'tb_quote_av', 'ignore'])\n",
"        data_df = data_df.drop(labels=['close_time', 'quote_av', 'trades', 'tb_base_av', 'tb_quote_av', 'ignore'], axis=1)\n",
"        data_df = data_df.apply(pd.to_numeric, errors='coerce')\n",
"        # Millisecond epoch timestamps -> naive UTC datetimes, independent of the\n",
"        # machine's local time zone (datetime.fromtimestamp would use local time).\n",
"        data_df['time'] = pd.to_datetime(data_df['timestamp'], unit='ms')\n",
"        data_df.index = pd.RangeIndex(len(data_df))\n",
"\n",
"        return data_df\n"
],
"metadata": {
"id": "br8spZ1wbXYy"
},
"execution_count": 6,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Set constants"
],
"metadata": {
"id": "kZyTFb60blRh"
}
},
{
"cell_type": "code",
"source": [
"# Set constants:\n",
"\n",
"ticker_list = ['BTCUSDT'\n",
" ]\n",
"\n",
"\n",
"time_interval = '1d'\n",
"\n",
"# Care format\n",
"start_date = '2015-01-01 00:00:00'\n",
"end_date = '2020-01-01 00:00:00'\n",
"\n",
"\n",
"technical_indicator_list = ['open',\n",
" 'high',\n",
" 'low',\n",
" 'close',\n",
" 'volume',\n",
" 'macd',\n",
" 'macd_signal',\n",
" 'macd_hist',\n",
" 'rsi',\n",
" 'cci',\n",
" 'dx'\n",
" ]\n",
"\n",
"if_vix = False"
],
"metadata": {
"id": "wE4-QHXYbYQW"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Download the configured tickers and attach indicators with BinanceProcessor\n",
"\n",
"DP = BinanceProcessor(API_KEY_Binance, API_SECRET_Binance)\n",
"data_ohlcv = DP.run(ticker_list, start_date, end_date, time_interval,\n",
"                    technical_indicator_list, if_vix)"
],
"metadata": {
"id": "ot2C0ItubZQD"
},
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Drop unnecessary columns. The timestamps stay available as the 'time' index\n",
"# set in BinanceProcessor.run. errors='ignore' keeps the cell safe to re-run.\n",
"\n",
"data_ohlcv = data_ohlcv.drop(columns=['timestamp', 'time'], errors='ignore')\n",
"\n",
"data_ohlcv.head(3)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 174
},
"id": "yl1W84f1baW_",
"outputId": "3e43695d-d13d-47e5-9e75-f0f725d33930"
},
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" open high low close volume tic rsi \\\n",
"time \n",
"2017-08-17 4261.48 4485.39 4200.74 4285.08 795.150377 BTCUSDT NaN \n",
"2017-08-18 4285.08 4371.52 3938.77 4108.37 1199.888264 BTCUSDT NaN \n",
"2017-08-19 4108.37 4184.69 3850.00 4139.98 381.309763 BTCUSDT NaN \n",
"\n",
" macd macd_signal macd_hist cci dx \n",
"time \n",
"2017-08-17 NaN NaN NaN NaN NaN \n",
"2017-08-18 NaN NaN NaN NaN NaN \n",
"2017-08-19 NaN NaN NaN NaN NaN "
],
"text/html": [
"\n",
" <div id=\"df-8667c1ad-9cb5-4761-8c1f-49d8fd700028\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>open</th>\n",
" <th>high</th>\n",
" <th>low</th>\n",
" <th>close</th>\n",
" <th>volume</th>\n",
" <th>tic</th>\n",
" <th>rsi</th>\n",
" <th>macd</th>\n",
" <th>macd_signal</th>\n",
" <th>macd_hist</th>\n",
" <th>cci</th>\n",
" <th>dx</th>\n",
" </tr>\n",
" <tr>\n",
" <th>time</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2017-08-17</th>\n",
" <td>4261.48</td>\n",
" <td>4485.39</td>\n",
" <td>4200.74</td>\n",
" <td>4285.08</td>\n",
" <td>795.150377</td>\n",
" <td>BTCUSDT</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2017-08-18</th>\n",
" <td>4285.08</td>\n",
" <td>4371.52</td>\n",
" <td>3938.77</td>\n",
" <td>4108.37</td>\n",
" <td>1199.888264</td>\n",
" <td>BTCUSDT</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2017-08-19</th>\n",
" <td>4108.37</td>\n",
" <td>4184.69</td>\n",
" <td>3850.00</td>\n",
" <td>4139.98</td>\n",
" <td>381.309763</td>\n",
" <td>BTCUSDT</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-8667c1ad-9cb5-4761-8c1f-49d8fd700028')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-8667c1ad-9cb5-4761-8c1f-49d8fd700028 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-8667c1ad-9cb5-4761-8c1f-49d8fd700028');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 9
}
]
},
{
"cell_type": "markdown",
"source": [
"# Triple barrier method\n",
"\n",
"I covered the triple barrier method in a previous Medium article; please refer to it if you want to understand how it works.\n",
"\n",
"However, understanding it is not required for PurgedKFoldCV, so you can skip ahead to the next section: PurgedKFoldCV =)\n",
"\n",
"https://medium.com/coinmonks/crypto-feature-importance-for-deep-reinforcement-learning-38416616c2a36-8416616c2a36\n"
],
"metadata": {
"id": "Jr7mFKxRdbUH"
}
},
{
"cell_type": "code",
"source": [
"# IMPORTANT: Delta must equal the bar length of time_interval so that get_vol\n",
"# measures volatility over one interval. Binance interval strings look like\n",
"# '<count><unit>' with unit m/h/d/w (e.g. '5m', '1h', '4h', '1d'); monthly bars\n",
"# ('1M') have no fixed length and are rejected.\n",
"\n",
"_INTERVAL_UNITS = {'m': 'minutes', 'h': 'hours', 'd': 'days', 'w': 'weeks'}\n",
"_interval_count, _interval_unit = time_interval[:-1], time_interval[-1:]\n",
"\n",
"if _interval_unit not in _INTERVAL_UNITS or not _interval_count.isdecimal():\n",
"    raise ValueError('Timeframe not supported yet, please manually add!')\n",
"\n",
"Delta = pd.Timedelta(**{_INTERVAL_UNITS[_interval_unit]: int(_interval_count)})"
],
"metadata": {
"id": "TPMJCMx3dwjw"
},
"execution_count": 10,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def get_vol(prices, span=100, delta=Delta):\n",
"    \"\"\"Exponentially weighted std of the returns p[t] / p[t - delta] - 1.\n",
"\n",
"    prices: Series indexed by timestamps, assumed sorted (searchsorted requires it).\n",
"    span:   EWM span passed to Series.ewm.\n",
"    delta:  return horizon; defaults to the Delta defined in the previous cell\n",
"            (bound when this cell is run).\n",
"\n",
"    Returns a Series indexed by the timestamps that have a bar at or before\n",
"    t - delta; the earliest timestamps without one are dropped.\n",
"    \"\"\"\n",
"    stamps = prices.index\n",
"\n",
"    # 1. For each t, locate the last bar at or before t - delta.\n",
"    #    side='right' makes an exact match at t - delta count as the previous bar;\n",
"    #    the default side='left' skipped it, so evenly spaced bars produced\n",
"    #    two-bar returns instead of p[t]/p[t-1] - 1.\n",
"    prev_pos = stamps.searchsorted(stamps - delta, side='right')\n",
"    # Bars with nothing at or before t - delta are dropped; prev_pos is\n",
"    # non-decreasing, so the survivors are exactly the tail of the index.\n",
"    prev_pos = prev_pos[prev_pos > 0]\n",
"    current_stamps = stamps[len(stamps) - len(prev_pos):]\n",
"    previous_stamps = stamps[prev_pos - 1]\n",
"\n",
"    # 2. One-delta returns, then their exponentially weighted rolling std.\n",
"    returns = prices.loc[current_stamps] / prices.loc[previous_stamps].values - 1\n",
"    return returns.ewm(span=span).std()"
],
"metadata": {
"id": "hqqNxQFIdxBX"
},
"execution_count": 11,
"outputs": []
},
{
"cell_type": "code",
"source": [
"data_ohlcv = data_ohlcv.assign(volatility=get_vol(data_ohlcv.close)).dropna()"
],
"metadata": {
"id": "XhpZGD7VdzRL"
},
"execution_count": 12,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def get_barriers():\n",
"    \"\"\"Build the triple-barrier table for every timestamp of ``daily_volatility``.\n",
"\n",
"    Reads the notebook globals ``daily_volatility``, ``prices``, ``t_final`` and\n",
"    ``upper_lower_multipliers`` (set in the barrier-parameters cell).\n",
"\n",
"    Returns a DataFrame indexed like ``daily_volatility`` with columns:\n",
"      datapoints_passed: 1-based row position\n",
"      price:             price at the row\n",
"      vert_barrier:      timestamp t_final + 1 rows ahead; missing past the end or if t_final == 0\n",
"      top_barrier:       price * (1 + upper multiplier * vol); NaN if that multiplier <= 0\n",
"      bottom_barrier:    price * (1 - lower multiplier * vol); NaN if that multiplier <= 0\n",
"    \"\"\"\n",
"    index = daily_volatility.index\n",
"    upper, lower = upper_lower_multipliers[0], upper_lower_multipliers[1]\n",
"\n",
"    records = []\n",
"    # enumerate gives the same 1-based position as len(daily_volatility.loc[first:datapoint])\n",
"    # for a sorted, duplicate-free index, without re-slicing the series on every row (O(n^2)).\n",
"    # Series.iteritems was removed in pandas 2.0; items() is the supported spelling.\n",
"    for datapoints_passed, (datapoint, vol) in enumerate(daily_volatility.items(), start=1):\n",
"        if datapoints_passed + t_final < len(index) and t_final != 0:\n",
"            vert_barrier = index[datapoints_passed + t_final]\n",
"        else:\n",
"            vert_barrier = np.nan\n",
"\n",
"        price = prices.loc[datapoint]\n",
"        # A non-positive multiplier disables that barrier: store a scalar NaN\n",
"        # (an all-NaN Series cannot be stored in a single cell).\n",
"        top_barrier = price + price * upper * vol if upper > 0 else np.nan\n",
"        bottom_barrier = price - price * lower * vol if lower > 0 else np.nan\n",
"\n",
"        records.append((datapoints_passed, price, vert_barrier, top_barrier, bottom_barrier))\n",
"\n",
"    # Build the frame once instead of assigning row by row with .loc.\n",
"    return pd.DataFrame(records,\n",
"                        columns=['datapoints_passed', 'price', 'vert_barrier',\n",
"                                 'top_barrier', 'bottom_barrier'],\n",
"                        index=index)"
],
"metadata": {
"id": "ZjaI2GDrd0Nr"
},
"execution_count": 13,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Set barrier parameters\n",
"\n",
"daily_volatility = data_ohlcv['volatility']\n",
"t_final = 10\n",
"upper_lower_multipliers = [2, 2]\n",
"price = data_ohlcv['close']\n",
"prices = price[daily_volatility.index]"
],
"metadata": {
"id": "SKWhNCGBd0o4"
},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"source": [
"barriers = get_barriers()\n",
"barriers.head(5)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 237
},
"id": "DgqsdrjPd2BQ",
"outputId": "c78b2685-a763-43ee-b527-e7f92b00acce"
},
"execution_count": 15,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" datapoints_passed price vert_barrier top_barrier \\\n",
"time \n",
"2017-09-19 1 3910.04 2017-09-30 00:00:00 4530.750516 \n",
"2017-09-20 2 3900.0 2017-10-01 00:00:00 4508.118671 \n",
"2017-09-21 3 3609.99 2017-10-02 00:00:00 4171.469472 \n",
"2017-09-22 4 3595.87 2017-10-03 00:00:00 4153.372755 \n",
"2017-09-23 5 3780.0 2017-10-04 00:00:00 4360.233001 \n",
"\n",
" bottom_barrier \n",
"time \n",
"2017-09-19 3289.329484 \n",
"2017-09-20 3291.881329 \n",
"2017-09-21 3048.510528 \n",
"2017-09-22 3038.367245 \n",
"2017-09-23 3199.766999 "
],
"text/html": [
"\n",
" <div id=\"df-b1a930d1-a5ec-418e-8712-b33cc24f0307\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>datapoints_passed</th>\n",
" <th>price</th>\n",
" <th>vert_barrier</th>\n",
" <th>top_barrier</th>\n",
" <th>bottom_barrier</th>\n",
" </tr>\n",
" <tr>\n",
" <th>time</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2017-09-19</th>\n",
" <td>1</td>\n",
" <td>3910.04</td>\n",
" <td>2017-09-30 00:00:00</td>\n",
" <td>4530.750516</td>\n",
" <td>3289.329484</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2017-09-20</th>\n",
" <td>2</td>\n",
" <td>3900.0</td>\n",
" <td>2017-10-01 00:00:00</td>\n",
" <td>4508.118671</td>\n",
" <td>3291.881329</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2017-09-21</th>\n",
" <td>3</td>\n",
" <td>3609.99</td>\n",
" <td>2017-10-02 00:00:00</td>\n",
" <td>4171.469472</td>\n",
" <td>3048.510528</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2017-09-22</th>\n",
" <td>4</td>\n",
" <td>3595.87</td>\n",
" <td>2017-10-03 00:00:00</td>\n",
" <td>4153.372755</td>\n",
" <td>3038.367245</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2017-09-23</th>\n",
" <td>5</td>\n",
" <td>3780.0</td>\n",
" <td>2017-10-04 00:00:00</td>\n",
" <td>4360.233001</td>\n",
" <td>3199.766999</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-b1a930d1-a5ec-418e-8712-b33cc24f0307')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-b1a930d1-a5ec-418e-8712-b33cc24f0307 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-b1a930d1-a5ec-418e-8712-b33cc24f0307');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 15
}
]
},
{
"cell_type": "code",
"source": [
"def get_labels():\n",
"    \"\"\"Fill ``barriers['label_barrier']`` with first-touch triple-barrier labels.\n",
"\n",
"    Mutates the notebook global ``barriers`` (from get_barriers) in place:\n",
"      2    -- the top (profit-taking) barrier is touched first\n",
"      0    -- the bottom (stop-loss) barrier is touched first\n",
"      1    -- neither is touched up to the vertical barrier\n",
"      None -- no vertical barrier (too close to the end of the data)\n",
"    A tie (both touched on the same bar) is labelled 2.\n",
"    \"\"\"\n",
"    labels = []\n",
"    for start, end, top_barrier, bottom_barrier in zip(barriers.index,\n",
"                                                        barriers['vert_barrier'],\n",
"                                                        barriers['top_barrier'],\n",
"                                                        barriers['bottom_barrier']):\n",
"        if pd.isna(end):\n",
"            labels.append(None)\n",
"            continue\n",
"\n",
"        # Prices from the entry bar up to and including the vertical barrier.\n",
"        path = barriers['price'].loc[start:end]\n",
"        top_hits = path.index[(path >= top_barrier).to_numpy()]\n",
"        bottom_hits = path.index[(path <= bottom_barrier).to_numpy()]\n",
"\n",
"        # The label follows whichever barrier is touched FIRST; a later touch of the\n",
"        # other barrier must not override an earlier stop-loss or profit-take.\n",
"        if len(top_hits) and (not len(bottom_hits) or top_hits[0] <= bottom_hits[0]):\n",
"            labels.append(2)\n",
"        elif len(bottom_hits):\n",
"            labels.append(0)\n",
"        else:\n",
"            labels.append(1)\n",
"\n",
"    # One column assignment instead of chained barriers['label_barrier'][i] = ...,\n",
"    # which does not write through under pandas copy-on-write. object dtype keeps\n",
"    # the None placeholders and integer labels as they were.\n",
"    barriers['label_barrier'] = pd.Series(labels, index=barriers.index, dtype=object)"
],
"metadata": {
"id": "3o1h3zZ8d3Ob"
},
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Label every bar, merge the labels into the main dataset and drop the last\n",
"# t_final + 1 rows (their vertical barrier would fall past the end of the data).\n",
"\n",
"get_labels()\n",
"\n",
"data_ohlcv = (\n",
"    data_ohlcv\n",
"    .merge(barriers[['vert_barrier', 'top_barrier', 'bottom_barrier', 'label_barrier']],\n",
"           left_on='time', right_on='time')\n",
"    .pipe(lambda frame: frame.drop(index=frame.tail(t_final + 1).index))\n",
"    .drop(columns=['vert_barrier', 'top_barrier', 'bottom_barrier', 'tic'])\n",
")\n",
"data_ohlcv.head(5)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 237
},
"id": "GE-s2pGHd4YV",
"outputId": "f256fe81-5f6a-481b-97fc-e77ea5ae169b"
},
"execution_count": 17,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" open high low close volume rsi \\\n",
"time \n",
"2017-09-19 4060.00 4089.97 3830.91 3910.04 902.332129 46.049881 \n",
"2017-09-20 3910.04 4046.08 3820.00 3900.00 720.935076 45.861376 \n",
"2017-09-21 3889.99 3910.00 3567.00 3609.99 1001.654084 40.681139 \n",
"2017-09-22 3592.84 3750.00 3505.55 3595.87 838.966425 40.441621 \n",
"2017-09-23 3595.88 3817.19 3542.91 3780.00 752.792791 44.990039 \n",
"\n",
" macd macd_signal macd_hist cci dx \\\n",
"time \n",
"2017-09-19 -117.398092 -51.421239 -65.976853 -19.131593 38.032614 \n",
"2017-09-20 -114.436313 -64.024254 -50.412060 -15.986404 38.307447 \n",
"2017-09-21 -133.946415 -78.008686 -55.937729 -63.756903 44.459188 \n",
"2017-09-22 -148.832034 -92.173355 -56.658678 -71.873305 45.871014 \n",
"2017-09-23 -144.110031 -102.560690 -41.549340 -35.179194 41.631687 \n",
"\n",
" volatility label_barrier \n",
"time \n",
"2017-09-19 0.079374 1 \n",
"2017-09-20 0.077964 1 \n",
"2017-09-21 0.077767 2 \n",
"2017-09-22 0.077520 2 \n",
"2017-09-23 0.076750 2 "
],
"text/html": [
"\n",
" <div id=\"df-44360e0f-930e-454d-8e78-b85753806e08\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>open</th>\n",
" <th>high</th>\n",
" <th>low</th>\n",
" <th>close</th>\n",
" <th>volume</th>\n",
" <th>rsi</th>\n",
" <th>macd</th>\n",
" <th>macd_signal</th>\n",
" <th>macd_hist</th>\n",
" <th>cci</th>\n",
" <th>dx</th>\n",
" <th>volatility</th>\n",
" <th>label_barrier</th>\n",
" </tr>\n",
" <tr>\n",
" <th>time</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2017-09-19</th>\n",
" <td>4060.00</td>\n",
" <td>4089.97</td>\n",
" <td>3830.91</td>\n",
" <td>3910.04</td>\n",
" <td>902.332129</td>\n",
" <td>46.049881</td>\n",
" <td>-117.398092</td>\n",
" <td>-51.421239</td>\n",
" <td>-65.976853</td>\n",
" <td>-19.131593</td>\n",
" <td>38.032614</td>\n",
" <td>0.079374</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2017-09-20</th>\n",
" <td>3910.04</td>\n",
" <td>4046.08</td>\n",
" <td>3820.00</td>\n",
" <td>3900.00</td>\n",
" <td>720.935076</td>\n",
" <td>45.861376</td>\n",
" <td>-114.436313</td>\n",
" <td>-64.024254</td>\n",
" <td>-50.412060</td>\n",
" <td>-15.986404</td>\n",
" <td>38.307447</td>\n",
" <td>0.077964</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2017-09-21</th>\n",
" <td>3889.99</td>\n",
" <td>3910.00</td>\n",
" <td>3567.00</td>\n",
" <td>3609.99</td>\n",
" <td>1001.654084</td>\n",
" <td>40.681139</td>\n",
" <td>-133.946415</td>\n",
" <td>-78.008686</td>\n",
" <td>-55.937729</td>\n",
" <td>-63.756903</td>\n",
" <td>44.459188</td>\n",
" <td>0.077767</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2017-09-22</th>\n",
" <td>3592.84</td>\n",
" <td>3750.00</td>\n",
" <td>3505.55</td>\n",
" <td>3595.87</td>\n",
" <td>838.966425</td>\n",
" <td>40.441621</td>\n",
" <td>-148.832034</td>\n",
" <td>-92.173355</td>\n",
" <td>-56.658678</td>\n",
" <td>-71.873305</td>\n",
" <td>45.871014</td>\n",
" <td>0.077520</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2017-09-23</th>\n",
" <td>3595.88</td>\n",
" <td>3817.19</td>\n",
" <td>3542.91</td>\n",
" <td>3780.00</td>\n",
" <td>752.792791</td>\n",
" <td>44.990039</td>\n",
" <td>-144.110031</td>\n",
" <td>-102.560690</td>\n",
" <td>-41.549340</td>\n",
" <td>-35.179194</td>\n",
" <td>41.631687</td>\n",
" <td>0.076750</td>\n",
" <td>2</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-44360e0f-930e-454d-8e78-b85753806e08')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
" \n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
" \n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-44360e0f-930e-454d-8e78-b85753806e08 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-44360e0f-930e-454d-8e78-b85753806e08');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n",
" "
]
},
"metadata": {},
"execution_count": 17
}
]
},
{
"cell_type": "markdown",
"source": [
"# Combinatorial PurgedKFoldCV"
],
"metadata": {
"id": "PNNUXedJd4s5"
}
},
{
"cell_type": "code",
"source": [
"class BaseTimeSeriesCrossValidator:\n",
"    \"\"\"\n",
"    Abstract base class for time series cross-validation.\n",
"\n",
"    Each sample carries a prediction time pred_time, at which its features are used to predict the response, and\n",
"    an evaluation time eval_time, at which the response is known and the error can be computed. Unlike standard\n",
"    sklearn cross-validation, the samples X, the response y, pred_times and eval_times must therefore all be pandas\n",
"    DataFrames/Series sharing the same index. The samples are assumed to be time-ordered with respect to the\n",
"    prediction time (i.e. pred_times is non-decreasing).\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    n_splits : int, default=10\n",
"        Number of folds. Must be at least 2.\n",
"    \"\"\"\n",
"    def __init__(self, n_splits=10):\n",
"        if not isinstance(n_splits, numbers.Integral):\n",
"            raise ValueError(f\"The number of folds must be of Integral type. {n_splits} of type {type(n_splits)}\"\n",
"                             f\" was passed.\")\n",
"        n_splits = int(n_splits)\n",
"        if n_splits <= 1:\n",
"            raise ValueError(f\"K-fold cross-validation requires at least one train/test split by setting n_splits = 2 \"\n",
"                             f\"or more, got n_splits = {n_splits}.\")\n",
"        self.n_splits = n_splits\n",
"        # Filled in by split(); helper functions read them back from the cross-validation object.\n",
"        self.pred_times = None\n",
"        self.eval_times = None\n",
"        self.indices = None\n",
"\n",
"    @abstractmethod\n",
"    def split(self, X: pd.DataFrame, y: pd.Series = None,\n",
"              pred_times: pd.Series = None, eval_times: pd.Series = None):\n",
"        \"\"\"\n",
"        Validate the inputs and store pred_times, eval_times and the position indices of X.\n",
"\n",
"        Subclasses call this through super().split(...) before yielding their train/test indices.\n",
"\n",
"        Raises\n",
"        ------\n",
"        ValueError\n",
"            If an argument has the wrong type, if X, y, pred_times and eval_times do not share the same index,\n",
"            or if pred_times / eval_times are not sorted in non-decreasing order.\n",
"        \"\"\"\n",
"        if not isinstance(X, (pd.DataFrame, pd.Series)):\n",
"            raise ValueError('X should be a pandas DataFrame/Series.')\n",
"        if y is not None and not isinstance(y, pd.Series):\n",
"            raise ValueError('y should be a pandas Series.')\n",
"        if not isinstance(pred_times, pd.Series):\n",
"            raise ValueError('pred_times should be a pandas Series.')\n",
"        if not isinstance(eval_times, pd.Series):\n",
"            raise ValueError('eval_times should be a pandas Series.')\n",
"        # Index.equals also handles a length mismatch, which an element-wise == comparison would turn into a\n",
"        # less informative pandas error instead of the messages below.\n",
"        if y is not None and not X.index.equals(y.index):\n",
"            raise ValueError('X and y must have the same index')\n",
"        if not X.index.equals(pred_times.index):\n",
"            raise ValueError('X and pred_times must have the same index')\n",
"        if not X.index.equals(eval_times.index):\n",
"            raise ValueError('X and eval_times must have the same index')\n",
"\n",
"        # Check the values only: Series.equals(sort_values()) also compares the index, and the default\n",
"        # (non-stable) sort may reorder tied timestamps and wrongly reject already-sorted input.\n",
"        if not pred_times.is_monotonic_increasing:\n",
"            raise ValueError('pred_times should be sorted')\n",
"        if not eval_times.is_monotonic_increasing:\n",
"            raise ValueError('eval_times should be sorted')\n",
"\n",
"        self.pred_times = pred_times\n",
"        self.eval_times = eval_times\n",
"        self.indices = np.arange(X.shape[0])\n",
"\n",
"class CombPurgedKFoldCVLocal(BaseTimeSeriesCrossValidator):\n",
"    \"\"\"\n",
"    Purged and embargoed combinatorial cross-validation.\n",
"\n",
"    As described in Advances in Financial Machine Learning, Marcos Lopez de Prado, 2018.\n",
"    The samples are decomposed into n_splits folds containing (approximately) equal numbers of samples, without\n",
"    shuffling. In each cross-validation round, n_test_splits folds are used as the test set, while the other folds\n",
"    are used as the train set. There is one round per combination of n_test_splits folds chosen among the n_splits\n",
"    folds, i.e. C(n_splits, n_test_splits) rounds.\n",
"    Each sample should be tagged with a prediction time pred_time and an evaluation time eval_time. The split is\n",
"    such that the intervals [pred_times, eval_times] associated to samples in the train and test set do not overlap\n",
"    (the overlapping train samples are dropped). In addition, an \"embargo\" period is defined, giving the minimal\n",
"    time between an evaluation time in the test set and a prediction time in the training set. This is to avoid, in\n",
"    the presence of temporal correlation, a contamination of the test set by the train set.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    n_splits : int, default=10\n",
"        Number of folds. Must be at least 2.\n",
"    n_test_splits : int, default=2\n",
"        Number of folds used in the test set. Must be between 1 and n_splits - 1.\n",
"    embargo_td : pd.Timedelta, default=0\n",
"        Embargo period (see explanations above). Must not be negative.\n",
"    \"\"\"\n",
"    def __init__(self, n_splits=10, n_test_splits=2, embargo_td=pd.Timedelta(minutes=0)):\n",
"        super().__init__(n_splits)\n",
"        if not isinstance(n_test_splits, numbers.Integral):\n",
"            raise ValueError(f\"The number of test folds must be of Integral type. {n_test_splits} of type \"\n",
"                             f\"{type(n_test_splits)} was passed.\")\n",
"        n_test_splits = int(n_test_splits)\n",
"        if n_test_splits <= 0 or n_test_splits > self.n_splits - 1:\n",
"            raise ValueError(f\"K-fold cross-validation requires at least one train/test split by setting \"\n",
"                             f\"n_test_splits between 1 and n_splits - 1, got n_test_splits = {n_test_splits}.\")\n",
"        self.n_test_splits = n_test_splits\n",
"        if not isinstance(embargo_td, pd.Timedelta):\n",
"            raise ValueError(f\"The embargo time should be of type Pandas Timedelta. {embargo_td} of type \"\n",
"                             f\"{type(embargo_td)} was passed.\")\n",
"        if embargo_td < pd.Timedelta(minutes=0):\n",
"            raise ValueError(f\"The embargo time should be positive, got embargo = {embargo_td}.\")\n",
"        self.embargo_td = embargo_td\n",
"\n",
"    def split(self, X: pd.DataFrame, y: pd.Series = None,\n",
"              pred_times: pd.Series = None, eval_times: pd.Series = None) -> Iterable[Tuple[np.ndarray, np.ndarray]]:\n",
"        \"\"\"\n",
"        Yield the indices of the train and test sets.\n",
"        Although the samples are passed in the form of a pandas dataframe, the indices returned are position\n",
"        indices, not labels.\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        X : pd.DataFrame, shape (n_samples, n_features), required\n",
"            Samples. Only used to extract n_samples (and validated against the other indices).\n",
"        y : pd.Series, optional\n",
"            Only validated (same index as X); not used to build the splits.\n",
"        pred_times : pd.Series, shape (n_samples,), required\n",
"            Times at which predictions are made. pred_times.index has to coincide with X.index.\n",
"        eval_times : pd.Series, shape (n_samples,), required\n",
"            Times at which the response becomes available and the error can be computed. eval_times.index has to\n",
"            coincide with X.index.\n",
"\n",
"        Yields\n",
"        ------\n",
"        train_indices: np.ndarray\n",
"            A numpy array containing all the indices in the train set.\n",
"        test_indices : np.ndarray\n",
"            A numpy array containing all the indices in the test set.\n",
"        \"\"\"\n",
"        super().split(X, y, pred_times, eval_times)\n",
"\n",
"        # Fold boundaries as (start, end) position pairs, end exclusive\n",
"        fold_bounds = [(fold[0], fold[-1] + 1) for fold in np.array_split(self.indices, self.n_splits)]\n",
"        # List of all combinations of n_test_splits folds selected to become test sets\n",
"        selected_fold_bounds = list(itt.combinations(fold_bounds, self.n_test_splits))\n",
"        # In order for the first round to have its whole test set at the end of the dataset\n",
"        selected_fold_bounds.reverse()\n",
"\n",
"        for fold_bound_list in selected_fold_bounds:\n",
"            # Computes the bounds of the test set, and the corresponding indices\n",
"            test_fold_bounds, test_indices = self.compute_test_set(fold_bound_list)\n",
"            # Computes the train set indices\n",
"            train_indices = self.compute_train_set(test_fold_bounds, test_indices)\n",
"\n",
"            yield train_indices, test_indices\n",
"\n",
"    def compute_train_set(self, test_fold_bounds: List[Tuple[int, int]], test_indices: np.ndarray) -> np.ndarray:\n",
"        \"\"\"\n",
"        Compute the position indices of samples in the train set.\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        test_fold_bounds : List of tuples of position indices\n",
"            Each tuple records the bounds of a block of indices in the test set.\n",
"        test_indices : np.ndarray\n",
"            A numpy array containing all the indices in the test set.\n",
"\n",
"        Returns\n",
"        -------\n",
"        train_indices: np.ndarray\n",
"            A numpy array containing all the indices in the train set.\n",
"        \"\"\"\n",
"        # As a first approximation, the train set is the complement of the test set\n",
"        train_indices = np.setdiff1d(self.indices, test_indices)\n",
"        # But we now have to purge and embargo around every test block\n",
"        for test_fold_start, test_fold_end in test_fold_bounds:\n",
"            # Purge\n",
"            train_indices = purge(self, train_indices, test_fold_start, test_fold_end)\n",
"            # Embargo\n",
"            train_indices = embargo(self, train_indices, test_indices, test_fold_end)\n",
"        return train_indices\n",
"\n",
"    def compute_test_set(self, fold_bound_list: List[Tuple[int, int]]) -> Tuple[List[Tuple[int, int]], np.ndarray]:\n",
"        \"\"\"\n",
"        Compute the indices of the samples in the test set.\n",
"\n",
"        Parameters\n",
"        ----------\n",
"        fold_bound_list: List of tuples of position indices\n",
"            Each tuple records the (start, end) bounds, end exclusive, of a fold belonging to the test set.\n",
"\n",
"        Returns\n",
"        -------\n",
"        test_fold_bounds: List of tuples of position indices\n",
"            Like fold_bound_list, but with the neighboring folds in the test set merged.\n",
"        test_indices: np.ndarray\n",
"            A sorted numpy integer array containing the test indices.\n",
"        \"\"\"\n",
"        test_fold_bounds = []\n",
"        for fold_start, fold_end in fold_bound_list:\n",
"            if test_fold_bounds and fold_start == test_fold_bounds[-1][-1]:\n",
"                # Contiguous with the previous test block: simply extend its endpoint\n",
"                test_fold_bounds[-1] = (test_fold_bounds[-1][0], fold_end)\n",
"            else:\n",
"                # Start a new test block\n",
"                test_fold_bounds.append((fold_start, fold_end))\n",
"        if not test_fold_bounds:\n",
"            return test_fold_bounds, np.empty(0, dtype=int)\n",
"        # Gather all test positions at once; np.unique sorts and de-duplicates them like np.union1d would.\n",
"        test_indices = np.unique(np.concatenate([self.indices[start:end] for start, end in fold_bound_list]))\n",
"        return test_fold_bounds, test_indices.astype(int)\n",
"\n",
"\n",
"def compute_fold_bounds(cv: BaseTimeSeriesCrossValidator, split_by_time: bool) -> List[int]:\n",
"    \"\"\"\n",
"    Return the left (start) boundary of every fold, as position indices.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    cv: BaseTimeSeriesCrossValidator\n",
"        Cross-validation object whose n_splits, pred_times and indices are set (pred_times and indices are\n",
"        filled in by its split method).\n",
"    split_by_time: bool\n",
"        False: the folds contain an (approximately) equal number of samples.\n",
"        True: the folds span time intervals of identical length.\n",
"    \"\"\"\n",
"    if not split_by_time:\n",
"        return [chunk[0] for chunk in np.array_split(cv.indices, cv.n_splits)]\n",
"\n",
"    span_per_fold = (cv.pred_times.max() - cv.pred_times.min()) / cv.n_splits\n",
"    first_time = cv.pred_times.iloc[0]\n",
"    boundary_times = [first_time + span_per_fold * fold_no for fold_no in range(cv.n_splits)]\n",
"    return cv.pred_times.searchsorted(boundary_times)\n",
"\n",
"\n",
"def embargo(cv: BaseTimeSeriesCrossValidator, train_indices: np.ndarray,\n",
"            test_indices: np.ndarray, test_fold_end: int) -> np.ndarray:\n",
"    \"\"\"\n",
"    Apply the embargo procedure to part of the train set.\n",
"\n",
"    This drops the train set samples whose prediction time occurs within cv.embargo_td of the test set sample\n",
"    evaluation times. The embargo is applied only to the part of the training set immediately following the end\n",
"    of the test block determined by test_fold_end.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    cv: Cross-validation class\n",
"        Needs to have the attributes cv.pred_times, cv.eval_times, cv.embargo_td and cv.indices.\n",
"    train_indices: np.ndarray\n",
"        A numpy array containing all the indices of the samples currently included in the train set.\n",
"    test_indices : np.ndarray\n",
"        A numpy array containing all the indices of the samples in the test set.\n",
"    test_fold_end : int\n",
"        Index corresponding to the end (exclusive) of a test set block.\n",
"\n",
"    Returns\n",
"    -------\n",
"    train_indices: np.ndarray\n",
"        The same indices (sorted), with the ones subject to embargo removed.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If cv has no embargo_td attribute.\n",
"    \"\"\"\n",
"    if not hasattr(cv, 'embargo_td'):\n",
"        raise ValueError(\"The passed cross-validation object should have a member cv.embargo_td defining the embargo \"\n",
"                         \"time.\")\n",
"    # Latest evaluation time among the test samples up to the end of this block.\n",
"    last_test_eval_time = cv.eval_times.iloc[test_indices[test_indices <= test_fold_end]].max()\n",
"    # pred_times is sorted, so this count is the first position past the embargo window.\n",
"    min_train_index = len(cv.pred_times[cv.pred_times <= last_test_eval_time + cv.embargo_td])\n",
"    # Keep everything before the block end and everything from min_train_index on. No special case is needed\n",
"    # when the window reaches the end of the data: the second slice is then empty, so every train sample after\n",
"    # the block is embargoed instead of silently kept.\n",
"    allowed_indices = np.concatenate((cv.indices[:test_fold_end], cv.indices[min_train_index:]))\n",
"    return np.intersect1d(train_indices, allowed_indices)\n",
"\n",
"\n",
"def purge(cv: BaseTimeSeriesCrossValidator, train_indices: np.ndarray,\n",
"          test_fold_start: int, test_fold_end: int) -> np.ndarray:\n",
"    \"\"\"\n",
"    Purge part of the train set.\n",
"\n",
"    Given a test block [test_fold_start, test_fold_end), this keeps the train samples whose evaluation time is\n",
"    strictly before the prediction time of the first test sample, plus the train samples from test_fold_end\n",
"    onwards. Every other train sample (in particular one whose label interval reaches into the test block) is\n",
"    removed.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    cv: Cross-validation class\n",
"        Needs to have the attributes cv.pred_times, cv.eval_times and cv.indices.\n",
"    train_indices: np.ndarray\n",
"        A numpy array containing all the indices of the samples currently included in the train set.\n",
"    test_fold_start : int\n",
"        Index corresponding to the start of a test set block.\n",
"    test_fold_end : int\n",
"        Index corresponding to the end (exclusive) of the same test set block.\n",
"\n",
"    Returns\n",
"    -------\n",
"    train_indices: np.ndarray\n",
"        A numpy array containing the train indices purged at test_fold_start.\n",
"    \"\"\"\n",
"    time_test_fold_start = cv.pred_times.iloc[test_fold_start]\n",
"    # The train indices before the start of the test fold, purged.\n",
"    ends_before_test = (cv.eval_times < time_test_fold_start).to_numpy()\n",
"    train_indices_1 = np.intersect1d(train_indices, cv.indices[ends_before_test])\n",
"    # The train indices after the end of the test fold.\n",
"    train_indices_2 = np.intersect1d(train_indices, cv.indices[test_fold_end:])\n",
"    return np.concatenate((train_indices_1, train_indices_2))"
],
"metadata": {
"id": "6--viSo7i0ql"
},
"execution_count": 31,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Constructing the backtest paths"
],
"metadata": {
"id": "VI8JXiu1eotV"
}
},
{
"cell_type": "code",
"source": [
"def back_test_paths_generator(t_span, n, k, prediction_times, evaluation_times, verbose=True):\n",
"    \"\"\"\n",
"    Assemble the backtest paths of combinatorial purged cross-validation (CPCV).\n",
"\n",
"    The t_span observations are split into n contiguous groups, and each of the C(n, k) simulations uses k of\n",
"    them as its test set. Every group is tested in C(n, k) * k / n simulations, so the test groups can be chained\n",
"    into that many complete backtest paths.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    t_span : int\n",
"        Number of observations. Must be at least n.\n",
"    n : int\n",
"        Number of groups.\n",
"    k : int\n",
"        Number of test groups per simulation.\n",
"    prediction_times, evaluation_times : pd.Series\n",
"        Not used by the computation; kept so existing calls keep working.\n",
"    verbose : bool, default=True\n",
"        Print the number of simulations and of paths.\n",
"\n",
"    Returns\n",
"    -------\n",
"    is_test : np.ndarray of bool, shape (t_span, C(n, k))\n",
"        is_test[t, s] is True when observation t belongs to the test set of simulation s.\n",
"    paths : np.ndarray of float, shape (t_span, n_paths)\n",
"        paths[t, p] is the simulation whose test prediction of observation t is used in path p.\n",
"    path_folds : np.ndarray of float, shape (n, n_paths)\n",
"        path_folds[g, p] is the simulation used for group g in path p.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If t_span < n (some groups would be empty).\n",
"    \"\"\"\n",
"    if t_span < n:\n",
"        raise ValueError(f'Need at least one observation per group, got t_span={t_span} < n={n}.')\n",
"\n",
"    # Split the data into n groups (n << t_span) by assigning each observation position a group number.\n",
"    # The t_span % n leftover observations can get group numbers >= n; fold all of them into the last group.\n",
"    group_num = np.minimum(np.arange(t_span) // (t_span // n), n - 1)\n",
"\n",
"    # One simulation per combination of k test groups.\n",
"    test_groups = np.array(list(itt.combinations(np.arange(n), k))).reshape(-1, k)\n",
"    C_nk = len(test_groups)\n",
"    n_paths = C_nk * k // n\n",
"\n",
"    if verbose:\n",
"        print('n_sim:', C_nk)\n",
"        print('n_paths:', n_paths)\n",
"\n",
"    # Columns flag, per simulation, the test groups (is_test_group, n x C(n, k)) and the test\n",
"    # observations (is_test, t_span x C(n, k)). Works for any k, not only k == 2.\n",
"    is_test_group = np.full((n, C_nk), fill_value=False)\n",
"    is_test = np.full((t_span, C_nk), fill_value=False)\n",
"    for sim, test_group_ids in enumerate(test_groups):\n",
"        is_test_group[test_group_ids, sim] = True\n",
"        is_test[np.isin(group_num, test_group_ids), sim] = True\n",
"\n",
"    # Chain the test folds of different simulations into backtest paths: for every path and group, take the\n",
"    # first simulation that still has that group available as a test fold, then mark it as used.\n",
"    path_folds = np.full((n, n_paths), fill_value=np.nan)\n",
"    for p in range(n_paths):\n",
"        for g in range(n):\n",
"            s_idx = int(is_test_group[g, :].argmax())\n",
"            path_folds[g, p] = s_idx\n",
"            is_test_group[g, s_idx] = False\n",
"\n",
"    # Finally, for each path record which simulation provides the prediction at every time index.\n",
"    paths = np.full((t_span, n_paths), fill_value=np.nan)\n",
"    for p in range(n_paths):\n",
"        for g in range(n):\n",
"            paths[group_num == g, p] = path_folds[g, p]\n",
"\n",
"    return is_test, paths, path_folds"
],
"metadata": {
"id": "MJ9RUhkSeL_j"
},
"execution_count": 53,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### The plotting function for the Combinatorial PurgedKFold\n",
"\n",
"Adapted from the scikit-learn example https://scikit-learn.org/stable/auto_examples/model_selection/plot_cv_indices.html"
],
"metadata": {
"id": "cMr24XD2er1a"
}
},
{
"cell_type": "code",
"source": [
"cmap_data = plt.cm.Paired\n",
"cmap_cv = plt.cm.coolwarm\n",
"\n",
"def plot_cv_indices(cv, X, y, group, ax, n_paths, k, paths, lw=5, pred_times=None, eval_times=None):\n",
"    \"\"\"\n",
"    Create a sample plot for the indices of a cross-validation object.\n",
"\n",
"    Adapted from the scikit-learn example plot_cv_indices. There is one column per split (train = 0, test = 1,\n",
"    in neither set = 2), followed by a column coloured by y (\"class\") and one coloured by group (\"group\").\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    cv : cross-validator\n",
"        Its split(X, y, pred_times=..., eval_times=...) must yield (train, test) position indices.\n",
"    X, y : samples and labels passed to cv.split; y also colours the \"class\" column.\n",
"    group : array-like\n",
"        Values colouring the \"group\" column.\n",
"    ax : matplotlib Axes to draw on.\n",
"    n_paths, k, paths :\n",
"        Unused; kept for backward compatibility. The number of splits is counted from cv.split, which keeps\n",
"        the tick labels consistent for any number of test folds, not only k == 2.\n",
"    lw : int, default=5\n",
"        Marker line width.\n",
"    pred_times, eval_times : pd.Series, optional\n",
"        Prediction/evaluation times passed to cv.split. When omitted, the notebook-level globals\n",
"        prediction_times and evaluation_times are used.\n",
"\n",
"    Returns\n",
"    -------\n",
"    ax : the Axes that was drawn on.\n",
"    \"\"\"\n",
"    # Fall back to the notebook globals so existing calls behave as before.\n",
"    if pred_times is None:\n",
"        pred_times = prediction_times\n",
"    if eval_times is None:\n",
"        eval_times = evaluation_times\n",
"\n",
"    n_samples = len(X)\n",
"    n_splits = 0\n",
"    # Generate the training/testing visualizations for each CV split\n",
"    for ii, (tr, tt) in enumerate(cv.split(X, y, pred_times=pred_times, eval_times=eval_times)):\n",
"        # 0 = train, 1 = test, 2 = in neither set (dropped by purging/embargo)\n",
"        indices = np.full(n_samples, 2.0)\n",
"        indices[tt] = 1\n",
"        indices[tr] = 0\n",
"\n",
"        ax.scatter(\n",
"            [ii + 0.5] * n_samples,\n",
"            range(n_samples),\n",
"            c=[indices],\n",
"            marker='_',\n",
"            lw=lw,\n",
"            cmap=cmap_cv,\n",
"            vmin=-0.2,\n",
"            vmax=1.2,\n",
"        )\n",
"        n_splits = ii + 1\n",
"\n",
"    # Plot the data classes and groups after the last split\n",
"    ax.scatter([n_splits + 0.5] * n_samples, range(n_samples), c=y, marker='_', lw=lw, cmap=cmap_data)\n",
"    ax.scatter([n_splits + 1.5] * n_samples, range(n_samples), c=group, marker='_', lw=lw, cmap=cmap_data)\n",
"\n",
"    # Formatting\n",
"    xticklabels = ['S' + str(x) for x in range(n_splits, 0, -1)] + ['class', 'group']\n",
"    ax.set(\n",
"        xticks=np.arange(n_splits + 2) + 0.45,\n",
"        xticklabels=xticklabels,\n",
"        ylabel='Sample index',\n",
"        xlabel='CV iteration',\n",
"        xlim=[n_splits + 2.2, -0.2],\n",
"        ylim=[0, X.shape[0]],\n",
"    )\n",
"    ax.set_title(type(cv).__name__, fontsize=5)\n",
"    ax.xaxis.tick_top()\n",
"\n",
"    return ax"
],
"metadata": {
"id": "8Qco6-mmeRaH"
},
"execution_count": 42,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Building X, y and the prediction/evaluation times"
],
"metadata": {
"id": "SLBYAi72exrd"
}
},
{
"cell_type": "code",
"source": [
"# Features (X), labels (y) and the prediction/evaluation time of every sample.\n",
"data = data_ohlcv\n",
"\n",
"# t_final is the maximum holding period (the size of the barrier 'box') defined earlier in the notebook.\n",
"# The last t_final rows have no evaluation time inside the data, so they are left out. Positional slicing\n",
"# also behaves correctly when t_final == 0, unlike [:-t_final].\n",
"n_obs = len(data) - t_final\n",
"\n",
"X = data.drop(columns=['label_barrier']).iloc[:n_obs]\n",
"y = data['label_barrier'].iloc[:n_obs]\n",
"\n",
"# Prediction time: the moment of observation.\n",
"# Evaluation time: the timestamp t_final rows later, when the label is known. (With stop-loss / take-profit\n",
"# events, the evaluation time could instead be the first barrier touch.)\n",
"prediction_times = pd.Series(data.index[:n_obs], index=X.index)\n",
"evaluation_times = pd.Series(data.index[t_final:], index=X.index)\n",
"\n",
"assert len(X) == len(y) == len(prediction_times) == len(evaluation_times)"
],
"metadata": {
"id": "NMRv7V_9eSWs"
},
"execution_count": 43,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# CPCV settings. With k = 2 test groups, N groups give N - 1 backtest paths, hence N = num_paths + 1.\n",
"num_paths, k = 5, 2\n",
"N = num_paths + 1\n",
"embargo_td = Delta * (t_final * 2)\n",
"cv = CombPurgedKFoldCVLocal(n_splits=N, n_test_splits=k, embargo_td=embargo_td)\n",
"\n",
"# Backtest paths (the verbose output reports the number of simulations and of paths)\n",
"_, paths, _ = back_test_paths_generator(X.shape[0], N, k, prediction_times, evaluation_times)\n",
"\n",
"# One row per sample; the y axis is flipped so that sample index 0 ends up at the top\n",
"groups = list(range(len(X)))\n",
"fig, ax = plt.subplots()\n",
"plot_cv_indices(cv, X, y, groups, ax, num_paths, k, paths)\n",
"ax.invert_yaxis();"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 661
},
"id": "piFfU3SBeUaK",
"outputId": "a7cd3b77-3e6f-455d-f746-d52eafcd8d9b"
},
"execution_count": 45,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"n_sim: 15\n",
"n_paths: 5\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1500x600 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Paths example"
],
"metadata": {
"id": "zcf2cHw1jE3O"
}
},
{
"cell_type": "code",
"source": [
"def back_test_paths_generator(t_span, n, k, prediction_times, evaluation_times, verbose=True):\n",
"    \"\"\"\n",
"    Assemble the backtest paths of combinatorial purged cross-validation (CPCV).\n",
"\n",
"    The t_span observations are split into n contiguous groups, and each of the C(n, k) simulations uses k of\n",
"    them as its test set. Every group is tested in C(n, k) * k / n simulations, so the test groups can be chained\n",
"    into that many complete backtest paths.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    t_span : int\n",
"        Number of observations. Must be at least n.\n",
"    n : int\n",
"        Number of groups.\n",
"    k : int\n",
"        Number of test groups per simulation.\n",
"    prediction_times, evaluation_times : pd.Series\n",
"        Not used by the computation; kept so existing calls keep working.\n",
"    verbose : bool, default=True\n",
"        Print the number of simulations and of paths.\n",
"\n",
"    Returns\n",
"    -------\n",
"    is_test : np.ndarray of bool, shape (t_span, C(n, k))\n",
"        is_test[t, s] is True when observation t belongs to the test set of simulation s.\n",
"    paths : np.ndarray of float, shape (t_span, n_paths)\n",
"        paths[t, p] is the simulation whose test prediction of observation t is used in path p.\n",
"    path_folds : np.ndarray of float, shape (n, n_paths)\n",
"        path_folds[g, p] is the simulation used for group g in path p.\n",
"\n",
"    Raises\n",
"    ------\n",
"    ValueError\n",
"        If t_span < n (some groups would be empty).\n",
"    \"\"\"\n",
"    if t_span < n:\n",
"        raise ValueError(f'Need at least one observation per group, got t_span={t_span} < n={n}.')\n",
"\n",
"    # Split the data into n groups (n << t_span) by assigning each observation position a group number.\n",
"    # The t_span % n leftover observations can get group numbers >= n; fold all of them into the last group.\n",
"    group_num = np.minimum(np.arange(t_span) // (t_span // n), n - 1)\n",
"\n",
"    # One simulation per combination of k test groups.\n",
"    test_groups = np.array(list(itt.combinations(np.arange(n), k))).reshape(-1, k)\n",
"    C_nk = len(test_groups)\n",
"    n_paths = C_nk * k // n\n",
"\n",
"    if verbose:\n",
"        print('n_sim:', C_nk)\n",
"        print('n_paths:', n_paths)\n",
"\n",
"    # Columns flag, per simulation, the test groups (is_test_group, n x C(n, k)) and the test\n",
"    # observations (is_test, t_span x C(n, k)). Works for any k, not only k == 2.\n",
"    is_test_group = np.full((n, C_nk), fill_value=False)\n",
"    is_test = np.full((t_span, C_nk), fill_value=False)\n",
"    for sim, test_group_ids in enumerate(test_groups):\n",
"        is_test_group[test_group_ids, sim] = True\n",
"        is_test[np.isin(group_num, test_group_ids), sim] = True\n",
"\n",
"    # Chain the test folds of different simulations into backtest paths: for every path and group, take the\n",
"    # first simulation that still has that group available as a test fold, then mark it as used.\n",
"    path_folds = np.full((n, n_paths), fill_value=np.nan)\n",
"    for p in range(n_paths):\n",
"        for g in range(n):\n",
"            s_idx = int(is_test_group[g, :].argmax())\n",
"            path_folds[g, p] = s_idx\n",
"            is_test_group[g, s_idx] = False\n",
"\n",
"    # Finally, for each path record which simulation provides the prediction at every time index.\n",
"    paths = np.full((t_span, n_paths), fill_value=np.nan)\n",
"    for p in range(n_paths):\n",
"        for g in range(n):\n",
"            paths[group_num == g, p] = path_folds[g, p]\n",
"\n",
"    return is_test, paths, path_folds"
],
"metadata": {
"id": "8OLHBM1gpVlr"
},
"execution_count": 54,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Toy example: 30 observations in 6 groups with k test groups per simulation.\n",
"# Adding 1 turns the 0-based simulation indices into 1-based simulation numbers.\n",
"_, paths, _ = back_test_paths_generator(t_span=30, n=6, k=k, prediction_times=prediction_times,\n",
"                                        evaluation_times=evaluation_times)\n",
"paths + 1"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ZGEkk6rzjHIM",
"outputId": "4b6bfbd9-2804-434d-8975-f77b31a84885"
},
"execution_count": 62,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"5\n",
"n_sim: 15\n",
"n_paths: 5\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"array([[ 1., 2., 3., 4., 5.],\n",
" [ 1., 2., 3., 4., 5.],\n",
" [ 1., 2., 3., 4., 5.],\n",
" [ 1., 2., 3., 4., 5.],\n",
" [ 1., 2., 3., 4., 5.],\n",
" [ 1., 6., 7., 8., 9.],\n",
" [ 1., 6., 7., 8., 9.],\n",
" [ 1., 6., 7., 8., 9.],\n",
" [ 1., 6., 7., 8., 9.],\n",
" [ 1., 6., 7., 8., 9.],\n",
" [ 2., 6., 10., 11., 12.],\n",
" [ 2., 6., 10., 11., 12.],\n",
" [ 2., 6., 10., 11., 12.],\n",
" [ 2., 6., 10., 11., 12.],\n",
" [ 2., 6., 10., 11., 12.],\n",
" [ 3., 7., 10., 13., 14.],\n",
" [ 3., 7., 10., 13., 14.],\n",
" [ 3., 7., 10., 13., 14.],\n",
" [ 3., 7., 10., 13., 14.],\n",
" [ 3., 7., 10., 13., 14.],\n",
" [ 4., 8., 11., 13., 15.],\n",
" [ 4., 8., 11., 13., 15.],\n",
" [ 4., 8., 11., 13., 15.],\n",
" [ 4., 8., 11., 13., 15.],\n",
" [ 4., 8., 11., 13., 15.],\n",
" [ 5., 9., 12., 14., 15.],\n",
" [ 5., 9., 12., 14., 15.],\n",
" [ 5., 9., 12., 14., 15.],\n",
" [ 5., 9., 12., 14., 15.],\n",
" [ 5., 9., 12., 14., 15.]])"
]
},
"metadata": {},
"execution_count": 62
}
]
},
{
"cell_type": "code",
"source": [
"# (number of observations, number of paths)\n",
"np.shape(paths)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UKN6a8JDoWZb",
"outputId": "9dc553f5-5130-40bf-eca9-0176bdc8c8e2"
},
"execution_count": 57,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(20, 5)"
]
},
"metadata": {},
"execution_count": 57
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment