@iliasprc
Created April 6, 2020 12:10
Useful functions on PyTorch.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Useful functions on PyTorch.ipynb",
"provenance": [],
"collapsed_sections": [],
"toc_visible": true,
"authorship_tag": "ABX9TyPv0byDSiBMgVxV0Y+VKElb",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/IliasPap/a423e261d6b84ec4e22aac33c28aefc8/useful-functions-on-pytorch.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TQ1gUv0V0aOJ",
"colab_type": "text"
},
"source": [
"# Useful functions on PyTorch\n",
"\n",
"A year ago I started implementing deep neural networks with the PyTorch framework. Below I have gathered several functions and implementations that I found useful during my first days of coding."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UX2mL4sJeMy1",
"colab_type": "text"
},
"source": [
"\n",
"## Metrics\n",
"\n",
"Accuracy calculation for binary and multi-class (top-k) classification problems\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "RJc_2Q4HMk7_",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"\n",
"\n",
"def binary_accuracy(output, target):\n",
"    \"\"\"Computes the accuracy for multiple binary predictions.\n",
"    output : network output with shape batch_size x 1, values in range [0, 1]\n",
"    target : binary targets with shape batch_size x 1, values 0 or 1\n",
"    \"\"\"\n",
"    pred = output > 0.5\n",
"    truth = target > 0.5\n",
"    # cast to float before dividing so the result is a fraction, not an integer\n",
"    acc = pred.eq(truth).float().sum() / target.numel()\n",
"    return acc\n",
"\n",
"\n",
"def accuracy(output, target, topk=(1,)):\n",
"    \"\"\"Computes the precision@k for the specified values of k.\n",
"    output : logits with shape batch_size x classes\n",
"    target : class indices with shape batch_size or batch_size x 1 (integers, not one-hot vectors)\n",
"    topk : tuple of k values for the top-k accuracy metric\n",
"    \"\"\"\n",
"    maxk = max(topk)\n",
"    batch_size = target.size(0)\n",
"\n",
"    # indices of the maxk highest logits per sample: batch_size x maxk\n",
"    _, pred = output.topk(maxk, 1, True, True)\n",
"    pred = pred.t()  # maxk x batch_size\n",
"    correct = pred.eq(target.view(1, -1).expand_as(pred))\n",
"\n",
"    res = []\n",
"    for k in topk:\n",
"        # reshape (not view): correct is non-contiguous after .t()\n",
"        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)\n",
"        res.append(correct_k.mul_(100.0 / batch_size))\n",
"    return res\n"
],
"execution_count": 0,
"outputs": []
},
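{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick cross-check of `accuracy`, here is a minimal sketch (not part of the original functions) that computes top-k accuracy directly: a sample counts as correct if its target index is among its k highest logits. `topk_accuracy` is a hypothetical name and the logits are made-up values for illustration.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"\n",
"def topk_accuracy(output, target, k=1):\n",
"    # indices of the k largest logits per sample: batch_size x k\n",
"    topk_idx = output.topk(k, dim=1).indices\n",
"    # a sample is correct if its target is any of its top-k predictions\n",
"    hits = (topk_idx == target.view(-1, 1)).any(dim=1)\n",
"    return hits.float().mean().item() * 100.0\n",
"\n",
"\n",
"# 3 samples, 4 classes (illustrative values)\n",
"logits = torch.tensor([[0.1, 0.7, 0.2, 0.0],\n",
"                       [0.5, 0.1, 0.3, 0.1],\n",
"                       [0.3, 0.2, 0.1, 0.5]])\n",
"labels = torch.tensor([1, 2, 0])\n",
"print(topk_accuracy(logits, labels, k=1))  # only sample 0 is correct: about 33.33\n",
"print(topk_accuracy(logits, labels, k=2))  # all three are correct: 100.0\n",
"```"
]
},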
{
"cell_type": "markdown",
"metadata": {
"id": "L_CaEJNBOZ9H",
"colab_type": "text"
},
"source": [
"## Checkpoints\n",
"\n",
"Functions to save and load checkpoints.\n",
"\n",
"The model, optimizer and epoch are saved in a single checkpoint file.\n",
"\n",
"Optionally, the state_dict of each top-level module of the model is also saved separately, so that each module can be loaded independently, e.g. if you change one layer of your model and want to load the rest."
]
},
{
"cell_type": "code",
"metadata": {
"id": "6ZYgrrdlOcBV",
"colab_type": "code",
"colab": {}
},
"source": [
"import os\n",
"\n",
"import torch\n",
"\n",
"\n",
"def load_checkpoint_modules(checkpoint, model, strict=True, optimizer=None, load_seperate_layers=False):\n",
"    \"\"\"Loads model parameters (state_dict) from the checkpoint file.\n",
"    If optimizer is provided, also loads the optimizer state_dict,\n",
"    assuming it is present in the checkpoint.\n",
"    Args:\n",
"        checkpoint: (string) filename of the checkpoint to load\n",
"        model: (torch.nn.Module) model into which the parameters are loaded\n",
"        strict: (bool) passed to load_state_dict; require exactly matching keys\n",
"        optimizer: (torch.optim) optional: resume optimizer from checkpoint\n",
"        load_seperate_layers: (bool) optional: if the checkpoint stores the state_dict of each module\n",
"            separately, load each top-level module independently.\n",
"            Useful to load part of a checkpoint, e.g. the CNN only\n",
"    \"\"\"\n",
"    if not os.path.exists(checkpoint):\n",
"        raise FileNotFoundError(\"File doesn't exist {}\".format(checkpoint))\n",
"    checkpoint = torch.load(checkpoint, map_location='cpu')\n",
"    print(checkpoint.keys())\n",
"    if not load_seperate_layers:\n",
"        model.load_state_dict(checkpoint['model_dict'], strict=strict)\n",
"    else:\n",
"        for name, module in model.named_children():\n",
"            module.load_state_dict(checkpoint[name + '_dict'], strict=strict)\n",
"\n",
"    # resume from the stored epoch if present, else start from 0\n",
"    epoch = checkpoint.get('epoch', 0)\n",
"    if optimizer is not None:\n",
"        optimizer.load_state_dict(checkpoint['optimizer_dict'])\n",
"\n",
"    return checkpoint, epoch\n",
"\n",
"\n",
"def save_checkpoint(model, optimizer, epoch, checkpoint, name, save_seperate_layers=False):\n",
"    \"\"\"Saves model, optimizer and epoch to <checkpoint>/<name>.pth.\"\"\"\n",
"    state = {}\n",
"    if save_seperate_layers:\n",
"        # one state_dict per top-level child module, stored as '<child_name>_dict'\n",
"        for child_name, module in model.named_children():\n",
"            state[child_name + '_dict'] = module.state_dict()\n",
"\n",
"    state['model_dict'] = model.state_dict()\n",
"    state['optimizer_dict'] = optimizer.state_dict()\n",
"    state['epoch'] = epoch\n",
"    filepath = os.path.join(checkpoint, name + '.pth')\n",
"    if not os.path.exists(checkpoint):\n",
"        print(\"Checkpoint directory does not exist! Making directory {}\".format(checkpoint))\n",
"        os.makedirs(checkpoint)\n",
"    else:\n",
"        print(\"Checkpoint directory exists! {}\".format(checkpoint))\n",
"\n",
"    torch.save(state, filepath)\n",
"    print(\"CHECKPOINT SAVED\")\n"
],
"execution_count": 0,
"outputs": []
},
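{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of why saving each module separately can help, using a hypothetical `ToyModel` whose children are named `cnn` and `fc`. If you change the size of `fc`, loading the full `model_dict` fails with a size-mismatch error even with `strict=False` (which only tolerates missing or unexpected keys), while loading the `cnn_dict` entry into `model.cnn` still works. The checkpoint is written to a temporary directory using the same keys as `save_checkpoint`.\n",
"\n",
"```python\n",
"import os\n",
"import tempfile\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class ToyModel(nn.Module):\n",
"    # hypothetical model: a small 'cnn' followed by a linear classifier 'fc'\n",
"    def __init__(self, n_classes):\n",
"        super(ToyModel, self).__init__()\n",
"        self.cnn = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.AdaptiveAvgPool2d(1), nn.Flatten())\n",
"        self.fc = nn.Linear(8, n_classes)\n",
"\n",
"    def forward(self, x):\n",
"        return self.fc(self.cnn(x))\n",
"\n",
"\n",
"old = ToyModel(n_classes=10)\n",
"optimizer = torch.optim.SGD(old.parameters(), lr=0.1)\n",
"\n",
"# same keys as save_checkpoint(..., save_seperate_layers=True)\n",
"state = {name + '_dict': module.state_dict() for name, module in old.named_children()}\n",
"state['model_dict'] = old.state_dict()\n",
"state['optimizer_dict'] = optimizer.state_dict()\n",
"state['epoch'] = 5\n",
"\n",
"path = os.path.join(tempfile.mkdtemp(), 'toy.pth')\n",
"torch.save(state, path)\n",
"\n",
"# new model with a different classifier size: only the 'cnn' weights fit\n",
"new = ToyModel(n_classes=4)\n",
"checkpoint = torch.load(path, map_location='cpu')\n",
"new.cnn.load_state_dict(checkpoint['cnn_dict'])\n",
"print(checkpoint['epoch'])  # 5\n",
"```"
]
},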
{
"cell_type": "markdown",
"metadata": {
"id": "_mVJoQDTO_xJ",
"colab_type": "text"
},
"source": [
"## Weight initialization\n",
"\n",
"\n",
"Several functions for CNN, linear, RNN and BatchNorm weight initialization, plus a helper that freezes BatchNorm layers."
]
},
{
"cell_type": "code",
"metadata": {
"id": "G-YKE4dBPDK2",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"# Freeze the running statistics of BatchNorm layers: in eval mode they normalize with the stored\n",
"# statistics and stop updating them (their affine weights still train).\n",
"# Call model.apply(set_bn_eval) after model.train(), because train() switches them back to training mode.\n",
"def set_bn_eval(m):\n",
"    classname = m.__class__.__name__\n",
"    if classname.find('BatchNorm') != -1:\n",
"        m.eval()\n",
"\n",
"\n",
"# Conv2d weights ~ N(0, 0.02); BatchNorm weights ~ N(1, 0.02) and biases = 0\n",
"def weights_init(m):\n",
"    classname = m.__class__.__name__\n",
"\n",
"    if classname.find('Conv2d') != -1:\n",
"        m.weight.data.normal_(0.0, 0.02)\n",
"    elif classname.find('BatchNorm') != -1:\n",
"        m.weight.data.normal_(1.0, 0.02)\n",
"        m.bias.data.fill_(0)\n",
"\n",
"\n",
"# Linear layers: Xavier-uniform weights, biases = 0.01\n",
"def init_weights_linear(m):\n",
"    if type(m) == nn.Linear:\n",
"        torch.nn.init.xavier_uniform_(m.weight)\n",
"        m.bias.data.fill_(0.01)\n",
"\n",
"\n",
"# Recurrent layers: Xavier-uniform input-hidden weights, orthogonal hidden-hidden weights, zero biases\n",
"def init_weights_rnn(model):\n",
"    for m in model.modules():\n",
"        if type(m) in [nn.GRU, nn.LSTM, nn.RNN]:\n",
"            for name, param in m.named_parameters():\n",
"                if 'weight_ih' in name:\n",
"                    torch.nn.init.xavier_uniform_(param.data)\n",
"                elif 'weight_hh' in name:\n",
"                    torch.nn.init.orthogonal_(param.data)\n",
"                elif 'bias' in name:\n",
"                    param.data.fill_(0)\n",
"\n",
"\n",
"# All parameters: weights ~ U(-0.1, 0.1), biases = 0\n",
"def weights_init_uniform(net):\n",
"    for name, param in net.named_parameters():\n",
"        if 'bias' in name:\n",
"            nn.init.constant_(param, 0.0)\n",
"        elif 'weight' in name:\n",
"            nn.init.uniform_(param, a=-0.1, b=0.1)\n"
],
"execution_count": 0,
"outputs": []
},
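{
"cell_type": "markdown",
"metadata": {},
"source": [
"The per-module functions are meant to be passed to `model.apply(fn)`, which calls `fn` on every submodule recursively (and on the model itself); `init_weights_rnn` and `weights_init_uniform` instead take the whole model. A minimal sketch on a hypothetical toy model, with two of the helpers re-declared so the snippet runs on its own:\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"def set_bn_eval(m):\n",
"    if m.__class__.__name__.find('BatchNorm') != -1:\n",
"        m.eval()\n",
"\n",
"\n",
"def init_weights_linear(m):\n",
"    if type(m) == nn.Linear:\n",
"        nn.init.xavier_uniform_(m.weight)\n",
"        m.bias.data.fill_(0.01)\n",
"\n",
"\n",
"# hypothetical toy model\n",
"model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU(),\n",
"                      nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 2))\n",
"model.apply(init_weights_linear)  # visits every submodule recursively\n",
"\n",
"model.train()             # train() also puts BatchNorm back into training mode...\n",
"model.apply(set_bn_eval)  # ...so freeze the BatchNorm statistics after calling it\n",
"print(model[1].training, model[5].bias)\n",
"```"
]
},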
{
"cell_type": "markdown",
"metadata": {
"id": "rpLLFaF5fLCi",
"colab_type": "text"
},
"source": [
"## Identity layer\n",
"\n",
"Say you want to use a 2D CNN trained on ImageNet as a feature extractor for your task: you want to remove its default classifier and extend the CNN with other modules, e.g. an RNN. You can replace the classifier with a simple module that just returns its input.\n",
"(`nn.Identity` is available in PyTorch >= 1.1.0.)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "5CZKZ0GJfN43",
"colab_type": "code",
"colab": {}
},
"source": [
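"import torch.nn as nn\n",
"\n",
"\n",
"class Identity(nn.Module):\n",
"    \"\"\"Returns its input unchanged; equivalent to nn.Identity.\"\"\"\n",
"    def __init__(self):\n",
"        super(Identity, self).__init__()\n",
"\n",
"    def forward(self, x):\n",
"        return x\n"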
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "6H5J7mI15nED",
"colab_type": "text"
},
"source": [
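"### Usage\n",
"\n",
"Simple examples with pretrained CNNs from torchvision:\n",
"\n",
"```python\n",
"import torchvision\n",
"\n",
"cnn = torchvision.models.alexnet(pretrained=True)\n",
"cnn.classifier[-1] = Identity()  # the CNN now outputs 4096-d features\n",
"```\n",
"\n",
"or\n",
"\n",
"```python\n",
"cnn = torchvision.models.resnet18(pretrained=True)\n",
"cnn.fc = Identity()  # the CNN now outputs 512-d features\n",
"```\n",
"\n",
"(In torchvision >= 0.13, `pretrained=True` is deprecated in favour of the `weights` argument.)\n"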
]
},
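{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to try the idea without downloading pretrained weights, here is a minimal sketch with a hypothetical `TinyNet` that, like ResNet, keeps its classifier in an `fc` attribute. After `fc` is replaced with `nn.Identity()` (or the `Identity` class above), the network returns the features that used to feed the classifier.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class TinyNet(nn.Module):\n",
"    # hypothetical backbone + classifier, mimicking the layout of torchvision's ResNet\n",
"    def __init__(self):\n",
"        super(TinyNet, self).__init__()\n",
"        self.features = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),\n",
"                                      nn.AdaptiveAvgPool2d(1), nn.Flatten())\n",
"        self.fc = nn.Linear(16, 1000)\n",
"\n",
"    def forward(self, x):\n",
"        return self.fc(self.features(x))\n",
"\n",
"\n",
"net = TinyNet()\n",
"x = torch.randn(2, 3, 32, 32)\n",
"print(net(x).shape)  # torch.Size([2, 1000])\n",
"net.fc = nn.Identity()\n",
"print(net(x).shape)  # torch.Size([2, 16])\n",
"```"
]
},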
{
"cell_type": "markdown",
"metadata": {
"id": "csq8fBeHfBxQ",
"colab_type": "text"
},
"source": [
"## Time Distributed layer in PyTorch\n",
"\n",
"2D CNNs in PyTorch accept a 4D input tensor with shape (Batch_size, Channels, Height, Width).\n",
"\n",
"\n",
"But what happens if you want to extract features from a video (image sequence), which is a 5D tensor with shape (Batch_size, Timesteps, Channels, Height, Width)?\n",
"\n",
"You can \"cheat\" the CNN by merging the time dimension into the batch dimension, which reshapes the tensor to 4D, as in the following example:"
]
},
{
"cell_type": "code",
"metadata": {
"id": "soOTSZqNfIDf",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch.nn as nn\n",
"import torchvision\n",
"\n",
"\n",
"class CNN_RNN(nn.Module):\n",
"    def __init__(self, hidden_size=512, n_layers=2, dropt=0.5, bi=False, N_classes=1000, mode='continuous',\n",
"                 backbone='alexnet'):\n",
"        super(CNN_RNN, self).__init__()\n",
"        self.cnn = torchvision.models.alexnet(pretrained=True)\n",
"        # ..... (rest of the model, e.g. the RNN, omitted)\n",
"\n",
"    def forward(self, x):\n",
"        # x: (batch_size, timesteps, C, H, W)\n",
"        batch_size, timesteps, C, H, W = x.size()\n",
"        # merge batch and time so the 2D CNN sees a 4D batch of frames\n",
"        c_in = x.view(batch_size * timesteps, C, H, W)\n",
"        c_outputs = self.cnn(c_in)\n",
"\n",
"        # split them again: (batch_size, timesteps, features)\n",
"        c_out = c_outputs.contiguous().view(batch_size, timesteps, -1)\n",
"        # ..... (rest of the forward pass, e.g. the RNN, omitted; here the frame features are returned)\n",
"        return c_out\n"
],
"execution_count": 0,
"outputs": []
},
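{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same reshape trick can be wrapped in a small reusable module, similar in spirit to the Keras `TimeDistributed` wrapper. This is a sketch, not part of PyTorch: `TimeDistributed` is a hypothetical name, and it assumes the wrapped module maps a 4D batch (N, C, H, W) to an output whose first dimension is N.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class TimeDistributed(nn.Module):\n",
"    # hypothetical helper: applies `module` to every timestep of a (B, T, ...) tensor\n",
"    def __init__(self, module):\n",
"        super(TimeDistributed, self).__init__()\n",
"        self.module = module\n",
"\n",
"    def forward(self, x):\n",
"        batch_size, timesteps = x.size(0), x.size(1)\n",
"        # merge batch and time: (B * T, ...)\n",
"        y = self.module(x.reshape(batch_size * timesteps, *x.shape[2:]))\n",
"        # split them again: (B, T, ...)\n",
"        return y.reshape(batch_size, timesteps, *y.shape[1:])\n",
"\n",
"\n",
"cnn = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),\n",
"                    nn.AdaptiveAvgPool2d(1), nn.Flatten())  # (N, 3, H, W) -> (N, 8)\n",
"video = torch.randn(2, 5, 3, 32, 32)  # (batch, timesteps, C, H, W)\n",
"features = TimeDistributed(cnn)(video)\n",
"print(features.shape)  # torch.Size([2, 5, 8])\n",
"```\n",
"\n",
"`reshape` is used instead of `view` so that the wrapper also accepts non-contiguous inputs."
]
},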
{
"cell_type": "markdown",
"metadata": {
"id": "8KqAWcrgPK1e",
"colab_type": "text"
},
"source": [
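"## Video Dataloader\n",
"\n",
"Implementation of a video Dataset that loads the frames of each video, given the paths to the video frame folders.\n",
"\n",
"Data augmentation (temporal frame sampling, random crops and color jitter) is applied to training examples only.\n"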
]
},
{
"cell_type": "code",
"metadata": {
"id": "hRFDlMyE_oQy",
"colab_type": "code",
"outputId": "c7f36b66-e0b4-4236-9e91-e4e1456e88c1",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 231
}
},
"source": [
"import glob\n",
"import os\n",
"import random\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from PIL import Image\n",
"from torch.utils.data import Dataset\n",
"\n",
"# Note: read_video_paths, sampling, VideoRandomResizedCrop and video_transforms are\n",
"# helpers that are not shown in this notebook; pad_video is defined further below.\n",
"\n",
"\n",
"class VideoDataset(Dataset):\n",
"    def __init__(self, path_prefix, classes, seq_length=250, dim=(224, 224),\n",
"                 padding=False, normalize=True):\n",
"        \"\"\"\n",
"        Args:\n",
"            path_prefix : 'train' or 'test'; selects the list of video paths to read\n",
"            classes : list of classes\n",
"            seq_length : number of frames to be loaded in a sample\n",
"            dim : dimensions (width, height) of the frames\n",
"            padding : pad each video to seq_length frames\n",
"            normalize : normalize tensors with the ImageNet mean and std\n",
"        \"\"\"\n",
"        self.mode = path_prefix\n",
"\n",
"        if self.mode == 'train':\n",
"            self.list_IDs, self.labels = read_video_paths('./files/train.txt')\n",
"        elif self.mode == 'test':\n",
"            self.list_IDs, self.labels = read_video_paths('./files/test.txt')\n",
"\n",
"        print(\"{} examples {}\".format(self.mode, len(self.list_IDs)))\n",
"        self.classes = classes\n",
"        self.seq_length = seq_length\n",
"        self.dim = dim\n",
"        self.normalize = normalize\n",
"        self.padding = padding\n",
"\n",
"    def __len__(self):\n",
"        return len(self.list_IDs)\n",
"\n",
"    def __getitem__(self, index):\n",
"        ID = self.list_IDs[index]\n",
"        label = int(self.labels[index])\n",
"\n",
"        y = torch.tensor(label, dtype=torch.long)\n",
"        x = self.load_video_sequence_uniform_sampling(\n",
"            path='/pathtovideo/' + ID, time_steps=self.seq_length,\n",
"            dim=self.dim,\n",
"            augmentation=self.mode, padding=self.padding, normalize=self.normalize,\n",
"            img_type='jpg')\n",
"\n",
"        return x, y\n",
"\n",
"    def load_video_sequence_uniform_sampling(self, path, time_steps, dim=(224, 224), augmentation='test',\n",
"                                             padding=False, normalize=True, img_type='png'):\n",
"        images = sorted(glob.glob(os.path.join(path, '*' + img_type)))\n",
"\n",
"        img_sequence = []\n",
"        if augmentation == 'train':\n",
"            # training set temporal augmentation: keep a random 80-99% of the frames\n",
"            temporal_augmentation = int((np.random.randint(80, 100) / 100.0) * len(images))\n",
"            if temporal_augmentation > 15:\n",
"                images = sorted(sampling(images, temporal_augmentation))\n",
"            if len(images) > time_steps:\n",
"                # random frame sampling\n",
"                images = sorted(sampling(images, time_steps))\n",
"        else:\n",
"            # test set: uniform sampling\n",
"            if len(images) > time_steps:\n",
"                images = sorted(sampling(images, time_steps))\n",
"\n",
"        # offsets and color jitter values, drawn once per video so that all frames get the same transform\n",
"        i = np.random.randint(0, 30)\n",
"        j = np.random.randint(0, 30)\n",
"        brightness = 1 + random.uniform(-0.1, +0.1)\n",
"        contrast = 1 + random.uniform(-0.1, +0.1)\n",
"        hue = random.uniform(0, 1) / 20.0\n",
"\n",
"        r_resize = (256, 256)\n",
"        t1 = VideoRandomResizedCrop(dim[0], scale=(0.9, 1.0), ratio=(0.8, 1.2))\n",
"        for img_path in images:\n",
"            # convert() returns a new image, so the result must be assigned\n",
"            frame = Image.open(img_path).convert('RGB')\n",
"\n",
"            if augmentation == 'train':\n",
"                # training set data augmentation\n",
"                frame = frame.resize(r_resize)\n",
"                img_tensor = video_transforms(img=frame, i=i, j=j, bright=brightness, cont=contrast, h=hue, dim=dim,\n",
"                                              resized_crop=t1, augmentation='train', normalize=normalize)\n",
"            else:\n",
"                # test set: no data augmentation\n",
"                frame = frame.resize(dim)\n",
"                img_tensor = video_transforms(img=frame, i=i, j=j, bright=0, cont=0, h=0, dim=dim,\n",
"                                              augmentation='test', normalize=normalize)\n",
"            img_sequence.append(img_tensor)\n",
"\n",
"        pad_len = time_steps - len(images)\n",
"        X1 = torch.stack(img_sequence).float()\n",
"        if padding:\n",
"            X1 = pad_video(X1, padding_size=pad_len, padding_type='zeros')\n",
"        elif len(images) < 52:\n",
"            # pad short videos to a minimum length of 52 frames\n",
"            X1 = pad_video(X1, padding_size=52 - len(images), padding_type='zeros')\n",
"\n",
"        return X1\n"
],
"execution_count": 0,
"outputs": []
},
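{
"cell_type": "markdown",
"metadata": {},
"source": [
"With `padding=False`, videos in a batch can still have different numbers of frames, and the default `DataLoader` collate function cannot stack them. One option, sketched below under that assumption, is a custom `collate_fn` (hypothetical name `pad_collate`) that zero-pads every clip to the longest one with `torch.nn.utils.rnn.pad_sequence` and also returns the true lengths. `RandomVideos` is a hypothetical stand-in for `VideoDataset` that returns random tensors instead of decoded frames. Unlike `pad_video`, which prepends padding, `pad_sequence` appends it.\n",
"\n",
"```python\n",
"import torch\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"from torch.utils.data import DataLoader, Dataset\n",
"\n",
"\n",
"class RandomVideos(Dataset):\n",
"    # stand-in for VideoDataset: clips of 3, 5 and 4 frames, each (T, C, H, W)\n",
"    def __init__(self):\n",
"        self.lengths = [3, 5, 4]\n",
"\n",
"    def __len__(self):\n",
"        return len(self.lengths)\n",
"\n",
"    def __getitem__(self, index):\n",
"        x = torch.randn(self.lengths[index], 3, 8, 8)\n",
"        y = torch.tensor(index, dtype=torch.long)\n",
"        return x, y\n",
"\n",
"\n",
"def pad_collate(batch):\n",
"    videos, labels = zip(*batch)\n",
"    lengths = torch.tensor([v.size(0) for v in videos])\n",
"    # zero-pad along time, at the end: (B, T_max, C, H, W)\n",
"    padded = pad_sequence(list(videos), batch_first=True)\n",
"    return padded, torch.stack(labels), lengths\n",
"\n",
"\n",
"loader = DataLoader(RandomVideos(), batch_size=3, collate_fn=pad_collate)\n",
"x, y, lengths = next(iter(loader))\n",
"print(x.shape, lengths)  # torch.Size([3, 5, 3, 8, 8]) tensor([3, 5, 4])\n",
"```"
]
},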
{
"cell_type": "markdown",
"metadata": {
"id": "tyM5onErOT1A",
"colab_type": "text"
},
"source": [
"List sampling functions, used for frame sampling in training-data augmentation."
]
},
{
"cell_type": "code",
"metadata": {
"id": "PcAgt9DrOTMT",
"colab_type": "code",
"colab": {}
},
"source": [
"import numpy as np\n",
"\n",
"\n",
"def rescale_list(input_list, size):\n",
"    \"\"\"Keeps every (len // size)-th element, truncated to `size` elements.\"\"\"\n",
"    assert len(input_list) >= size\n",
"\n",
"    # Get the number to skip between iterations.\n",
"    skip = len(input_list) // size\n",
"\n",
"    # Build our new output.\n",
"    output = [input_list[i] for i in range(0, len(input_list), skip)]\n",
"\n",
"    # Cut off the extra elements if needed.\n",
"    return output[:size]\n",
"\n",
"\n",
"def uniform_sampling(clip, size):\n",
"    \"\"\"Picks `size` evenly spaced elements of `clip`.\"\"\"\n",
"    return_ind = [int(i) for i in np.linspace(1, len(clip), num=size)]\n",
"\n",
"    return [clip[i - 1] for i in return_ind]\n"
],
"execution_count": 0,
"outputs": []
},
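{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see how the two functions differ, here is the index arithmetic they perform for a 10-frame clip sampled down to 4 frames (a worked example with values chosen for illustration). `rescale_list` steps by `10 // 4 = 2` and truncates, so it only covers the first part of the clip, while `uniform_sampling` spreads the picks over the whole clip, including the first and last frames.\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"n_frames, size = 10, 4\n",
"\n",
"# rescale_list: every (n_frames // size)-th index, then keep the first `size`\n",
"skip = n_frames // size\n",
"rescale_idx = list(range(0, n_frames, skip))[:size]\n",
"\n",
"# uniform_sampling: evenly spaced 1-based positions, shifted back to 0-based indices\n",
"uniform_idx = [int(i) - 1 for i in np.linspace(1, n_frames, num=size)]\n",
"\n",
"print(rescale_idx)  # [0, 2, 4, 6]\n",
"print(uniform_idx)  # [0, 3, 6, 9]\n",
"```"
]
},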
{
"cell_type": "markdown",
"metadata": {
"id": "AXyB8MIqANAa",
"colab_type": "text"
},
"source": [
"## Video padding\n",
"Pad a 4D video tensor (T, C, H, W) at the beginning, with copies of the first frame or with black frames (torch.zeros).\n",
"\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "DOMcR_50OByq",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"\n",
"\n",
"def pad_video(x, padding_size=0, padding_type='images'):\n",
"    \"\"\"Prepends padding_size frames to a (T, C, H, W) video tensor.\"\"\"\n",
"    if padding_size != 0:\n",
"        if padding_type == 'images':\n",
"            # repeat the first frame\n",
"            pad_img = x[0]\n",
"            padx = pad_img.repeat(padding_size, 1, 1, 1)\n",
"            X = torch.cat((padx, x))\n",
"            return X\n",
"        elif padding_type == 'zeros':\n",
"            # black frames with the same dtype and device as the video\n",
"            T, C, H, W = x.size()\n",
"            padx = torch.zeros((padding_size, C, H, W), dtype=x.dtype, device=x.device)\n",
"            X = torch.cat((padx, x))\n",
"            return X\n",
"    return x\n"
],
"execution_count": 0,
"outputs": []
},
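{
"cell_type": "markdown",
"metadata": {},
"source": [
"The zero-padding branch can also be written with `torch.nn.functional.pad`, whose `pad` tuple lists (before, after) amounts starting from the last dimension. A sketch, assuming a (T, C, H, W) clip: padding two frames in front of the time dimension needs `(0, 0, 0, 0, 0, 0, 2, 0)`, and the result keeps the dtype and device of `x` automatically.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"x = torch.rand(3, 3, 4, 4)  # (T, C, H, W) clip with 3 frames\n",
"padding_size = 2\n",
"\n",
"# (W_left, W_right, H_top, H_bottom, C_front, C_back, T_front, T_back)\n",
"padded = F.pad(x, (0, 0, 0, 0, 0, 0, padding_size, 0))\n",
"\n",
"# same result as concatenating black frames in front, like the 'zeros' branch of pad_video\n",
"reference = torch.cat((torch.zeros(padding_size, 3, 4, 4), x))\n",
"print(padded.shape)  # torch.Size([5, 3, 4, 4])\n",
"```"
]
},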
{
"cell_type": "markdown",
"metadata": {
"id": "JWvSt3ELOCNf",
"colab_type": "text"
},
"source": [
"## Channel shuffling\n",
"Random shuffling of the RGB channels of image (C, H, W) or video (T, C, H, W) tensors"
]
},
{
"cell_type": "code",
"metadata": {
"id": "-1xPQ-ZaOD4R",
"colab_type": "code",
"colab": {}
},
"source": [
"import random\n",
"\n",
"import torch\n",
"\n",
"\n",
"def channel_shuffle(x):\n",
"    \"\"\"Randomly permutes the RGB channels of an image tensor (C, H, W).\"\"\"\n",
"    r = x[0, :, :]\n",
"    g = x[1, :, :]\n",
"    b = x[2, :, :]\n",
"\n",
"    rgb = [r, g, b]\n",
"    random.shuffle(rgb)\n",
"    x = torch.stack(rgb, dim=0)\n",
"    return x\n",
"\n",
"\n",
"def video_channel_shuffle(x):\n",
"    \"\"\"Randomly permutes the RGB channels of a video tensor (T, C, H, W); all frames use the same order.\"\"\"\n",
"    r = x[:, 0, :, :]\n",
"    g = x[:, 1, :, :]\n",
"    b = x[:, 2, :, :]\n",
"\n",
"    rgb = [r, g, b]\n",
"    random.shuffle(rgb)\n",
"    x = torch.stack(rgb, dim=1)\n",
"    return x\n"
],
"execution_count": 0,
"outputs": []
},
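{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same shuffle can be written with a random permutation of the channel indices, sketched below with a hypothetical helper `channel_shuffle_perm`. This variant draws its randomness from the PyTorch generator (so `torch.manual_seed` controls it) instead of the Python `random` module, and it works for any number of channels.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"\n",
"def channel_shuffle_perm(x, channel_dim=0):\n",
"    # hypothetical helper: reorder channels with a random permutation\n",
"    # channel_dim=0 for an image (C, H, W), channel_dim=1 for a video (T, C, H, W)\n",
"    perm = torch.randperm(x.size(channel_dim))\n",
"    return x.index_select(channel_dim, perm)\n",
"\n",
"\n",
"torch.manual_seed(0)\n",
"image = torch.rand(3, 4, 4)\n",
"video = torch.rand(5, 3, 4, 4)\n",
"shuffled_image = channel_shuffle_perm(image, channel_dim=0)\n",
"shuffled_video = channel_shuffle_perm(video, channel_dim=1)\n",
"```"
]
},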
{
"cell_type": "markdown",
"metadata": {
"id": "Q3yvwbygOFT3",
"colab_type": "text"
},
"source": [
"## Multi-Label target to tensor\n",
"\n",
"The dictionary id2w maps each label to an index: id2w = { 'label1' : 0 , 'label2' : 1 , .... }\n",
"\n",
"The function converts a space-separated string of labels into a tensor of indices; a label that is not in the dictionary (out of vocabulary, OOV) raises a KeyError."
]
},
{
"cell_type": "code",
"metadata": {
"id": "-JWqoSUOPNBC",
"colab_type": "code",
"outputId": "00de3695-3565-4620-c4ca-6017c850c650",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 129
}
},
"source": [
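"import torch\n",
"\n",
"\n",
"def multi_label_to_index(classes, id2w, target_labels):\n",
"    \"\"\"Maps a space-separated label string to a tensor of indices using id2w.\"\"\"\n",
"    indexes = []\n",
"    for word in target_labels.split(' '):\n",
"        if word not in id2w:\n",
"            raise KeyError('OOV label: {}'.format(word))\n",
"        indexes.append(id2w[word])\n",
"\n",
"    return torch.tensor(indexes, dtype=torch.int)\n"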
],
"execution_count": 0,
"outputs": []
},
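{
"cell_type": "markdown",
"metadata": {},
"source": [
"A sketch of how `id2w` can be built from a list of class names and then used, with a hypothetical label vocabulary. The dictionary comprehension assigns consecutive indices in list order; the last lines reproduce what `multi_label_to_index` returns for a space-separated target string.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"classes = ['hello', 'how', 'are', 'you']  # hypothetical label vocabulary\n",
"id2w = {label: index for index, label in enumerate(classes)}\n",
"\n",
"target_labels = 'how are you'\n",
"indexes = torch.tensor([id2w[word] for word in target_labels.split(' ')], dtype=torch.int)\n",
"print(id2w)     # {'hello': 0, 'how': 1, 'are': 2, 'you': 3}\n",
"print(indexes)  # tensor([1, 2, 3], dtype=torch.int32)\n",
"```"
]
},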
{
"cell_type": "markdown",
"metadata": {
"id": "SmIgo_7lwUKH",
"colab_type": "text"
},
"source": [
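"## One-hot vector\n",
"\n",
"Convert a tensor of class indices with shape (N,) into one-hot vectors with shape (N, num_classes)."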
]
},
{
"cell_type": "code",
"metadata": {
"id": "QxPML4x_wX33",
"colab_type": "code",
"colab": {}
},
"source": [
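"import torch\n",
"\n",
"\n",
"def one_hot_vector(labels, num_classes):\n",
"    \"\"\"labels: integer tensor of shape (N,) -> one-hot float tensor of shape (N, num_classes).\"\"\"\n",
"    # broadcasting (N, 1) == (1, num_classes) gives a boolean (N, num_classes) matrix\n",
"    classes = torch.arange(num_classes, device=labels.device).reshape(1, num_classes)\n",
"    one_hot_target = (labels.view(-1, 1) == classes)\n",
"    return one_hot_target.float()\n"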
],
"execution_count": 0,
"outputs": []
},
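{
"cell_type": "markdown",
"metadata": {},
"source": [
"PyTorch also provides `torch.nn.functional.one_hot`, which returns an int64 tensor with an extra trailing dimension of size `num_classes`. A short sketch comparing it with the broadcasting comparison, on made-up labels:\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"labels = torch.tensor([0, 2, 1, 2])\n",
"num_classes = 3\n",
"\n",
"# broadcasting: (N, 1) == (1, num_classes) -> boolean (N, num_classes)\n",
"by_broadcast = labels.view(-1, 1) == torch.arange(num_classes).reshape(1, num_classes)\n",
"by_builtin = F.one_hot(labels, num_classes=num_classes)  # int64, shape (N, num_classes)\n",
"\n",
"print(by_builtin)\n",
"# tensor([[1, 0, 0],\n",
"#         [0, 0, 1],\n",
"#         [0, 1, 0],\n",
"#         [0, 0, 1]])\n",
"```"
]
},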
{
"cell_type": "markdown",
"metadata": {
"id": "yBUuQ0n0zEYh",
"colab_type": "text"
},
"source": [
"## CTC Loss variants implementation\n",
"\n",
"Original CTC loss from torch.nn (`crit='normal'`)\n",
"\n",
"Focal CTC loss (`crit='focal'`)\n",
"\n",
"Aggregation cross-entropy loss (`crit='aggregation'`)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "w_D-nFPZzEyY",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class CTC_Loss(nn.Module):\n",
"    \"\"\"CTC loss variants (crit = 'normal', 'focal' or 'aggregation').\n",
"    output: unnormalized scores of shape (T, N, C), blank index 0.\n",
"    target: label indices of shape (N, S), all targets of length S (CTC variants),\n",
"            or per-class label counts of shape (N, C) (aggregation cross-entropy).\n",
"    \"\"\"\n",
"    def __init__(self, crit, average=True, alpha=0.99, gamma=2.0):\n",
"        super(CTC_Loss, self).__init__()\n",
"        self.alpha = alpha\n",
"        self.gamma = gamma\n",
"        self.crit = crit\n",
"        self.average = average\n",
"\n",
"        if crit == 'normal':\n",
"            self.loss = self.normal_ctc_loss\n",
"        elif crit == 'aggregation':\n",
"            self.loss = self.Aggregation_CE\n",
"        elif crit == 'focal':\n",
"            self.loss = self.focal_ctc_loss\n",
"        else:\n",
"            raise ValueError(\"crit must be 'normal', 'aggregation' or 'focal', got {}\".format(crit))\n",
"\n",
"    def forward(self, output, target):\n",
"        cost = self.loss(output, target)\n",
"        return cost\n",
"\n",
"    def normal_ctc_loss(self, log_probs, target):\n",
"        reduction = 'mean' if self.average else 'sum'\n",
"        criterion = nn.CTCLoss(blank=0, reduction=reduction, zero_infinity=True)\n",
"        T, N, _ = log_probs.size()\n",
"        # one length per batch element (assumes unpadded inputs and equal-length targets)\n",
"        input_len = torch.full((N,), T, dtype=torch.long)\n",
"        target_len = torch.full((N,), target.size(1), dtype=torch.long)\n",
"        loss = criterion(nn.functional.log_softmax(log_probs, dim=2), target, input_len, target_len)\n",
"        return loss\n",
"\n",
"    def Aggregation_CE(self, outputs, target):\n",
"        Time, batch_size, N_classes = outputs.size()\n",
"        probs = nn.functional.softmax(outputs, dim=-1)\n",
"        target = target.clone().float()  # avoid modifying the caller's tensor in place\n",
"        target[:, 0] = 0.0  # the blank class does not contribute to the loss\n",
"        input = torch.sum(probs, dim=0) / float(Time)  # (N, C) time-averaged probabilities\n",
"        target = target / float(Time)\n",
"\n",
"        loss = -torch.sum(torch.log(input) * target)\n",
"        return loss\n",
"\n",
"    def focal_ctc_loss(self, log_probs, target):\n",
"        reduction = 'mean' if self.average else 'sum'\n",
"        criterion = nn.CTCLoss(blank=0, reduction=reduction, zero_infinity=True)\n",
"        T, N, _ = log_probs.size()\n",
"        input_len = torch.full((N,), T, dtype=torch.long)\n",
"        target_len = torch.full((N,), target.size(1), dtype=torch.long)\n",
"        loss = criterion(nn.functional.log_softmax(log_probs, dim=2), target, input_len, target_len)\n",
"        # focal weighting: down-weight well-predicted sequences, whose probability p = exp(-loss) is high\n",
"        p = torch.exp(-loss)\n",
"        focal_loss = self.alpha * ((1 - p) ** self.gamma) * loss\n",
"        return focal_loss"
],
"execution_count": 0,
"outputs": []
},
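{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal usage sketch with random data, assuming unnormalized network outputs of shape (T, N, C) with the blank at index 0 and padded targets of shape (N, S); all sizes are made up. `nn.CTCLoss` expects one input length and one target length per batch element. The last lines apply the same focal weighting as `focal_ctc_loss`, `alpha * (1 - p) ** gamma * loss` with `p = exp(-loss)`: since `1 - p < 1`, the focal loss is always smaller than the plain CTC loss, and sequences that are already well predicted (small loss, p close to 1) are down-weighted the most.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"torch.manual_seed(0)\n",
"T, N, C, S = 50, 2, 20, 10  # timesteps, batch size, classes (incl. blank), target length\n",
"\n",
"outputs = torch.randn(T, N, C, requires_grad=True)  # unnormalized network outputs\n",
"targets = torch.randint(1, C, (N, S), dtype=torch.long)  # labels 1..C-1, 0 is the blank\n",
"input_lengths = torch.full((N,), T, dtype=torch.long)\n",
"target_lengths = torch.full((N,), S, dtype=torch.long)\n",
"\n",
"ctc = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)\n",
"loss = ctc(nn.functional.log_softmax(outputs, dim=2), targets, input_lengths, target_lengths)\n",
"\n",
"# focal weighting as in focal_ctc_loss (alpha=0.99, gamma=2.0)\n",
"alpha, gamma = 0.99, 2.0\n",
"p = torch.exp(-loss)\n",
"focal_loss = alpha * (1 - p) ** gamma * loss\n",
"focal_loss.backward()\n",
"print(loss.item(), focal_loss.item())\n",
"```"
]
},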
{
"cell_type": "markdown",
"metadata": {
"id": "_r96egyxAPyQ",
"colab_type": "text"
},
"source": [
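"## CTC Decoder\n",
"\n",
"Greedy CTC decoding function that removes blanks and repeated labels, e.g. F(A A - - B) = AB. Note that it removes the blanks before merging repeats, so F(A - A) = A, whereas standard CTC decoding gives AA."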
]
},
{
"cell_type": "code",
"metadata": {
"id": "sLcXcp16_lBR",
"colab_type": "code",
"colab": {}
},
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"def ctc_tensors_decode(output, target):\n",
"    \"\"\"\n",
"    Greedy CTC decoding of the output tensor (T x 1 x C logits, batch size 1, blank index 0)\n",
"    using tensor operations only, so autograd is kept and it is possible to backpropagate.\n",
"    Returns the probability vectors of the decoded labels, shape L x C. (target is unused.)\n",
"    \"\"\"\n",
"    probs = nn.functional.softmax(output, dim=2)\n",
"    pred = probs.argmax(dim=2).squeeze(1)\n",
"\n",
"    # remove blanks\n",
"    no_blanks_probs = probs[pred != 0]\n",
"    if no_blanks_probs.nelement() == 0:\n",
"        # only blanks were predicted: keep the last frame (slicing keeps the T x 1 x C shape)\n",
"        no_blanks_probs = probs[-1:, :, :]\n",
"    new_pred = no_blanks_probs.argmax(dim=2).squeeze(1)\n",
"\n",
"    # merge repeated labels: keep a frame only if its label differs from the previously kept one\n",
"    decoded_labels = no_blanks_probs[:1]\n",
"    prev_found_word = new_pred[0].item()\n",
"    for i in range(1, no_blanks_probs.size(0)):\n",
"        if new_pred[i].item() != prev_found_word:\n",
"            decoded_labels = torch.cat((decoded_labels, no_blanks_probs[i, :, :].unsqueeze(0)), dim=0)\n",
"            prev_found_word = new_pred[i].item()\n",
"\n",
"    return decoded_labels.squeeze(1)\n"
],
"execution_count": 0,
"outputs": []
},
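{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, standard greedy (best-path) CTC decoding first merges repeated labels and only then removes blanks, so a blank between two identical labels keeps both of them. A sketch on per-frame label indices with `torch.unique_consecutive`, assuming blank index 0; `ctc_greedy_decode` is a hypothetical name. Unlike `ctc_tensors_decode`, it returns label indices rather than probability vectors, so it suits evaluation rather than backpropagation.\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"\n",
"def ctc_greedy_decode(pred, blank=0):\n",
"    # pred: 1D tensor of per-frame argmax labels\n",
"    collapsed = torch.unique_consecutive(pred)  # merge repeated labels first\n",
"    return collapsed[collapsed != blank]        # then remove blanks\n",
"\n",
"\n",
"# A A - - B  ->  A B   (A=1, B=2, blank=0)\n",
"print(ctc_greedy_decode(torch.tensor([1, 1, 0, 0, 2])))  # tensor([1, 2])\n",
"# A - A  ->  A A   (the blank separates the two A's)\n",
"print(ctc_greedy_decode(torch.tensor([1, 0, 1])))  # tensor([1, 1])\n",
"```"
]
},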
{
"cell_type": "markdown",
"metadata": {
"id": "pdVeXAPNeFaU",
"colab_type": "text"
},
"source": [
"## Gradient checking\n",
"\n",
"Inspect gradients during training. The following function must be called after the backward pass (`loss.backward()`); it prints the name, type and shape of each parameter of the model, together with its gradient."
]
},
{
"cell_type": "code",
"metadata": {
"id": "-JLC-5UBeHNc",
"colab_type": "code",
"colab": {}
},
"source": [
"def showgradients(model):\n",
"    for name, param in model.named_parameters():\n",
"        # print name of parameter and weights shape\n",
"        print(name, ' ', type(param.data), param.size())\n",
"        # weight gradient tensor (None if no gradient was computed)\n",
"        print(param.grad)\n"
],
"execution_count": 0,
"outputs": []
},
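{
"cell_type": "markdown",
"metadata": {},
"source": [
"Printing whole gradient tensors gets unreadable for large models. A compact variant, sketched on a hypothetical toy model, collects the L2 norm of each gradient instead and reports `None` for parameters that received no gradient (e.g. frozen parameters). `gradient_norms` is a hypothetical name.\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"def gradient_norms(model):\n",
"    # name -> L2 norm of the gradient, or None if backward() produced no gradient\n",
"    return {name: (param.grad.norm().item() if param.grad is not None else None)\n",
"            for name, param in model.named_parameters()}\n",
"\n",
"\n",
"torch.manual_seed(0)\n",
"model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))\n",
"model[0].bias.requires_grad_(False)  # a frozen parameter never gets a gradient\n",
"\n",
"loss = model(torch.randn(16, 4)).pow(2).mean()\n",
"loss.backward()\n",
"for name, norm in gradient_norms(model).items():\n",
"    print(name, norm)\n",
"```"
]
},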
{
"cell_type": "markdown",
"metadata": {
"id": "wXAWI2Cf5-zm",
"colab_type": "text"
},
"source": [
"This notebook will be updated whenever something new comes up.\n",
"I hope this guide will be helpful and save you some time when you start coding in PyTorch."
]
}
]
}