@JJRyan0
Created May 26, 2017 09:08
Digit Image Recognition SGD.ipynb
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": "## scikit-learn: Image Classification - Predicting Digits - Stochastic Gradient Descent (SGD)\n\nJohn Ryan - 18 May 2017\n\n### Digit Recognition - MNIST dataset\n\nDataset containing 70,000 small images of handwritten digits.\n\n\n### Contents\n\nStep 1:\n\n- Preprocessing data\n\n- Stochastic Gradient Descent binary classifier\n\n- K-fold Cross Validation\n\n\n**What is Stochastic Gradient Descent?**\n\nIt is a stochastic approximation of the gradient descent optimization algorithm for minimizing an objective function that is written as a sum of differentiable functions. Instead of computing the gradient over the whole training set, SGD picks a random instance from the training set at every step and computes the gradient from that single instance. Because it works on one instance at a time, it is highly efficient on large datasets."
},
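{
"metadata": {},
"cell_type": "markdown",
"source": "A minimal sketch of the idea described above, not scikit-learn's implementation: it fits a line `y = w*x + b` by SGD on squared error, updating the weights from one randomly chosen instance per step. The function name `sgd_linear_fit`, the toy data, the learning rate and the number of steps are all assumptions made for illustration. Because the toy data is noise-free, the estimates should settle close to w = 2 and b = 1."
},
{
"metadata": {
"collapsed": false,
"trusted": true
},
"cell_type": "code",
"source": "import random\n\ndef sgd_linear_fit(xs, ys, learning_rate=0.05, n_steps=5000, seed=60):\n    # Fit y ~ w*x + b by stochastic gradient descent on squared error\n    rng = random.Random(seed)\n    w, b = 0.0, 0.0\n    for _ in range(n_steps):\n        i = rng.randrange(len(xs))       # pick ONE random training instance\n        error = (w * xs[i] + b) - ys[i]  # prediction error on that instance\n        # Gradient of 0.5 * error**2 with respect to w and b\n        w -= learning_rate * error * xs[i]\n        b -= learning_rate * error\n    return w, b\n\n# Hypothetical toy data generated from y = 2x + 1 (no noise)\nxs = [0.0, 0.5, 1.0, 1.5, 2.0]\nys = [2.0 * x + 1.0 for x in xs]\nw, b = sgd_linear_fit(xs, ys)\nprint('w = {:.3f}, b = {:.3f}'.format(w, b))",
"execution_count": null,
"outputs": []
},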
{
"metadata": {},
"cell_type": "markdown",
"source": "### 1. Fetching Digit Images"
},
{
"metadata": {
"collapsed": false,
"trusted": true
},
"cell_type": "code",
"source": "#Import the dataset with a helper function from the sklearn datasets module\n#Note: fetch_mldata was deprecated in scikit-learn 0.20 and removed in 0.22;\n#newer versions use fetch_openml('mnist_784') instead\nfrom sklearn.datasets import fetch_mldata\nmnist = fetch_mldata('MNIST original')\nmnist",
"execution_count": 1,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": "{'COL_NAMES': ['label', 'data'],\n 'DESCR': 'mldata.org dataset: mnist-original',\n 'data': array([[0, 0, 0, ..., 0, 0, 0],\n [0, 0, 0, ..., 0, 0, 0],\n [0, 0, 0, ..., 0, 0, 0],\n ..., \n [0, 0, 0, ..., 0, 0, 0],\n [0, 0, 0, ..., 0, 0, 0],\n [0, 0, 0, ..., 0, 0, 0]], dtype=uint8),\n 'target': array([ 0., 0., 0., ..., 9., 9., 9.])}"
},
"metadata": {},
"execution_count": 1
}
]
},
{
"metadata": {
"collapsed": false,
"trusted": true
},
"cell_type": "code",
"source": "#Create X (pixel data) and y (labels) to hold the data for training later\nX, y = mnist[\"data\"], mnist[\"target\"]\nX.shape",
"execution_count": 2,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": "(70000L, 784L)"
},
"metadata": {},
"execution_count": 2
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 2. Preprocessing the images\n\n#### Train/Test Split\n\n- The dataset used for this example is already split into a training set (the first 60,000 images) and a test set (the final 10,000 images), so we will not perform a separate train/test split on the data. \n\n#### Shuffling\n\n- With image data that comes from different sources it is sometimes important to shuffle the dataset so that no part of it is dominated by a single source, which removes the possibility of the model training on data from one specific source only. The obvious exception is data that has an ordering or time sequence.\n\nShuffling the training set also ensures that, in the cross-validation stage later, the folds are similar to one another in content as well as in size. "
},
{
"metadata": {
"collapsed": false,
"trusted": true
},
"cell_type": "code",
"source": "#Split into training and test sets, then shuffle the training set\nimport numpy as np\nX_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]\nshuffle = np.random.permutation(60000)\nX_train, y_train = X_train[shuffle], y_train[shuffle]",
"execution_count": 3,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "#### Training the Classification Model"
},
{
"metadata": {
"collapsed": false,
"trusted": true
},
"cell_type": "code",
"source": "#Create binary targets for the training and test sets: 4 = True, any other digit = False\ny_train4 = (y_train == 4)\ny_test4 = (y_test == 4)\ny_train4",
"execution_count": 4,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": "array([False, False, False, ..., False, False, False], dtype=bool)"
},
"metadata": {},
"execution_count": 4
}
]
},
{
"metadata": {
"collapsed": false,
"trusted": true
},
"cell_type": "code",
"source": "#Import the plotting libraries and use imshow to display the image at index 29000\n%matplotlib inline\nimport matplotlib\nimport matplotlib.pyplot as plt\ndigit = X[29000]\ndigit_im = digit.reshape(28,28)\nplt.imshow(digit_im, cmap = matplotlib.cm.binary,\n interpolation=\"nearest\")\nplt.axis(\"off\")\nplt.show()",
"execution_count": 5,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAABd1JREFUeJzt3b9rFFsYx+Hdi4UptEyjVsZGohCsIvijsQ5220ksLERM\nK4iV/4CVCrYWgoLYiaWNhaKVFmICVoJaaCEsFtlb3HreSTKZ3dx8n6d9M2dG4cMpzu7scDKZDIA8\n/8z6AYDZED+EEj+EEj+EEj+EEj+EEj+EEj+EEj+EOjDl+/k4IfRvuJU/svNDKPFDKPFDKPFDKPFD\nKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFD\nKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDKPFDqAOzfgBo8vnz53K+vLxc\nzi9dutQ4e/LkyY6eaT+x80Mo8UMo8UMo8UMo8UMo8UMo8UMo5/zsWQ8ePCjnv379Kufj8Xg3H2ff\nsfNDKPFDKPFDKPFDKPFDKPFDKEd9zMzPnz/L+cOHDzutv7Ky0un6/c7OD6HED6HED6HED6HED6HE\nD6HED6Gc809B21dPb926Vc7v3LlTzo8cObLtZ9oLXr58Wc7//v1bztv+3aurq9t+piR2fgglfggl\nfgglfgglfgglfgglfgjlnH8KHj9+XM4fPXpUzk+fPl3Or1+/vu1nmpYvX740zto+v9Dm9u3bna5P\nZ+eHUOKHUOKHUOKHUOKHUOKHUOKHUM75d8H379/L+c2bNzutv7m52en6Wfrw4UPj7OvXr53WPnfu\nXKfr09n5IZT4IZT4IZT4IZT4IZT4IZT4IZRz/l3w4sWLcj4cDjutf/ny5U7Xz1L1ff62/5f5+fly\nfvjw4R09E/+x80Mo8UMo8UMo8UMo8UMo8UMoR31bNB6PG2d3797ttPbVq1fL+V7+Ce62r+W2vZa8\ncuXKlXJ+9OjRHa+NnR9iiR9CiR9CiR9CiR9CiR9CiR9CDSeTyTTvN9Wb7aanT582zkajUae1FxYW\nyvmNGzfK+alTpzrdv4tr166V8/X19R2vvby8XM77/PzD2bNny/na2lpv994FW/oOuZ0fQokfQokf\nQokfQokfQokfQokfQvk+/x5Qvd56MGg/U+76avAu2j4n0uXZ3rx5s+NrB4PB4Pz5842zxcXF8tqL\nFy92uvf/gZ0fQokfQokfQokfQokfQokfQokfQjnn36KlpaXGWdv75Z8/f17Of//+Xc6n/M6Fbeny\nbIcOHSrn79+/L+fHjx/f8b2x80Ms8UMo8UMo8UMo8UMo8UMo8UMo7+2fgk+fPpXzHz9+lPPXr1+X\n848fPzbO5ubmymvPnDlTzjc2Nsr5vXv3ynn1ff75+fny2m/fvpVzGnlvP9BM/BBK/BBK/BBK/BBK\n/BDKV3qn4OTJk52uv3Dhwi49yfbdv3+/t7UPHjzY29q0s/NDKPFDKPFDKPFDKPFDKPFDKPFDKOf8\nlF69etXb2qurq72tTTs7P4QSP4QSP4QSP4QSP4QSP4QSP4Ryzk9pfX29t7VHo1Fva9POzg+hxA+h\nxA+hxA+hxA+hxA+hxA+hnPOH+/PnTzkfj8flvO0n3peWlhpnx44dK6+lX3Z+CCV+CCV+CCV+CCV+\nCCV+COWoL9zbt2/L+cbGRjkfDofl/MSJE42zubm58lr6ZeeHUOKHUOKHUOKHUOKHUOKHUOKHUM75\nw7Wd87N/2fkhlPghlPghlPghlPghlPghlPghlHP+cO/evZv1IzAjdn4IJX4IJX4IJX4IJX4IJX4I\nJX4I5Zw/3Gg0KufPnj3rtP7i4mKn6+mPnR9CiR9CiR9CiR9CiR9CiR9CiR9CDSeTyTTvN9WbQajh\nVv7Izg+hxA+hxA+hxA+hxA+hxA+hxA+hxA+h
xA+hxA+hxA+hxA+hxA+hxA+hxA+hxA+hxA+hxA+h\nxA+hxA+hxA+hxA+hpv0T3Vt6pTDQPzs/hBI/hBI/hBI/hBI/hBI/hBI/hBI/hBI/hBI/hBI/hBI/\nhBI/hBI/hBI/hBI/hBI/hBI/hBI/hBI/hBI/hBI/hBI/hPoXilG1gMuy+hwAAAAASUVORK5CYII=\n",
"text/plain": "<matplotlib.figure.Figure at 0x845f748>"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 3. Build Basic Classifier"
},
{
"metadata": {
"collapsed": false,
"trusted": true
},
"cell_type": "code",
"source": "#Create the SGD classifier. SGD picks training instances at random,\n#so we set random_state to make the results reproducible\nfrom sklearn.linear_model import SGDClassifier\nsgdmodel = SGDClassifier(random_state=60)\nsgdmodel.fit(X_train,y_train4)",
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": "SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,\n eta0=0.0, fit_intercept=True, l1_ratio=0.15,\n learning_rate='optimal', loss='hinge', n_iter=5, n_jobs=1,\n penalty='l2', power_t=0.5, random_state=60, shuffle=True, verbose=0,\n warm_start=False)"
},
"metadata": {},
"execution_count": 6
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 3.1 Prediction of the digit 4"
},
{
"metadata": {
"collapsed": false,
"trusted": true
},
"cell_type": "code",
"source": "#Call the predict function to identify if the digit is a 4, \n#true below indicates that it is.\nsgdmodel.predict([digit])",
"execution_count": 7,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": "array([ True], dtype=bool)"
},
"metadata": {},
"execution_count": 7
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### 4 Evaluating Model Performance\n\n##### 4.1. Cross-validation - Kfold \n\n##### What is KFold Cross-Validation?\n\nK-fold cross-validation splits the training data into a specified number of folds; here we use 3 first and then 10. For each fold, a model is trained on the remaining folds and then makes predictions on the held-out fold, so every fold is used once to evaluate performance.\n"
},
{
"metadata": {
"collapsed": false,
"trusted": true
},
"cell_type": "code",
"source": "#this outputs a validation score using 3 folds\nfrom sklearn.model_selection import cross_val_score\ncross_val_score(sgdmodel,X_train, y_train4, cv=3, scoring='accuracy')",
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": "array([ 0.97315134, 0.9791 , 0.96094805])"
},
"metadata": {},
"execution_count": 8
}
]
},
{
"metadata": {
"collapsed": false,
"trusted": true
},
"cell_type": "code",
"source": "#this outputs a validation score using 10 folds\nfrom sklearn.model_selection import cross_val_score\ncross_val_score(sgdmodel,X_train, y_train4, cv=10, scoring='accuracy')",
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": "array([ 0.96850525, 0.95884019, 0.97816667, 0.97566667, 0.97333333,\n 0.94433333, 0.9775 , 0.961 , 0.97549592, 0.97832972])"
},
"metadata": {},
"execution_count": 9
}
]
},
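{
"metadata": {},
"cell_type": "markdown",
"source": "As a rough illustration of what happens inside `cross_val_score`, the sketch below splits sample indices into k folds and yields one train/test pair per fold. The helper `k_fold_indices` is hypothetical and uses plain, unshuffled folds; for a classifier with an integer `cv`, scikit-learn uses stratified folds, which also keep the class proportions similar in every fold."
},
{
"metadata": {
"collapsed": false,
"trusted": true
},
"cell_type": "code",
"source": "def k_fold_indices(n_samples, n_folds):\n    # Yield (train_idx, test_idx) index lists for contiguous, near-equal folds\n    fold_sizes = [n_samples // n_folds] * n_folds\n    for i in range(n_samples % n_folds):\n        fold_sizes[i] += 1               # spread the remainder over the first folds\n    start = 0\n    for size in fold_sizes:\n        test_idx = list(range(start, start + size))\n        train_idx = list(range(0, start)) + list(range(start + size, n_samples))\n        yield train_idx, test_idx        # train on the rest, evaluate on this fold\n        start += size\n\n# Each sample lands in exactly one test fold\nfor train_idx, test_idx in k_fold_indices(n_samples=10, n_folds=3):\n    print('train: {} test: {}'.format(train_idx, test_idx))",
"execution_count": null,
"outputs": []
},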
{
"metadata": {},
"cell_type": "markdown",
"source": "We see strong cross-validation scores on the 3 folds, around 96-98% accuracy, with similar results on 10 folds. Accuracy can be misleading here, however: only about 10% of the images are 4s, so a simple classifier that labels every image as not a 4 would already score about 90%."
}
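,
{
"metadata": {},
"cell_type": "markdown",
"source": "To see why accuracy alone can flatter a classifier on a skewed target such as \"is it a 4?\", the sketch below scores a baseline that always predicts not-4. The labels are hypothetical (exactly 1 in 10 is a 4, roughly the share in MNIST) and the `accuracy` helper is written here for illustration rather than taken from scikit-learn."
},
{
"metadata": {
"collapsed": false,
"trusted": true
},
"cell_type": "code",
"source": "def accuracy(y_true, y_pred):\n    # Fraction of predictions that match the true labels\n    matches = sum(1 for t, p in zip(y_true, y_pred) if t == p)\n    return matches / float(len(y_true))\n\n# Hypothetical labels: exactly 1 in 10 images is a 4\ny_true = [(i % 10 == 4) for i in range(1000)]\ny_pred_never4 = [False] * len(y_true)    # baseline: always predict 'not 4'\n\nprint('baseline accuracy: {:.2f}'.format(accuracy(y_true, y_pred_never4)))",
"execution_count": null,
"outputs": []
}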
],
"metadata": {
"kernelspec": {
"name": "python2",
"display_name": "Python 2",
"language": "python"
},
"language_info": {
"mimetype": "text/x-python",
"nbconvert_exporter": "python",
"name": "python",
"pygments_lexer": "ipython2",
"version": "2.7.13",
"file_extension": ".py",
"codemirror_mode": {
"version": 2,
"name": "ipython"
}
},
"gist": {
"id": "",
"data": {
"description": "Digit Image Recognition SGD.ipynb",
"public": true
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}