Skip to content

Instantly share code, notes, and snippets.

@lasseha
lasseha / conv_upsampling.py
Last active August 23, 2019 14:15
conv upsampling
class ConvUpsampling(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU stack that also records a scale factor.

    Only the constructor is shown here. It stores ``kernel_size`` as
    ``self.scale_factor`` and wraps a bias-free ``nn.Conv2d``, an
    ``nn.BatchNorm2d`` over ``out_channels`` and a default ``nn.LeakyReLU``
    in the ``self.conv`` ``nn.Sequential``.

    NOTE(review): ``scale_factor`` is set to ``kernel_size``; how it is used
    (presumably by a ``forward`` that upsamples before ``self.conv``) is not
    visible here -- confirm that tying the two values together is intended.

    Args:
        in_channels: Channels expected by the convolution.
        out_channels: Channels produced by the convolution and normalised by
            the batch-norm layer.
        kernel_size: Convolution kernel size; also stored as ``scale_factor``.
        stride: Convolution stride.
        padding: Convolution padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.scale_factor = kernel_size
        # No conv bias, presumably because the batch-norm layer that follows
        # re-centres each channel and applies its own learnable shift.
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU()
        self.conv = nn.Sequential(conv, norm, activation)
@lasseha
lasseha / perceptual_loss.py
Last active August 24, 2019 13:00
perceptual loss
# Network used by the perceptual loss below.
# NOTE(review): VGG is defined/imported outside this view; presumably a
# torchvision-style VGG where pretrained=True loads trained weights -- confirm.
vgg = VGG(pretrained=True)
# Switch to evaluation mode; assuming VGG is an nn.Module, this disables
# dropout and makes batch-norm layers use their running statistics.
# NOTE(review): parameters are not frozen in the visible code; if no
# gradients should reach VGG, consider vgg.requires_grad_(False) -- confirm.
vgg.eval()
def get_output():
    """Return a hook that saves the output it receives onto the module.

    The returned callable takes ``(model, input, output)``, which is the
    argument order used by ``torch.nn.Module.register_forward_hook``. That is
    presumably how it gets registered, but the registration is not visible
    here. Each call assigns ``output`` to ``model.output`` and returns
    ``None``, so a forward hook leaves the module's result unchanged. The
    stored value is overwritten on every call, and every call to
    ``get_output`` builds a new, independent hook.
    """

    def _record_output(model, input, output):
        # ``input`` is unused; the parameter exists to match the hook signature.
        model.output = output

    return _record_output
layer = [2,5,9]
for i in layer:
@lasseha
lasseha / kld_loss.py
Last active August 24, 2019 12:58
kld loss
def kld_loss(mu, logvar, *, reduction="mean"):
    """Return the KL divergence between N(mu, exp(logvar)) and N(0, 1).

    Uses the closed form for diagonal Gaussians, evaluated element-wise::

        kl_i = -0.5 * (1 + logvar_i - mu_i**2 - exp(logvar_i))

    and then reduces the element-wise terms.

    Args:
        mu: Tensor of means; broadcast-compatible with (normally the same
            shape as) ``logvar``.
        logvar: Tensor of log-variances.
        reduction: How to reduce the element-wise terms (keyword-only).
            ``"mean"`` (default, the original behaviour) averages over every
            element, i.e. over batch *and* latent dimensions. ``"sum"`` sums
            every element. ``"batchmean"`` sums every element and divides by
            ``mu.size(0)``, which gives the per-sample KL used in the usual
            VAE objective (assumes dim 0 is the batch dimension).

    Returns:
        A 0-dim tensor.

    Raises:
        ValueError: If ``reduction`` is not ``"mean"``, ``"sum"`` or
            ``"batchmean"``.
    """
    if reduction not in ("mean", "sum", "batchmean"):
        raise ValueError(
            f"reduction must be 'mean', 'sum' or 'batchmean', got {reduction!r}"
        )
    terms = 1 + logvar - mu.pow(2) - logvar.exp()
    if reduction == "mean":
        # Note: this is the per-element average, so it is smaller than the
        # per-sample KL by a factor of the number of latent dimensions.
        return -0.5 * torch.mean(terms)
    total = -0.5 * torch.sum(terms)
    if reduction == "sum":
        return total
    return total / mu.size(0)
def total_loss(img, recon, mu, logvar):
    """Combine a reconstruction loss with the KL regularisation term.

    Args:
        img: Target images.
        recon: Reconstructions compared against ``img``.
        mu: Latent means forwarded to ``kld_loss``.
        logvar: Latent log-variances forwarded to ``kld_loss``.

    Returns:
        ``l1_loss(recon, img) + kld_loss(mu, logvar)``, with both terms
        weighted equally.

    NOTE(review): ``l1_loss`` is defined outside this view; presumably an
    L1 / mean-absolute-error loss taking (prediction, target) -- confirm.
    """
    reconstruction = l1_loss(recon, img)
    kl_term = kld_loss(mu, logvar)
    return reconstruction + kl_term
@lasseha
lasseha / vae.py
Last active February 23, 2021 18:11
vae
class Conv(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU building block.

    The constructor wraps a bias-free ``nn.Conv2d``, an ``nn.BatchNorm2d``
    over ``out_channels`` and a default ``nn.LeakyReLU`` in the
    ``self.conv`` ``nn.Sequential``.

    Args:
        in_channels: Channels expected by the convolution.
        out_channels: Channels produced by the convolution and normalised by
            the batch-norm layer.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        # No conv bias, presumably because the batch-norm layer that follows
        # re-centres each channel and applies its own learnable shift.
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU()
        self.conv = nn.Sequential(conv, norm, activation)