Gist: djsutherland/6f564b460d14f11ba8ee1df664b12136
Last active April 8, 2025 18:52
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d325051b-74dd-4c95-9254-14c3275c2aae",
   "metadata": {},
   "source": [
    "Our goal is to efficiently compute permutation tests for MMD which are finite-sample valid,\n",
    "based on samples $(X_i)_{i=1}^m$, $(Y_j)_{j=1}^n$."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4986363a-a710-438d-af2a-49c8546261f3",
   "metadata": {},
   "source": [
    "## Computation\n",
    "Write $Z_i = \\begin{cases}X_i & 1 \\le i \\le m \\\\ Y_{i-m} & m < i \\le m + n .\\end{cases}$\n",
    "\n",
    "First, note that the plug-in estimator (\"the biased estimator\") for $\\operatorname{MMD}^2$ is\n",
    "\\begin{align*}\n",
    "\\widehat{\\operatorname{MMD}_b^2}\n",
    "  &= \\frac1{m^2} \\sum_{i=1}^m \\sum_{i'=1}^m k(X_i, X_{i'})\n",
    "   + \\frac{1}{n^2} \\sum_{j=1}^n \\sum_{j'=1}^n k(Y_j, Y_{j'})\n",
    "   - 2 \\frac{1}{mn} \\sum_{i=1}^m \\sum_{j=1}^n k(X_i, Y_j)\n",
    "\\\\&= \\begin{bmatrix} \\mathbf{1}_m /m \\\\ -\\mathbf{1}_n / n \\end{bmatrix}^\\top\n",
    "  \\underbrace{\\begin{bmatrix} K_X & K_{XY} \\\\ K_{YX} & K_Y \\end{bmatrix}}_K\n",
    "  \\begin{bmatrix} \\mathbf{1}_m /m \\\\ -\\mathbf{1}_n / n \\end{bmatrix}\n",
    ",\\end{align*}\n",
    "where $\\mathbf 1_m$ is a vector of $m$ ones, and $K$ is the $(m + n) \\times (m + n)$ matrix with entries $K_{ij} = k(Z_i, Z_j)$.\n",
    "To get a permuted value of this estimator, we just need to permute the weight vector we're hitting $K$ with:\n",
    "put $1/m$ on the indices assigned to the first sample and $-1/n$ on the rest."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b44f61f-1dd5-41e1-b69f-c50611152b06",
   "metadata": {},
   "source": [
    "The typical unbiased estimator (claimed to be the MVUE, but I'm not sure this is actually true...) is\n",
    "\\begin{align*}\n",
    "\\widehat{\\operatorname{MMD}_u^2}\n",
    "  &= \\frac{1}{m (m-1)} \\sum_{i \\ne i'} k(X_i, X_{i'})\n",
    "   + \\frac{1}{n (n-1)} \\sum_{j \\ne j'} k(Y_j, Y_{j'})\n",
    "   - \\frac{2}{m n} \\sum_{i, j} k(X_i, Y_j)\n",
    ";\\end{align*}\n",
    "this doesn't seem especially amenable to easy permutation in the same way as the previous one."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a65afbca-9469-4100-9a2c-f20fc5337fe1",
   "metadata": {},
   "source": [
    "But the U-statistic estimator, which assumes $m = n$ and is unbiased but not quite minimum variance, is\n",
    "\\begin{align*}\n",
    "\\widehat{\\operatorname{MMD}_U^2}\n",
    "  &= \\frac{1}{n (n-1)} \\sum_{i \\ne j} \\left[ k(X_i, X_j) + k(Y_i, Y_j) - k(X_i, Y_j) - k(X_j, Y_i) \\right]\n",
    "\\\\&= \\frac{1}{n(n-1)} \\sum_{i \\ne j} k(X_i, X_j)\n",
    "   + \\frac{1}{n(n-1)} \\sum_{i \\ne j} k(Y_i, Y_j)\n",
    "   - \\frac{2}{n(n-1)} \\sum_{i \\ne j} k(X_i, Y_j)\n",
    "\\\\&= \\frac{1}{n(n-1)} \\left( \\sum_{i,j} k(X_i, X_j) - \\sum_i k(X_i, X_i) \\right)\n",
    "   + \\frac{1}{n(n-1)} \\left( \\sum_{i,j} k(Y_i, Y_j) - \\sum_i k(Y_i, Y_i) \\right)\n",
    "\\\\&\\qquad\n",
    "   - \\frac{2}{n(n-1)} \\left( \\sum_{i,j} k(X_i, Y_j) - \\sum_i k(X_i, Y_i) \\right)\n",
    "\\\\&= \\frac{n}{n-1} \\widehat{\\operatorname{MMD}_b^2}\n",
    "   - \\frac{1}{n(n-1)} \\left( \\sum_i k(X_i, X_i) + \\sum_i k(Y_i, Y_i) - 2 \\sum_i k(X_i, Y_i) \\right)\n",
    ";\\end{align*}\n",
    "together, the first two terms of the correction are simply the trace of $K$, which doesn't depend on the particular permutation.\n",
    "The sum $\\sum_i k(X_i, Y_i)$ does, but isn't so bad: under a permutation $\\pi$ of $\\{1, \\dots, 2n\\}$ it becomes $\\sum_{i=1}^n K_{\\pi(i), \\pi(n+i)}$, a sum of just $n$ entries of $K$."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a23f6a33-72d7-4549-af56-313e5b11b9cf",
   "metadata": {},
   "source": [
    "## p-value\n",
    "As discussed in section 3.4 of Hemerik and Goeman (TEST 2018), [Exact testing with random permutations](https://arxiv.org/abs/1411.7565),\n",
    "let $T_1, T_2, \\dots, T_w$ be the $w$ test statistics returned by the previous procedure,\n",
    "where $T_1$ is computed on the actual data split\n",
    "and $T_2$ through $T_w$ on uniformly random permutations.\n",
    "A (possibly conservative) $p$-value is $\\lvert \\{ k \\in [w] : T_k \\ge T_1 \\} \\rvert / w$;\n",
    "since $T_1$ always counts itself, this is at least $1/w$.\n",
    "An exact (randomized) $p$-value can be found based on counting ties, but we don't bother with that here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "01140287-8c45-4dff-b1d4-cddd0f3db08a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8e31cd9f-392e-4c09-821e-bd7827a7379c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import namedtuple\n",
    "\n",
    "PermutationResult = namedtuple(\n",
    "    \"PermutationResult\", [\"estimate\", \"p_value\", \"permuted_estimates\"]\n",
    ")\n",
    "\n",
    "\n",
    "def mmd2_permutation(joint_kernel, n_X, n_perm=500, u_stat=False):\n",
    "    \"\"\"\n",
    "    joint_kernel: should be an array of shape [n_X + n_Y, n_X + n_Y] with kernel values,\n",
    "        ie [ K_XX  K_XY ]\n",
    "           [ K_YX  K_YY ]\n",
    "    n_X: number of entries in the first set\n",
    "    n_perm: total number of permutations, including the identity\n",
    "\n",
    "    If u_stat is False (the default), uses the plug-in estimator (MMD between empirical distributions).\n",
    "    If True, uses the U-statistic estimator, which requires n_X == n_Y; it is unbiased but drops the k(x_i, y_i) terms.\n",
    "    (I'm not sure how to implement the \"unbiased estimator\" (which includes those terms) efficiently.)\n",
    "    \"\"\"\n",
    "    K = joint_kernel = torch.as_tensor(joint_kernel)\n",
    "    device = K.device\n",
    "    dtype = K.dtype\n",
    "\n",
    "    n = K.shape[0]\n",
    "    if K.shape != (n, n):\n",
    "        raise ValueError(f\"joint_kernel should be square, got {K.shape}\")\n",
    "    n_X = int(n_X)\n",
    "    n_Y = n - n_X\n",
    "    if n_X <= 0 or n_Y <= 0:\n",
    "        raise ValueError(\"need a positive number of samples from each\")\n",
    "\n",
    "    if u_stat:\n",
    "        if n_X != n_Y:\n",
    "            raise ValueError(\"u-stat estimator only defined for equal sample sizes\")\n",
    "        w_X = 1\n",
    "        w_Y = -1\n",
    "    else:\n",
    "        w_X = 1 / n_X\n",
    "        w_Y = -1 / n_Y\n",
    "\n",
    "    # construct permutations\n",
    "    # there probably should be a faster way to do this but, idk\n",
    "    perms = torch.stack(\n",
    "        [torch.arange(n, device=device)]\n",
    "        + [torch.randperm(n, device=device) for _ in range(n_perm - 1)]\n",
    "    )\n",
    "    X_inds = perms[:, :n_X]\n",
    "    Y_inds = perms[:, n_X:]\n",
    "\n",
    "    # set weights to w_X for things in X_inds, w_Y for others\n",
    "    ws = torch.full((n_perm, n), w_Y, device=device, dtype=dtype)\n",
    "    ws.scatter_(1, X_inds, w_X)\n",
    "\n",
    "    # the \"basic\" estimate; either the biased est or a constant (n_X^2) times it\n",
    "    ests = torch.einsum(\"pi,ij,pj->p\", ws, joint_kernel, ws)\n",
    "\n",
    "    if u_stat:\n",
    "        # need to subtract \\sum_i k(X_i, X_i) + k(Y_i, Y_i) - 2 k(X_i, Y_i)\n",
    "        # first two are just trace\n",
    "        # for the last one, we need to see which ones were lined up\n",
    "        # NOTE: this makes an unnecessary copy if joint_kernel isn't already contiguous,\n",
    "        # but this generally shouldn't be a big deal\n",
    "        cross_terms = joint_kernel.take(X_inds * n + Y_inds).sum(1)\n",
    "        ests = (ests - joint_kernel.trace() + 2 * cross_terms) / (n_X * (n_X - 1))\n",
    "\n",
    "    p_val = (ests >= ests[0]).float().mean()\n",
    "    return PermutationResult(ests[0], p_val, ests)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5e149931-01e5-426a-9a90-bef389e14a4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c45c7919-859e-434e-9679-f40ab7b12dbe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0640)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "X = torch.randn(100, 3)\n",
    "Y = torch.randn(100, 3) * 1.4\n",
    "Z = torch.cat((X, Y))\n",
    "K = torch.exp(-0.5 * torch.cdist(Z, Z) ** 2)\n",
    "\n",
    "res = mmd2_permutation(K, X.shape[0], u_stat=True)\n",
    "\n",
    "plt.hist(res.permuted_estimates, bins=\"auto\")\n",
    "plt.axvline(res.estimate, color=\"r\")\n",
    "res.p_value"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "210265ef-6139-4c14-a5f1-c0ee2ffa4222",
   "metadata": {},
   "source": [
    "## HSIC"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39ec253e-16a1-4ea4-96a5-6e091d3dcff7",
   "metadata": {},
   "source": [
    "What about for HSIC?\n",
    "\n",
    "The biased (plug-in) estimator can be written, up to a factor of $1/n^2$ that doesn't affect the permutation $p$-value, as\n",
    "$$\\DeclareMathOperator{\\Tr}{Tr}\n",
    "\\langle H K H, H L H\\rangle_F = \\Tr(H K H L) = \\langle HKH, L \\rangle_F\n",
    ",$$\n",
    "where $H = I - \\frac1n \\mathbf 1 \\mathbf 1^\\top$ is the centering matrix,\n",
    "$K$ is the kernel matrix of $X$ values,\n",
    "and $L$ is the kernel matrix of $Y$ values\n",
    "(the identities use that $H$ is symmetric and idempotent, and that $L$ is symmetric).\n",
    "Note that there's no reason to bother permuting _both_ $X$ and $Y$:\n",
    "the statistic is unchanged by applying the same permutation to both, so permuting both is equivalent to permuting just one.\n",
    "So we might as well just permute the one that we don't center,\n",
    "to only have to center once."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "02621101-0334-4a77-8572-ddf75270cb38",
   "metadata": {},
   "source": [
    "I'm not sure which will be faster, so let's try two different implementation approaches:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c01efb65-5819-4c4d-89ad-a0598d059bce",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hsic_permutation_torch(K, L, n_perm=500):\n",
    "    \"\"\"\n",
    "    K: the pairwise kernel matrix of X values\n",
    "    L: the pairwise kernel matrix of corresponding Y values\n",
    "    n_perm: total number of permutations, including the identity\n",
    "    \"\"\"\n",
    "    K = torch.as_tensor(K)\n",
    "    device = K.device\n",
    "    dtype = K.dtype\n",
    "    L = torch.as_tensor(L).to(device=device, dtype=dtype)\n",
    "\n",
    "    n = K.shape[0]\n",
    "    if K.shape != (n, n):\n",
    "        raise ValueError(f\"K should be square, got {K.shape}\")\n",
    "    if L.shape != (n, n):\n",
    "        raise ValueError(f\"L should be same shape as K ({K.shape}), got {L.shape}\")\n",
    "\n",
    "    # double-center K, i.e. H K H: subtract column and row means, add back the grand mean\n",
    "    row_mean = K.mean(dim=0, keepdim=True)\n",
    "    col_mean = K.mean(dim=1, keepdim=True)\n",
    "    HKH_flat = (K - row_mean - col_mean + row_mean.mean()).ravel()\n",
    "\n",
    "    # one flattened, permuted copy of L per row (this uses n_perm * n^2 entries of memory)\n",
    "    L_flats = torch.empty((n_perm, n * n), device=device, dtype=dtype)\n",
    "    L_flats[0, :] = L.ravel()\n",
    "    for i in range(1, n_perm):\n",
    "        perm = torch.randperm(n, device=device)\n",
    "        L_flats[i, :] = L[perm.unsqueeze(1), perm.unsqueeze(0)].ravel()\n",
    "\n",
    "    # ests[p] = <HKH, L_p>_F, the biased HSIC up to a factor of 1/n^2\n",
    "    ests = L_flats @ HKH_flat\n",
    "\n",
    "    p_val = (ests >= ests[0]).float().mean()\n",
    "    return PermutationResult(ests[0], p_val, ests)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e0edc21b-99e6-41b1-a052-61db065777c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from numba import jit\n",
    "\n",
    "@jit(nopython=True)\n",
    "def _hsic_perms(HKH, L, perms):\n",
    "    # ests[e] = sum_{i,j} HKH[i, j] * L[perms[e, i], perms[e, j]] = <HKH, L permuted by perms[e]>_F\n",
    "    n = HKH.shape[0]\n",
    "    n_perms = perms.shape[0]\n",
    "    ests = np.zeros(n_perms)\n",
    "    for i in range(n):\n",
    "        for j in range(n):\n",
    "            for e in range(n_perms):\n",
    "                ests[e] += HKH[i, j] * L[perms[e, i], perms[e, j]]\n",
    "    return ests\n",
    "\n",
    "def hsic_permutation_numba(K, L, n_perm=500):\n",
    "    \"\"\"\n",
    "    K: the pairwise kernel matrix of X values\n",
    "    L: the pairwise kernel matrix of corresponding Y values\n",
    "    n_perm: total number of permutations, including the identity\n",
    "    \"\"\"\n",
    "    K = np.asarray(K)\n",
    "    L = np.asarray(L)\n",
    "\n",
    "    n = K.shape[0]\n",
    "    if K.shape != (n, n):\n",
    "        raise ValueError(f\"K should be square, got {K.shape}\")\n",
    "    if L.shape != (n, n):\n",
    "        raise ValueError(f\"L should be same shape as K ({K.shape}), got {L.shape}\")\n",
    "\n",
    "    # double-center K, i.e. H K H\n",
    "    row_mean = K.mean(axis=0, keepdims=True)\n",
    "    col_mean = K.mean(axis=1, keepdims=True)\n",
    "    HKH = K - row_mean - col_mean + row_mean.mean()\n",
    "\n",
    "    # make permutations, with first row in order\n",
    "    perms = np.tile(np.arange(n), (n_perm, 1))\n",
    "    rng = np.random.default_rng()\n",
    "    rng.permuted(perms[1:], axis=1, out=perms[1:])\n",
    "\n",
    "    ests = _hsic_perms(HKH, L, perms)\n",
    "\n",
    "    p_val = (ests >= ests[0]).mean()\n",
    "    return PermutationResult(ests[0], p_val, ests)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bc73065a-9196-4802-a5bc-5dca68c07083",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics.pairwise import rbf_kernel\n",
    "\n",
    "X = np.random.randn(500, 2)\n",
    "Y = X[:, [0, 0]] / 5 + np.random.randn(*X.shape)\n",
    "\n",
    "K = rbf_kernel(X, gamma=1)\n",
    "L = rbf_kernel(Y, gamma=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "dabb8428-2871-40ef-8442-6df0d9fb8d0d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.0010), 0.001)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 800x100 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "f, (a1, a2) = plt.subplots(1, 2, figsize=(8, 1))\n",
    "\n",
    "r1 = hsic_permutation_torch(K, L, n_perm=1_000)\n",
    "a1.hist(r1.permuted_estimates, bins=\"auto\")\n",
    "a1.axvline(r1.estimate, color=\"r\")\n",
    "\n",
    "r2 = hsic_permutation_numba(K, L, n_perm=1_000)\n",
    "a2.hist(r2.permuted_estimates, bins=\"auto\")\n",
    "a2.axvline(r2.estimate, color=\"r\")\n",
    "\n",
    "r1.p_value, r2.p_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "60040945-a203-46fa-8f7f-86e0dad2046e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "358 ms ± 11.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "hsic_permutation_torch(K, L).p_value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "7524d67d-b5d1-41c1-941e-d189ae0e013e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "152 ms ± 513 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "hsic_permutation_numba(K, L).p_value"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
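The "Computation" cell can be sanity-checked without torch. This is a minimal sketch, not the notebook's code. It assumes a unit-bandwidth Gaussian kernel and a few made-up 2-D points, and the helpers (`gaussian_kernel`, `mmd2_biased_direct`, `quad_form`, `split_weights`) are hypothetical names introduced here. It checks two things. First, the weighted quadratic form equals the three-double-sum definition of the biased MMD^2 estimate. Second, moving the 1/m and -1/n weights around is the same as re-splitting the data, which is what `ws.scatter_(1, X_inds, w_X)` does in `mmd2_permutation`.

```python
import math
import random


def gaussian_kernel(a, b, bandwidth=1.0):
    # k(a, b) = exp(-||a - b||^2 / (2 * bandwidth^2)); points are tuples of floats
    sq_dist = sum((ai - bi) ** 2 for ai, bi in zip(a, b))
    return math.exp(-sq_dist / (2 * bandwidth**2))


def mmd2_biased_direct(X, Y, k):
    # Plug-in estimator written as the three double sums
    m, n = len(X), len(Y)
    k_xx = sum(k(x, x2) for x in X for x2 in X) / m**2
    k_yy = sum(k(y, y2) for y in Y for y2 in Y) / n**2
    k_xy = sum(k(x, y) for x in X for y in Y) / (m * n)
    return k_xx + k_yy - 2 * k_xy


def quad_form(w, K):
    # w^T K w for a list-of-lists matrix K
    N = len(w)
    return sum(w[i] * K[i][j] * w[j] for i in range(N) for j in range(N))


def split_weights(perm, m, n):
    # 1/m on the indices assigned to the first sample, -1/n on the rest
    w = [0.0] * (m + n)
    for i in perm[:m]:
        w[i] = 1 / m
    for i in perm[m:]:
        w[i] = -1 / n
    return w


random.seed(0)
m, n = 4, 6
X = [(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(m)]
Y = [(random.gauss(0.5, 1), random.gauss(0, 1)) for _ in range(n)]
Z = X + Y
K = [[gaussian_kernel(a, b) for b in Z] for a in Z]

# Identity permutation: the weight vector is [1_m / m; -1_n / n]
identity = list(range(m + n))
print(quad_form(split_weights(identity, m, n), K), mmd2_biased_direct(X, Y, gaussian_kernel))

# A random permutation: permute the weights instead of re-splitting the data
perm = identity[:]
random.shuffle(perm)
X_perm = [Z[i] for i in perm[:m]]
Y_perm = [Z[i] for i in perm[m:]]
print(quad_form(split_weights(perm, m, n), K), mmd2_biased_direct(X_perm, Y_perm, gaussian_kernel))
```

Both printed pairs should agree up to floating-point rounding.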
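The U-statistic identity, and the way `mmd2_permutation` applies it when `u_stat=True`, can be checked the same way. This pure-Python sketch assumes a unit-bandwidth Gaussian kernel on made-up scalar data, and `mmd2_ustat_direct` and `mmd2_ustat_from_joint` are hypothetical names. The steps are:

1. Weights of +1 and -1 give n^2 times the biased estimate.
2. Subtract the trace.
3. Add back twice the lined-up cross terms, found through the flat index `X_inds * N + Y_inds`. Here `N` plays the role of the notebook's `n`, the total number of points.

The result should match the defining double sum applied directly to the permuted samples.

```python
import math
import random


def k(a, b):
    # Gaussian kernel with unit bandwidth on scalars (illustrative choice)
    return math.exp(-0.5 * (a - b) ** 2)


def mmd2_ustat_direct(X, Y):
    # (1 / (n (n - 1))) * sum_{i != j} [k(X_i, X_j) + k(Y_i, Y_j) - k(X_i, Y_j) - k(X_j, Y_i)]
    n = len(X)
    total = 0.0
    for i in range(n):
        for j in range(n):
            if i != j:
                total += k(X[i], X[j]) + k(Y[i], Y[j]) - k(X[i], Y[j]) - k(X[j], Y[i])
    return total / (n * (n - 1))


def mmd2_ustat_from_joint(K, perm, n):
    # Mirrors the u_stat=True branch of mmd2_permutation for a single permutation
    N = 2 * n
    X_inds, Y_inds = perm[:n], perm[n:]
    w = [-1.0] * N  # like torch.full(..., w_Y) ...
    for i in X_inds:
        w[i] = 1.0  # ... followed by scatter_ of w_X onto X_inds
    base = sum(w[i] * K[i][j] * w[j] for i in range(N) for j in range(N))  # n^2 * biased estimate
    trace = sum(K[i][i] for i in range(N))
    K_flat = [v for row in K for v in row]  # row-major flattening, as Tensor.take sees it
    cross = sum(K_flat[a * N + b] for a, b in zip(X_inds, Y_inds))  # sum_i k(X_i, Y_i)
    return (base - trace + 2 * cross) / (n * (n - 1))


random.seed(1)
n = 5
X = [random.gauss(0, 1) for _ in range(n)]
Y = [random.gauss(0, 1.4) for _ in range(n)]
Z = X + Y
K = [[k(a, b) for b in Z] for a in Z]

identity = list(range(2 * n))
perm = identity[:]
random.shuffle(perm)
X_perm = [Z[i] for i in perm[:n]]
Y_perm = [Z[i] for i in perm[n:]]

for p, (Xs, Ys) in [(identity, (X, Y)), (perm, (X_perm, Y_perm))]:
    print(mmd2_ustat_from_joint(K, p, n), mmd2_ustat_direct(Xs, Ys))
```

Unlike the biased estimate, this statistic can be negative, which is expected for an unbiased estimator of a quantity near zero.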
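The counting rule from the "p-value" cell is short enough to write out. In this sketch `permutation_p_value` is a hypothetical name; the notebook computes the same quantity inline as `(ests >= ests[0]).float().mean()`. The simulation at the end is only a toy stand-in for a null hypothesis. It draws the w statistics as i.i.d. uniforms, which are exchangeable and continuous. It then checks that p <= 0.1 occurs about 10% of the time, as it should when alpha * w is an integer.

```python
import random


def permutation_p_value(stats):
    # stats[0] is T_1 (the observed split); stats[1:] come from random permutations.
    # Returns |{k : T_k >= T_1}| / w, which counts T_1 itself and so is at least 1/w.
    t_obs = stats[0]
    return sum(1 for t in stats if t >= t_obs) / len(stats)


print(permutation_p_value([0.3, 0.1, 0.5, 0.3, 0.2]))  # 0.6
print(permutation_p_value([0.9, 0.1, 0.2, 0.3]))  # 0.25

# Toy validity check with exchangeable, continuous statistics (i.i.d. uniforms)
random.seed(3)
w, alpha, trials = 20, 0.1, 5000
rejections = sum(
    permutation_p_value([random.random() for _ in range(w)]) <= alpha
    for _ in range(trials)
)
print(rejections / trials)
```

Ties with T_1 count against rejection here, which is what makes this version possibly conservative.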
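The HSIC cell's trace identities and its "permute only L" shortcut can also be checked in pure Python. The sketch assumes an RBF kernel with gamma = 1 on made-up scalar data shaped like the notebook's example (y depends weakly on x). The helpers `rbf_gram`, `double_center`, `frob`, `matmul` and `trace` are hypothetical. `double_center` uses the same row-mean/column-mean formula as the notebook's `HKH` lines.

```python
import math
import random


def rbf_gram(points, gamma=1.0):
    # K_ij = exp(-gamma * (p_i - p_j)^2) for scalar points
    return [[math.exp(-gamma * (a - b) ** 2) for b in points] for a in points]


def double_center(K):
    # H K H with H = I - (1/n) 1 1^T: subtract row and column means, add back the grand mean
    n = len(K)
    row_means = [sum(row) / n for row in K]
    col_means = [sum(K[i][j] for i in range(n)) / n for j in range(n)]
    grand_mean = sum(row_means) / n
    return [[K[i][j] - row_means[i] - col_means[j] + grand_mean for j in range(n)] for i in range(n)]


def frob(A, B):
    # Frobenius inner product <A, B>_F = sum_ij A_ij B_ij
    return sum(a * b for row_a, row_b in zip(A, B) for a, b in zip(row_a, row_b))


def matmul(A, B):
    n = len(A)
    return [[sum(A[i][t] * B[t][j] for t in range(n)) for j in range(n)] for i in range(n)]


def trace(A):
    return sum(A[i][i] for i in range(len(A)))


random.seed(2)
n = 6
x = [random.gauss(0, 1) for _ in range(n)]
y = [xi / 5 + random.gauss(0, 1) for xi in x]
K, L = rbf_gram(x), rbf_gram(y)
HKH = double_center(K)

# Three ways of writing the (unnormalized) biased HSIC statistic
print(frob(HKH, double_center(L)), trace(matmul(HKH, L)), frob(HKH, L))

# Permuting the rows and columns of L is the same as recomputing it on permuted y,
# so only K needs to be centered, and only once.
perm = list(range(n))
random.shuffle(perm)
L_perm = [[L[perm[i]][perm[j]] for j in range(n)] for i in range(n)]
y_perm = [y[p] for p in perm]
print(L_perm == rbf_gram(y_perm), frob(HKH, L_perm), frob(HKH, double_center(L_perm)))
```

Dividing any of these values by n^2 gives the plug-in HSIC; the notebook skips that factor because it does not change the p-value.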