@djsutherland
Last active April 8, 2025 18:52
{
"cells": [
{
"cell_type": "markdown",
"id": "d325051b-74dd-4c95-9254-14c3275c2aae",
"metadata": {},
"source": [
"Our goal is to efficiently compute permutation tests for MMD that are valid in finite samples,\n",
"based on samples $(X_i)_{i=1}^m$, $(Y_j)_{j=1}^n$."
]
},
{
"cell_type": "markdown",
"id": "4986363a-a710-438d-af2a-49c8546261f3",
"metadata": {},
"source": [
"## Computation\n",
"Write $Z_i = \\begin{cases}X_i & 1 \\le i \\le m \\\\ Y_{i-m} & m < i \\le m + n .\\end{cases}$\n",
"\n",
"First, note that the plug-in estimator (\"the biased estimator\") for MMD is\n",
"\\begin{align*}\n",
"\\widehat{\\operatorname{MMD}_b^2}\n",
" &= \\frac1{m^2} \\sum_{i=1}^m \\sum_{i'=1}^m k(X_i, X_{i'})\n",
" + \\frac{1}{n^2} \\sum_{j=1}^n \\sum_{j'=1}^n k(Y_j, Y_{j'})\n",
" - 2 \\frac{1}{mn} \\sum_{i=1}^m \\sum_{j=1}^n k(X_i, Y_j)\n",
"\\\\&= \\begin{bmatrix} \\mathbf{1}_m /m \\\\ -\\mathbf{1}_n / n \\end{bmatrix}^\\top\n",
" \\underbrace{\\begin{bmatrix} K_X & K_{XY} \\\\ K_{YX} & K_Y \\end{bmatrix}}_K\n",
" \\begin{bmatrix} \\mathbf{1}_m /m \\\\ -\\mathbf{1}_n / n \\end{bmatrix}\n",
",\\end{align*}\n",
"where $\\mathbf 1_m$ is a vector of $m$ ones and $K$ is the $(m + n) \\times (m + n)$ matrix with entries $K_{ij} = k(Z_i, Z_j)$.\n",
"To get the value of this estimator under a permutation of the pooled sample, we only need to permute the weight vector that multiplies $K$ on both sides; $K$ itself stays fixed."
]
},
{
"cell_type": "markdown",
"id": "9b44f61f-1dd5-41e1-b69f-c50611152b06",
"metadata": {},
"source": [
"The typical unbiased estimator (claimed to be the MVUE but I'm not sure this is actually true...) is\n",
"\\begin{align*}\n",
"\\widehat{\\operatorname{MMD}_u^2}\n",
" &= \\frac{1}{m (m-1)} \\sum_{i \\ne i'} k(X_i, X_{i'})\n",
" + \\frac{1}{n (n-1)} \\sum_{j \\ne j'} k(Y_j, Y_{j'})\n",
" - \\frac{2}{m n} \\sum_{i, j} k(X_i, Y_j)\n",
";\\end{align*}\n",
"this doesn't seem especially amenable to easy permutation in the same way as the previous one."
]
},
{
"cell_type": "markdown",
"id": "a65afbca-9469-4100-9a2c-f20fc5337fe1",
"metadata": {},
"source": [
"But the U-statistic estimator, which assumes $m = n$ and is unbiased but not quite minimum variance, is\n",
"\\begin{align*}\n",
"\\widehat{\\operatorname{MMD}_U^2}\n",
" &= \\frac{1}{n (n-1)} \\sum_{i \\ne j} \\left[ k(X_i, X_j) + k(Y_i, Y_j) - k(X_i, Y_j) - k(X_j, Y_i) \\right]\n",
"\\\\&= \\frac{1}{n(n-1)} \\sum_{i \\ne j} k(X_i, X_j)\n",
" + \\frac{1}{n(n-1)} \\sum_{i \\ne j} k(Y_i, Y_j)\n",
" - \\frac{2}{n(n-1)} \\sum_{i \\ne j} k(X_i, Y_j)\n",
"\\\\&= \\frac{1}{n(n-1)} \\left( \\sum_{i,j} k(X_i, X_j) - \\sum_i k(X_i, X_i) \\right)\n",
" + \\frac{1}{n(n-1)} \\left( \\sum_{i,j} k(Y_i, Y_j) - \\sum_i k(Y_i, Y_i) \\right)\n",
"\\\\&\\qquad\n",
" - \\frac{2}{n(n-1)} \\left( \\sum_{i,j} k(X_i, Y_j) - \\sum_i k(X_i, Y_i) \\right)\n",
"\\\\&= \\frac{n}{n-1} \\widehat{\\operatorname{MMD}_b^2}\n",
" - \\frac{1}{n(n-1)} \\left( \\sum_i k(X_i, X_i) + \\sum_i k(Y_i, Y_i) - 2 \\sum_i k(X_i, Y_i) \\right)\n",
";\\end{align*}\n",
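"the first two terms of the correction together make up the trace of $K$, which doesn't depend on the particular permutation.\n",
"The third term does depend on which $X_i$ gets paired with which $Y_i$, but it's just a sum of $n$ entries of $K$, so it's cheap to compute for each permutation."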
]
},
{
"cell_type": "markdown",
"id": "a23f6a33-72d7-4549-af56-313e5b11b9cf",
"metadata": {},
"source": [
"## p-value\n",
"As discussed in section 3.4 of Hemerik and Goeman (TEST 2018), [Exact testing with random permutations](https://arxiv.org/abs/1411.7565),\n",
"let $T_1, T_2, \\dots, T_w$ be the $w$ test statistics returned by the previous procedure,\n",
"where $T_1$ is computed on the actual data split\n",
"and $T_2$ through $T_w$ on uniformly random permutations.\n",
"A (possibly conservative) $p$-value is $\\lvert \\{ k \\in [w] : T_k \\ge T_1 \\} \\rvert / w$.\n",
"An exact (randomized) $p$-value can be found by accounting for ties, but we don't bother with that here."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "01140287-8c45-4dff-b1d4-cddd0f3db08a",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8e31cd9f-392e-4c09-821e-bd7827a7379c",
"metadata": {},
"outputs": [],
"source": [
"from collections import namedtuple\n",
"\n",
"PermutationResult = namedtuple(\n",
" \"PermutationResult\", [\"estimate\", \"p_value\", \"permuted_estimates\"]\n",
")\n",
"\n",
"\n",
"def mmd2_permutation(joint_kernel, n_X, n_perm=500, u_stat=False):\n",
" \"\"\"\n",
" joint_kernel: should be an array of shape [n_X + n_Y, n_X + n_Y] with kernel values,\n",
" i.e. [ K_XX K_XY ]\n",
" [ K_YX K_YY ]\n",
" n_X: number of entries in the first set\n",
" n_perm: total number of permutations, including the identity\n",
"\n",
" If u_stat is False (the default), uses the plug-in estimator (MMD between empirical distributions).\n",
" If u_stat is True, uses the U-statistic estimator, which requires n_X == n_Y and is unbiased but drops the k(x_i, y_i) terms.\n",
" (I'm not sure how to implement the \"unbiased estimator\" (which includes those terms) efficiently.)\n",
" \"\"\"\n",
" K = joint_kernel = torch.as_tensor(joint_kernel)\n",
" device = K.device\n",
" dtype = K.dtype\n",
"\n",
" n = K.shape[0]\n",
" if K.shape != (n, n):\n",
" raise ValueError(f\"joint_kernel should be square, got {K.shape}\")\n",
" n_X = int(n_X)\n",
" n_Y = n - n_X\n",
" if n_X <= 0 or n_Y <= 0:\n",
" raise ValueError(\"need a positive number of samples from each\")\n",
"\n",
" if u_stat:\n",
" if n_X != n_Y:\n",
" raise ValueError(\"u-stat estimator only defined for equal sample sizes\")\n",
" w_X = 1\n",
" w_Y = -1\n",
" else:\n",
" w_X = 1 / n_X\n",
" w_Y = -1 / n_Y\n",
"\n",
" # construct permutations\n",
" # there probably should be a faster way to do this but, idk\n",
" perms = torch.stack(\n",
" [torch.arange(n, device=device)]\n",
" + [torch.randperm(n, device=device) for _ in range(n_perm - 1)]\n",
" )\n",
" X_inds = perms[:, :n_X]\n",
" Y_inds = perms[:, n_X:]\n",
"\n",
" # set weights to w_X for things in X_inds, w_Y for others\n",
" ws = torch.full((n_perm, n), w_Y, device=device, dtype=dtype)\n",
" ws.scatter_(1, X_inds, w_X)\n",
"\n",
" # the \"basic\" estimate; either the biased est or a constant times it\n",
" ests = torch.einsum(\"pi,ij,pj->p\", ws, joint_kernel, ws)\n",
"\n",
" if u_stat:\n",
" # need to subtract \\sum_i k(X_i, X_i) + k(Y_i, Y_i) - 2 k(X_i, Y_i)\n",
" # first two are just trace\n",
" # for the last one, we need to see which ones were lined up\n",
" # NOTE: this makes an unnecessary copy if joint_kernel isn't already contiguous,\n",
" # but this generally shouldn't be a big deal\n",
" cross_terms = joint_kernel.take(X_inds * n + Y_inds).sum(1)\n",
" ests = (ests - joint_kernel.trace() + 2 * cross_terms) / (n_X * (n_X - 1))\n",
"\n",
" p_val = (ests >= ests[0]).float().mean()\n",
" return PermutationResult(ests[0], p_val, ests)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5e149931-01e5-426a-9a90-bef389e14a4f",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c45c7919-859e-434e-9679-f40ab7b12dbe",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.0640)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"X = torch.randn(100, 3)\n",
"Y = torch.randn(100, 3) * 1.4\n",
"Z = torch.cat((X, Y))\n",
"K = torch.exp(-0.5 * torch.cdist(Z, Z) ** 2)\n",
"\n",
"res = mmd2_permutation(K, X.shape[0], u_stat=True)\n",
"\n",
"plt.hist(res.permuted_estimates, bins=\"auto\")\n",
"plt.axvline(res.estimate, color=\"r\")\n",
"res.p_value"
]
},
{
"cell_type": "markdown",
"id": "210265ef-6139-4c14-a5f1-c0ee2ffa4222",
"metadata": {},
"source": [
"## HSIC"
]
},
{
"cell_type": "markdown",
"id": "39ec253e-16a1-4ea4-96a5-6e091d3dcff7",
"metadata": {},
"source": [
"What about for HSIC?\n",
"\n",
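"Up to a constant normalization factor (such as $1/n^2$), which doesn't change the permutation test, the biased (plug-in) estimator can be written as\n",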
"$$\\DeclareMathOperator{\\Tr}{Tr}\n",
"\\langle H K H, H L H\\rangle_F = \\Tr(H K H L) = \\langle HKH, L \\rangle_F\n",
",$$\n",
"where $H = I - \\frac1n \\mathbf 1 \\mathbf 1^\\top$ is the centering matrix,\n",
"$K$ is the kernel matrix of $X$ values,\n",
"and $L$ is the kernel matrix of $Y$ values.\n",
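"Note that there's no reason to bother permuting _both_ $X$ and $Y$:\n",
"applying the same permutation to both leaves the statistic unchanged, so permuting $X$ by $\\sigma$ and $Y$ by $\\tau$ gives the same value as permuting only $Y$ by $\\tau \\circ \\sigma^{-1}$.\n",
"So we might as well permute only the one that we don't center,\n",
"so that we only have to center once."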
]
},
{
"cell_type": "markdown",
"id": "02621101-0334-4a77-8572-ddf75270cb38",
"metadata": {},
"source": [
"I'm not sure which will be faster, so let's try two different implementation approaches:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c01efb65-5819-4c4d-89ad-a0598d059bce",
"metadata": {},
"outputs": [],
"source": [
"def hsic_permutation_torch(K, L, n_perm=500):\n",
" \"\"\"\n",
" K: the pairwise kernel matrix of X values\n",
" L: the pairwise kernel matrix of corresponding Y values\n",
" n_perm: total number of permutations, including the identity\n",
" \"\"\"\n",
" K = torch.as_tensor(K)\n",
" device = K.device\n",
" dtype = K.dtype\n",
" L = torch.as_tensor(L).to(device=device, dtype=dtype)\n",
"\n",
" n = K.shape[0]\n",
" if K.shape != (n, n):\n",
" raise ValueError(f\"K should be square, got {K.shape}\")\n",
" if L.shape != (n, n):\n",
" raise ValueError(f\"L should be same shape as K ({K.shape}), got {L.shape}\")\n",
"\n",
" row_mean = K.mean(dim=0, keepdim=True)\n",
" col_mean = K.mean(dim=1, keepdim=True)\n",
" HKH_flat = (K - row_mean - col_mean + row_mean.mean()).ravel()\n",
"\n",
" L_flats = torch.empty((n_perm, n * n), device=device, dtype=dtype)\n",
" L_flats[0, :] = L.ravel()\n",
" for i in range(1, n_perm):\n",
" perm = torch.randperm(n, device=device)\n",
" L_flats[i, :] = L[perm.unsqueeze(1), perm.unsqueeze(0)].ravel()\n",
"\n",
" ests = L_flats @ HKH_flat\n",
"\n",
" p_val = (ests >= ests[0]).float().mean()\n",
" return PermutationResult(ests[0], p_val, ests)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e0edc21b-99e6-41b1-a052-61db065777c2",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from numba import jit\n",
"\n",
"@jit(nopython=True)\n",
"def _hsic_perms(HKH, L, perms):\n",
" n = HKH.shape[0]\n",
" n_perms = perms.shape[0]\n",
" ests = np.zeros(n_perms)\n",
" for i in range(n):\n",
" for j in range(n):\n",
" for e in range(n_perms):\n",
" ests[e] += HKH[i, j] * L[perms[e, i], perms[e, j]]\n",
" return ests\n",
"\n",
"def hsic_permutation_numba(K, L, n_perm=500):\n",
" \"\"\"\n",
" K: the pairwise kernel matrix of X values\n",
" L: the pairwise kernel matrix of corresponding Y values\n",
" n_perm: total number of permutations, including the identity\n",
" \"\"\"\n",
" K = np.asarray(K)\n",
" L = np.asarray(L)\n",
"\n",
" n = K.shape[0]\n",
" if K.shape != (n, n):\n",
" raise ValueError(f\"K should be square, got {K.shape}\")\n",
" if L.shape != (n, n):\n",
" raise ValueError(f\"L should be same shape as K ({K.shape}), got {L.shape}\")\n",
"\n",
" row_mean = K.mean(axis=0, keepdims=True)\n",
" col_mean = K.mean(axis=1, keepdims=True)\n",
" HKH = K - row_mean - col_mean + row_mean.mean()\n",
"\n",
" # make permutations, with first row in order\n",
" perms = np.tile(np.arange(n), (n_perm, 1))\n",
" rng = np.random.default_rng()\n",
" rng.permuted(perms[1:], axis=1, out=perms[1:])\n",
" \n",
" ests = _hsic_perms(HKH, L, perms)\n",
" \n",
" p_val = (ests >= ests[0]).mean()\n",
" return PermutationResult(ests[0], p_val, ests)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bc73065a-9196-4802-a5bc-5dca68c07083",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics.pairwise import rbf_kernel\n",
"\n",
"X = np.random.randn(500, 2)\n",
"Y = X[:, [0, 0]] / 5 + np.random.randn(*X.shape)\n",
"\n",
"K = rbf_kernel(X, gamma=1)\n",
"L = rbf_kernel(Y, gamma=1)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "dabb8428-2871-40ef-8442-6df0d9fb8d0d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(0.0010), 0.001)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 800x100 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"f, (a1, a2) = plt.subplots(1, 2, figsize=(8, 1))\n",
"\n",
"r1 = hsic_permutation_torch(K, L, n_perm=1_000)\n",
"a1.hist(r1.permuted_estimates, bins=\"auto\")\n",
"a1.axvline(r1.estimate, color=\"r\")\n",
"\n",
"r2 = hsic_permutation_numba(K, L, n_perm=1_000)\n",
"a2.hist(r2.permuted_estimates, bins=\"auto\")\n",
"a2.axvline(r2.estimate, color=\"r\")\n",
"\n",
"r1.p_value, r2.p_value"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "60040945-a203-46fa-8f7f-86e0dad2046e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"358 ms ± 11.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"hsic_permutation_torch(K, L).p_value"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7524d67d-b5d1-41c1-941e-d189ae0e013e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"152 ms ± 513 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"hsic_permutation_numba(K, L).p_value"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
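The quadratic-form view of the plug-in estimator from the notebook's Computation section can be sanity-checked numerically. Below is a minimal NumPy sketch, not taken from the original notebook: `rbf`, `mmd2_biased_direct`, and `mmd2_biased_permuted` are hypothetical helper names, and the unit-bandwidth Gaussian kernel is an arbitrary choice. It builds one weight vector per permutation, as `mmd2_permutation` does with `scatter_`, and compares each quadratic form against the estimator computed directly on the re-split samples. The permutations come from `Generator.permuted` with `axis=1`, the same call the numba HSIC cell uses, which shuffles each row independently.

```python
import numpy as np


def rbf(A, B):
    # Gaussian kernel with unit bandwidth: k(a, b) = exp(-||a - b||^2 / 2)
    sq_dists = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-0.5 * sq_dists)


def mmd2_biased_direct(X, Y):
    # Plug-in estimator, written out term by term
    return rbf(X, X).mean() + rbf(Y, Y).mean() - 2 * rbf(X, Y).mean()


def mmd2_biased_permuted(K, m, perms):
    # perms has shape [n_perm, m + n]; the first m entries of each row are the
    # pooled-sample indices treated as "X" under that permutation.
    N = K.shape[0]
    n = N - m
    ws = np.full((len(perms), N), -1.0 / n)               # weight -1/n everywhere...
    np.put_along_axis(ws, perms[:, :m], 1.0 / m, axis=1)  # ...except +1/m on the "X" indices
    return np.einsum("pi,ij,pj->p", ws, K, ws)            # w_p^T K w_p for each row p


rng = np.random.default_rng(0)
m, n = 30, 20
X = rng.normal(size=(m, 2))
Y = 1.5 * rng.normal(size=(n, 2))
Z = np.concatenate([X, Y])
K = rbf(Z, Z)

# One row per permutation; the first row is the identity (the actual split)
perms = rng.permuted(np.tile(np.arange(m + n), (8, 1)), axis=1)
perms[0] = np.arange(m + n)

ests = mmd2_biased_permuted(K, m, perms)
direct = np.array([mmd2_biased_direct(Z[p[:m]], Z[p[m:]]) for p in perms])
print(np.allclose(ests, direct))  # True
```

Only the weight vectors change between permutations; `K` is computed once, which is the point of the quadratic-form rewrite.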
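The notebook remarks that the usual unbiased estimator doesn't seem to permute as easily as the plug-in one. One possible approach, which is our own sketch and not something the original notebook does, is to zero the diagonal of $K$ and use 0/1 indicator vectors in place of weight vectors. Each of the three sums in the unbiased estimator then becomes its own quadratic form, so a permutation costs three quadratic forms instead of one. The helper names below are hypothetical; the sketch is checked against a direct computation on the re-indexed kernel matrix.

```python
import numpy as np


def mmd2_unbiased_direct(K, m):
    # Term-by-term unbiased estimator: excludes the i == i' and j == j' terms
    n = K.shape[0] - m
    KX, KY, KXY = K[:m, :m], K[m:, m:], K[:m, m:]
    s_XX = KX.sum() - np.trace(KX)
    s_YY = KY.sum() - np.trace(KY)
    return s_XX / (m * (m - 1)) + s_YY / (n * (n - 1)) - 2 * KXY.mean()


def mmd2_unbiased_permuted(K, m, perms):
    # Hypothetical helper: unbiased MMD^2 for every permutation (row of perms)
    N = K.shape[0]
    n = N - m
    K0 = K - np.diag(np.diag(K))  # zero the diagonal so within-sample sums skip i == i'
    in_X = np.zeros((len(perms), N))
    np.put_along_axis(in_X, perms[:, :m], 1.0, axis=1)  # 0/1 indicator of the "X" indices
    in_Y = 1.0 - in_X
    s_XX = np.einsum("pi,ij,pj->p", in_X, K0, in_X)
    s_YY = np.einsum("pi,ij,pj->p", in_Y, K0, in_Y)
    s_XY = np.einsum("pi,ij,pj->p", in_X, K0, in_Y)  # cross pairs never touch the diagonal
    return s_XX / (m * (m - 1)) + s_YY / (n * (n - 1)) - 2 * s_XY / (m * n)


rng = np.random.default_rng(1)
m, n = 15, 25
Z = rng.normal(size=(m + n, 3))
K = np.exp(-0.5 * ((Z[:, None, :] - Z[None, :, :]) ** 2).sum(axis=-1))

perms = rng.permuted(np.tile(np.arange(m + n), (6, 1)), axis=1)
perms[0] = np.arange(m + n)  # identity first, as in the notebook
ests = mmd2_unbiased_permuted(K, m, perms)

# Same statistic computed directly on each permuted kernel matrix K[p][:, p]
direct = np.array([mmd2_unbiased_direct(K[np.ix_(p, p)], m) for p in perms])
```

Unlike the U-statistic route, this does not require $m = n$. Whether three batched quadratic forms are fast enough in practice is untested here.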
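The U-statistic rewrite in the notebook can also be checked numerically. This NumPy sketch is not part of the original notebook, and its data and kernel are arbitrary choices. It computes the U-statistic three ways: directly as an average over pairs $i \ne j$, through the last line of the derivation, and through the $\pm 1$-weight shortcut that `mmd2_permutation` uses when `u_stat=True`.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 25
X = rng.normal(size=(n, 3))
Y = rng.normal(size=(n, 3)) + 0.3
Z = np.concatenate([X, Y])
K = np.exp(-0.5 * ((Z[:, None, :] - Z[None, :, :]) ** 2).sum(axis=-1))
KX, KY, KXY = K[:n, :n], K[n:, n:], K[:n, n:]

# 1. Direct: h[i, j] = k(X_i, X_j) + k(Y_i, Y_j) - k(X_i, Y_j) - k(X_j, Y_i), averaged over i != j
h = KX + KY - KXY - KXY.T
u_direct = (h.sum() - np.trace(h)) / (n * (n - 1))

# 2. Via the identity: n/(n-1) * MMD_b^2 minus the diagonal correction
w = np.concatenate([np.full(n, 1.0 / n), np.full(n, -1.0 / n)])
mmd2_b = w @ K @ w
u_identity = n / (n - 1) * mmd2_b - (np.trace(K) - 2 * np.trace(KXY)) / (n * (n - 1))

# 3. As in mmd2_permutation: +-1 weights give n^2 * MMD_b^2, then correct and rescale
v = np.concatenate([np.ones(n), -np.ones(n)])
u_code = (v @ K @ v - np.trace(K) + 2 * np.trace(KXY)) / (n * (n - 1))
```

Here `np.trace(KXY)` is $\sum_i k(X_i, Y_i)$, the only permutation-dependent piece of the correction.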
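The $p$-value from the notebook's p-value section is a one-liner once the statistics are stacked with the actual split first, which is the layout that `mmd2_permutation` and both HSIC functions return. Below is a minimal NumPy version with hand-checkable inputs; the name `permutation_p_value` is hypothetical.

```python
import numpy as np


def permutation_p_value(ests):
    # ests[0] is T_1 (the actual split); ests[1:] come from random permutations.
    # T_1 is counted in both the numerator and the denominator w, so p >= 1/w.
    ests = np.asarray(ests, dtype=float)
    return np.mean(ests >= ests[0])


p_tie = permutation_p_value([3.0, 1.0, 3.0, 5.0, 2.0])  # 3.0, 3.0, 5.0 are >= 3.0 -> 3/5 = 0.6
p_max = permutation_p_value([9.0, 1.0, 2.0, 3.0])       # T_1 is the largest -> 1/4 = 0.25
```

Because $T_1$ is always counted, the smallest attainable value is $1/w$; with the notebook's default `n_perm=500` that is $0.002$. Ties with $T_1$ all count toward the numerator, which is why this version can be conservative.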
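The HSIC identities and the "permute only one side" argument from the notebook's HSIC section can be checked numerically too. In this NumPy sketch (not from the original notebook; `center` and `hsic_stat` are hypothetical names, and the data are arbitrary), `center` uses the same row/column-mean trick as both HSIC implementations instead of forming $H$. `hsic_stat` returns the unnormalized $\langle HKH, L\rangle_F$ that those implementations compute. The last check uses `np.argsort(sigma)` as the inverse permutation $\sigma^{-1}$.

```python
import numpy as np


def center(K):
    # H K H without forming H: subtract column means and row means, add back the grand mean
    return K - K.mean(axis=0, keepdims=True) - K.mean(axis=1, keepdims=True) + K.mean()


def hsic_stat(K, L):
    # Unnormalized biased HSIC <HKH, L>_F, as computed in the notebook
    return np.sum(center(K) * L)


rng = np.random.default_rng(3)
n = 40
x = rng.normal(size=(n, 1))
y = x + rng.normal(size=(n, 1))
K = np.exp(-0.5 * (x - x.T) ** 2)  # Gaussian kernel matrices, unit bandwidth
L = np.exp(-0.5 * (y - y.T) ** 2)
H = np.eye(n) - np.ones((n, n)) / n

# <HKH, HLH>_F == Tr(HKHL) == <HKH, L>_F
via_HLH = np.sum((H @ K @ H) * (H @ L @ H))
via_trace = np.trace(H @ K @ H @ L)

# Permuting X by sigma and Y by tau == permuting only Y by tau o sigma^{-1}
sigma, tau = rng.permutation(n), rng.permutation(n)
both = hsic_stat(K[np.ix_(sigma, sigma)], L[np.ix_(tau, tau)])
rho = tau[np.argsort(sigma)]  # rho[a] = tau[sigma^{-1}[a]]
only_L = hsic_stat(K, L[np.ix_(rho, rho)])
```

The last identity is why the notebook only permutes $L$: the centered matrix $HKH$ is computed once and reused for every permutation.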