@djsutherland
Last active April 8, 2025 18:52
{
"cells": [
{
"cell_type": "markdown",
"id": "d325051b-74dd-4c95-9254-14c3275c2aae",
"metadata": {},
"source": [
"Our goal is to efficiently compute permutation tests for MMD which are finite-sample valid,\n",
"based on samples $(X_i)_{i=1}^m$, $(Y_j)_{j=1}^n$."
]
},
{
"cell_type": "markdown",
"id": "4986363a-a710-438d-af2a-49c8546261f3",
"metadata": {},
"source": [
"## Computation\n",
"Write $Z_i = \\begin{cases}X_i & 1 \\le i \\le m \\\\ Y_{i-m} & m < i \\le m + n .\\end{cases}$\n",
"\n",
"First, note that the plug-in estimator (\"the biased estimator\") for MMD is\n",
"\\begin{align*}\n",
"\\widehat{\\operatorname{MMD}_b^2}\n",
" &= \\frac1{m^2} \\sum_{i=1}^m \\sum_{i'=1}^m k(X_i, X_{i'})\n",
" + \\frac{1}{n^2} \\sum_{j=1}^n \\sum_{j'=1}^n'k(Y_j, Y_{j'})\n",
" - 2 \\frac{1}{mn} \\sum_{i=1}^m \\sum_{j=1}^n k(X_i, Y_j)\n",
"\\\\&= \\begin{bmatrix} \\mathbf{1}_m /m \\\\ -\\mathbf{1}_n / n \\end{bmatrix}^\\top\n",
" \\underbrace{\\begin{bmatrix} K_X & K_{XY} \\\\ K_{YX} & K_Y \\end{bmatrix}}_K\n",
" \\begin{bmatrix} \\mathbf{1}_m /m \\\\ -\\mathbf{1}_n / n \\end{bmatrix}\n",
",\\end{align*}\n",
"where $\\mathbf 1_m$ is a vector of $m$ ones, and $K$ an $(m + n) \\times (m + n)$ matrix with entries $K_{ij} = k(Z_i, Z_j)$.\n",
"To get a permuted value of this estimator, we just need to permute the vector we're hitting it with."
]
},
{
"cell_type": "markdown",
"id": "9b44f61f-1dd5-41e1-b69f-c50611152b06",
"metadata": {},
"source": [
"The typical unbiased estimator (claimed to be the MVUE but I'm not sure this is actually true...) is\n",
"\\begin{align*}\n",
"\\widehat{\\operatorname{MMD}_u^2}\n",
" &= \\frac{1}{m (m-1)} \\sum_{i \\ne i'} k(X_i, X_{i'})\n",
" + \\frac{1}{n (n-1)} \\sum_{j \\ne j'} k(Y_j, Y_{j'})\n",
" - \\frac{2}{m n} \\sum_{i, j} k(X_i, Y_j)\n",
";\\end{align*}\n",
"this doesn't seem especially amenable to easy permutation in the same way as the previous one."
]
},
{
"cell_type": "markdown",
"id": "a65afbca-9469-4100-9a2c-f20fc5337fe1",
"metadata": {},
"source": [
"But the U-statistic estimator, which assumes $m = n$ and is unbiased but not quite minimum variance, is\n",
"\\begin{align*}\n",
"\\widehat{\\operatorname{MMD}_U^2}\n",
" &= \\frac{1}{n (n-1)} \\sum_{i \\ne j} \\left[ k(X_i, X_j) + k(Y_i, Y_j) - k(X_i, Y_j) - k(X_j, Y_i) \\right]\n",
"\\\\&= \\frac{1}{n(n-1)} \\sum_{i \\ne j} k(X_i, X_j)\n",
" + \\frac{1}{n(n-1)} \\sum_{i \\ne j} k(Y_i, Y_j)\n",
" - \\frac{2}{n(n-1)} \\sum_{i \\ne j} k(X_i, Y_j)\n",
"\\\\&= \\frac{1}{n(n-1)} \\left( \\sum_{i,j} k(X_i, X_j) - \\sum_i k(X_i, X_i) \\right)\n",
" + \\frac{1}{n(n-1)} \\left( \\sum_{i,j} k(Y_i, Y_j) - \\sum_i k(Y_i, Y_i) \\right)\n",
"\\\\&\\qquad\n",
" - \\frac{2}{n(n-1)} \\left( \\sum_{i,j} k(X_i, Y_j) - \\sum_i k(X_i, Y_i) \\right)\n",
"\\\\&= \\frac{n}{n-1} \\widehat{\\operatorname{MMD}_b^2}\n",
" - \\frac{1}{n(n-1)} \\left( \\sum_i k(X_i, X_i) + \\sum_i k(Y_i, Y_i) - 2 \\sum_i k(X_i, Y_i) \\right)\n",
";\\end{align*}\n",
"the first two terms of the correction are simply the trace of $K$ and don't depend on the particular permutation.\n",
"The third term does, but isn't so bad."
]
},
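{
"cell_type": "markdown",
"id": "added-ustat-matrix-form",
"metadata": {},
"source": [
"As a sketch of how this fits the same permutation trick (the sign vector $s$ is notation introduced here, not something defined elsewhere in these notes):\n",
"take $s = \\begin{bmatrix} \\mathbf{1}_n \\\\ -\\mathbf{1}_n \\end{bmatrix}$, so that $s^\\top K s = n^2 \\, \\widehat{\\operatorname{MMD}_b^2}$ when $m = n$. Then\n",
"\\begin{align*}\n",
"\\widehat{\\operatorname{MMD}_U^2}\n",
" = \\frac{1}{n(n-1)} \\left( s^\\top K s - \\operatorname{Tr}(K) + 2 \\sum_i k(X_i, Y_i) \\right)\n",
".\\end{align*}\n",
"Permuting the entries of $s$ takes care of the first term, the trace stays fixed,\n",
"and only the paired cross term $\\sum_i k(X_i, Y_i)$ needs recomputing for each permutation."
]
},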
{
"cell_type": "markdown",
"id": "a23f6a33-72d7-4549-af56-313e5b11b9cf",
"metadata": {},
"source": [
"## p-value\n",
"As discussed in section 3.4 of Hemerik and Goeman (STAT 2018), [Exact testing with random permutations](https://arxiv.org/abs/1411.7565),\n",
"let $T_1, T_2, \\dots, T_w$ be the $w$ permuted test statistics returned by the previous procedure,\n",
"where $T_1$ is the actual data split\n",
"and $2$ through $w$ are uniformly random permutations.\n",
"A (possibly conservative) $p$-value is $\\lvert \\{ k \\in [w] : T_k \\ge T_1 \\} \\rvert / w$.\n",
"An exact (randomized) $p$-value can be found based on counting ties, but whatever."
]
},
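{
"cell_type": "markdown",
"id": "added-pvalue-example",
"metadata": {},
"source": [
"To make the arithmetic concrete (the counts here are illustrative assumptions, not results from these notes):\n",
"with $w = 500$, if $32$ of the statistics, counting $T_1$ itself, are at least $T_1$, then\n",
"$$p = \\frac{32}{500} = 0.064 .$$\n",
"Since $T_1 \\ge T_1$ always counts, the smallest attainable $p$-value is $1/w$,\n",
"so rejecting at level $\\alpha$ is only possible when $w \\ge 1/\\alpha$."
]
},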
{
"cell_type": "code",
"execution_count": 1,
"id": "01140287-8c45-4dff-b1d4-cddd0f3db08a",
"metadata": {},
"outputs": [],
"source": [
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "8e31cd9f-392e-4c09-821e-bd7827a7379c",
"metadata": {},
"outputs": [],
"source": [
"from collections import namedtuple\n",
"\n",
"PermutationResult = namedtuple(\n",
" \"PermutationResult\", [\"estimate\", \"p_value\", \"permuted_estimates\"]\n",
")\n",
"\n",
"\n",
"def mmd2_permutation(joint_kernel, n_X, n_perm=500, u_stat=False):\n",
" \"\"\"\n",
" joint_kernel: should be an array of shape [n_X + n_Y, n_X + n_Y] with kernel values,\n",
" ie [ K_XX K_XY ]\n",
" [ K_YX K_YY ]\n",
" n_X: number of entries in the first set\n",
" n_perm: total number of permutations, including the identity\n",
"\n",
" If biased is True, uses the plug-in estimator (MMD between empirical distributions).\n",
" If False, it uses the U-statistic estimator, which is unbiased but drops k(x_i, y_i) terms.\n",
" (I'm not sure how to implement the \"unbiased estimator\" (which includes those terms) efficiently.)\n",
" \"\"\"\n",
" K = joint_kernel = torch.as_tensor(joint_kernel)\n",
" device = K.device\n",
" dtype = K.dtype\n",
"\n",
" n = K.shape[0]\n",
" if K.shape != (n, n):\n",
" raise ValueError(f\"joint_kernel should be square, got {K.shape}\")\n",
" n_X = int(n_X)\n",
" n_Y = n - n_X\n",
" if n_X <= 0 or n_Y <= 0:\n",
" raise ValueError(\"need a positive number of samples from each\")\n",
"\n",
" if u_stat:\n",
" if n_X != n_Y:\n",
" raise ValueError(\"u-stat estimator only defined for equal sample sizes\")\n",
" w_X = 1\n",
" w_Y = -1\n",
" else:\n",
" w_X = 1 / n_X\n",
" w_Y = -1 / n_Y\n",
"\n",
" # construct permutations\n",
" # there probably should be a faster way to do this but, idk\n",
" perms = torch.stack(\n",
" [torch.arange(n, device=device)]\n",
" + [torch.randperm(n, device=device) for _ in range(n_perm - 1)]\n",
" )\n",
" X_inds = perms[:, :n_X]\n",
" Y_inds = perms[:, n_X:]\n",
"\n",
" # set weights to w_X for things in X_inds, w_Y for others\n",
" ws = torch.full((n_perm, n), w_Y, device=device, dtype=dtype)\n",
" ws.scatter_(1, X_inds, w_X)\n",
"\n",
" # the \"basic\" estimate; either the biased est or a constant times it\n",
" ests = torch.einsum(\"pi,ij,pj->p\", ws, joint_kernel, ws)\n",
"\n",
" if u_stat:\n",
" # need to subtract \\sum_i k(X_i, X_i) + k(Y_i, Y_i) + 2 k(X_i, Y_i)\n",
" # first two are just trace\n",
" # for the last one, we need to see which ones were lined up\n",
" # NOTE: this makes an unnecessary copy if joint_kernel isn't already contiguous,\n",
" # but this generally shouldn't be a big deal\n",
" cross_terms = joint_kernel.take(X_inds * n + Y_inds).sum(1)\n",
" ests = (ests - joint_kernel.trace() + 2 * cross_terms) / (n_X * (n_X - 1))\n",
"\n",
" p_val = (ests >= ests[0]).float().mean()\n",
" return PermutationResult(ests[0], p_val, ests)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5e149931-01e5-426a-9a90-bef389e14a4f",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c45c7919-859e-434e-9679-f40ab7b12dbe",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.0640)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAh8AAAGdCAYAAACyzRGfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAjVklEQVR4nO3df3BU1f3/8dfKjyUJmxVQdrMSSWwjqAF/gA3E2kQLUYraDo6/oHyw1g4UsETqUFLmO0T7cROxjelMFAeGoXEcxKmidUpF4lRjp4EaMIw0UItDgKisKTRmA6SJwPn+wSdb1iTIJrtns+H5mLkzu+eee897D2vy8uTuXYcxxggAAMCSi+JdAAAAuLAQPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYNTjeBXzV6dOn9dlnn8nlcsnhcMS7HAAAcB6MMWptbZXP59NFF517baPfhY/PPvtM6enp8S4DAAD0QmNjo8aMGXPOPv0ufLhcLklnik9NTY1zNQAuCMePSz7fmceffSalpMS3HiABBYNBpaenh36Pn0u/Cx+df2pJTU0lfACwY9Cg/z5OTSV8AH1wPpdMcMEpAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsGhzvAhBfGcs3x+S8B0pnxuS8AIDEx8oHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArBoc7wIwMGUs3xyzcx8onRmzcwMAYi+ilY+MjAw5HI4u26JFiyRJxhgVFxfL5/MpKSlJ+fn5qq+vj0nhAAAgMUUUPmpra3X48OHQVlVVJUm65557JEmrVq1SWVmZKioqVFtbK6/Xq+nTp6u1tTX6lQMAgIQUUfi49NJL5fV6Q9sf//hHfeMb31BeXp6MMSovL9eKFSs0a9YsZWdnq7KyUidOnNCGDRtiVT8AAEgwvb7gtKOjQy+++KIeeughORwONTQ0KBAIqKCgINTH6XQqLy9PNTU1PZ6nvb1dwWAwbAMAAANXr8PH66+/ri+++EIPPvigJCkQCEiSPB5PWD+PxxPa152SkhK53e7Qlp6e3tuSAABAAuh1+Fi3bp1mzJghn88X1u5wOMKeG2O6tJ2tqKhILS0toa2xsbG3JQEAgATQq4/aHjx4UG+//bY2bdoUavN6vZLOrICkpaWF2puamrqshpzN6XTK6XT2pgwAAJCAerXysX79eo0ePVozZ/73fguZmZnyer2hT8BIZ64Lqa6uVm5ubt8rBQAAA0LEKx+nT5/W+vXrNW/ePA0e/N/DHQ6HCgsL5ff7lZWVpaysLPn9fiUnJ2v27NlRLRoAACSuiMPH22+/rUOHDumhhx7qsm/ZsmVqa2vTwoUL1dzcrJycHG3dulUulysqxQJS7O6eyp1TAcAOhzHGxLuIswWDQbndbrW0tCg1NTXe5Qx4sbwNeqIhfFzAjh+Xhg8/8/jYMSklJb71AAkokt/ffLEcAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqyIOH59++ql++MMfatSoUUpOTtZ1112nnTt3hvYbY1RcXC
yfz6ekpCTl5+ervr4+qkUDAIDEFVH4aG5u1k033aQhQ4bozTff1J49e/Sb3/xGF198cajPqlWrVFZWpoqKCtXW1srr9Wr69OlqbW2Ndu0AACABDY6k81NPPaX09HStX78+1JaRkRF6bIxReXm5VqxYoVmzZkmSKisr5fF4tGHDBs2fPz86VQMAgIQV0crHG2+8ocmTJ+uee+7R6NGjdf3112vt2rWh/Q0NDQoEAiooKAi1OZ1O5eXlqaampttztre3KxgMhm0AAGDgiih87N+/X6tXr1ZWVpbeeustLViwQD/72c/0wgsvSJICgYAkyePxhB3n8XhC+76qpKREbrc7tKWnp/fmdQAAgAQRUfg4ffq0brjhBvn9fl1//fWaP3++fvKTn2j16tVh/RwOR9hzY0yXtk5FRUVqaWkJbY2NjRG+BAAAkEgiCh9paWm6+uqrw9quuuoqHTp0SJLk9XolqcsqR1NTU5fVkE5Op1OpqalhGwAAGLgiCh833XSTPvroo7C2f/7znxo7dqwkKTMzU16vV1VVVaH9HR0dqq6uVm5ubhTKBQAAiS6iT7s8+uijys3Nld/v17333qv3339fa9as0Zo1aySd+XNLYWGh/H6/srKylJWVJb/fr+TkZM2ePTsmLwAAACSWiMLHjTfeqNdee01FRUV64oknlJmZqfLycs2ZMyfUZ9myZWpra9PChQvV3NysnJwcbd26VS6XK+rFAwCAxOMwxph4F3G2YDAot9utlpYWrv+wIGP55niX0G8cKJ0Z7xIQL8ePS8OHn3l87JiUkhLfeoAEFMnvb77bBQAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFg1ON4FAP1FxvLNMTv3gdKZMTs3ACQaVj4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGBVROGjuLhYDocjbPN6vaH9xhgVFxfL5/MpKSlJ+fn5qq+vj3rRAAAgcUW88nHNNdfo8OHDoW337t2hfatWrVJZWZkqKipUW1srr9er6dOnq7W1NapFAwCAxBVx+Bg8eLC8Xm9ou/TSSyWdWfUoLy/XihUrNGvWLGVnZ6uyslInTpzQhg0bol44AABITBGHj3379snn8ykzM1P333+/9u/fL0lqaGhQIBBQQUFBqK/T6VReXp5qamp6PF97e7uCwWDYBgAABq6IwkdOTo5eeOEFvfXWW1q7dq0CgYByc3N19OhRBQIBSZLH4wk7xuPxhPZ1p6SkRG63O7Slp6f34mUAAIBEEVH4mDFjhu6++25NmDBB06ZN0+bNmyVJlZWVoT4OhyPsGGNMl7azFRUVqaWlJbQ1NjZGUhIAAEgwffqobUpKiiZMmKB9+/aFPvXy1VWOpqamLqshZ3M6nUpNTQ3bAADAwNWn8NHe3q69e/cqLS1NmZmZ8nq9qqqqCu3v6OhQdXW1cnNz+1woAAAYGAZH0vmxxx7TnXfeqcsvv1xNTU363//9XwWDQc2bN08Oh0OFhYXy+/3KyspSVlaW/H6/kpOTNXv27FjVDwAAEkxE4eOTTz7RAw88oCNHjujSSy/VlClTtH37do0dO1aStGzZMrW1tWnhwoVqbm5WTk6Otm7dKpfLFZPiAQBA4nEYY0y8izhbMBiU2+1WS0sL139YkLF8c7xLuCAcKJ0Z7xJwLsePS8OHn3l87JiUkhLfeoAEFMnvb77bBQAAWEX4AAAAVkV0zQeA3onVn7f4cw6ARMTKBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifA
AAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMCqPoWPkpISORwOFRYWhtqMMSouLpbP51NSUpLy8/NVX1/f1zoBAMAA0evwUVtbqzVr1mjixIlh7atWrVJZWZkqKipUW1srr9er6dOnq7W1tc/FAgCAxNer8HHs2DHNmTNHa9eu1YgRI0LtxhiVl5drxYoVmjVrlrKzs1VZWakTJ05ow4YNUSsaAAAkrl6Fj0WLFmnmzJmaNm1aWHtDQ4MCgYAKCgpCbU6nU3l5eaqpqen2XO3t7QoGg2EbAAAYuAZHesDGjRv1wQcfqLa2tsu+QCAgSfJ4PGHtHo9HBw8e7PZ8JSUlevzxxyMt44KSsXxzvEsAACBqIlr5aGxs1JIlS/Tiiy9q2LBhPfZzOBxhz40xXdo6FRUVqaWlJbQ1NjZGUhIAAEgwEa187Ny5U01NTZo0aVKo7dSpU3rvvfdUUVGhjz76SNKZFZC0tLRQn6ampi6rIZ2cTqecTmdvagcAAAkoopWP7373u9q9e7d27doV2iZPnqw5c+Zo165duuKKK+T1elVVVRU6pqOjQ9XV1crNzY168QAAIPFEtPLhcrmUnZ0d1paSkqJRo0aF2gsLC+X3+5WVlaWsrCz5/X4lJydr9uzZ0asaAAAkrIgvOP06y5YtU1tbmxYuXKjm5mbl5ORo69atcrlc0R4KAAAkIIcxxsS7iLMFg0G53W61tLQoNTU13uX0C3zaBT05UDoz3iUMDMePS8OHn3l87JiUkhLfeoAEFMnvb77bBQAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVUb/JGAB7YnkPGO4hAiBWWPkAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGBVROFj9erVmjhxolJTU5WamqqpU6fqzTffDO03xqi4uFg+n09JSUnKz89XfX191IsGAACJK6LwMWbMGJWWlmrHjh3asWOHbr31Vn3/+98PBYxVq1aprKxMFRUVqq2tldfr1fTp09Xa2hqT4gEAQOKJKHzceeed+t73vqcrr7xSV155pZ588kkNHz5c27dvlzFG5eXlWrFihWbNmqXs7GxVVlbqxIkT2rBhQ6zqBwAACabX13ycOnVKGzdu1PHjxzV16lQ1NDQoEAiooKAg1MfpdCovL081NTU9nqe9vV3BYDBsAwAAA1fE4WP37t0aPny4nE6nFixYoNdee01XX321AoGAJMnj8YT193g8oX3dKSkpkdvtDm3p6emRlgQAABJIxOFj3Lhx2rVrl7Zv366f/vSnmjdvnvbs2RPa73A4wvobY7q0na2oqEgtLS2hrbGxMdKSAABAAhkc6QFDhw7VN7/5TUnS5MmTVVtbq9/+9rf6xS9+IUkKBAJKS0sL9W9qauqyGnI2p9Mpp9MZaRkAACBB9fk+H8YYtbe3KzMzU16vV1VVVaF9HR0dqq6uVm5ubl+HAQAAA0REKx+//OUvNWPGDKWnp6u1tVUbN27Uu+++qy1btsjhcKiwsFB+v19ZWVnKysqS3+9XcnKyZs+eHav6AQBAgokofHz++eeaO3euDh8+LLfbrYkTJ2rLli2aPn26JGnZsmVqa2vTwoUL1dzcrJycHG3dulUulysmxQMAgMTjMMaYeBdxtmAwKLfbrZaWFqWmpsa7nH4hY/nmeJeAC9CB0pnxLsGe48el4cPPPD
52TEpJiW89QAKK5Pc33+0CAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALBqcLwLGEgylm+OdwkAAPR7rHwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAq7jDKYBuxeqOvQdKZ8bkvAASBysfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsiCh8lJSW68cYb5XK5NHr0aP3gBz/QRx99FNbHGKPi4mL5fD4lJSUpPz9f9fX1US0aAAAkrojCR3V1tRYtWqTt27erqqpKJ0+eVEFBgY4fPx7qs2rVKpWVlamiokK1tbXyer2aPn26Wltbo148AABIPBHdXn3Lli1hz9evX6/Ro0dr586d+s53viNjjMrLy7VixQrNmjVLklRZWSmPx6MNGzZo/vz50ascAAAkpD5d89HS0iJJGjlypCSpoaFBgUBABQUFoT5Op1N5eXmqqanpy1AAAGCA6PUXyxljtHTpUn37299Wdna2JCkQCEiSPB5PWF+Px6ODBw92e5729na1t7eHngeDwd6WBAAAEkCvVz4WL16sDz/8UC+99FKXfQ6HI+y5MaZLW6eSkhK53e7Qlp6e3tuSAABAAuhV+HjkkUf0xhtv6J133tGYMWNC7V6vV9J/V0A6NTU1dVkN6VRUVKSWlpbQ1tjY2JuSAABAgogofBhjtHjxYm3atEl//vOflZmZGbY/MzNTXq9XVVVVobaOjg5VV1crNze323M6nU6lpqaGbQAAYOCK6JqPRYsWacOGDfrDH/4gl8sVWuFwu91KSkqSw+FQYWGh/H6/srKylJWVJb/fr+TkZM2ePTsmLwAAACSWiMLH6tWrJUn5+flh7evXr9eDDz4oSVq2bJna2tq0cOFCNTc3KycnR1u3bpXL5YpKwQAAILFFFD6MMV/bx+FwqLi4WMXFxb2tCQAADGB8twsAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACrCB8AAMAqwgcAALCK8AEAAKwaHO8CAFxYMpZvjtm5D5TOjNm5AUQPKx8AAMAqwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAqojDx3vvvac777xTPp9PDodDr7/+eth+Y4yKi4vl8/mUlJSk/Px81dfXR6teAACQ4CIOH8ePH9e1116rioqKbvevWrVKZWVlqqioUG1trbxer6ZPn67W1tY+FwsAABJfxN/tMmPGDM2YMaPbfcYYlZeXa8WKFZo1a5YkqbKyUh6PRxs2bND8+fP7Vi0AAEh4Ub3mo6GhQYFAQAUFBaE2p9OpvLw81dTUdHtMe3u7gsFg2AYAAAauqIaPQCAgSfJ4PGHtHo8ntO+rSkpK5Ha7Q1t6eno0SwIAAP1MTD7t4nA4wp4bY7q0dSoqKlJLS0toa2xsjEVJAACgn4j4mo9z8Xq9ks6sgKSlpYXam5qauqyGdHI6nXI6ndEsAwAA9GNRXfnIzMyU1+tVVVVVqK2jo0PV1dXKzc2N5lAAACBBRbzycezYMX388ceh5w0NDdq1a5dGjhypyy+/XIWFhfL7/crKylJWVpb8fr+Sk5M1e/bsqBYOAAASU8ThY8eOHbrllltCz5cuXSpJmjdvnn73u99p2bJlamtr08KFC9Xc3KycnBxt3bpVLpcrelUDQDcylm/u1XFJHf/R3v97fNX/26K2ocOiV9TXOFA609pYQH8RcfjIz8+XMabH/Q
6HQ8XFxSouLu5LXQAAYIDiu10AAIBVhA8AAGBVVD9qCwCITG+vUzkfXE+C/oqVDwAAYBXhAwAAWEX4AAAAVhE+AACAVRfcBaexvLgLAAB8PVY+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFZdcDcZA4ALRaxuqsi35aKvWPkAAABWET4AAIBVhA8AAGAV4QMAAFhF+AAAAFYRPgAAgFWEDwAAYBXhAwAAWMVNxgAAEYnVzcskbmB2oWDlAwAAWEX4AAAAVhE+AACAVYQPAABgFeEDAABYRfgAAABWET4AAIBVhA8AAGAV4QMAAFgVszucPvfcc3r66ad1+PBhXXPNNSovL9fNN98cq+EAAANALO+eGiuxuivrQL6TbExWPl5++WUVFhZqxYoVqqur080336wZM2bo0KFDsRgOAAAkkJiEj7KyMv34xz/Www8/rKuuukrl5eVKT0/X6tWrYzEcAABIIFH/s0tHR4d27typ5cuXh7UXFBSopqamS//29na1t7eHnre0tEiSgsFgtEuTJJ1uPxGT8wJIXKc6/qPOnzin2k/otDkd13qQWBLx91Usau48pzHma/tGPXwcOXJEp06dksfjCWv3eDwKBAJd+peUlOjxxx/v0p6enh7t0gCgR+7OB8/9TzzLQAJyl8e7gsjFsubW1la53e5z9onZBacOhyPsuTGmS5skFRUVaenSpaHnp0+f1r///W+NGjWq2/7nKxgMKj09XY2NjUpNTe31eQYi5qZ7zEvPmJvuMS/dY156NpDnxhij1tZW+Xy+r+0b9fBxySWXaNCgQV1WOZqamrqshkiS0+mU0+kMa7v44oujVk9qauqA+weOFuame8xLz5ib7jEv3WNeejZQ5+brVjw6Rf2C06FDh2rSpEmqqqoKa6+qqlJubm60hwMAAAkmJn92Wbp0qebOnavJkydr6tSpWrNmjQ4dOqQFCxbEYjgAAJBAYhI+7rvvPh09elRPPPGEDh8+rOzsbP3pT3/S2LFjYzFct5xOp1auXNnlTzpgbnrCvPSMueke89I95qVnzM0ZDnM+n4kBAACIEr7bBQAAWEX4AAAAVhE+AACAVYQPAABgVcKEj+bmZs2dO1dut1tut1tz587VF198cc5jjDEqLi6Wz+dTUlKS8vPzVV9fH9ZnzZo1ys/PV2pqqhwOR7fn7M3YNsVqbtrb2/XII4/okksuUUpKiu666y598sknYX0yMjLkcDjCtq9+r48tzz33nDIzMzVs2DBNmjRJf/nLX87Zv7q6WpMmTdKwYcN0xRVX6Pnnn+/S59VXX9XVV18tp9Opq6++Wq+99lqfx7UtHvNSXFzc5X3h9Xqj+rqiIdpzU19fr7vvvjv030V5eXlUxrUtHvOSCO+ZaM/L2rVrdfPNN2vEiBEaMWKEpk2bpvfff7/P4yYEkyBuv/12k52dbWpqakxNTY3Jzs42d9xxxzmPKS0tNS6Xy7z66qtm9+7d5r777jNpaWkmGAyG+jzzzDOmpKTElJSUGEmmubk5KmPbFKu5WbBggbnssstMVVWV+eCDD8wtt9xirr32WnPy5MlQn7Fjx5onnnjCHD58OLS1trbG7LX2ZOPGjWbIkCFm7dq1Zs+ePWbJkiUmJSXFHDx4sNv++/fvN8nJyWbJkiVmz549Zu3atWbIkCHmlVdeCfWpqakxgwYNMn6/3+zdu9f4/X4zePBgs3379l6Pa1u85mXlypXmmmuuCXtfNDU1xfz1RiIWc/P++++bxx57zLz00kvG6/WaZ555ps/j2haveenv75lYzMvs2bPNs88+a+rq6szevXvNj370I+N2u80nn3zS63ETRUKEjz179hhJYT/ctm3bZiSZf/zjH90ec/r0aeP1ek1paWmo7T//+Y9xu93m+eef79L/nXfe6TZ89GZsm2I1N1988YUZMmSI2bhxY6jPp59+ai666CKzZcuWUNvYsWO7/UFi27e+9S2zYM
GCsLbx48eb5cuXd9t/2bJlZvz48WFt8+fPN1OmTAk9v/fee83tt98e1ue2224z999/f6/HtS1e87Jy5Upz7bXX9rH62IrF3Jytp/82LsT3zNl6mpf+/p6J9bwYY8zJkyeNy+UylZWVvR43USTEn122bdsmt9utnJycUNuUKVPkdrtVU1PT7TENDQ0KBAIqKCgItTmdTuXl5fV4TLTGtilWc7Nz5059+eWXYX18Pp+ys7O7nPepp57SqFGjdN111+nJJ59UR0dHNF/i1+ro6NDOnTvDapWkgoKCHudg27ZtXfrfdttt2rFjh7788stz9uk8Z2/GtSle89Jp37598vl8yszM1P3336/9+/f39SVFTazmJhbj2hSveenUX98ztublxIkT+vLLLzVy5Mhej5soEiJ8BAIBjR49ukv76NGju3yB3dnHSOryZXYej6fHY6I1tk2xmptAIKChQ4dqxIgRPfaRpCVLlmjjxo165513tHjxYpWXl2vhwoV9ek2ROnLkiE6dOhXRv3UgEOi2/8mTJ3XkyJFz9uk8Z2/GtSle8yJJOTk5euGFF/TWW29p7dq1CgQCys3N1dGjR6Px0vosVnMTi3Ftite8SP37PWNrXpYvX67LLrtM06ZN6/W4iSKu4aO7C4y+uu3YsUOS5HA4uhxvjOm2/Wxf3X8+x3zdOXp7nkj017n5ap9HH31UeXl5mjhxoh5++GE9//zzWrduXVx+YET6errr/9X28zlnNN5jsRSPeZkxY4buvvtuTZgwQdOmTdPmzZslSZWVlb17ETESi7mJxbi2xWNeEuE9E8t5WbVqlV566SVt2rRJw4YN69O4iSAm3+1yvhYvXqz777//nH0yMjL04Ycf6vPPP++y71//+leXRNip8yrpQCCgtLS0UHtTU1OPx/R0nkjHjoZ4z43X61VHR4eam5vDVj+amprO+e3EU6ZMkSR9/PHHGjVq1Dnrj5ZLLrlEgwYN6vJ/Auf6t/Z6vd32Hzx4cKjunvp0nrM349oUr3npTkpKiiZMmKB9+/b15qVEXazmJhbj2hSveelOf3rPxHpefv3rX8vv9+vtt9/WxIkT+zRuoojryscll1yi8ePHn3MbNmyYpk6dqpaWlrCPIP3tb39TS0tLj78IMzMz5fV6VVVVFWrr6OhQdXX1OX95flVvxo6GeM/NpEmTNGTIkLA+hw8f1t///vdzvu66ujpJCgs1sTZ06FBNmjQprFZJqqqq6rHWqVOndum/detWTZ48WUOGDDlnn85z9mZcm+I1L91pb2/X3r17rb4vziVWcxOLcW2K17x0pz+9Z2I5L08//bR+9atfacuWLZo8eXKfx00YVi9v7YPbb7/dTJw40Wzbts1s27bNTJgwocvHSceNG2c2bdoUel5aWmrcbrfZtGmT2b17t3nggQe6fJz08OHDpq6uzqxdu9ZIMu+9956pq6szR48ejWjseIrV3CxYsMCMGTPGvP322+aDDz4wt956a9hHbWtqakxZWZmpq6sz+/fvNy+//LLx+XzmrrvusvPCz9L5cbR169aZPXv2mMLCQpOSkmIOHDhgjDFm+fLlZu7cuaH+nR+De/TRR82ePXvMunXrunwM7q9//asZNGiQKS0tNXv37jWlpaU9ftS2p3HjLV7z8vOf/9y8++67Zv/+/Wb79u3mjjvuMC6Xq9/MizGxmZv29nZTV1dn6urqTFpamnnsscdMXV2d2bdv33mPG2/xmpf+/p6Jxbw89dRTZujQoeaVV17p8XYF/f390lsJEz6OHj1q5syZY1wul3G5XGbOnDldPhYryaxfvz70/PTp02blypXG6/Uap9NpvvOd75jdu3eHHbNy5Uojqct29nnOZ+x4itXctLW1mcWLF5uRI0eapKQkc8cdd5hDhw6F9u/cudPk5OQYt9tthg0bZsaNG2dWrlxpjh8/HsuX26Nnn33WjB071gwdOtTccMMNprq6OrRv3rx5Ji8vL6z/u+++a66//nozdO
hQk5GRYVavXt3lnL///e/NuHHjzJAhQ8z48ePNq6++GtG4/UE85qXzvjFDhgwxPp/PzJo1y9TX18fk9fVFtOemoaGh258nXz3PhfaeOZ95SYT3TLTnZezYsd3Oy8qVK8973ETlMOb/roABAACwICE+agsAAAYOwgcAALCK8AEAAKwifAAAAKsIHwAAwCrCBwAAsIrwAQAArCJ8AAAAqwgfAADAKsIHAACwivABAACsInwAAACr/j/4GN8PvHajxQAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"X = torch.randn(100, 3)\n",
"Y = torch.randn(100, 3) * 1.4\n",
"Z = torch.cat((X, Y))\n",
"K = torch.exp(-0.5 * torch.cdist(Z, Z) ** 2)\n",
"\n",
"res = mmd2_permutation(K, X.shape[0], u_stat=True)\n",
"\n",
"plt.hist(res.permuted_estimates, bins=\"auto\")\n",
"plt.axvline(res.estimate, color=\"r\")\n",
"res.p_value"
]
},
{
"cell_type": "markdown",
"id": "210265ef-6139-4c14-a5f1-c0ee2ffa4222",
"metadata": {},
"source": [
"## HSIC"
]
},
{
"cell_type": "markdown",
"id": "39ec253e-16a1-4ea4-96a5-6e091d3dcff7",
"metadata": {},
"source": [
"What about for HSIC?\n",
"\n",
"The biased (plug-in) estimator can be written as\n",
"$$\\DeclareMathOperator{\\Tr}{Tr}\n",
"\\langle H K H, H L H\\rangle_F = \\Tr(H K H L) = \\langle HKH, L \\rangle_F\n",
",$$\n",
"where $H = I - \\frac1n \\mathbf 1 \\mathbf 1^\\top$ is the centering matrix,\n",
"$K$ is the kernel matrix of $X$ values,\n",
"and $Y$ is the kernel matrix of $Y$ values.\n",
"Note that there's no reason to bother permuting _both_ $X$ and $Y$;\n",
"we can choose to only permute one or the other.\n",
"So we might as well just permute the one that we don't center,\n",
"to only have to center once."
]
},
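{
"cell_type": "markdown",
"id": "added-hsic-permuted-sum",
"metadata": {},
"source": [
"Spelled out entrywise, as a sketch (the permutation $\\sigma$ of $\\{1, \\dots, n\\}$ and the matrix $L^\\sigma$ are notation introduced here):\n",
"writing $L^\\sigma_{ij} = L_{\\sigma(i) \\sigma(j)}$ for the permuted kernel matrix, the permuted statistic is\n",
"$$\\langle HKH, L^\\sigma \\rangle_F = \\sum_{i,j} (HKH)_{ij} \\, L_{\\sigma(i) \\sigma(j)} ,$$\n",
"and the centered entries are\n",
"$$(HKH)_{ij} = K_{ij} - \\frac1n \\sum_a K_{aj} - \\frac1n \\sum_b K_{ib} + \\frac1{n^2} \\sum_{a,b} K_{ab} .$$\n",
"So $HKH$ costs $O(n^2)$ once, and each permutation costs another $O(n^2)$ sum."
]
},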
{
"cell_type": "markdown",
"id": "02621101-0334-4a77-8572-ddf75270cb38",
"metadata": {},
"source": [
"I'm not sure which will be faster, so let's try two different implementation approches:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c01efb65-5819-4c4d-89ad-a0598d059bce",
"metadata": {},
"outputs": [],
"source": [
"def hsic_permutation_torch(K, L, n_perm=500):\n",
" \"\"\"\n",
" K: the pairwise kernel matrix of X values\n",
" L: the pairwise kernel matrix of corresponding Y values\n",
" n_perm: total number of permutations, including the identity\n",
" \"\"\"\n",
" K = torch.as_tensor(K)\n",
" device = K.device\n",
" dtype = K.dtype\n",
" L = torch.as_tensor(L).to(device=device, dtype=dtype)\n",
"\n",
" n = K.shape[0]\n",
" if K.shape != (n, n):\n",
" raise ValueError(f\"K should be square, got {K.shape}\")\n",
" if L.shape != (n, n):\n",
" raise ValueError(f\"L should be same shape as K ({K.shape}), got {L.shape}\")\n",
"\n",
" row_mean = K.mean(dim=0, keepdim=True)\n",
" col_mean = K.mean(dim=1, keepdim=True)\n",
" HKH_flat = (K - row_mean - col_mean + row_mean.mean()).ravel()\n",
"\n",
" L_flats = torch.empty((n_perm, n * n), device=device, dtype=dtype)\n",
" L_flats[0, :] = L.ravel()\n",
" for i in range(1, n_perm):\n",
" perm = torch.randperm(n, device=device)\n",
" L_flats[i, :] = L[perm.unsqueeze(1), perm.unsqueeze(0)].ravel()\n",
"\n",
" ests = L_flats @ HKH_flat\n",
"\n",
" p_val = (ests >= ests[0]).float().mean()\n",
" return PermutationResult(ests[0], p_val, ests)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e0edc21b-99e6-41b1-a052-61db065777c2",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from numba import jit\n",
"\n",
"@jit(nopython=True)\n",
"def _hsic_perms(HKH, L, perms):\n",
" n = HKH.shape[0]\n",
" n_perms = perms.shape[0]\n",
" ests = np.zeros(n_perms)\n",
" for i in range(n):\n",
" for j in range(n):\n",
" for e in range(n_perms):\n",
" ests[e] += HKH[i, j] * L[perms[e, i], perms[e, j]]\n",
" return ests\n",
"\n",
"def hsic_permutation_numba(K, L, n_perm=500):\n",
" \"\"\"\n",
" K: the pairwise kernel matrix of X values\n",
" L: the pairwise kernel matrix of corresponding Y values\n",
" n_perm: total number of permutations, including the identity\n",
" \"\"\"\n",
" K = np.asarray(K)\n",
" L = np.asarray(L)\n",
"\n",
" n = K.shape[0]\n",
" if K.shape != (n, n):\n",
" raise ValueError(f\"K should be square, got {K.shape}\")\n",
" if L.shape != (n, n):\n",
" raise ValueError(f\"L should be same shape as K ({K.shape}), got {L.shape}\")\n",
"\n",
" row_mean = K.mean(axis=0, keepdims=True)\n",
" col_mean = K.mean(axis=1, keepdims=True)\n",
" HKH = K - row_mean - col_mean + row_mean.mean()\n",
"\n",
" # make permutations, with first row in order\n",
" perms = np.tile(np.arange(n), (n_perm, 1))\n",
" rng = np.random.default_rng()\n",
" rng.permuted(perms[1:], axis=1, out=perms[1:])\n",
" \n",
" ests = _hsic_perms(HKH, L, perms)\n",
" \n",
" p_val = (ests >= ests[0]).mean()\n",
" return PermutationResult(ests[0], p_val, ests)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "bc73065a-9196-4802-a5bc-5dca68c07083",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics.pairwise import rbf_kernel\n",
"\n",
"X = np.random.randn(500, 2)\n",
"Y = X[:, [0, 0]] / 5 + np.random.randn(*X.shape)\n",
"\n",
"K = rbf_kernel(X, gamma=1)\n",
"L = rbf_kernel(Y, gamma=1)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "dabb8428-2871-40ef-8442-6df0d9fb8d0d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor(0.0010), 0.001)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAqQAAAB4CAYAAAAkJeDUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAPHklEQVR4nO3dXWwUVR/H8d9C2wVrWylol5WiVdFECr0oBluVF0FIAxLCha8xJGIiKkhTiHnAi9bEUAIJoGBRogHUaL0ADIlRKVFKCDFiSUNbDcFQoUjXjRH6JmyxnOeCh3lc2sJuu92zL99PMkl3zmk5c2b55z9nzpxxGWOMAAAAAEuG2W4AAAAAkhsJKQAAAKwiIQUAAIBVJKQAAACwioQUAAAAVpGQAgAAwCoSUgAAAFiVYrsBA3HlyhWdO3dOGRkZcrlctpsDIAEZY9TR0SGv16thwxLz2p1YCmAohRNH4zIhPXfunHJzc203A0ASaGlp0bhx42w3Y0gQSwFEQyhxNC4T0oyMDElXDzAzM9NyawDEtK4uyeu9+vO5c1J6eki/1t7ertzcXCfeJCJiKYCQDSCWhhNH4zIhvXZrKTMzkyAK4MaGD///z5mZISek1yTyrWxiKYCQDSKWhhJHE3NiFAAAAOJGXI6Qwr67//PVDct/WzcvSi0BgNhBbAQGJuyE9NChQ9qwYYPq6urU2tqqvXv3auHChU65MUZvvfWWtm/frvPnz2vq1Kl67733NHHiRKdOIBDQqlWr9Pnnn+vixYuaNWuWqqqqEvbBgVhzs4ApETQBAED0hH3LvqurSwUFBdq6dWuf5evXr9fGjRu1detWHT16VB6PR0888YQ6OjqcOqWlpdq7d6+qq6t1+PBhdXZ2av78+erp6Rn4kQAAACAuhT1CWlJSopKSkj7LjDHavHmz3nzzTS1atEiStGvXLuXk5Oizzz7Tyy+/rLa2Nn300Uf65JNPNHv2bEnSp59+qtzcXB04cEBz584dxOEAAAAg3kT0oabm5mb5fD7NmTPH2ed2uzV9+nQdOXJEklRXV6fLly8H1fF6vcrPz3fqXC8QCKi9vT1oAwAAQGKIaELq8/kkSTk5OUH7c3JynDKfz6e0tDSNGjWq3zrXq6ysVFZWlrOxkDOAeHbo0CE9+eST8nq9crlc+vLLL4PKjTGqqKiQ1+vVyJEjNWPGDDU1NQXVCQQCWr58ucaMGaP09HQtWLBAZ8+ejeJRAEDkDMmyT9evN2WMuekaVDeqs3r1arW1tTlbS0tLxNoKANHGXHwACBbRZZ88Ho+kq6OgY8eOdfb7/X5n1NTj8ai7u1vnz58PGiX1+/0qLi7u8++63W653e5INjWhhfIUve028BQ/kpmtufiBQECBQMD5zPSn8A02vhIbgb5FdIQ0Ly9PHo9HNTU1zr7u7m7V1tY6yWZhYaFSU1OD6rS2tqqxsbHfhBQAksVQzcWXmP4EIHaFPULa2dmpX3/91fnc3Nys+vp6ZWdna/z48SotLdXatWs1YcIETZgwQWvXrtUtt9yi5557TpKUlZWlJUuWaOXKlRo9erSys7O1atUqTZo0ybnSB4BkdaO5+KdPn3bqhDsXX7o6/amsrMz5fO090wBgW9gJ6U8//aSZM2c6n68Ft8WLF2vnzp164403dPHiRb366qvOwvj79+9XRkaG8zubNm1SSkqKnnrqKWdh/J07d2r4v9+TCgBJLNJz8SWmPwGIXWEnpDNmzJAxpt9yl8uliooKVVRU9FtnxIgR2rJli7Zs2RLuPw8ACW2o5uIDQCwbkqfsAQADw1x8AMkook/ZAwBujrn4ABCMhBR9ioWlo4BExVx8AAjmMjeaEBqj2tvblZWVpba2NmVmZtpuTsyJh2SStfYQNV1d0q23Xv25s1NKTw/p15IhziTDMUaa7fhK7IQ1A4il4cQY5pACAADAKhJSAAAAWMUcUgAA4gSvHkWiYoQUAAAAVpGQAgAAwCoSUg
AAAFhFQgoAAACreKgpDtleBw8AACCSGCEFAACAVYyQxhhGPwFg6LBsEhCbGCEFAACAVSSkAAAAsIpb9gAA/A/TpgA7SEhhRShBn7lcAAAkBxJSAAASBA9tIV4xhxQAAABWkZACAADAKhJSAAAAWEVCCgAAAKtISAEAAGAVCSkAAACsIiEFAACAVaxDGmW8BQQAACAYI6QAAACwihFSAEDC4C7UjfEmJ8QqElLELAInAADJgVv2AAAAsIqEFAAAAFaRkAIAAMAqElIAAABYRUIKAAAAq3jKPsJYcgQAACA8jJACAADAKkZIAQCAJNZ/hj2MkAIAAMAqRkgRt7iSB4DoIu5iqDBCCgAAAKtISAEAAGAVt+wBAHGDpfWAxERCioTFXCcAiC7iLgaKW/YAAACwioQUAAAAVpGQAgAAwCoSUgAAAFjFQ01h4gnPxBHKuWQCPgAAQ4+EFAAQM7joB5ITCSlwAyxhAgDA0GMOKQAAAKxihBQAAEQFd53QH0ZIAQAAYBUjpMAgcLUPAJFDTE1eJKT/wtOdAAAA0UdCCgCIGi78AfQlqRJSAiEAAPGLW/qJy2pCWlVVpQ0bNqi1tVUTJ07U5s2b9dhjj9lsEhBVvC0Kg0UcBZAIrCWkX3zxhUpLS1VVVaVHHnlEH3zwgUpKSvTzzz9r/PjxtpoFRFQkRuWjMSLAqEN8Io4CSBTWEtKNGzdqyZIleumllyRJmzdv1rfffqtt27apsrIyqG4gEFAgEHA+t7W1SZLa29vD+jevBP4eZKuB2HOz/wf55d8O+b8R07q6/v9ze7vU0xPSr107ZmPMULQqIsKJo1LkYumNROL7BgxUXMeqWDeAWBpWHDUWBAIBM3z4cLNnz56g/a+//rqZNm1ar/rl5eVGEhsbG1vUt5aWlmiFxrCEG0eNIZaysbHZ2UKJo1ZGSP/880/19PQoJycnaH9OTo58Pl+v+qtXr1ZZWZnz+cqVK/rrr780evRouVyuIW9vompvb1dubq5aWlqUmZlpuzlxi36MjFjrR2OMOjo65PV6bTelT+HGUSm0WBpr5yHe0H+DRx8OTiz1Xzhx1OpDTdcnk8aYPhNMt9stt9sdtO+2224byqYllczMTOtf2kRAP0ZGLPVjVlaW7SbcVKhxVAovlsbSeYhH9N/g0YeDEyv9F2octfLq0DFjxmj48OG9ruL9fn+vq30AQG/EUQCJxEpCmpaWpsLCQtXU1ATtr6mpUXFxsY0mAUBcIY4CSCTWbtmXlZXphRde0JQpU1RUVKTt27frzJkzWrp0qa0mJR23263y8vJet/AQHvoxMujH8A1FHOU8DA79N3j04eDEa/+5jLG3pklVVZXWr1+v1tZW5efna9OmTZo2bZqt5gBA3CGOAkgEVhNSAAAAwMocUgAAAOAaElIAAABYRUIKAAAAq0hIAQAAYBUJaQLZtm2bJk+e7LydoaioSF9//bVTboxRRUWFvF6vRo4cqRkzZqipqSnobwQCAS1fvlxjxoxRenq6FixYoLNnz0b7UGJKZWWlXC6XSktLnX30ZWgqKirkcrmCNo/H45TTj9FRWVmphx56SBkZGbrjjju0cOFCnThxIqgO56J/xNbIIqaGLyli6U3fdo+4sW/fPvPVV1+ZEydOmBMnTpg1a9aY1NRU09jYaIwxZt26dSYjI8Ps3r3bNDQ0mKefftqMHTvWtLe3O39j6dKl5s477zQ1NTXm2LFjZubMmaagoMD8888/tg7Lqh9//NHcfffdZvLkyWbFihXOfvoyNOXl5WbixImmtbXV2fx+v1NOP0bH3LlzzY4dO0xjY6Opr6838+bNM+PHjzednZ1OHc5F/4itkUNMHZhkiKUkpAlu1KhR5sMPPzRXrlwxHo/HrFu3zim7dOmSycrKMu+//74xxpgLFy6Y1NRUU11d7dT5/fffzbBhw8w333wT9bbb1tHRYSZMmGBqamrM9OnTneBJX4auvLzcFBQU9FlGP9rj9/uNJFNbW2
uM4VwMBLE1fMTUgUuGWMot+wTV09Oj6upqdXV1qaioSM3NzfL5fJozZ45Tx+12a/r06Tpy5Igkqa6uTpcvXw6q4/V6lZ+f79RJJq+99prmzZun2bNnB+2nL8Nz8uRJeb1e5eXl6ZlnntGpU6ck0Y82tbW1SZKys7MlcS7CQWwdOGLq4CR6LLX26lAMjYaGBhUVFenSpUu69dZbtXfvXj344IPOFy4nJyeofk5Ojk6fPi1J8vl8SktL06hRo3rV8fl80TmAGFFdXa1jx47p6NGjvcqu9QV9eXNTp07Vxx9/rPvvv19//PGH3n77bRUXF6upqYl+tMQYo7KyMj366KPKz8+XxHc6FMTWwSGmDk4yxFIS0gTzwAMPqL6+XhcuXNDu3bu1ePFi1dbWOuUulyuovjGm177rhVInkbS0tGjFihXav3+/RowY0W89+vLmSkpKnJ8nTZqkoqIi3Xvvvdq1a5cefvhhSfRjtC1btkzHjx/X4cOHe5VxLvpHbB04YurgJUMs5ZZ9gklLS9N9992nKVOmqLKyUgUFBXrnnXecp/GuvxLy+/3OVZXH41F3d7fOnz/fb51kUFdXJ7/fr8LCQqWkpCglJUW1tbV69913lZKS4vQFfRm+9PR0TZo0SSdPnuQ7acHy5cu1b98+ff/99xo3bpyzn3Nxc8TWgSOmRl4ixlIS0gRnjFEgEFBeXp48Ho9qamqcsu7ubtXW1qq4uFiSVFhYqNTU1KA6ra2tamxsdOokg1mzZqmhoUH19fXONmXKFD3//POqr6/XPffcQ18OUCAQ0C+//KKxY8fynYwiY4yWLVumPXv26LvvvlNeXl5QOecifMTW0BFTIy8hY2m0n6LC0Fm9erU5dOiQaW5uNsePHzdr1qwxw4YNM/v37zfGXF0WIisry+zZs8c0NDSYZ599ts9lIcaNG2cOHDhgjh07Zh5//PGYWhbCln8/EWoMfRmqlStXmoMHD5pTp06ZH374wcyfP99kZGSY3377zRhDP0bLK6+8YrKysszBgweDlo35+++/nTqci/4RWyOPmBqeZIilJKQJ5MUXXzR33XWXSUtLM7fffruZNWuWEzCNubo0RHl5ufF4PMbtdptp06aZhoaGoL9x8eJFs2zZMpOdnW1Gjhxp5s+fb86cORPtQ4k51wdP+jI019bCS01NNV6v1yxatMg0NTU55fRjdEjqc9uxY4dTh3PRP2Jr5BFTw5MMsdRljDE2R2gBAACQ3JhDCgAAAKtISAEAAGAVCSkAAACsIiEFAACAVSSkAAAAsIqEFAAAAFaRkAIAAMAqElIAAABYRUIKAAAAq0hIAQAAYBUJKQAAAKz6L8EIHqGN/0SMAAAAAElFTkSuQmCC",
"text/plain": [
"<Figure size 800x100 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"f, (a1, a2) = plt.subplots(1, 2, figsize=(8, 1))\n",
"\n",
"r1 = hsic_permutation_torch(K, L, n_perm=1_000)\n",
"a1.hist(r1.permuted_estimates, bins=\"auto\")\n",
"a1.axvline(r1.estimate, color=\"r\")\n",
"\n",
"r2 = hsic_permutation_numba(K, L, n_perm=1_000)\n",
"a2.hist(r2.permuted_estimates, bins=\"auto\")\n",
"a2.axvline(r2.estimate, color=\"r\")\n",
"\n",
"r1.p_value, r2.p_value"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "60040945-a203-46fa-8f7f-86e0dad2046e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"358 ms ± 11.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"hsic_permutation_torch(K, L).p_value"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7524d67d-b5d1-41c1-941e-d189ae0e013e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"152 ms ± 513 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%%timeit\n",
"hsic_permutation_numba(K, L).p_value"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}