Skip to content

Instantly share code, notes, and snippets.

@RasinGue
Created July 8, 2019 09:07
Show Gist options
  • Save RasinGue/cd9b4077388c823feb808cddf33c8607 to your computer and use it in GitHub Desktop.
Save RasinGue/cd9b4077388c823feb808cddf33c8607 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn

import VGG19
class VGGLoss(nn.Module):
    """Perceptual loss: weighted sum of L1 distances between feature maps.

    ``self.vgg`` is called on both inputs and must return a sized, iterable
    sequence of feature tensors. The i-th pair of feature maps is compared
    with ``nn.L1Loss`` and scaled by ``self.weights[i]``; the scaled terms are
    summed. The target ``y`` is treated as a constant: its features are
    computed without autograd and detached, so gradients flow only via ``x``.

    Args:
        gpu_ids: Accepted for interface compatibility but not used; the
            default extractor is moved with a plain ``.cuda()`` call.
        vgg: Optional feature extractor to use instead of the default one
            built from the ``VGG19`` import (e.g. for CPU use or testing).
        weights: Optional per-feature-map weights. Defaults to
            ``[1/32, 1/16, 1/8, 1/4, 1]`` (presumably ordered shallow to deep
            layers of the extractor -- confirm against VGG19). Must contain at
            least as many entries as the extractor returns feature maps.
    """

    def __init__(self, gpu_ids, vgg=None, weights=None):
        super().__init__()
        if vgg is None:
            # NOTE(review): `import VGG19` binds a *module*, so `VGG19.cuda()`
            # only works if that module exposes a `cuda` callable returning the
            # network. It likely meant something like `VGG19.Vgg19().cuda()`;
            # confirm against the VGG19 module, which is not visible here.
            vgg = VGG19.cuda()
        self.vgg = vgg
        self.criterion = nn.L1Loss()
        if weights is None:
            weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
        self.weights = list(weights)

    def forward(self, x, y):
        """Return the weighted L1 feature distance between ``x`` and target ``y``.

        Raises:
            ValueError: if the extractor yields a different number of feature
                maps for ``x`` and ``y``, or more maps than there are weights.
        """
        x_vgg = self.vgg(x)
        # The target branch is detached below anyway; running it under
        # no_grad additionally avoids building an unused autograd graph
        # through the extractor's parameters (same gradients, less memory).
        with torch.no_grad():
            y_vgg = self.vgg(y)

        if len(x_vgg) != len(y_vgg):
            raise ValueError(
                f"feature extractor returned {len(x_vgg)} maps for x "
                f"but {len(y_vgg)} for y"
            )
        if len(x_vgg) > len(self.weights):
            raise ValueError(
                f"feature extractor returned {len(x_vgg)} maps but only "
                f"{len(self.weights)} weights are configured"
            )

        loss = 0
        for weight, x_feat, y_feat in zip(self.weights, x_vgg, y_vgg):
            loss += weight * self.criterion(x_feat, y_feat.detach())
        return loss
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment