Skip to content

Instantly share code, notes, and snippets.

@lasseha
Last active February 23, 2021 18:11
Show Gist options
  • Save lasseha/f651243e5904f9b44940a807e5325ffb to your computer and use it in GitHub Desktop.
Save lasseha/f651243e5904f9b44940a807e5325ffb to your computer and use it in GitHub Desktop.
vae
class Conv(nn.Module):
    """Convolution block: ``Conv2d`` -> ``BatchNorm2d`` -> ``LeakyReLU``.

    The convolution is built without a bias term; the ``BatchNorm2d`` that
    follows it (affine by default) supplies the learnable shift instead.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride (default 1).
        padding: Zero padding added to each spatial side (default 0).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        layers = [
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),  # PyTorch default negative slope
        ]
        # Kept as an index-keyed Sequential named ``conv`` so state_dict keys
        # are ``conv.0.*``, ``conv.1.*``.
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply convolution, batch norm and leaky ReLU to ``x``."""
        return self.conv(x)


class ConvTranspose(nn.Module):
    """Transposed-convolution block: ``ConvTranspose2d`` -> ``BatchNorm2d`` -> ``LeakyReLU``.

    Mirrors ``Conv`` but uses a transposed convolution, whose output spatial
    size is ``(H - 1) * stride - 2 * padding + kernel_size``. The transposed
    convolution has no bias; the following ``BatchNorm2d`` provides the shift.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the transposed convolution.
        kernel_size: Transposed-convolution kernel size.
        stride: Transposed-convolution stride (default 1).
        padding: Implicit padding removed from each spatial side (default 0).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        tconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.LeakyReLU()
        self.conv = nn.Sequential(tconv, norm, activation)

    def forward(self, x):
        """Apply transposed convolution, batch norm and leaky ReLU to ``x``."""
        return self.conv(x)


class VAE(nn.Module):
    """Convolutional variational autoencoder.

    Spatial arithmetic (from the layer hyper-parameters below):

    * The encoder halves the spatial size five times with stride-2 3x3
      convolutions (padding 1). It then applies an unpadded 8x8 convolution,
      so a 256x256 input yields a 1x1 latent map.
    * The decoder mirrors this. An 8x8 transposed convolution expands a 1x1
      latent to 8x8, and five stride-2 4x4 transposed convolutions (padding 1)
      double it back to 256x256.

    Other input sizes are not validated. They may raise inside the encoder
    (e.g. 200x200), or produce a reconstruction whose size differs from the
    input (e.g. 250x250 reconstructs to 256x256).

    Args:
        base: Channel-width multiplier (default 16, the previously hard-coded
            value). Hidden widths are ``base``, ``2*base``, ``4*base`` and
            ``64*base``; the latent has ``32*base`` channels.
        in_channels: Number of image channels read by the encoder and written
            by the decoder (default 3).
    """

    def __init__(self, base=16, in_channels=3):
        super().__init__()
        self.encoder = nn.Sequential(
            Conv(in_channels, base, 3, stride=2, padding=1),
            Conv(base, 2 * base, 3, padding=1),
            Conv(2 * base, 2 * base, 3, stride=2, padding=1),
            Conv(2 * base, 2 * base, 3, padding=1),
            Conv(2 * base, 2 * base, 3, stride=2, padding=1),
            Conv(2 * base, 4 * base, 3, padding=1),
            Conv(4 * base, 4 * base, 3, stride=2, padding=1),
            Conv(4 * base, 4 * base, 3, padding=1),
            Conv(4 * base, 4 * base, 3, stride=2, padding=1),
            # Unpadded 8x8 convolution: collapses an 8x8 feature map to 1x1.
            nn.Conv2d(4 * base, 64 * base, 8),
            nn.LeakyReLU(),
        )
        # 1x1 heads producing the mean and log-variance of q(z | x).
        self.encoder_mu = nn.Conv2d(64 * base, 32 * base, 1)
        self.encoder_logvar = nn.Conv2d(64 * base, 32 * base, 1)
        self.decoder = nn.Sequential(
            # NOTE(review): no activation follows this 1x1 conv. It and the
            # ConvTranspose2d inside the next block are therefore consecutive
            # linear maps. Inserting a layer here would renumber the decoder.*
            # state_dict keys, so the layout is deliberately left unchanged.
            nn.Conv2d(32 * base, 64 * base, 1),
            ConvTranspose(64 * base, 4 * base, 8),  # 1x1 latent -> 8x8
            Conv(4 * base, 4 * base, 3, padding=1),
            ConvTranspose(4 * base, 4 * base, 4, stride=2, padding=1),
            Conv(4 * base, 4 * base, 3, padding=1),
            ConvTranspose(4 * base, 4 * base, 4, stride=2, padding=1),
            Conv(4 * base, 2 * base, 3, padding=1),
            ConvTranspose(2 * base, 2 * base, 4, stride=2, padding=1),
            Conv(2 * base, 2 * base, 3, padding=1),
            ConvTranspose(2 * base, 2 * base, 4, stride=2, padding=1),
            Conv(2 * base, base, 3, padding=1),
            ConvTranspose(base, base, 4, stride=2, padding=1),
            nn.Conv2d(base, in_channels, 3, padding=1),
            # Output values lie in [-1, 1]; presumably inputs are scaled to the
            # same range -- TODO confirm against the data pipeline.
            nn.Tanh(),
        )

    def encode(self, x):
        """Encode images ``x`` into ``(mu, logvar)``, each with ``32*base`` channels."""
        features = self.encoder(x)
        return self.encoder_mu(features), self.encoder_logvar(features)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Expressing the sample this way (the reparameterization trick) keeps
        ``z`` differentiable with respect to ``mu`` and ``logvar``. A fresh
        ``eps`` is drawn on every call, including in eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent ``z`` back to image space (values in [-1, 1] via ``tanh``)."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for input images ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment