Skip to content

Instantly share code, notes, and snippets.

@nzw0301
Created January 12, 2018 14:09
Show Gist options
  • Save nzw0301/66d88fc81687cab6856ad2f1a947383e to your computer and use it in GitHub Desktop.
Save nzw0301/66d88fc81687cab6856ad2f1a947383e to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n",
"/home/nzw/.pyenv/versions/miniconda3-latest/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n",
" return f(*args, **kwds)\n"
]
}
],
"source": [
"'''\n",
"Based on https://github.com/keras-team/keras/blob/master/examples/cifar10_cnn.py\n",
"- Change the batch size from 32 to 128\n",
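"- Change the number of epochs to 40\n",
"- Train a second model with keras_contrib's LearningRateWarmRestarter and compare validation loss\n",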
"'''\n",
"\n",
"# A __future__ import has to be the first statement after the docstring;\n",
"# placed lower down, this cell stops being valid once exported to a script.\n",
"from __future__ import print_function\n",
"\n",
"import keras\n",
"from keras.datasets import cifar10\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Dropout, Activation, Flatten\n",
"from keras.layers import Conv2D, MaxPooling2D\n",
"from keras.preprocessing.image import ImageDataGenerator\n",
"\n",
"from keras_contrib.callbacks import LearningRateWarmRestarter\n",
"\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"%matplotlib inline\n",
"\n",
"# Seed NumPy's global RNG.\n",
"# NOTE(review): nothing in this notebook seeds the Keras backend, so runs are\n",
"# presumably not fully reproducible -- confirm before comparing numbers.\n",
"np.random.seed(7)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Training configuration shared by the cells below.\n",
"batch_size, num_classes, epochs = 128, 10, 40\n",
"num_predictions = 20  # not referenced by any other cell shown in this notebook\n",
"\n",
"# Load CIFAR-10, then convert the images to float32 and divide them by 255.\n",
"(x_train, y_train), (x_test, y_test) = cifar10.load_data()\n",
"x_train = x_train.astype('float32') / 255\n",
"x_test = x_test.astype('float32') / 255\n",
"\n",
"# Turn both label arrays into categorical matrices with num_classes columns.\n",
"y_train, y_test = (keras.utils.to_categorical(labels, num_classes)\n",
"                   for labels in (y_train, y_test))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def create_model():\n",
"    '''Build and compile the CNN that both training runs start from.\n",
"\n",
"    Relies on two notebook-level names: x_train (its shape without the\n",
"    first axis becomes the input shape) and num_classes (number of units in\n",
"    the final Dense layer).\n",
"\n",
"    Returns:\n",
"        A Sequential model compiled with categorical cross-entropy, an SGD\n",
"        optimizer and the accuracy metric.\n",
"    '''\n",
"    stack = [\n",
"        # Block 1: two 3x3 convolutions (32 filters), pooling, dropout.\n",
"        Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),\n",
"        Activation('relu'),\n",
"        Conv2D(32, (3, 3)),\n",
"        Activation('relu'),\n",
"        MaxPooling2D(pool_size=(2, 2)),\n",
"        Dropout(0.25),\n",
"        # Block 2: same layout with 64 filters.\n",
"        Conv2D(64, (3, 3), padding='same'),\n",
"        Activation('relu'),\n",
"        Conv2D(64, (3, 3)),\n",
"        Activation('relu'),\n",
"        MaxPooling2D(pool_size=(2, 2)),\n",
"        Dropout(0.25),\n",
"        # Classifier head.\n",
"        Flatten(),\n",
"        Dense(512),\n",
"        Activation('relu'),\n",
"        Dropout(0.5),\n",
"        Dense(num_classes),\n",
"        Activation('softmax'),\n",
"    ]\n",
"    model = Sequential(stack)\n",
"\n",
"    sgd = keras.optimizers.SGD(lr=0.01, decay=0.0005, momentum=0.9, nesterov=False)\n",
"    model.compile(loss='categorical_crossentropy',\n",
"                  optimizer=sgd,\n",
"                  metrics=['accuracy'])\n",
"    return model\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using real-time data augmentation.\n",
"Epoch 1/40\n",
"391/391 [==============================] - 17s 44ms/step - loss: 2.0405 - acc: 0.2455 - val_loss: 1.7301 - val_acc: 0.3815\n",
"Epoch 2/40\n",
"391/391 [==============================] - 15s 38ms/step - loss: 1.7017 - acc: 0.3739 - val_loss: 1.5146 - val_acc: 0.4629\n",
"Epoch 3/40\n",
"391/391 [==============================] - 16s 41ms/step - loss: 1.5202 - acc: 0.4433 - val_loss: 1.3368 - val_acc: 0.5188\n",
"Epoch 4/40\n",
"391/391 [==============================] - 16s 41ms/step - loss: 1.4117 - acc: 0.4860 - val_loss: 1.2133 - val_acc: 0.5712\n",
"Epoch 5/40\n",
"391/391 [==============================] - 17s 43ms/step - loss: 1.3376 - acc: 0.5197 - val_loss: 1.1535 - val_acc: 0.5869\n",
"Epoch 6/40\n",
"391/391 [==============================] - 17s 43ms/step - loss: 1.2804 - acc: 0.5407 - val_loss: 1.0958 - val_acc: 0.6137\n",
"Epoch 7/40\n",
"391/391 [==============================] - 17s 44ms/step - loss: 1.2197 - acc: 0.5635 - val_loss: 1.0672 - val_acc: 0.6194\n",
"Epoch 8/40\n",
"391/391 [==============================] - 17s 44ms/step - loss: 1.1747 - acc: 0.5790 - val_loss: 0.9996 - val_acc: 0.6489\n",
"Epoch 9/40\n",
"391/391 [==============================] - 16s 42ms/step - loss: 1.1441 - acc: 0.5901 - val_loss: 0.9804 - val_acc: 0.6565\n",
"Epoch 10/40\n",
"391/391 [==============================] - 17s 43ms/step - loss: 1.1099 - acc: 0.6041 - val_loss: 0.9411 - val_acc: 0.6686\n",
"Epoch 11/40\n",
"391/391 [==============================] - 17s 43ms/step - loss: 1.0854 - acc: 0.6114 - val_loss: 0.9278 - val_acc: 0.6741\n",
"Epoch 12/40\n",
"391/391 [==============================] - 17s 45ms/step - loss: 1.0600 - acc: 0.6231 - val_loss: 0.9047 - val_acc: 0.6796\n",
"Epoch 13/40\n",
"391/391 [==============================] - 17s 44ms/step - loss: 1.0364 - acc: 0.6294 - val_loss: 0.8925 - val_acc: 0.6886\n",
"Epoch 14/40\n",
"391/391 [==============================] - 17s 44ms/step - loss: 1.0204 - acc: 0.6363 - val_loss: 0.8963 - val_acc: 0.6900\n",
"Epoch 15/40\n",
"391/391 [==============================] - 17s 43ms/step - loss: 1.0030 - acc: 0.6426 - val_loss: 0.8438 - val_acc: 0.7084\n",
"Epoch 16/40\n",
"391/391 [==============================] - 17s 42ms/step - loss: 0.9880 - acc: 0.6496 - val_loss: 0.8477 - val_acc: 0.7037\n",
"Epoch 17/40\n",
"391/391 [==============================] - 17s 44ms/step - loss: 0.9793 - acc: 0.6519 - val_loss: 0.8369 - val_acc: 0.7091\n",
"Epoch 18/40\n",
"391/391 [==============================] - 17s 44ms/step - loss: 0.9573 - acc: 0.6610 - val_loss: 0.8247 - val_acc: 0.7125\n",
"Epoch 19/40\n",
"391/391 [==============================] - 18s 46ms/step - loss: 0.9453 - acc: 0.6660 - val_loss: 0.8056 - val_acc: 0.7170\n",
"Epoch 20/40\n",
"391/391 [==============================] - 17s 43ms/step - loss: 0.9336 - acc: 0.6689 - val_loss: 0.7921 - val_acc: 0.7253\n",
"Epoch 21/40\n",
"391/391 [==============================] - 15s 39ms/step - loss: 0.9246 - acc: 0.6719 - val_loss: 0.7914 - val_acc: 0.7241\n",
"Epoch 22/40\n",
"391/391 [==============================] - 16s 41ms/step - loss: 0.9069 - acc: 0.6779 - val_loss: 0.7748 - val_acc: 0.7306\n",
"Epoch 23/40\n",
"391/391 [==============================] - 17s 44ms/step - loss: 0.9056 - acc: 0.6792 - val_loss: 0.7701 - val_acc: 0.7323\n",
"Epoch 24/40\n",
"391/391 [==============================] - 15s 40ms/step - loss: 0.8981 - acc: 0.6820 - val_loss: 0.7651 - val_acc: 0.7309\n",
"Epoch 25/40\n",
"391/391 [==============================] - 15s 39ms/step - loss: 0.8875 - acc: 0.6860 - val_loss: 0.7442 - val_acc: 0.7437\n",
"Epoch 26/40\n",
"391/391 [==============================] - 17s 44ms/step - loss: 0.8805 - acc: 0.6878 - val_loss: 0.7534 - val_acc: 0.7386\n",
"Epoch 27/40\n",
"391/391 [==============================] - 17s 43ms/step - loss: 0.8670 - acc: 0.6940 - val_loss: 0.7412 - val_acc: 0.7421\n",
"Epoch 28/40\n",
"391/391 [==============================] - 17s 44ms/step - loss: 0.8653 - acc: 0.6935 - val_loss: 0.7319 - val_acc: 0.7437\n",
"Epoch 29/40\n",
"391/391 [==============================] - 17s 44ms/step - loss: 0.8601 - acc: 0.6952 - val_loss: 0.7489 - val_acc: 0.7422\n",
"Epoch 30/40\n",
"391/391 [==============================] - 16s 41ms/step - loss: 0.8504 - acc: 0.6987 - val_loss: 0.7161 - val_acc: 0.7522\n",
"Epoch 31/40\n",
"391/391 [==============================] - 18s 47ms/step - loss: 0.8437 - acc: 0.7019 - val_loss: 0.7148 - val_acc: 0.7513\n",
"Epoch 32/40\n",
"391/391 [==============================] - 15s 39ms/step - loss: 0.8342 - acc: 0.7078 - val_loss: 0.6997 - val_acc: 0.7557\n",
"Epoch 33/40\n",
"391/391 [==============================] - 15s 39ms/step - loss: 0.8354 - acc: 0.7041 - val_loss: 0.7134 - val_acc: 0.7559\n",
"Epoch 34/40\n",
"391/391 [==============================] - 18s 45ms/step - loss: 0.8331 - acc: 0.7049 - val_loss: 0.7083 - val_acc: 0.7539\n",
"Epoch 35/40\n",
"391/391 [==============================] - 18s 45ms/step - loss: 0.8228 - acc: 0.7122 - val_loss: 0.6853 - val_acc: 0.7621\n",
"Epoch 36/40\n",
"391/391 [==============================] - 15s 39ms/step - loss: 0.8185 - acc: 0.7098 - val_loss: 0.6955 - val_acc: 0.7594\n",
"Epoch 37/40\n",
"391/391 [==============================] - 17s 43ms/step - loss: 0.8125 - acc: 0.7137 - val_loss: 0.6954 - val_acc: 0.7588\n",
"Epoch 38/40\n",
"391/391 [==============================] - 16s 41ms/step - loss: 0.8126 - acc: 0.7132 - val_loss: 0.6907 - val_acc: 0.7602\n",
"Epoch 39/40\n",
"391/391 [==============================] - 17s 43ms/step - loss: 0.8019 - acc: 0.7171 - val_loss: 0.6874 - val_acc: 0.7623\n",
"Epoch 40/40\n",
"391/391 [==============================] - 15s 38ms/step - loss: 0.7972 - acc: 0.7200 - val_loss: 0.6755 - val_acc: 0.7669\n"
]
}
],
"source": [
"print('Using real-time data augmentation.')\n",
"\n",
"# Preprocessing and real-time augmentation settings. Every centering,\n",
"# normalization and whitening switch is off; the transforms left enabled are\n",
"# the width/height shifts and the horizontal flip.\n",
"augmentation_options = dict(\n",
"    featurewise_center=False,  # set input mean to 0 over the dataset\n",
"    samplewise_center=False,  # set each sample mean to 0\n",
"    featurewise_std_normalization=False,  # divide inputs by std of the dataset\n",
"    samplewise_std_normalization=False,  # divide each input by its std\n",
"    zca_whitening=False,  # apply ZCA whitening\n",
"    rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)\n",
"    width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)\n",
"    height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)\n",
"    horizontal_flip=True,  # randomly flip images horizontally\n",
"    vertical_flip=False,  # randomly flip images vertically\n",
")\n",
"datagen = ImageDataGenerator(**augmentation_options)\n",
"\n",
"# Compute quantities required for feature-wise normalization\n",
"# (std, mean, and principal components if ZCA whitening is applied).\n",
"# NOTE(review): all of those options are disabled above, so this call is\n",
"# presumably not needed for this configuration -- confirm.\n",
"datagen.fit(x_train)\n",
"\n",
"# Baseline run: fit a fresh model on the batches generated by datagen.flow(),\n",
"# validating on the test split.\n",
"model = create_model()\n",
"train_batches = datagen.flow(x_train, y_train, batch_size=batch_size)\n",
"hist = model.fit_generator(train_batches,\n",
"                           epochs=epochs,\n",
"                           validation_data=(x_test, y_test),\n",
"                           workers=4)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000/10000 [==============================] - 1s 106us/step\n",
"Test loss: 0.6755105630874634\n",
"Test accuracy: 0.7669\n"
]
}
],
"source": [
"# Evaluate whatever is currently bound to `model` on the test split and\n",
"# print the two reported values (loss first, then accuracy).\n",
"scores = model.evaluate(x_test, y_test, verbose=1)\n",
"for label, value in zip(('Test loss:', 'Test accuracy:'), scores):\n",
"    print(label, value)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/40\n",
"391/391 [==============================] - 16s 41ms/step - loss: 2.0546 - acc: 0.2359 - val_loss: 1.6990 - val_acc: 0.3935\n",
"Epoch 2/40\n",
"391/391 [==============================] - 14s 36ms/step - loss: 1.6966 - acc: 0.3731 - val_loss: 1.4856 - val_acc: 0.4540\n",
"Epoch 3/40\n",
"391/391 [==============================] - 17s 45ms/step - loss: 1.5238 - acc: 0.4426 - val_loss: 1.3598 - val_acc: 0.5095\n",
"Epoch 4/40\n",
"391/391 [==============================] - 15s 38ms/step - loss: 1.4380 - acc: 0.4778 - val_loss: 1.2942 - val_acc: 0.5294\n",
"Epoch 5/40\n",
"391/391 [==============================] - 17s 42ms/step - loss: 1.3956 - acc: 0.4925 - val_loss: 1.2574 - val_acc: 0.5511\n",
"Epoch 6/40\n",
"391/391 [==============================] - 17s 43ms/step - loss: 1.4110 - acc: 0.4916 - val_loss: 1.2505 - val_acc: 0.5566\n",
"Epoch 7/40\n",
"391/391 [==============================] - 16s 42ms/step - loss: 1.3411 - acc: 0.5178 - val_loss: 1.1696 - val_acc: 0.5819\n",
"Epoch 8/40\n",
"391/391 [==============================] - 16s 42ms/step - loss: 1.2848 - acc: 0.5383 - val_loss: 1.1356 - val_acc: 0.5931\n",
"Epoch 9/40\n",
"391/391 [==============================] - 17s 43ms/step - loss: 1.2403 - acc: 0.5560 - val_loss: 1.0811 - val_acc: 0.6153\n",
"Epoch 10/40\n",
"391/391 [==============================] - 15s 40ms/step - loss: 1.2098 - acc: 0.5679 - val_loss: 1.0347 - val_acc: 0.6337\n",
"Epoch 11/40\n",
"391/391 [==============================] - 16s 40ms/step - loss: 1.1754 - acc: 0.5787 - val_loss: 1.0254 - val_acc: 0.6399\n",
"Epoch 12/40\n",
"391/391 [==============================] - 17s 43ms/step - loss: 1.1499 - acc: 0.5891 - val_loss: 1.0023 - val_acc: 0.6490\n",
"Epoch 13/40\n",
"391/391 [==============================] - 15s 39ms/step - loss: 1.1362 - acc: 0.5948 - val_loss: 1.0023 - val_acc: 0.6455\n",
"Epoch 14/40\n",
"391/391 [==============================] - 14s 36ms/step - loss: 1.1291 - acc: 0.5956 - val_loss: 0.9897 - val_acc: 0.6541\n",
"Epoch 15/40\n",
"391/391 [==============================] - 15s 40ms/step - loss: 1.1190 - acc: 0.5983 - val_loss: 0.9908 - val_acc: 0.6504\n",
"Epoch 16/40\n",
"391/391 [==============================] - 14s 37ms/step - loss: 1.1409 - acc: 0.5921 - val_loss: 1.0039 - val_acc: 0.6444\n",
"Epoch 17/40\n",
"391/391 [==============================] - 14s 37ms/step - loss: 1.1228 - acc: 0.5998 - val_loss: 0.9734 - val_acc: 0.6553\n",
"Epoch 18/40\n",
"391/391 [==============================] - 15s 39ms/step - loss: 1.1000 - acc: 0.6074 - val_loss: 0.9597 - val_acc: 0.6650\n",
"Epoch 19/40\n",
"391/391 [==============================] - 15s 38ms/step - loss: 1.0823 - acc: 0.6130 - val_loss: 0.9445 - val_acc: 0.6648\n",
"Epoch 20/40\n",
"391/391 [==============================] - 15s 38ms/step - loss: 1.0632 - acc: 0.6207 - val_loss: 0.9186 - val_acc: 0.6777\n",
"Epoch 21/40\n",
"391/391 [==============================] - 15s 37ms/step - loss: 1.0480 - acc: 0.6282 - val_loss: 0.9000 - val_acc: 0.6823\n",
"Epoch 22/40\n",
"391/391 [==============================] - 14s 36ms/step - loss: 1.0389 - acc: 0.6271 - val_loss: 0.8920 - val_acc: 0.6892\n",
"Epoch 23/40\n",
"391/391 [==============================] - 16s 42ms/step - loss: 1.0257 - acc: 0.6322 - val_loss: 0.8795 - val_acc: 0.6892\n",
"Epoch 24/40\n",
"391/391 [==============================] - 14s 37ms/step - loss: 1.0127 - acc: 0.6407 - val_loss: 0.9037 - val_acc: 0.6793\n",
"Epoch 25/40\n",
"391/391 [==============================] - 16s 40ms/step - loss: 1.0083 - acc: 0.6420 - val_loss: 0.8663 - val_acc: 0.6966\n",
"Epoch 26/40\n",
"391/391 [==============================] - 15s 39ms/step - loss: 0.9929 - acc: 0.6479 - val_loss: 0.8512 - val_acc: 0.7001\n",
"Epoch 27/40\n",
"391/391 [==============================] - 15s 38ms/step - loss: 0.9877 - acc: 0.6462 - val_loss: 0.8488 - val_acc: 0.7038\n",
"Epoch 28/40\n",
"391/391 [==============================] - 15s 40ms/step - loss: 0.9863 - acc: 0.6489 - val_loss: 0.8486 - val_acc: 0.6991\n",
"Epoch 29/40\n",
"391/391 [==============================] - 18s 46ms/step - loss: 0.9726 - acc: 0.6559 - val_loss: 0.8419 - val_acc: 0.7045\n",
"Epoch 30/40\n",
"391/391 [==============================] - 15s 40ms/step - loss: 0.9688 - acc: 0.6554 - val_loss: 0.8403 - val_acc: 0.7047\n",
"Epoch 31/40\n",
"391/391 [==============================] - 15s 39ms/step - loss: 0.9637 - acc: 0.6561 - val_loss: 0.8335 - val_acc: 0.7073\n",
"Epoch 32/40\n",
"391/391 [==============================] - 16s 42ms/step - loss: 0.9617 - acc: 0.6581 - val_loss: 0.8363 - val_acc: 0.7063\n",
"Epoch 33/40\n",
"391/391 [==============================] - 17s 45ms/step - loss: 0.9620 - acc: 0.6583 - val_loss: 0.8335 - val_acc: 0.7068\n",
"Epoch 34/40\n",
"391/391 [==============================] - 17s 44ms/step - loss: 0.9626 - acc: 0.6549 - val_loss: 0.8364 - val_acc: 0.7058\n",
"Epoch 35/40\n",
"391/391 [==============================] - 15s 39ms/step - loss: 0.9555 - acc: 0.6584 - val_loss: 0.8344 - val_acc: 0.7062\n",
"Epoch 36/40\n",
"391/391 [==============================] - 16s 41ms/step - loss: 0.9761 - acc: 0.6520 - val_loss: 0.8453 - val_acc: 0.7019\n",
"Epoch 37/40\n",
"391/391 [==============================] - 16s 41ms/step - loss: 0.9723 - acc: 0.6519 - val_loss: 0.8287 - val_acc: 0.7108\n",
"Epoch 38/40\n",
"391/391 [==============================] - 16s 41ms/step - loss: 0.9632 - acc: 0.6586 - val_loss: 0.8401 - val_acc: 0.7047\n",
"Epoch 39/40\n",
"391/391 [==============================] - 17s 42ms/step - loss: 0.9597 - acc: 0.6581 - val_loss: 0.8228 - val_acc: 0.7120\n",
"Epoch 40/40\n",
"391/391 [==============================] - 15s 40ms/step - loss: 0.9484 - acc: 0.6631 - val_loss: 0.8000 - val_acc: 0.7215\n",
"10000/10000 [==============================] - 1s 94us/step\n",
"Test loss: 0.8000454781532288\n",
"Test accuracy: 0.7215\n"
]
}
],
"source": [
"# Second run: a freshly built model trained on the same augmented batches,\n",
"# this time with the LearningRateWarmRestarter callback attached.\n",
"model = create_model()\n",
"warm_restart_batches = datagen.flow(x_train, y_train, batch_size=batch_size)\n",
"warm_restarter = LearningRateWarmRestarter(max_lr=0.01,\n",
"                                           num_restart_epochs=5,\n",
"                                           factor=2)\n",
"warm_restart_hist = model.fit_generator(warm_restart_batches,\n",
"                                        epochs=epochs,\n",
"                                        validation_data=(x_test, y_test),\n",
"                                        callbacks=[warm_restarter],\n",
"                                        workers=4)\n",
"\n",
"# Report test loss and accuracy for the warm-restart model.\n",
"scores = model.evaluate(x_test, y_test, verbose=1)\n",
"for label, value in zip(('Test loss:', 'Test accuracy:'), scores):\n",
"    print(label, value)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7f6cfefa3860>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<matplotlib.figure.Figure at 0x7f6cfebf8f28>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Compare the per-epoch validation loss of the two runs.\n",
"# The x axis is taken from the length of each recorded history instead of a\n",
"# hard-coded 40, so the cell keeps working when `epochs` is changed.\n",
"for label, history in (('warm restart', warm_restart_hist), ('normal', hist)):\n",
"    val_loss = history.history['val_loss']\n",
"    plt.plot(np.arange(len(val_loss)), val_loss, label=label)\n",
"plt.title('CIFAR10 validation loss')\n",
"plt.xlabel('epoch')\n",
"plt.ylabel('validation loss')\n",
"# Trailing semicolon keeps the Legend repr out of the cell output.\n",
"plt.legend();"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment