Created July 15, 2023 18:29
-
-
Save Micky774/05cfb28a6fd1b20ee5ed577452fb9817 to your computer and use it in GitHub Desktop.
Generate benchmarks with a permutation test
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %%
# Setup: choose where this run's timings are written. Run the benchmark cell
# once per branch, setting ``branch`` accordingly ("main" when checked out on
# the main branch, "PR" on the pull-request branch); the analysis cells read
# both data/main.csv and data/PR.csv.
from os import mkdir
from pathlib import Path
from os.path import join

branch = "PR"
path = Path("data") / f"{branch}.csv"
# parents=True / exist_ok=True make this idempotent and avoid the
# check-then-create race of ``exists()`` followed by ``mkdir``.
path.parent.mkdir(parents=True, exist_ok=True)
# %%
# Benchmark: time ``safe_sqr`` on N_REPEAT generated matrices and write the
# per-repeat durations (in seconds) to ``path`` from the setup cell.
import numpy as np
import pandas as pd
from time import perf_counter, time
from sklearn.utils import safe_sqr

n_samples = 10_000
n_feat = 1_000
N_REPEAT = 30

results = np.empty(N_REPEAT, dtype=np.float64)
for i in range(N_REPEAT):
    # Seeding with the repeat index makes every run (and every branch) time
    # exactly the same sequence of input matrices.
    rng = np.random.RandomState(i)
    X = rng.randn(n_samples, n_feat)
    # perf_counter is monotonic and high-resolution; time.time() is
    # wall-clock time and is not meant for measuring short intervals.
    start = perf_counter()
    safe_sqr(X)
    results[i] = perf_counter() - start

df = pd.DataFrame(results, columns=("duration",))
df.to_csv(path, index=False)
# %%
# Analysis: load the timings recorded on each branch.
import pandas as pd
from os.path import join
from pathlib import Path
from scipy.stats import bootstrap, permutation_test

branches = ("main", "PR")
# One DataFrame per branch, in the same order as ``branches``.
dfs = [pd.read_csv(Path("data") / f"{name}.csv") for name in branches]
# Raw duration arrays ordered (main, PR); the statistics below compute
# main minus PR.
result_arrs = tuple(frame["duration"].values for frame in dfs)
# %%
from matplotlib import pyplot as plt


def plot_null_distribution(simulated_statistics, bins, observed_statistic, tick_step=None):
    """Plot a histogram of simulated statistics, shading the smaller tail.

    Whichever side of ``observed_statistic`` holds the smaller fraction of
    ``simulated_statistics`` (ties go to the left) is drawn in colour "C1";
    the remaining bins are drawn in "C0". A vertical "C1" line marks the
    observed statistic.

    Parameters
    ----------
    simulated_statistics : array-like of float
        Simulated statistics, e.g. a permutation test's ``null_distribution``.
    bins : int or sequence of float
        Passed straight to ``numpy.histogram``.
    observed_statistic : float
        Statistic computed on the actual data.
    tick_step : float, optional
        Spacing of x-axis ticks, starting at the first bin edge. When None
        (the default) matplotlib chooses the ticks, which adapts to the scale
        of the data; pass 2 to reproduce the previous fixed spacing.
    """
    simulated_statistics = np.asarray(simulated_statistics, dtype=float)

    # Fraction of simulations at or beyond the observed value on each side.
    left_tail_prob = np.mean(simulated_statistics <= observed_statistic)
    right_tail_prob = np.mean(simulated_statistics >= observed_statistic)
    # Same rule as argmin over (left, right): right only when strictly smaller.
    tail = "right" if right_tail_prob < left_tail_prob else "left"

    counts, edges = np.histogram(simulated_statistics, bins=bins, density=True)
    left_edges, right_edges = edges[:-1], edges[1:]
    # Shade every bin that overlaps the chosen tail, so the bin containing the
    # observed value is treated the same way for either tail.
    if tail == "left":
        in_tail = left_edges <= observed_statistic
    else:
        in_tail = right_edges >= observed_statistic
    colors = ["C1" if flag else "C0" for flag in in_tail]

    plt.figure(figsize=(10, 7))
    # align="edge" anchors each bar at its bin's left edge; the default
    # align="center" would shift every bar left by half a bin width.
    plt.bar(
        left_edges,
        counts,
        width=right_edges - left_edges,
        align="edge",
        edgecolor="black",
        color=colors,
    )
    plt.axvline(x=observed_statistic, color="C1")
    if tick_step is not None:
        plt.xticks(np.arange(edges[0], edges[-1], tick_step))
    plt.show()
# %%
# Permutation test for a p-value: how likely is a mean difference this large
# if the branch labels were irrelevant? One-sided: the alternative is that
# mean(main) - mean(PR) > 0, i.e. the PR is faster.
perm_res = permutation_test(
    result_arrs,
    # Accepting ``axis`` lets SciPy evaluate all resamples in one vectorized
    # call instead of invoking the statistic once per resample.
    statistic=lambda x, y, axis: np.mean(x, axis=axis) - np.mean(y, axis=axis),
    vectorized=True,
    alternative="greater",
)
print(f"p-value = {perm_res.pvalue}")
# Statistic on the unpermuted data: mean(main) - mean(PR).
obs_stat = perm_res.statistic
plot_null_distribution(perm_res.null_distribution, 70, obs_stat)
# %%
# Bootstrap confidence interval for the mean difference mean(main) - mean(PR).
boot_res = bootstrap(
    result_arrs,
    # Accepting ``axis`` lets SciPy evaluate resamples in vectorized batches.
    statistic=lambda x, y, axis: np.mean(x, axis=axis) - np.mean(y, axis=axis),
    vectorized=True,
    method="BCa",
)
ci = boot_res.confidence_interval
# Durations are in seconds and the difference can be far below a millisecond,
# which a fixed three-decimal format would round to 0.000.
print(f"95% confidence interval = ({ci.low:.3e}, {ci.high:.3e})")

# The bootstrap distribution approximates the sampling distribution of the
# statistic around its observed value, not a null distribution, so tail
# shading is not meaningful here; mark the interval bounds instead.
plt.figure(figsize=(10, 7))
plt.hist(boot_res.bootstrap_distribution, bins=70, density=True, color="C0", edgecolor="black")
plt.axvline(x=obs_stat, color="C1")
plt.axvline(x=ci.low, color="C2", linestyle="--")
plt.axvline(x=ci.high, color="C2", linestyle="--")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.