Skip to content

Instantly share code, notes, and snippets.

@joshfp
Last active June 29, 2019 01:14
Show Gist options
  • Save joshfp/d61521965f3491f654e08512e078f397 to your computer and use it in GitHub Desktop.
Save joshfp/d61521965f3491f654e08512e078f397 to your computer and use it in GitHub Desktop.
cat-embeds-dropout.ipynb
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "%reload_ext autoreload\n%autoreload 2",
"execution_count": 2,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "from fastai import *\nfrom fastai.tabular import *\nfrom fastai.metrics import accuracy\n\nPATH = Path('~/data/').expanduser()",
"execution_count": 3,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "df = pd.read_feather(PATH/'listings-df')\ndf = df.drop('title', axis=1)",
"execution_count": 5,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "cont_cols = ['col1', 'col2', 'col3', 'col4', 'col5', 'col6',\n 'col7', 'col8', 'col9', 'col10', 'col11', 'col12'] # real columns names were replaced\ncat_cols = sorted(list(set(df.columns) - set(cont_cols) - {'condition'}))\nvalid_idx = range(len(df)-10000, len(df))\nprocs = [FillMissing, Categorify, Normalize]",
"execution_count": 6,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "data = (TabularList.from_df(df, cat_cols, cont_cols, procs=procs, path=PATH)\n .split_by_idx(valid_idx)\n .label_from_df(cols='condition')\n .databunch())",
"execution_count": 7,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Basic model"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn = tabular_learner(data, layers=[64], ps=[0.5], emb_drop=0.05, metrics=accuracy)",
"execution_count": 8,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.lr_find()",
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"text": "LR Finder is complete, type {learner_name}.recorder.plot() to see the graph.\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.recorder.plot()",
"execution_count": 10,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEKCAYAAADjDHn2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4xLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvDW2N/gAAIABJREFUeJzt3Xl8lNXVwPHfmewhCUlIgoEACZCwI0JAgaosVYG6VK0KrVsXbN2t1VarVV99rXZ1aa1v1WppXRBccCmIVsUdIez7vgXCvi/JbOf9YwYcY0ICzDMzmZzv5zMfZ565zzznOgkn997n3iuqijHGGHM0rmgHYIwxJvZZsjDGGNMgSxbGGGMaZMnCGGNMgyxZGGOMaZAlC2OMMQ2yZGGMMaZBliyMMcY0yJKFMcaYBiVGO4BwycvL0+Li4miHYYwxTcqsWbO2q2p+Q+XiJlkUFxdTUVER7TCMMaZJEZF1jSln3VDGGGMaZMnCGGNMgyxZGGOMaZAlC2OMMQ2yZGGMMaZBliyMMcY0yJKFMcaYBlmyMHFr+ZZ9TKzYgG0dbMyJi5tJeaZ5OeT2sXH3QQ65/fRok4XLJUfeU1X++flaHpqyFLfXT8XaXfz2ol4khJQxxhwbSxYmZq3cuo+Hpyxlf40XVVCFaq+PTbsPsX2/+0i54lbpjBnQnov7FaEKt78yj2nLtjG8awGdW2fw949Ws7/GyyOX9SE50RrTJr7MXr8Lv18pL8519DqOJgsRGQE8BiQAz6jqw7XefwQYGnyZDhSoaraI9AGeBLIAH/Cgqr7sZKwmtuw+6ObH4yrYfdBDl5MyESDBJbRMS6J7YRbtctMpyknD7fUzsaKSh6Ys5Y/vLiMtKYEar5/7L+jBFad1QETIz0jhf/+zhP01Xv7v8n6kJSdEu3rGhM0j7y1nX7WXSdcPdvQ6jiULEUkAngDOAiqBmSLypqouPlxGVX8eUv5G4JTgy4PAlaq6QkTaALNEZKqq7nYqXhM7vD4/N7w4h6rd1bx0zWn065Bz1PKXlLdj5dZ9vPjlBjbsOsjt53ShrHXmkfd/cnpHMlISufP1BVz13AzG/XBAnQnjH5+uYd2OA9x3Xo+vdWsZE8vcXn9EWsxOtiwGACtVdTWAiIwHLgAW11N+DHAvgKouP3xQVTeJyFYgH7BkEUY+v3L1czPw+ZW7vtONHm1afuP9L1fvoHXLVDrlZ0QsrgcnL+HTldv5/fd6N5goDutckMk953Wv9/3RA9qTnpLILePn8LPnZ/H0leVf+wV79tM1PPB24EczOz2ZW88qO7FKGBMhHp+f9GTnRxScvEJbYEPI60rg1LoKikgHoAT4oI73BgDJwCoHYmzWXvxyHZ+s2E6L5ATO/cunjO7fntvOLiPBJbw8cwP/nr6Oyl2HADi9NI8fDi5mSFnBkb+6VZU9hzykJiWQmhSerp0JMzfw3Gdr+dHgEi4tbxeWzzzs/JPbcLDGyx2vLeDWCXN5bPQpJLiEV2dVcv/bizmnR2syU5N4/P0V9GrbkrO6t/7a+aqKKtbqMDHF7fOT3cRbFnX9RtV3D+No4BVV9X3tA0QKgX8DV6mq/xsXELkGuAagffv2JxZtM7N1XzW/n7qMwZ1b8bfv9+PxD1Yw7vO1vD1vEx6/n2qPnwElufxyRFfWbj/A89PX8aN/VtA+N52CzBQ2761m694a3L7A15KWlEBui2TyMpI5v09bLj+tPSmJx5ZAZq3byV2TFnB6aR6/HtXViWozekB79lZ7+O3kpWSlJTGkLJ9fvjqfwZ1b8djoQC/o8i37+PnLc3njhsF0ys9AVZm6aAsPT1lCcqKLJ77fl9KQbi6AZZv3cdvEeeyt9jC0SwHDuxUwoCT3mP8fGHOsPF4lKcH5P2DEqXvQRWQgcJ+qnhN8fSeAqj5UR9k5wPWq+nnIsSxgGvCQqk5s6Hrl5eVq+1k03i3j5zB5wWbeueV0Oga7mFZu3c/j76+gRUoCV5xWTPc2WUfKe3x+3lm4mQkVG/D6lNZZKbRumUp+Rgpun5+d+93sPOhmzfYDzFm/m7bZafzi7DIu6NO2Ubesbtx9iAv++ikZKYlMun4w2enJjtUd4PfvLOVv01YhAr2LsnnhJ6eSkZJ4JJbz/vIpOelJ/PbCXvz5veV8uWYnpQUZ7Dro5kCNj4cv7sUFfdoCMKFiA/e8sZCMlCR6F7Xks5XbqfH6aZGcwLVDOnH90M6IWGvEOGPoH6fRs21L/jLmlIYL10FEZqlqeUPlnGxZzARKRaQE2Eig9fD92oVEpAuQA3wRciwZeB34V2MShTk2n6/czqS5m7hpWOcjiQKgc0EGj9fzA5eU4OK8k9tw3sltGvz8T1Zs43fvLOXWCfN46uPVXDukEyN6nlTvX9kH3V7GjqugxuNn/DXljicKgNvP6YLb62fuht08fWX5kUQB0DY7jb9+/xSu+McMLntqOrktkvnf7/ZkdP927Djg5sYX53Dz+LnMWLOTao+fV2dXMrBjKx4b04eCzFQOuX18sXo742ds4I/vLueA28cvz+liCcM4wu31R6Rl4ViyUFWviNwATCVw6+yzqrpIRO4HKlT1zWDRMcB4/XoT51LgDKCViFwdPHa1qs51Kt7mosbr4+43FtI+N53rhnZ25Bqnl+YzuFMeby+o4pH3lnPz+LnkZSRzWf92fP/UDrTNTjtS1u9Xbps4jyWb9/LsVf3pXJB5lE8OHxHh7nPrHxAf1CmPP196Mqu2HeAnp5eQlZoEQOusVF4ceyp/eHcZf/9oNSJw8/BSbhpeeqQFlZacwLCurRlSVsA9by7kyWmr8PmVO0d2tYRhws7t85PSxMcsUNXJwORax+6p9fq+Os57HnjeydiaqyenrWL1tgP884f9wzYoXReXSzj/5Dac26uQT1Zu599frOPJaav427RVdGmdSb8OOfTrkMOyLfuYvGAzvx7VlaFdCxyL53gc7maqLTHBxZ0ju3FmWT4piS76dah7MpTLJTxwQU8SXS6e+ng1Xp/ym3O7WcIwYeXx+UlOaOLJwsQOVeVv01bx6H9XcP7JbRjSJTL/MLtcwpll+ZxZlk/lroO8PnsjM9bu5I25m3jhy/UAXNS3LWNP7xiReMJpUKe8BsuICPee1x2XCM9+toakROHOkd0iEJ1pLgLdUJYsTBj4/Mr9by1i3Bfr+G6fNvz+eydHJY6inHRuHF56JKblW/axatt+zureOq7/2hYRfnNuNzw+P3//aDUlrVoweoDdvWfCw+Nr+pPyTBS8MXcjn67YTpeTMulemEWnggzuf2sx/1lQxdjTS7hzZLeYmCeQ4BK6FWbRrTCr4cJx4HALY93Og9w9aSHtW6U3qmVizNH4/YrHp9ayMMdmX7WHuyctpMbjPzL/4bC7RnVj7BlNr6snniQmuPjr90/h4r99zrXPz+b16wZ97W40Y46Vxx/4PbeWhTkmz09fz75qL2/eMJjClmksqdrL0s176XJSFmeW5Uc7PANkpSbx7NX9+e4Tn/HjcRW8fM1pFGSlRjss00S5vcFkYS0L01jVHh//+HQ1p5fm0bsoG4D8zHzOsCQRc9rlpvPUlf0Y89SXDHz4AwZ1asW5vQs5p8dJEZljYuLHkWQRgZaFLe4fJyZUbGD7fjfXOzR3woRXvw65TL75W/z0jI6s33mQX726gPL//S8TKzY0fLIxQR5fYHpaJMYsLFnEgcN32fRtn82pJc5ugGLCp3NBJr8c0ZVptw3h7Ru/xSnts/mftxazeU91tEMzTYS1LMwxeXPuJjbuPmRrEDVRIkLPti350yV98Pr93PvmwmiHZJqIwzeyRGK5D0sWTZzfrzz50Sq6npTJsBibAW2OTftW6dw8vIypi7bwzsLN0Q7HNAGHWxaRWO7DkkUT9+7iLazcup/rrFURF35yegldT8rkvjcXsa/aE+1wTIzzHGlZWLIwR+Hx+fnTu8sobpXOqJ4nRTscEwZJCS4evrg3W/ZV88epy6Idjolxh7uhbJ6FOapxn69lxdb9PH1lOYkR+MvCREafdtlcNbCYcV+sJS05kf7FOfRpl02rjJRoh2ZijMcbuZaFJYso27T7ENOWbePifm2PaVe1rXurefS/KxjaJZ9vd7Oxinhz2zldWL5lH09/spr/+yhwe2SHVukMKM5lUOdWDOyYx0ktbTJfc1djLYvmodrj48fjKlhStZd/fbGWRy7r0+i1kh6ashS318+95/WwsYo4lJGSyItjT+OQ28eCjXuYs34Xs9bt4t3FW5g4qxKAkrwWdMpvQdvsNNpkp9EpP4NhXQtiYu0vExkem8HdPDzw9mKWVO3lpmGdeXHGei7462fcenYZY0/veNStSGes2cnrczZyw9DOFOe1iGDEJtLSkhMYUJLLgOD8Gb9fWVy1l+mrdzBjzU7W7zzIl2t2sq/aCwR2ALSJmc2HjVk0A2/NC+zn8LMzO3Hr2V24alAxd72+kIenLGXKgiquHlzMyJ6F39igyOvzc88bC2nTMpXrhnaKUvQmWlyuwJyMnm1b8pOQPUD2Vnv41Svzeez9FYzqVUiJ/RHRLBy+G8paFnFq7fYD3PnaAvp1yOEXZ5cB0CojhScv78vrczby6H9X8POX53HPG4v4bp+2DOrUii17q9m4+xBLqvaxdPM+nvxBX9KT7eszAVmpSfzP+T34dOV2fv3aAl4ce6p1TzYDh+dZJFnLIv5Ue3xc/+JsEhOEv4w55Wt3MYgIF/Ut4rt92jJ9zQ4mzNzAhIoN/Hv6OiAw8aZtThpjTy9hhN0qa2opyErlzpHd+PXrC5g4q5JLy9tFOyTjMHdwbShrWcShx95fwaJNe3n26nLaZKfVWcblEgZ1ymNQpzz+55CHdTsOUNgyjbyMZPtr0RzV6P7tmDRnIw/+ZwlDuxSQn2m328azSC5RbjfnR9DyLft4+uPVXNKviGFdWzfqnJZpSfQuyiY/M8UShWmQyyX89qJeHHL7eODtxdEOxzjME8EBbksWEaKq3D1pIRmpidw5qlu0wzFxrHNBBtcP7cyb8zbx9vxN0Q7HOOjImIUtJBg/Xp29kRlrdnLHiK7ktrANboyzrh3SiX4dcvjlK/NZtnlftMMxDnF7/biEiKzgYMmilp0H3Fzxjy/51Svz+c/8KnYfdAOwfX8NL81Yz1XPzqD3fVOZtW5noz9z90E3v528hL7ts23Q0UREcqKLv/2gLy1SEvnpvyvYc8gWJYxHHp8/Ikt9gA1wf8NDk5fw+aodpCcn8HLFBkQCM2XXbj+AX6F9bjoJLuHhKUuZ8NOBjRpH+N07y9hzyMODF/ay2bUmYlpnpfLkD/oy+qnp3DJ+Dv+4qr/9/MWZGq8/IuMVYMnia75cvYOJsyr52ZmduO3sMuZV7uGTFduYu2E35/Zuw4geJ9GtMJPnv1zPbyYt5KPl2xjSpf51mTbvqealGet5acZ6fvKtkkYv5WFMuJQX53LPed25541FPPr+Cm49qyzaIZkw8vj8EbkTCixZHOH2+rlr0kKKctK4eXgpiQku+nXIoV+HnG+Uvay8HX//aBV/enc5Z5blf611oapMW7aNF75czwdLt+BXGNa1gFvsl9REyRWndWDehj08/v4KSgsyOO/kNtEOyYSJ21oWkff0J6tZuXU/z15dTlry0Vd/TU50cfPwUm5/ZT5TF21mRM9CILBuz12TFvLSjPXkZSTz0zM7Mbp/Ozq0sqUXTPSICA9e2JMNOw/yiwnzaNUimUGd86IdlgmDSI5Z2AA3sH7HQR5/fwUje57U6PkPF57Slo75LfjTu8vx+RWvz89tr8zjpRmB9Z4+v2M4vxrR1RKFiQmpSQk8fWU5xXnpXPPvWSzatCfaIZkwcPsi17Jo9slCVfnNGwtJdAn3nNe90eclJri49awyVmzdz6uzK7n55bm8NnsjvzirjDtGdo3YF2hMY7VMT+KfPxxAZmoiVz83kw07D0Y7JHOC3F6Nj5aFiIwQkWUislJE7qjj/UdEZG7wsVxEdoe8d5WIrAg+rnIqxjXbDzB99Q5+cXYXClvWvfxGfUb1LKRbYRa/ejVwm+1do7px4/BShyI15sS1yU5j3I8GUOPxcdVzM9h1wB3tkMwJiIuWhYgkAE8AI4HuwBgR+dqf7qr6c1Xto6p9gL8ArwXPzQXuBU4FBgD3isg3R5rDoGN+Bv+99UyuHNjhmM91uYQ7R3YlOcHF/Rf0YOwZHRs+yZgoK2udyTNX9ady5yGufWHWkVnApunxeP2kxEHLYgCwUlVXq6obGA9ccJTyY4CXgs/PAd5T1Z2qugt4DxjhVKDtctOPewbkGWX5LLjvHK4cWBzeoIxx0ICSXH73vV5MX72TuyctQFWjHZI5Dm6fn6TEyMydcTJZtAU2hLyuDB77BhHpAJQAHxzrubHAxidMU3ThKUXcOKwzEyoqefqT1dEOxxyHSM6zcPIqdaW7+v58GQ28oqq+YzlXRK4RkQoRqdi2bdtxhmlM8/Xzb5fxnd6FPDRlKe8u2hztcMwxcnvj49bZSiB0IaQioL4lMEfzVRdUo89V1adUtVxVy/Pz808wXGOaH5dL+NMlJ9O7KJtbXp7L8i226GBTEhcD3MBMoFRESkQkmUBCeLN2IRHpAuQAX4QcngqcLSI5wYHts4PHjDFhlpqUwNNX9CM9OZFrn5/FgRpvtEMyjeT2xkE3lKp6gRsI/CO/BJigqotE5H4ROT+k6BhgvIaMsKnqTuABAglnJnB/8JgxxgEFWak8PqYPa4L7w9uAd9MQN8t9qOpkYHKtY/fUen1fPec+CzzrWHDGmK8Z1CmPX5zdhT9MXUb/klyuOO3Ybyc3kWXLfRhjouLaMzsxtEs+D7y1mPmVuxs+wURVJFsWliyMMUe4XMKfL+1DfmYK170w28YvYpzHFyfLfRhjmp6cFsk8OroPlbsO8cSHK6MdjqmHqsbN3VDGmCaqf3EuF53Slmc+WcPa7QeiHY6pg8cXuAkhOaHpz+A2xjRhh1dPvv/txdEOxdTB7Qus6WUtC2NMVBVkpXLT8M58sHQrHyzdEu1wTC2e4AKQNmZhjIm6qweV0DG/Bfe/tZgar6/hE0zEWMvCGBMzkhNd3HdeD9buOMgzn6yJdjgmxOGl5Zv8DG5jTHw4oyyfs7u35vH3V1h3VAyxloUxJuY8dFEvylpnMvZfs3h1VmW0wzEEZm+DtSyMMTGkVUYKL11zGgM7tuIXE+fx949W2fpRUea2AW5jTCzKSEnk2av7c97JbXhoylIenrI02iE1a54Id0M5upCgMSa+JCe6eOyyPmSnJfH3j1fTqSCDS8vbNXyiCbsaa1kYY2KZyyXcd34PBnduxW8mLWRJ1d5oh9QsHbkbyga4jTGxKsElPHrZKbRMS+K6F2azr9oT7ZCana+W+7BkYYyJYfmZKfxlzCms22EbJkWDtSyMMU3GqR1bcds5XXh7fhX/nr4u2uE0K4cHuJNsIUFjTFPwszMCGyY9+J8lbNp9KNrhNBvWsjDGNCkul3D/BT1RhT+/tzza4TQbbpuUZ4xpatrlpnP14GJenV3J4k12d1QkWMvCGNMkXT+kM1mpSTz8jk3Wi4SvxiwsWRhjmpCW6UncOKwzHy/fxicrtkU7nLhnLQtjTJN1xcAOFOWk8dvJS/H77VZaJ3l8fkQg0WV3QxljmpiUxARuP6cLS6r2MmnuxmiHE9dqfH6SElyIWLIwxjRB5/VuQ++iljz63xU2Uc9BHq+SEqHxCrBkYYwJM5dLuPy0DqzfeZD5lXuiHU7ccvt8JEVovAIsWRhjHHB299YkuoTJC6uiHUrc8ng1YnMswJKFMcYB2enJDOqcx+QFVdYV5RC3z09SYmTGK8CShTHGId/pdRIbdh5ikU3Sc4Tb57eWhTGm6Tu7+0kkuIT/LLCuKCe4vf6ITcgDh5OFiIwQkWUislJE7qinzKUislhEFonIiyHHfx88tkREHpdI3R9mjAmLnBbJDOrUyrqiHOL2+kmJhwFuEUkAngBGAt2BMSLSvVaZUuBOYLCq9gBuCR4fBAwGegM9gf7AmU7FaoxxxqhehazbcZDFtpte2Hl88dOyGACsVNXVquoGxgMX1CozFnhCVXcBqOrW4HEFUoFkIAVIArY4GKsxxgHn9Ah0RU22rqiwc3v9EVvqA5xNFm2BDSGvK4PHQpUBZSLymYhMF5ERAKr6BfAhUBV8TFXVJQ7GaoxxQG6LZE7rmMvkBZutKyrM4qllUdcYQ+2flkSgFBgCjAGeEZFsEekMdAOKCCSYYSJyxjcuIHKNiFSISMW2bbZwmTGxaFSvQtZsP8DSzfuiHUpcqYmjlkUl0C7kdRGwqY4yb6iqR1XXAMsIJI8Lgemqul9V9wNTgNNqX0BVn1LVclUtz8/Pd6QSxpgTc06Pk3AJTLGuqLDyxNGtszOBUhEpEZFkYDTwZq0yk4ChACKSR6BbajWwHjhTRBJFJInA4LZ1QxnTBOVlpHBqSSvemm93RYWT2xcnLQtV9QI3AFMJ/EM/QVUXicj9InJ+sNhUYIeILCYwRnG7qu4AXgFWAQuAecA8VX3LqViNMc66sG9b1mw/wOz1u6MdStzweJWkhMjNKEh08sNVdTIwudaxe0KeK3Br8BFaxgf81MnYjDGRM6pXIfe+sYhXZ1fSr0NOtMOJC3HTsjDGmMMyUhIZ2fMk3pq3iWqPL9rhxAWP109yQkLErteoZCEinUQkJfh8iIjcJCLZzoZmjIkn3+tXxL5qL+8ttilT4VATowsJvgr4gre0/gMoAV48+inGGPOV0zq2om12Gq/Mqox2KE2equLx+WNy8yN/cMD6QuBRVf05UOhcWMaYeONyCRf1bcsnK7axZW91tMNp0rx+RZWYnJTnEZExwFXA28FjSc6EZIyJVxf1LcKv8Poc25/7RHh8foCYHOD+ITAQeFBV14hICfC8c2EZY+JRSV4Lyjvk8MqsSptzcQLc3kCyiLmWhaouVtWbVPUlEckBMlX1YYdjM8bEoe/1K2Ll1v3Ms/25j9vhZBFzLQsRmSYiWSKSS2CS3HMi8mdnQzPGxKNRvQtJSXTxyqwNDRc2dXIf7oaKtZYF0FJV9wIXAc+paj/g286FZYyJV1mpSXy7e2veWbgFv9+6oo5HzLYsgEQRKQQu5asBbmOMOS5ndWvN9v01LNhoXVHHw+MLJNmYG7MA7iewjtMqVZ0pIh2BFc6FZYyJZ2eW5eMSeH/p1oYLm2+I2ZaFqk5U1d6qem3w9WpVvdjZ0Iwx8SqnRTL9OuTwwVKbzX08Do9ZRHIhwcYOcBeJyOsislVEtojIqyJS5HRwxpj4NaxraxZu3MvmPTZB71jFbMsCeI7AXhRtCOxc91bwmDHGHJfh3QoA+MC6oo6ZJ4bvhspX1edU1Rt8/BOwremMMcettCCDopw064o6DrHcstguIpeLSELwcTmww8nAjDHxTUQY3rWAT1dut2XLj1EsL/fxIwK3zW4GqoDvEVgCxBhjjtuwbq2p9vj5YpX97XksvhrgjrFkoarrVfV8Vc1X1QJV/S6BCXrGGHPcTi3JJT05gfetK+qYHOmGirVkUY9bGy5ijDH1S01K4PTSPD5YstUWFjwG7hjuhqpL5G7wNcbEreFdW7NpTzVLN++LdihNhqeJtSzszwBjzAkb0jVwY6XdQtt4R8YsYqVlISL7RGRvHY99BOZcGGPMCSnITOWU9tm8NGM9B2q80Q6nSYi5MQtVzVTVrDoemaqaGKkgjTHx7c6R3ajcdYg/TF0W7VCaBPeRhQRjbLkPY4xx0oCSXK4c2IFxX6ylYu3OaIcT89xeP8kJLkQsWRhjmplfjuhKm5Zp/PLV+TZJrwEenz+irQqwZGGMiREZKYk8dFEvVm87wOPv2w4IR+P2+iN62yxYsjDGxJAzyvK5pF8Rf/94NQttY6R6BVoWliyMMc3Y3d/pTk56Eo/+11oX9bGWhTGm2WuZnsTwrq2ZuXan7dFdD7fPH9HbZsGShTEmBpUX57DnkIeV2/ZHO5SYFHctCxEZISLLRGSliNxRT5lLRWSxiCwSkRdDjrcXkXdFZEnw/WInYzXGxI7+xbkAVKzdFeVIYlNcjVmISALwBDAS6A6MEZHutcqUAncCg1W1B3BLyNv/Av6gqt2AAYCtBWBMM9GhVTp5Gck256Iebl98tSwGACtVdbWquoHxwAW1yowFnlDVXQCquhUgmFQSVfW94PH9qnrQwViNMTFERCjvkEvFOmtZ1MXj1bgas2gLbAh5XRk8FqoMKBORz0RkuoiMCDm+W0ReE5E5IvKHYEvFGNNMlBfnsH7nQbbsrY52KDGnxueP6CKC4GyyqGt6Ye1bGxKBUmAIMAZ4RkSyg8dPB24D+gMdgau/cQGRa0SkQkQqtm3bFr7IjTFRV27jFvXyeOPrbqhKoF3I6yJgUx1l3lBVj6quAZYRSB6VwJxgF5YXmAT0rX0BVX1KVctVtTw/P9+RShhjoqNHmyxSk1xUrLNxi9oCYxbxs9zHTKBUREpEJBkYDbxZq8wkYCiAiOQR6H5aHTw3R0QOZ4BhwGIHYzXGxJikBBd92mVby6IOnniaZxFsEdwATAWWABNUdZGI3C8i5weLTQV2iMhi4EPgdlXdoao+Al1Q74vIAgJdWk87FasxJjb1L85lcdVe2+eiFrc38rfOOronhapOBibXOnZPyHMlsJf3N/bzDt4J1dvJ+Iwxsa1fhxx8fmXuht0M7pwX7XBiRtxNyjPGmBPRt0MOIjbIXZs7niblGWPMicpKTaJL60wb5K7F7fWTYi0LY4z5Sv/iXGav24XX5492KDEjrpb7MMaYcCgvzuGA28fSzfuiHUpM8Pr8+BUbszDGmFBfTc6zrigAjy8wt9laFsYYE6Jtdhpts9P4ZMX2aIcSE9zeQHectSyMMaaWc08uZNrybWzdZ+tEuYNjN8kJ8TOD2xhjwuKSfkX4/Mobc2qvGNT8HEkW1rIwxpiv61yQSZ922UyctYHAXN7my2PdUMYYU79LyotYvmU/8yv3RDuUqDrcsrABbmOMqcN5J7chJdHFxFkbGi4cx44McFuyMMaYb8pKTWJEz5N4c+4mqj2+aIcTNUdaFtYNZYwxdftevyL2Vnt5b/GWaIdT/99DAAAS70lEQVQSNYfHLFKsZWGMMXUb1CmPNi1TmTirMtqhRI21LIwxpgEJLuHifkV8smIbVXsORTucqLAxC2OMaYTv9StCFZ6fvi7aoUSFx+6GMsaYhnVo1YLzTm7Dk9NW8WkzXAKkxuZZGGNM4zx0US865Wdw40uz2bDzYLTDiajDCwlaN5QxxjQgIyWRp64sx+tXfvb8rGZ1K60tJGiMMcegJK8Fj17Wh0Wb9vLr1xY0m2VAvhqzsIUEjTGmUYZ3a83Pv13Ga3M2Mn5m85jZbS0LY4w5DjcO68yAklz+OHUZ+2u80Q7HcbY2lDHGHAeXS7hzZFd2HHDzj0/WRDscx9k8C2OMOU6ntM/hnB6teerjVezYXxPtcBzl8flJShBcLhuzMMaYY3b7OV045PHx1w9XRjsUR7m9/oh3QYElC2NMnOhckMkl/drxwvT1cT33wuPzR3xwGyxZGGPiyC1nlYLAI+8tj3YojnH7rGVhjDEnpLBlGlcPKub1uRtZUrU32uE4wu3ViA9ugyULY0ycuW5IJ9KSEvjXF/G50KDbuqGMMebEZacnM6xrAe8u2ow3OCchXizcuIdPV2wjPyMl4td2NFmIyAgRWSYiK0XkjnrKXCoii0VkkYi8WOu9LBHZKCJ/dTJOY0x8GdWrkB0H3MxYuzPaoYRNxdqdjHlqOunJifzue70jfn3HkoWIJABPACOB7sAYEeleq0wpcCcwWFV7ALfU+pgHgI+citEYE5+GdMknNcnFlAWbox1KWHyyYhtX/GMG+ZkpTPzZQEryWkQ8BidbFgOAlaq6WlXdwHjgglplxgJPqOouAFXdevgNEekHtAbedTBGY0wcSk9OZFjXAqYs3IzP37QXGPx4+TZ+/M8KivNa8PJPB9ImOy0qcTiZLNoCoSt7VQaPhSoDykTkMxGZLiIjAETEBfwJuP1oFxCRa0SkQkQqtm3bFsbQjTFN3ciehWzfX0NFE++Keu6zNeRnpjB+7GnkZ0Z+rOIwJ5NFXXPRa6f4RKAUGAKMAZ4RkWzgOmCyqh51GUlVfUpVy1W1PD8/PwwhG2PixbCuBaQkupiysGl3RS3dvI/+xTm0TE+KahxOJotKoF3I6yJgUx1l3lBVj6quAZYRSB4DgRtEZC3wR+BKEXnYwViNMXGmRUoiQ7rkM2VhFf4m2hW1+6Cbqj3VdCvMinYojiaLmUCpiJSISDIwGnizVplJwFAAEckj0C21WlV/oKrtVbUYuA34l6rWeTeVMcbUZ1SvQrbsrWH2+l3RDuW4LA5OLOwaz8lCVb3ADcBUYAkwQVUXicj9InJ+sNhUYIeILAY+BG5X1R1OxWSMaV6GdS0gOcHF5CZ6V9SSqn0AdCvMjHIkgTEDx6jqZGByrWP3hDxX4Nbgo77P+CfwT2ciNMbEs8zUJM4oy2PKwiru/k63iC/rfaKWVu0lLyOZgszUaIdiM7iNMfFtVK9CqvZUM7dyd7RDOWZLNu+NifEKsGRhjIlzw7u1JjXJxUtfro92KMfE6/OzfMt+SxbGGBMJLdOSuKy8HZPmbqRqz6Foh9Noq7cfwO310/Wk6I9XgCULY0wzMPaMjvgVnmlCe3QfXmLdWhbGGBMhRTnpXHByG16asZ5dB9zRDqdRllTtIylB6JSfEe1QAEsWxphm4mdDOnHQ7WPcF2ujHUqjLKnaS+eCzKjsXVGX2IjCGGMcVtY6k293a80/P1/LgRpvtMNp0JKqvXSLkfEKsGRhjGlGrh3Sid0HPYyfedRl56Jux/4atu6riZnxCrBkYYxpRvp1yOHUklye+WQ1bm/s7qK3dPPhmduWLIwxJiquHdKJqj3V3DphLjtjdLD7qzuhrBvKGGOi4syyfG49q4ypizbz7T9/xBtzNxJYeSh2LK7aS0FmCq2isNd2fSxZGGOaFRHhpuGlvH3j6bTLTefm8XP58bgK9hz0RDu0I5ZU7YuJlWZDWbIwxjRLXU7K5LVrB/Gbc7vz8fJt/H7q0miHBIDH52fl1n0x1QUFliyMMc1Ygkv48bdKuPy0DoyfuYEVW/ZFOyRWbduPx6d0t5aFMcbElpuGl5KelMDDU6Lfuoi1ZT4Os2RhjGn2clskc93Qzry/dCufr9oe1Vimr9pJcqKLjnktohpHbZYsjDEG+OHgYtpmp/HbyUuitmf3Zyu383LFBi4rb0diQmz98xxb0RhjTJSkJiVw+zldWLhxL2/M2xjx6+856OEXE+bRMb8Fvx7VLeLXb4glC2OMCTr/5Db0bJvFH95Zxt7qyN1Kq6r8etICtu+v4dHL+pCWnBCxazeWJQtjjAlyuYS7v9Odqr3VnPH7D3n8/RXsOeR80pg0dyP/mV/FLd8upXdRtuPXOx6WLIwxJsRpHVsx6brBlHfI4c/vLedbD3/An99bjsfnzFpSG3Ye5J5JiyjvkMO1Qzo7co1wSIx2AMYYE2tObpfNM1f1Z9GmPfzl/ZU8/v4KMlISuOaMTmG9zn/mV3H3pAUAPHJZHxJcEtbPDydrWRhjTD16tGnJ/13RjyFd8vnrByvZfTA8Cw/uPujmppfmcP2Ls2mXm87r1w+iXW56WD7bKZYsjDGmAb8a0ZV9NV6e+HDlCX/WzLU7OfuRj5m8oIpbzyrj1WsH0bkgtpb2qIslC2OMaUC3wiwu7lvEuM/XsWHnweP+nA+XbuXyZ74kIyWRSdcP5qbhpSTF2HyK+jSNKI0xJspuPasMEfjTu8uO6/y3529i7L8qKG2dwcSfDaRn25ZhjtBZliyMMaYR2mSn8aNvlTBp7iYWbtxzTOeOn7GeG1+aQ9/2Obw49rSY2qeisSxZGGNMI107pBM56Uk8NGVJozZM8vuVR95bzh2vLeCM0nzG/WgAWalJEYg0/CxZGGNMI2WlJnHjsFI+W7mDcx79mHGfr613pveegx5+PG4mj72/gov7FvH0leUxOTO7sWyehTHGHIOrBxWTkZrIC9PXce+bi3h4ylJG9SqkvDiHHm2yKGudyapt+7n2+dlU7TnEA9/tyeWntkckdudQNIY4ufesiIwAHgMSgGdU9eE6ylwK3AcoME9Vvy8ifYAngSzABzyoqi8f7Vrl5eVaUVER5hoYY0z9FlTu4cUZ63h7fhX7qr1AYEMll0CrFin87fK+9G2fE+Uoj05EZqlqeYPlnEoWIpIALAfOAiqBmcAYVV0cUqYUmAAMU9VdIlKgqltFpAxQVV0hIm2AWUA3Vd1d3/UsWRhjosXvVzbsOsjiTXtZtGkv+2u8XD+0M/mZsT+Q3dhk4WQ31ABgpaquDgY0HrgAWBxSZizwhKruAlDVrcH/Lj9cQFU3ichWIB+oN1kYY0y0uFxCh1Yt6NCqBSN7FUY7HEc4OcDdFtgQ8royeCxUGVAmIp+JyPRgt9XXiMgAIBlYVcd714hIhYhUbNu2LYyhG2OMCeVksqhrNKd2n1ciUAoMAcYAz4jIkfV5RaQQ+DfwQ1X9xpKPqvqUqparanl+fn7YAjfGGPN1TiaLSqBdyOsiYFMdZd5QVY+qrgGWEUgeiEgW8B/gblWd7mCcxhhjGuBkspgJlIpIiYgkA6OBN2uVmQQMBRCRPALdUquD5V8H/qWqEx2M0RhjTCM4lixU1QvcAEwFlgATVHWRiNwvIucHi00FdojIYuBD4HZV3QFcCpwBXC0ic4OPPk7Faowx5ugcnWcRSXbrrDHGHLvG3jpry30YY4xpkCULY4wxDYqbbigR2Qasq+OtlkDt9YRrHwt9Xdfz0GN5wPbjCLGuOBpbJhx1CH1+vHU4WoyNKXO0mBt6Xfu7iJU61HUsVr6Lo71/vN9FLP881XXMfrcb1kFVG557oKpx/QCeauhY6Ou6ntc6VhGuOBpbJhx1qFWf46pDuOtxLK9rfxexUodY/i6O9v7xfhex/PN0PN+F/W43/tEcuqHeasSxtxp4XtdnhCOOxpYJRx0aG0NDwlmPY3lt30XjYmns+8f7XcTyz1Ndx+x3O0ziphsqUkSkQhtx50AsszrEjnioRzzUAeKjHk7WoTm0LMLtqWgHEAZWh9gRD/WIhzpAfNTDsTpYy8IYY0yDrGVhjDGmQc02WYjIsyKyVUQWHse5/URkgYisFJHHJWS/RBG5UUSWicgiEfl9eKOuM5aw10NE7hORjSFLrYwKf+Rfi8OR7yL4/m0iosG1xxzl0HfxgIjMD34P7wY3A3OMQ3X4g4gsDdbj9dCVpZ3gUB0uCf5O+0XEsXGNE4m9ns+7SkRWBB9XhRw/6u9NnZy6zSrWHwTWnuoLLDyOc2cAAwkswz4FGBk8PhT4L5ASfF3QROtxH3BbU/4ugu+1I7D+2DogrynWA8gKKXMT8H9NsA5nA4nB578DftcE69AN6AJMA8pjLfZgXMW1juUCq4P/zQk+zzlaPY/2aLYtC1X9GNgZekxEOonIOyIyS0Q+EZGutc8L7rGRpapfaOD/+r+A7wbfvhZ4WFVrgtfY6mwtHKtHRDlYh0eAX/LNfVQc4UQ9VHVvSNEWOFwXh+rwrgYWFgWYTmC7gqZWhyWquszJuE8k9nqcA7ynqjs1sBvpe8CI4/3db7bJoh5PATeqaj/gNuBvdZRpS2AfjsNCdwAsA04XkS9F5CMR6e9otPU70XoA3BDsNnhWRKKx4/wJ1UECKxtvVNV5TgfagBP+LkTkQRHZAPwAuMfBWOsTjp+nw35E4C/ZSAtnHSKtMbHXpb7dSo+rnk7uwd2kiEgGMAiYGNJ9V9du60fbATCRQHPvNKA/MEFEOgazd0SEqR5PAg8EXz8A/InAL3lEnGgdRCQduItA90fUhOm7QFXvAu4SkTsJLPt/b5hDrVe46hD8rLsAL/BCOGNsSDjrEGlHi11EfgjcHDzWGZgsIm5gjapeSP31Oa56WrL4igvYrapf2zdDRBKAWcGXbxL4hzS0GR26A2Al8FowOcwQET+BtVoiuUH4CddDVbeEnPc08LaTAdfhROvQCSgB5gV/wYqA2SIyQFU3Oxx7qHD8TIV6kcDukRFLFoSpDsHB1XOB4ZH84yko3N9DJNUZO4CqPgc8ByAi04CrVXVtSJFKAltWH1ZEYGyjkuOpp1MDNU3hARQTMpAEfA5cEnwuwMn1nDeTQOvh8ODQqODxnwH3B5+XEWgCShOsR2FImZ8D45taHWqVWUsEBrgd+i5KQ8rcCLzSBOswAlgM5EfiO3Dy5wmHB7iPN3bqH+BeQ6C3Iyf4PLcx9awzrkh9ebH2AF4CqgAPgUz7YwJ/jb4DzAv+cN9Tz7nlwEJgFfBXvprcmAw8H3xvNjCsidbj38ACYD6Bv7gKm1odapVZS2TuhnLiu3g1eHw+gfV/2jbBOqwk8IfT3ODD6Tu6nKjDhcHPqgG2AFNjKXbqSBbB4z8K/v9fCfzwWH5vaj9sBrcxxpgG2d1QxhhjGmTJwhhjTIMsWRhjjGmQJQtjjDENsmRhjDGmQZYsTFwTkf0Rvt4zItI9TJ/lk8BqswtF5K2GVmsVkWwRuS4c1zamNrt11sQ1Edmvqhlh/LxE/WpRPEeFxi4i44DlqvrgUcoXA2+ras9IxGeaF2tZmGZHRPJF5FURmRl8DA4eHyAin4vInOB/uwSPXy0iE0XkLeBdERkiItNE5BUJ7NPwwuH9AILHy4PP9wcXAZwnItNFpHXweKfg65kicn8jWz9f8NUiiRki8r6IzJbAngQXBMs8DHQKtkb+ECx7e/A680Xkf8L4v9E0M5YsTHP0GPCIqvYHLgaeCR5fCpyhqqcQWN31tyHnDASuUtVhwdenALcA3YGOwOA6rtMCmK6qJwMfA2NDrv9Y8PoNrskTXMNoOIHZ9ADVwIWq2pfAHip/CiarO4BVqtpHVW8XkbOBUmAA0AfoJyJnNHQ9Y+piCwma5ujbQPeQVTyzRCQTaAmME5FSAqtwJoWc856qhu4zMENVKwFEZC6B9Xw+rXUdN18twjgLOCv4fCBf7R/wIvDHeuJMC/nsWQT2I4DAej6/Df7D7yfQ4mhdx/lnBx9zgq8zCCSPj+u5njH1smRhmiMXMFBVD4UeFJG/AB+q6oXB/v9pIW8fqPUZNSHPfdT9u+TRrwYF6ytzNIdUtY+ItCSQdK4HHiewr0U+0E9VPSKyFkit43wBHlLVvx/jdY35BuuGMs3RuwT2hQBARA4v/9wS2Bh8frWD159OoPsLYHRDhVV1D4EtVW8TkSQCcW4NJoqhQIdg0X1AZsipU4EfBfdEQETaikhBmOpgmhlLFibepYtIZcjjVgL/8JYHB30XE1haHuD3wEMi8hmQ4GBMtwC3isgMoBDY09AJqjqHwKqjowlsHlQuIhUEWhlLg2V2AJ8Fb7X9g6q+S6Cb6wsRWQC8wteTiTGNZrfOGhNhwZ38DqmqishoYIyqXtDQecZEk41ZGBN5/YC/Bu9g2k0Et6w15nhZy8IYY0yDbMzCGGNMgyxZGGOMaZAlC2OMMQ2yZGGMMaZBliyMMcY0yJKFMcaYBv0/ierE89t3UsQAAAAASUVORK5CYII=\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn.fit_one_cycle(10, 1e-2, wd=1e-6)",
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"text": "Total time: 01:10\nepoch train_loss valid_loss accuracy\n1 0.251339 0.246639 0.902300 (00:06)\n2 0.201433 0.232700 0.910700 (00:06)\n3 0.170635 0.232286 0.909600 (00:06)\n4 0.136034 0.249854 0.912600 (00:06)\n5 0.113001 0.260503 0.912100 (00:07)\n6 0.104206 0.246688 0.916800 (00:07)\n7 0.079228 0.280759 0.915900 (00:07)\n8 0.079510 0.269010 0.915800 (00:07)\n9 0.067095 0.281458 0.913500 (00:07)\n10 0.072313 0.277925 0.915200 (00:07)\n\n",
"name": "stdout"
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Categorical Embedding Dropout"
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def get_is_cat_unk(ni):\n emb = nn.Embedding(ni, 1)\n emb.weight.requires_grad = False\n emb.weight.zero_()\n emb.weight[0] = 1.\n return emb\n\nget_is_cat_unk(5)",
"execution_count": 18,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 18,
"data": {
"text/plain": "Embedding(5, 1)"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class LongDropout(nn.Module):\n \"Dropout for LongTensor\"\n def __init__(self, p=0.5): \n super().__init__()\n self.p = p.item() if isinstance(p, torch.Tensor) else p \n def forward(self, input):\n rand = torch.rand_like(input, dtype=torch.float)\n return torch.where(rand >= self.p, input, torch.zeros_like(input))",
"execution_count": 19,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "class TabularModel2(nn.Module):\n def __init__(self, emb_szs:ListSizes, n_cont:int, out_sz:int, layers:Collection[int],\n ps:Collection[float]=None, emb_in_drop: Collection[float]=None, emb_out_drop:float=0., \n y_range:OptRange=None, use_bn:bool=True):\n super().__init__()\n ps = ifnone(ps, [0]*len(layers))\n ps = listify(ps, layers)\n self.embeds = nn.ModuleList([embedding(ni, nf) for ni,nf in emb_szs])\n self.is_cat_unk = nn.ModuleList([get_is_cat_unk(ni) for ni,_ in emb_szs])\n emb_in_drop = ifnone(emb_in_drop, [0.]*len(emb_szs))\n emb_in_drop = listify(emb_in_drop, emb_szs)\n self.emb_in_drop = nn.ModuleList([LongDropout(p) for p in emb_in_drop])\n self.emb_out_drop = nn.Dropout(emb_out_drop)\n self.bn_cont = nn.BatchNorm1d(n_cont)\n n_emb = sum(e.embedding_dim+1 for e in self.embeds)\n self.n_emb,self.n_cont,self.y_range = n_emb,n_cont,y_range\n sizes = self.get_sizes(layers, out_sz)\n actns = [nn.ReLU(inplace=True)] * (len(sizes)-2) + [None]\n layers = []\n for i,(n_in,n_out,dp,act) in enumerate(zip(sizes[:-1],sizes[1:],[0.]+ps,actns)):\n layers += bn_drop_lin(n_in, n_out, bn=use_bn and i!=0, p=dp, actn=act)\n self.layers = nn.Sequential(*layers)\n\n def get_sizes(self, layers, out_sz):\n return [self.n_emb + self.n_cont] + layers + [out_sz]\n \n def forward(self, x_cat:Tensor, x_cont:Tensor) -> Tensor:\n if self.n_emb != 0:\n x = []\n for i,(drop,emb,unk) in enumerate(zip(self.emb_in_drop, self.embeds, self.is_cat_unk)):\n x_i_cat = drop(x_cat[:,i]) if self.training else x_cat[:,i] # emb_in_dropout (for each cat)\n x_i_emb = emb(x_i_cat) # embedding vector\n x_i_unk = unk(x_i_cat) # 1: if unknown category; 0: otherwise \n x.append(torch.cat([x_i_emb, x_i_unk], dim=1))\n x = torch.cat(x, 1)\n x = self.emb_out_drop(x)\n if self.n_cont != 0:\n x_cont = self.bn_cont(x_cont)\n x = torch.cat([x, x_cont], 1) if self.n_emb != 0 else x_cont\n x = self.layers(x)\n if self.y_range is not None:\n x = (self.y_range[1]-self.y_range[0]) * torch.sigmoid(x) + self.y_range[0]\n return x",
"execution_count": 20,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "def tabular_learner2(data:DataBunch, layers:Collection[int], emb_szs:Dict[str,int]=None, metrics=None,\n ps:Collection[float]=None, emb_in_drop=None, emb_out_drop:float=0., y_range:OptRange=None, use_bn:bool=True, **kwargs):\n \"Get a `Learner` using `data`, with `metrics`, including a `TabularModel` created using the remaining params.\"\n emb_szs = data.get_emb_szs(ifnone(emb_szs, {}))\n model = TabularModel2(emb_szs, len(data.cont_names), out_sz=data.c, layers=layers, ps=ps, \n emb_in_drop=emb_in_drop, emb_out_drop=emb_out_drop, y_range=y_range, use_bn=use_bn)\n return Learner(data, model, metrics=metrics, **kwargs)",
"execution_count": 21,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# get the % of unknown values for each category in valid set\npct_unk_by_cat = torch.zeros(len(cat_cols))\nfor x,_ in data.valid_dl: pct_unk_by_cat += (x[0] == 0).sum(dim=0).cpu().float()\npct_unk_by_cat /= len(data.valid_ds)\npct_unk_by_cat",
"execution_count": 22,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 22,
"data": {
"text/plain": "tensor([0.0000, 0.0433, 0.0000, 0.0000, 0.0157, 0.0010, 0.0178, 0.0011, 0.0026,\n 0.0000, 0.0000, 0.0021, 0.2679, 0.0000, 0.0000, 0.0000, 0.0000])"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "emb_in_drop = 0.1 * pct_unk_by_cat\nlearn2 = tabular_learner2(data, layers=[64], ps=[0.5], emb_in_drop=emb_in_drop, emb_out_drop=0.05, metrics=accuracy)",
"execution_count": 90,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "learn2.fit_one_cycle(10, 3e-3, wd=1e-5)",
"execution_count": 91,
"outputs": [
{
"output_type": "stream",
"text": "Total time: 01:31\nepoch train_loss valid_loss accuracy\n1 0.261289 0.235679 0.907700 (00:08)\n2 0.211450 0.224960 0.910800 (00:09)\n3 0.168766 0.241331 0.910700 (00:08)\n4 0.137618 0.253354 0.904200 (00:09)\n5 0.116874 0.230630 0.912900 (00:09)\n6 0.103587 0.233860 0.914100 (00:09)\n7 0.095109 0.240633 0.914900 (00:09)\n8 0.092940 0.244328 0.915300 (00:09)\n9 0.082886 0.250083 0.914300 (00:09)\n10 0.066324 0.249013 0.915700 (00:09)\n\n",
"name": "stdout"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"kernelspec": {
"name": "conda-env-fastai-py",
"display_name": "Python [conda env:fastai]",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.6.6",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",
"nbconvert_exporter": "python",
"file_extension": ".py"
},
"varInspector": {
"window_display": false,
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"library": "var_list.py",
"delete_cmd_prefix": "del ",
"delete_cmd_postfix": "",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"library": "var_list.r",
"delete_cmd_prefix": "rm(",
"delete_cmd_postfix": ") ",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
]
},
"gist": {
"id": "93790dc3f0926f67b0d984e660679051",
"data": {
"description": "mercadolibre/2-Copy1. tabular-cat-embs.ipynb",
"public": false
}
},
"_draft": {
"nbviewer_url": "https://gist.github.com/93790dc3f0926f67b0d984e660679051"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment