Lasagne LSGAN example
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Example employing Lasagne for digit generation using the MNIST dataset and
Least Squares Generative Adversarial Networks
(LSGANs, see https://arxiv.org/abs/1611.04076 for the paper).

It is based on a WGAN example:
https://gist.github.com/f0k/f3190ebba6c53887d598d03119ca2066
This, in turn, is based on a DCGAN example:
https://gist.github.com/f0k/738fa2eedd9666b78404ed1751336f56
This, in turn, is based on the MNIST example in Lasagne:
https://lasagne.readthedocs.io/en/latest/user/tutorial.html

Jan Schlüter, 2017-03-07
"""

from __future__ import print_function

import sys
import os
import time

import numpy as np
import theano
import theano.tensor as T
import lasagne

# ################## Download and prepare the MNIST dataset ##################
# This is just some way of getting the MNIST dataset from an online location
# and loading it into numpy arrays. It doesn't involve Lasagne at all.

def load_dataset():
    # We first define a download function, supporting both Python 2 and 3.
    if sys.version_info[0] == 2:
        from urllib import urlretrieve
    else:
        from urllib.request import urlretrieve

    def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
        print("Downloading %s" % filename)
        urlretrieve(source + filename, filename)

    # We then define functions for loading MNIST images and labels.
    # For convenience, they also download the requested files if needed.
    import gzip

    def load_mnist_images(filename):
        if not os.path.exists(filename):
            download(filename)
        # Read the inputs in Yann LeCun's binary format.
        with gzip.open(filename, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=16)
        # The inputs are vectors now, we reshape them to monochrome 2D images,
        # following the shape convention: (examples, channels, rows, columns)
        data = data.reshape(-1, 1, 28, 28)
        # The inputs come as bytes, we convert them to float32 in range [0,1].
        # (Actually to range [0, 255/256], for compatibility to the version
        # provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.)
        return data / np.float32(256)

    def load_mnist_labels(filename):
        if not os.path.exists(filename):
            download(filename)
        # Read the labels in Yann LeCun's binary format.
        with gzip.open(filename, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=8)
        # The labels are vectors of integers now, that's exactly what we want.
        return data

    # We can now download and read the training and test set images and labels.
    X_train = load_mnist_images('train-images-idx3-ubyte.gz')
    y_train = load_mnist_labels('train-labels-idx1-ubyte.gz')
    X_test = load_mnist_images('t10k-images-idx3-ubyte.gz')
    y_test = load_mnist_labels('t10k-labels-idx1-ubyte.gz')

    # We reserve the last 10000 training examples for validation.
    X_train, X_val = X_train[:-10000], X_train[-10000:]
    y_train, y_val = y_train[:-10000], y_train[-10000:]

    # We just return all the arrays in order, as expected in main().
    # (It doesn't matter how we do this as long as we can read them again.)
    return X_train, y_train, X_val, y_val, X_test, y_test

# ##################### Build the neural network model #######################
# We create two models: the generator and the critic network.
# The models are the same as in the Lasagne DCGAN example, except that the
# discriminator is now a critic with linear output instead of sigmoid output.

def build_generator(input_var=None):
    from lasagne.layers import InputLayer, ReshapeLayer, DenseLayer
    try:
        from lasagne.layers import TransposedConv2DLayer as Deconv2DLayer
    except ImportError:
        raise ImportError("Your Lasagne is too old. Try the bleeding-edge "
                          "version: http://lasagne.readthedocs.io/en/latest/"
                          "user/installation.html#bleeding-edge-version")
    try:
        from lasagne.layers.dnn import batch_norm_dnn as batch_norm
    except ImportError:
        from lasagne.layers import batch_norm
    from lasagne.nonlinearities import sigmoid
    # input: 100-dim noise vector
    layer = InputLayer(shape=(None, 100), input_var=input_var)
    # fully-connected layer
    layer = batch_norm(DenseLayer(layer, 1024))
    # project and reshape
    layer = batch_norm(DenseLayer(layer, 128*7*7))
    layer = ReshapeLayer(layer, ([0], 128, 7, 7))
    # two fractional-stride convolutions
    layer = batch_norm(Deconv2DLayer(layer, 64, 5, stride=2, crop='same',
                                     output_size=14))
    layer = Deconv2DLayer(layer, 1, 5, stride=2, crop='same', output_size=28,
                          nonlinearity=sigmoid)
    print("Generator output:", layer.output_shape)
    return layer
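
# A quick sanity check one could run interactively (a hedged sketch, assuming
# only the definitions above; not used by the training script itself):
#
#   z = T.matrix('z')
#   g = build_generator(z)
#   sample = theano.function(
#       [z], lasagne.layers.get_output(g, deterministic=True))
#   print(sample(lasagne.utils.floatX(np.random.rand(2, 100))).shape)
#   # -> (2, 1, 28, 28)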

def build_critic(input_var=None):
    from lasagne.layers import InputLayer, Conv2DLayer, DenseLayer
    try:
        from lasagne.layers.dnn import batch_norm_dnn as batch_norm
    except ImportError:
        from lasagne.layers import batch_norm
    from lasagne.nonlinearities import LeakyRectify
    lrelu = LeakyRectify(0.2)
    # input: (None, 1, 28, 28)
    layer = InputLayer(shape=(None, 1, 28, 28), input_var=input_var)
    # two convolutions
    layer = batch_norm(Conv2DLayer(layer, 64, 5, stride=2, pad='same',
                                   nonlinearity=lrelu))
    layer = batch_norm(Conv2DLayer(layer, 128, 5, stride=2, pad='same',
                                   nonlinearity=lrelu))
    # fully-connected layer
    layer = batch_norm(DenseLayer(layer, 1024, nonlinearity=lrelu))
    # output layer (linear)
    layer = DenseLayer(layer, 1, nonlinearity=None)
    print("Critic output:", layer.output_shape)
    return layer

# ############################# Batch iterator ###############################
# This is just a simple helper function iterating over training data in
# mini-batches of a particular size, optionally in random order. It assumes
# data is available as numpy arrays. For big datasets, you could load numpy
# arrays as memory-mapped files (np.load(..., mmap_mode='r')), or write your
# own custom data iteration function. For small datasets, you can also copy
# them to GPU at once for slightly improved performance; see the sketch after
# this function for the idea. This would involve several changes in the main
# program, though, and is not demonstrated here.

def iterate_minibatches(inputs, targets, batchsize, shuffle=False,
                        forever=False):
    assert len(inputs) == len(targets)
    if shuffle:
        indices = np.arange(len(inputs))
    while True:
        if shuffle:
            np.random.shuffle(indices)
        for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
            if shuffle:
                excerpt = indices[start_idx:start_idx + batchsize]
            else:
                excerpt = slice(start_idx, start_idx + batchsize)
            yield inputs[excerpt], targets[excerpt]
        if not forever:
            break
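
# As a hedged sketch of the GPU-resident variant mentioned above (assuming a
# dataset small enough to fit in GPU memory), one could upload the data once
# as a Theano shared variable and slice it symbolically instead of feeding
# numpy batches. This helper is illustrative only and is not used below.
def as_shared_dataset(inputs):
    # Upload the full float32 array to the device once; mini-batches can then
    # be taken as cheap symbolic slices, e.g. shared[i*bs:(i+1)*bs].
    return theano.shared(lasagne.utils.floatX(inputs), borrow=True)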

# ############################## Main program ################################
# Everything else will be handled in our main program now. We could pull out
# more functions to better separate the code, but it wouldn't make it any
# easier to read.

def main(num_epochs=1000, epochsize=100, batchsize=64, initial_eta=1e-4):
    # Load the dataset
    print("Loading data...")
    X_train, y_train, X_val, y_val, X_test, y_test = load_dataset()

    # Prepare Theano variables for inputs and targets
    noise_var = T.matrix('noise')
    input_var = T.tensor4('inputs')

    # Create neural network model
    print("Building model and compiling functions...")
    generator = build_generator(noise_var)
    critic = build_critic(input_var)

    # Create expression for passing real data through the critic
    real_out = lasagne.layers.get_output(critic)
    # Create expression for passing fake data through the critic
    fake_out = lasagne.layers.get_output(critic,
                                         lasagne.layers.get_output(generator))

    # Create loss expressions to be minimized
    # a, b, c = -1, 1, 0  # Equation (8) in the paper
    a, b, c = 0, 1, 1  # Equation (9) in the paper
    generator_loss = lasagne.objectives.squared_error(fake_out, c).mean()
    critic_loss = (lasagne.objectives.squared_error(real_out, b).mean() +
                   lasagne.objectives.squared_error(fake_out, a).mean())
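    # (Following the LSGAN paper: a and b are the critic's target values for
    # fake and real data, respectively, and c is the value the generator
    # wants the critic to assign to fake data.)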

    # Create update expressions for training
    generator_params = lasagne.layers.get_all_params(generator, trainable=True)
    critic_params = lasagne.layers.get_all_params(critic, trainable=True)
    eta = theano.shared(lasagne.utils.floatX(initial_eta))
    generator_updates = lasagne.updates.rmsprop(
        generator_loss, generator_params, learning_rate=eta)
    critic_updates = lasagne.updates.rmsprop(
        critic_loss, critic_params, learning_rate=eta)

    # Instantiate a symbolic noise generator to use for training
    from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
    srng = RandomStreams(seed=np.random.randint(2147462579, size=6))
    noise = srng.uniform((batchsize, 100))

    # Compile functions performing a training step on a mini-batch (according
    # to the updates dictionary) and returning the corresponding score:
    generator_train_fn = theano.function([], generator_loss,
                                         givens={noise_var: noise},
                                         updates=generator_updates)
    critic_train_fn = theano.function([input_var], critic_loss,
                                      givens={noise_var: noise},
                                      updates=critic_updates)
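    # (The `givens` substitution replaces noise_var by freshly drawn noise on
    # every call, so neither training function takes noise as an input.)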

    # Compile another function generating some data
    gen_fn = theano.function([noise_var],
                             lasagne.layers.get_output(generator,
                                                       deterministic=True))

    # Finally, launch the training loop.
    print("Starting training...")
    # We create an infinite supply of batches (as an iterable generator):
    batches = iterate_minibatches(X_train, y_train, batchsize, shuffle=True,
                                  forever=True)
    # We iterate over epochs:
    for epoch in range(num_epochs):
        start_time = time.time()

        # In each epoch, we do `epochsize` generator and critic updates.
        critic_losses = []
        generator_losses = []
        for _ in range(epochsize):
            inputs, targets = next(batches)
            critic_losses.append(critic_train_fn(inputs))
            generator_losses.append(generator_train_fn())

        # Then we print the results for this epoch:
        print("Epoch {} of {} took {:.3f}s".format(
            epoch + 1, num_epochs, time.time() - start_time))
        print("  generator loss: {}".format(np.mean(generator_losses)))
        print("  critic loss:    {}".format(np.mean(critic_losses)))

        # And finally, we plot some generated data
        samples = gen_fn(lasagne.utils.floatX(np.random.rand(42, 100)))
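        # (The 42 samples are tiled into a 6-row by 7-column grid of 28x28
        # digit images below.)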
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            pass
        else:
            plt.imsave('lsgan_mnist_samples.png',
                       (samples.reshape(6, 7, 28, 28)
                               .transpose(0, 2, 1, 3)
                               .reshape(6*28, 7*28)),
                       cmap='gray')

        # After half the epochs, we start decaying the learning rate to zero
        if epoch >= num_epochs // 2:
            progress = float(epoch) / num_epochs
            eta.set_value(lasagne.utils.floatX(initial_eta*2*(1 - progress)))

    # Optionally, you could now dump the network weights to a file like this:
    np.savez('lsgan_mnist_gen.npz',
             *lasagne.layers.get_all_param_values(generator))
    np.savez('lsgan_mnist_crit.npz',
             *lasagne.layers.get_all_param_values(critic))
    #
    # And load them again later on like this:
    # with np.load('model.npz') as f:
    #     param_values = [f['arr_%d' % i] for i in range(len(f.files))]
    # lasagne.layers.set_all_param_values(network, param_values)
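    #
    # As a hedged sketch building on that generic recipe (assuming this
    # script's build_generator() and the filenames saved above), restoring
    # the trained generator could look like:
    # generator = build_generator(noise_var)
    # with np.load('lsgan_mnist_gen.npz') as f:
    #     lasagne.layers.set_all_param_values(
    #         generator, [f['arr_%d' % i] for i in range(len(f.files))])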

if __name__ == '__main__':
    if ('--help' in sys.argv) or ('-h' in sys.argv):
        print("Trains an LSGAN on MNIST using Lasagne.")
        print("Usage: %s [EPOCHS [EPOCHSIZE]]" % sys.argv[0])
        print()
        print("EPOCHS: number of training epochs to perform (default: 1000)")
        print("EPOCHSIZE: number of network updates per epoch (default: 100)")
    else:
        kwargs = {}
        if len(sys.argv) > 1:
            kwargs['num_epochs'] = int(sys.argv[1])
        if len(sys.argv) > 2:
            kwargs['epochsize'] = int(sys.argv[2])
        main(**kwargs)
@pclucas14 Just tested with lasagne.updates.adam. Within 200 epochs of training, both rmsprop and adam work for the lsgan_cifar10.py code. I just used the default hyperparameters. The adam training seems slightly noisier at the beginning. According to the LSGAN paper, both should work.

> Thanks for sharing @hma02, looking forward to trying it out! Also, is there any specific reason as to why rmsprop is used (and not adam)? Maybe it's a leftover from WGAN?

From p. 12 of the paper, rmsprop is more stable than adam: "First, for BN_G with Adam, there is a chance for LSGANs to generate relatively good quality images. We test 10 times, and 5 of them succeeds to generate relatively good quality images. [...] Third, [...] for BN_G with RMSProp, both LSGANs and regular GANs learn the data distribution successfully, [...]."
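For anyone who wants to try the swap, it is a two-line change in main() (a sketch using Lasagne's stock adam with its default beta parameters; these hyperparameters are untested here):

    generator_updates = lasagne.updates.adam(
        generator_loss, generator_params, learning_rate=eta)
    critic_updates = lasagne.updates.adam(
        critic_loss, critic_params, learning_rate=eta)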