# perceptual loss
# Source: GitHub Gist 97efcd4f4da9ecf0782c83000b252db4 by lasseha
# (last active August 24, 2019).
# Fixed VGG network used only as a feature extractor for perceptual_loss.
# NOTE(review): assumes VGG is a torch.nn.Module (it is called, has .eval()
# and a .features container whose elements accept forward hooks) — confirm.
vgg = VGG(pretrained=True)
# eval() only switches layers to inference behaviour; it does not stop autograd
# from computing gradients for the VGG weights themselves.
vgg.eval()
# Freeze the weights. Gradients still flow *through* the network to its input
# (``recon`` in perceptual_loss), but none are accumulated into VGG's own
# parameters, so backward is cheaper and VGG can never be trained by accident.
vgg.requires_grad_(False)
def get_output():
    """Create a forward hook that records a module's latest output on the module.

    The returned callable has the ``(module, input, output)`` signature used by
    ``register_forward_hook``. After each forward pass, the hooked module's
    most recent output is available as ``module.output``.

    Returns:
        The hook function. It returns ``None``, so it never replaces the
        module's output.
    """
    def _store_output(module, _inputs, result):
        # Overwritten on every forward pass: read it before the next call.
        module.output = result

    return _store_output
# Positions in ``vgg.features`` whose activations are compared by perceptual_loss.
# Each of those modules gets a forward hook that stores its latest output as
# ``module.output``.
layer = [2, 5, 9]
for feature_idx in layer:
    hooked_module = vgg.features[feature_idx]
    hooked_module.register_forward_hook(get_output())
def perceptual_loss(img, recon):
    """Return the summed L1 distance between VGG features of ``recon`` and ``img``.

    Both inputs are passed through ``vgg``. For every index ``i`` in ``layer``,
    the activation captured by the forward hook on ``vgg.features[i]`` is
    compared with ``l1_loss``, and the per-layer losses are summed.

    Args:
        img: Target images. Treated as a constant: their features are
            computed without gradient tracking.
        recon: Reconstructed images. Autograd stays enabled for this pass so
            the loss can back-propagate into whatever produced ``recon``.

    Returns:
        The summed per-layer L1 feature loss (``0.0`` if ``layer`` is empty).
    """
    # Target features are fixed, so don't build an autograd graph for them.
    with torch.no_grad():
        vgg(img)
        # Take references now: the next forward pass overwrites ``.output``.
        features_img = [vgg.features[i].output for i in layer]

    # Deliberately outside ``no_grad``: wrapping this pass would cut the
    # gradient to ``recon`` and turn the perceptual term into a constant.
    vgg(recon)
    features_recon = [vgg.features[i].output for i in layer]

    # NOTE(review): the hooks store tensors by reference. If a later layer
    # modifies a hooked output in place (e.g. an in-place ReLU after a hooked
    # conv), the captured values are the modified ones — confirm intent.
    per_layer = (
        l1_loss(f_recon, f_img)
        for f_recon, f_img in zip(features_recon, features_img)
    )
    return sum(per_layer, 0.0)
def kld_loss(mu, logvar):
    """Return the mean KL divergence between N(mu, exp(logvar)) and N(0, 1).

    Uses the closed form ``-0.5 * (1 + logvar - mu**2 - exp(logvar))``
    per element, with ``logvar`` as the log-variance.

    The result is averaged over *every* element (batch and latent dimensions
    alike), because ``mean`` is taken without a ``dim``. This is a per-dimension
    KL. Summing over latent dimensions instead would change its weight relative
    to the other loss terms.
    """
    elementwise = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * elementwise.mean()
def total_loss(img, recon, mu, logvar):
    """Return the VAE training objective as an unweighted sum of three terms.

    Terms:
        - pixel-space L1 reconstruction loss between ``recon`` and ``img``;
        - KL regulariser on the latent distribution ``(mu, logvar)``;
        - VGG perceptual loss between ``img`` and ``recon``.
    """
    reconstruction = l1_loss(recon, img)
    regularization = kld_loss(mu, logvar)
    perceptual = perceptual_loss(img, recon)
    return reconstruction + regularization + perceptual
# End of the original gist page (commenting on the gist requires a GitHub account).