Last active
September 28, 2020 09:12
-
-
Save BenedictWilkins/d58bcecc48eaf0553320484ee7eda040 to your computer and use it in GitHub Desktop.
Train a siamese network on MNIST
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# REQUIRES\n", | |
"\n", | |
"`[h5py, numpy, matplotlib, pytorch]`" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%matplotlib notebook\n", | |
"\n", | |
"import numpy as np\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"\n", | |
"class CNet(nn.Module):\n", | |
"\n", | |
" def __init__(self, input_shape):\n", | |
" super(CNet, self).__init__() \n", | |
" self.input_shape = input_shape\n", | |
" \n", | |
" self.conv1 = nn.Conv2d(input_shape[0], 64, kernel_size=4, stride=2)\n", | |
" self.conv2 = nn.Conv2d(64, 32, kernel_size=4, stride=1)\n", | |
" self.conv3 = nn.Conv2d(32, 16, kernel_size=4, stride=1)\n", | |
" \n", | |
" s1 = conv_output_shape(input_shape[1:], kernel_size=4, stride=2)\n", | |
" s2 = conv_output_shape(s1, kernel_size=4, stride=1)\n", | |
" s3 = conv_output_shape(s2, kernel_size=4, stride=1)\n", | |
" \n", | |
" self.output_shape = np.prod(s3) * 16\n", | |
" \n", | |
" def to(self, device):\n", | |
" self.device = device\n", | |
" return super(CNet, self).to(device)\n", | |
"\n", | |
" def forward(self, x_):\n", | |
" x_ = x_.to(self.device)\n", | |
" y_ = F.leaky_relu(self.conv1(x_))\n", | |
" y_ = F.leaky_relu(self.conv2(y_))\n", | |
" y_ = F.leaky_relu(self.conv3(y_)).view(x_.shape[0], -1)\n", | |
" return y_\n", | |
" \n", | |
"class CNet2(CNet):\n", | |
" \n", | |
" def __init__(self, input_shape, output_shape, activation=lambda x: x):\n", | |
" super(CNet2, self).__init__(input_shape)\n", | |
" self.out_layer = nn.Linear(self.output_shape, output_shape)\n", | |
" self.output_shape = output_shape\n", | |
" self.activation = activation\n", | |
" \n", | |
" def forward(self, x_):\n", | |
" x_ = super(CNet2, self).forward(x_)\n", | |
" y_ = self.activation(self.out_layer(x_))\n", | |
" return y_\n", | |
"\n", | |
"def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1):\n", | |
" from math import floor\n", | |
" if type(kernel_size) is not tuple:\n", | |
" kernel_size = (kernel_size, kernel_size)\n", | |
" h = floor( ((h_w[0] + (2 * pad) - ( dilation * (kernel_size[0] - 1) ) - 1 )/ stride) + 1)\n", | |
" w = floor( ((h_w[1] + (2 * pad) - ( dilation * (kernel_size[1] - 1) ) - 1 )/ stride) + 1)\n", | |
" return h, w" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import h5py #pip install h5py -- https://www.h5py.org/\n", | |
"\n", | |
"def mnist():\n", | |
" #load train\n", | |
" f = h5py.File(\"./train.hdf5\", 'r')\n", | |
" train_x, train_y = f['image'][...], f['label'][...]\n", | |
" f.close()\n", | |
"\n", | |
" #load test\n", | |
" f = h5py.File(\"./test.hdf5\", 'r')\n", | |
" test_x, test_y = f['image'][...], f['label'][...]\n", | |
" f.close()\n", | |
"\n", | |
" print(\"train_x\", train_x.shape, train_x.dtype)\n", | |
" print(\"train_y\", train_y.shape, train_y.dtype)\n", | |
" \n", | |
" return train_x[:,np.newaxis], train_y, test_x[:,np.newaxis], test_y\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def distance_matrix(x1, x2=None): #L22 distance by default\n", | |
" if x2 is None:\n", | |
" x2 = x1\n", | |
" n_dif = x1.unsqueeze(1) - x2.unsqueeze(0)\n", | |
" return torch.sum(n_dif * n_dif, -1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def loss(model, x, y, margin=0.2):\n", | |
" x_ = model(x)\n", | |
" unique = np.unique(y)\n", | |
" device = list(model.parameters())[0].device\n", | |
" loss = torch.FloatTensor(np.array([0.])).to(device)\n", | |
"\n", | |
" for u in unique:\n", | |
" pi = np.nonzero(y == u)[0]\n", | |
" ni = np.nonzero(y != u)[0]\n", | |
" \n", | |
" #slightly more efficient below\n", | |
" xp_ = x_[pi] # get all positive images\n", | |
" xn_ = x_[ni] # get all negative images\n", | |
" xp = distance_matrix(xp_, xp_) #P-P distance\n", | |
" xn = distance_matrix(xp_, xn_) #P-N distance\n", | |
"\n", | |
" #3D tensor, (a - p) - (a - n) \n", | |
" xf = xp.unsqueeze(2) - xn\n", | |
"\n", | |
" xf = F.relu(xf + margin) #triplet loss\n", | |
" loss += xf.sum()\n", | |
"\n", | |
" return loss" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import matplotlib.pyplot as plt\n", | |
"plt.rcParams['figure.figsize'] = [8,4.5]\n", | |
"\n", | |
"def plot(fig, model, x, y):\n", | |
" plt.clf()\n", | |
" with torch.no_grad():\n", | |
" z = model(x).cpu().numpy()\n", | |
" for i in range(0,10):\n", | |
" plt.scatter(*z[y==i].T, marker=\".\", alpha=0.5, edgecolors='none')\n", | |
" plt.legend([str(i) for i in range(0,10)], loc=\"upper right\")\n", | |
" fig.canvas.draw()\n", | |
" \n", | |
"def figtoimage(fig):\n", | |
" # Get the RGBA buffer from the figure\n", | |
" w,h = fig.canvas.get_width_height()\n", | |
" buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)\n", | |
" return np.flip(buf.reshape((h,w,3)), 2) #bgr format for opencv!\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [], | |
"source": [ | |
"input_dim = (1, 28, 28)\n", | |
"batch_size = 100\n", | |
"margin = 0.2\n", | |
"latent_dim = 2\n", | |
"lr = 0.0005\n", | |
"epochs = 3\n", | |
"\n", | |
"if torch.cuda.is_available(): \n", | |
" device = 'cuda'\n", | |
"else:\n", | |
" device = 'cpu'\n", | |
"print(\"USING DEVICE:\", device)\n", | |
"\n", | |
"x_train, y_train, x_test, y_test = mnist()\n", | |
"x_train = torch.FloatTensor(x_train).to(device)\n", | |
"x_test = torch.FloatTensor(x_test).to(device)\n", | |
"model = CNet2(input_dim, latent_dim).to(device)\n", | |
"\n", | |
"optim = torch.optim.Adam(model.parameters(), lr=lr)\n", | |
"\n", | |
"fig = plt.figure()\n", | |
"fig.tight_layout()\n", | |
"plot(fig, model, x_test, y_test)\n", | |
"img = figtoimage(fig)\n", | |
"\n", | |
"plt.imsave(\"./initial.png\", img)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [], | |
"source": [ | |
"fig = plt.figure()\n", | |
"fig.tight_layout()\n", | |
"video = []\n", | |
"\n", | |
"x_train = x_train.reshape(x_train.shape[0] // batch_size, batch_size, *x_train.shape[1:])\n", | |
"y_train = y_train.reshape(y_train.shape[0] // batch_size, batch_size, *y_train.shape[1:])\n", | |
"\n", | |
"for e in range(epochs):\n", | |
" for x,y in zip(*[x_train, y_train]):\n", | |
" optim.zero_grad()\n", | |
" _loss = loss(model, x, y, margin=margin)\n", | |
" _loss.backward()\n", | |
" optim.step()\n", | |
" #print(_loss.item())\n", | |
" \n", | |
" plot(fig, model, x_test, y_test)\n", | |
" video.append(figtoimage(fig))\n", | |
" \n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"scrolled": false | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import cv2\n", | |
"\n", | |
"file = \"./video.mp4\"\n", | |
"fps = 24\n", | |
"#video must be CV format (NHWC)\n", | |
"fourcc = cv2.VideoWriter_fourcc(*'mp4v') #ehhh.... platform specific?\n", | |
"writer = cv2.VideoWriter(file, fourcc, fps, (video[0].shape[1], video[0].shape[0]), True)\n", | |
"for frame in video:\n", | |
" writer.write(frame)\n", | |
" " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "PhD", | |
"language": "python", | |
"name": "phd" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment