{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "accelerator": "TPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "eoNiO909Cywc" }, "source": [ "## Lab Introduction\n", "This lab introduces [TensorFlow](https://www.tensorflow.org/) for building multi-layer perceptrons and convolutional neural networks. It is loosely based on [this tutorial](https://www.tensorflow.org/guide/keras/training_with_built_in_methods), and TensorFlow also has [excellent documentation](https://www.tensorflow.org/api_docs/python/tf/keras). If you're running this in Colab, you can speed up execution by selecting Runtime | Change Runtime Type and choosing a GPU (graphics processing unit) or [TPU](https://en.wikipedia.org/wiki/Tensor_Processing_Unit)." ] }, { "cell_type": "markdown", "metadata": { "id": "Qy0u8f9ZIHaQ" }, "source": [ "## Imports\n", "We need to import TensorFlow, which is conventionally aliased as tf.\n", "We also use tf.keras and tf.keras.layers frequently, so we give those shorter names too.\n" ] }, { "cell_type": "code", "metadata": { "id": "YkbiGFbfaoIH" }, "source": [ "import tensorflow as tf\n", "keras = tf.keras\n", "layers = tf.keras.layers\n", "# NumPy for basic array operations\n", "import numpy as np\n", "# matplotlib allows us to visualise our data.\n", "import matplotlib.pyplot as plt\n", "# Image displays the saved model diagrams inline\n", "from IPython.display import Image" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "lADpUL33aqzB" }, "source": [ "## Let's get some data\n", "For this lab we're going to look at the [MNIST handwritten digit dataset](http://yann.lecun.com/exdb/mnist/). This dataset has been used to evaluate a wide variety of models and serves as a standard benchmark for comparing them. Today we're going to be looking at a digit classification task.\n", "1. 
Look at the size and shape of the data elements. How large are the images? How are the labels stored?\n", "2. Use [plt.imshow](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.imshow.html?highlight=imshow#matplotlib.pyplot.imshow) to visualise some examples from the data. Check the corresponding label for each example." ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 319 }, "id": "qESOAFEUbBjs", "outputId": "74aaa35d-1505-48f6-9c51-b32d2b48390b" }, "source": [ "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", "# Print the array shapes and the unique label values\n", "print(x_train.shape, y_train.shape, np.unique(y_test))\n", "\n", "# Pick an example index\n", "example_idx = 0\n", "plt.imshow(x_train[example_idx,...],cmap='gray')\n", "# Print the label\n", "print('Label:', y_train[example_idx])\n", "plt.show()\n" ], "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz\n", "11490434/11490434 [==============================] - 0s 0us/step\n", "(60000, 28, 28) (60000,) [0 1 2 3 4 5 6 7 8 9]\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAAAN8klEQVR4nO3df6jVdZ7H8ddrbfojxzI39iZOrWOEUdE6i9nSyjYRTj8o7FYMIzQ0JDl/JDSwyIb7xxSLIVu6rBSDDtXYMus0UJHFMNVm5S6BdDMrs21qoxjlphtmmv1a9b1/3K9xp+75nOs53/PD+34+4HDO+b7P93zffPHl99f53o8jQgAmvj/rdQMAuoOwA0kQdiAJwg4kQdiBJE7o5sJsc+of6LCI8FjT29qy277C9lu237F9ezvfBaCz3Op1dtuTJP1B0gJJOyW9JGlRROwozMOWHeiwTmzZ50l6JyLejYgvJf1G0sI2vg9AB7UT9hmS/jjq/c5q2p+wvcT2kO2hNpYFoE0dP0EXEeskrZPYjQd6qZ0t+y5JZ4x6/51qGoA+1E7YX5J0tu3v2j5R0o8kbaynLQB1a3k3PiIO2V4q6SlJkyQ9EBFv1NYZgFq1fOmtpYVxzA50XEd+VAPg+EHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEi0P2Yzjw6RJk4r1U045paPLX7p0acPaSSedVJx39uzZxfqtt95arN9zzz0Na4sWLSrO+/nnnxfrK1euLNbvvPPOYr0X2gq77fckHZB0WNKhiJhbR1MA6lfHlv3SiPiwhu8B0EEcswNJtBv2kPS07ZdtLxnrA7aX2B6yPdTmsgC0od3d+PkRscv2X0h6xvZ/R8Tm0R+IiHWS1kmS7WhzeQBa1NaWPSJ2Vc97JD0maV4dTQGoX8thtz3Z9pSjryX9QNL2uhoDUK92duMHJD1m++j3/HtE/L6WriaYM888s1g/8cQTi/WLL764WJ8/f37D2tSpU4vzXn/99cV6L+3cubNYX7NmTbE+ODjYsHbgwIHivK+++mqx/sILLxTr/ajlsEfEu5L+qsZeAHQQl96AJAg7kARhB5Ig7EAShB1IwhHd+1HbRP0F3Zw5c4r1TZs2Feudvs20Xx05cqRYv/nmm4v1Tz75pOVlDw8PF+sfffRRsf7WW2+1vOxOiwiPNZ0tO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwXX2GkybNq1Y37JlS7E+a9asOtupVbPe9+3bV6xfeumlDWtffvllcd6svz9oF9fZgeQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJhmyuwd69e4v1ZcuWFetXX311sf7KK68U683+pHLJtm3bivUFCxYU6wcPHizWzzvvvIa12267rTgv6sWWHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeS4H72PnDyyScX682GF167dm3D2uLFi4vz3njjjcX6hg0binX0n5bvZ7f9gO09trePmjbN9jO2366eT62zWQD1G89u/K8kXfG1abdLejYizpb0bPUeQB9rGvaI2Czp678HXShpffV6vaRr620LQN1a/W38QEQcHSzrA0kDjT5oe4mkJS0uB0BN2r4RJiKidOItItZJWidxgg7opVYvve22PV2Squc99bUEoBNaDftGSTdVr2+S9Hg97QDolKa78bY3SPq+pNNs75T0c0krJf3W9mJJ70v6YSebnOj279/f1vwff/xxy/PecsstxfrDDz9crDcbYx39o2nYI2JRg9JlNfcCoIP4uSyQBGEHkiDsQBKEHUiCsANJcIvrBDB58uSGtSeeeKI47yWXXFKsX3nllcX6008/Xayj+xiyGUiOsANJEHYgCcIOJEHYgSQIO5AEYQeS4Dr7B
HfWWWcV61u3bi3W9+3bV6w/99xzxfrQ0FDD2n333Vect5v/NicSrrMDyRF2IAnCDiRB2IEkCDuQBGEHkiDsQBJcZ09ucHCwWH/wwQeL9SlTprS87OXLlxfrDz30ULE+PDxcrGfFdXYgOcIOJEHYgSQIO5AEYQeSIOxAEoQdSILr7Cg6//zzi/XVq1cX65dd1vpgv2vXri3WV6xYUazv2rWr5WUfz1q+zm77Adt7bG8fNe0O27tsb6seV9XZLID6jWc3/leSrhhj+r9ExJzq8bt62wJQt6Zhj4jNkvZ2oRcAHdTOCbqltl+rdvNPbfQh20tsD9lu/MfIAHRcq2H/haSzJM2RNCxpVaMPRsS6iJgbEXNbXBaAGrQU9ojYHRGHI+KIpF9KmldvWwDq1lLYbU8f9XZQ0vZGnwXQH5peZ7e9QdL3JZ0mabekn1fv50gKSe9J+mlENL25mOvsE8/UqVOL9WuuuaZhrdm98vaYl4u/smnTpmJ9wYIFxfpE1eg6+wnjmHHRGJPvb7sjAF3Fz2WBJAg7kARhB5Ig7EAShB1Igltc0TNffPFFsX7CCeWLRYcOHSrWL7/88oa1559/vjjv8Yw/JQ0kR9iBJAg7kARhB5Ig7EAShB1IgrADSTS96w25XXDBBcX6DTfcUKxfeOGFDWvNrqM3s2PHjmJ98+bNbX3/RMOWHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeS4Dr7BDd79uxifenSpcX6ddddV6yffvrpx9zTeB0+fLhYHx4u//XyI0eO1NnOcY8tO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwXX240Cza9mLFo010O6IZtfRZ86c2UpLtRgaGirWV6xYUaxv3LixznYmvKZbdttn2H7O9g7bb9i+rZo+zfYztt+unk/tfLsAWjWe3fhDkv4+Is6V9DeSbrV9rqTbJT0bEWdLerZ6D6BPNQ17RAxHxNbq9QFJb0qaIWmhpPXVx9ZLurZDPQKowTEds9ueKel7krZIGoiIoz9O/kDSQIN5lkha0kaPAGow7rPxtr8t6RFJP4uI/aNrMTI65JiDNkbEuoiYGxFz2+oUQFvGFXbb39JI0H8dEY9Wk3fbnl7Vp0va05kWAdSh6W68bUu6X9KbEbF6VGmjpJskrayeH+9IhxPAwMCYRzhfOffcc4v1e++9t1g/55xzjrmnumzZsqVYv/vuuxvWHn+8/E+GW1TrNZ5j9r+V9GNJr9veVk1brpGQ/9b2YknvS/phRzoEUIumYY+I/5I05uDuki6rtx0AncLPZYEkCDuQBGEHkiDsQBKEHUiCW1zHadq0aQ1ra9euLc47Z86cYn3WrFmttFSLF198sVhftWpVsf7UU08V65999tkx94TOYMsOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0mkuc5+0UUXFevLli0r1ufNm9ewNmPGjJZ6qsunn37asLZmzZrivHfddVexfvDgwZZ6Qv9hyw4kQdiBJAg7kARhB5Ig7EAShB1IgrADSaS5zj44ONhWvR07duwo1p988sli/dChQ8V66Z7zffv2FedFHmzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJR0T5A/YZkh6SNCApJK2LiH+1fYekWyT9b/XR5RHxuybfVV4YgLZFxJijLo8n7NMlTY+IrbanSHpZ0rUaGY/9k4i4Z7xNEHag8xqFfTzjsw9LGq5eH7D9pqTe/mkWAMfsmI7Zbc+U9D1JW6pJS22/ZvsB26c2mGeJ7SHbQ+21CqAdTXfjv/qg/W1JL0haERGP2h6Q9KFGjuP/SSO7+jc3+Q5244EOa/mYXZJsf0vSk5KeiojVY9RnSnoyIs5v8j2EHeiwRmFvuhtv25Lul/Tm6KBXJ+6OGpS0vd0mAXTOeM7Gz5f0n5Jel3Skmrxc0iJJczSyG/+epJ9WJ/NK38WWHeiwtnbj60LYgc5reTcewMRA2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgC
cIOJEHYgSQIO5AEYQeSIOxAEoQdSKLbQzZ/KOn9Ue9Pq6b1o37trV/7kuitVXX29peNCl29n/0bC7eHImJuzxoo6Nfe+rUvid5a1a3e2I0HkiDsQBK9Dvu6Hi+/pF9769e+JHprVVd66+kxO4Du6fWWHUCXEHYgiZ6E3fYVtt+y/Y7t23vRQyO237P9uu1tvR6frhpDb4/t7aOmTbP9jO23q+cxx9jrUW932N5Vrbtttq/qUW9n2H7O9g7bb9i+rZre03VX6Ksr663rx+y2J0n6g6QFknZKeknSoojY0dVGGrD9nqS5EdHzH2DY/jtJn0h66OjQWrb/WdLeiFhZ/Ud5akT8Q5/0doeOcRjvDvXWaJjxn6iH667O4c9b0Yst+zxJ70TEuxHxpaTfSFrYgz76XkRslrT3a5MXSlpfvV6vkX8sXdegt74QEcMRsbV6fUDS0WHGe7ruCn11RS/CPkPSH0e936n+Gu89JD1t+2XbS3rdzBgGRg2z9YGkgV42M4amw3h309eGGe+bddfK8Oft4gTdN82PiL+WdKWkW6vd1b4UI8dg/XTt9BeSztLIGIDDklb1splqmPFHJP0sIvaPrvVy3Y3RV1fWWy/CvkvSGaPef6ea1hciYlf1vEfSYxo57Ognu4+OoFs97+lxP1+JiN0RcTgijkj6pXq47qphxh+R9OuIeLSa3PN1N1Zf3VpvvQj7S5LOtv1d2ydK+pGkjT3o4xtsT65OnMj2ZEk/UP8NRb1R0k3V65skPd7DXv5Evwzj3WiYcfV43fV8+POI6PpD0lUaOSP/P5L+sRc9NOhrlqRXq8cbve5N0gaN7Nb9n0bObSyW9OeSnpX0tqT/kDStj3r7N40M7f2aRoI1vUe9zdfILvprkrZVj6t6ve4KfXVlvfFzWSAJTtABSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBL/DyJ7caZa7LphAAAAAElFTkSuQmCC\n" }, "metadata": { "needs_background": "light" } } ] }, { "cell_type": "markdown", "metadata": { "id": "ZWFV7px3hGTy" }, "source": [ "## Preparing the data for training\n", "To get the data ready for training, it's best to squash the input pixel values to being in the range 0-1. Image data is normally stored using unsigned 8bit integers, so takes value between 0 and 255. 
We can simply convert to floats and divide by 255.\n", "\n", "We also hold out the last 10,000 examples of the training set for validation, and shuffle the training data (in case the order of the examples could introduce bias).\n", "\n", "Trace through the code below to understand what's going on.\n" ] }, { "cell_type": "code", "metadata": { "id": "4eZPB72Hhckz" }, "source": [ "# Divide the image data to put it in the right range and convert to floating point numbers\n", "x_train = x_train.astype(\"float32\") / 255\n", "x_test = x_test.astype(\"float32\") / 255\n", "# Convert the labels to floating point\n", "y_train = y_train.astype(\"float32\")\n", "y_test = y_test.astype(\"float32\")\n", "\n", "# Reserve the last 10,000 samples for validation. We can use these to optimise our hyper-parameters.\n", "x_val = x_train[-10000:]\n", "y_val = y_train[-10000:]\n", "x_train = x_train[:-10000]\n", "y_train = y_train[:-10000]\n", "\n", "# Shuffle the training examples (using a 1,024-example buffer) and group them into batches of 64\n", "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n", "train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)\n", "\n", "# Group the validation examples into batches of 64 (no shuffling is needed for evaluation)\n", "val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))\n", "val_dataset = val_dataset.batch(64)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "UHMkka1BcAc-" }, "source": [ "## Define an MLP model\n", "First, let's define a simple MLP model. In Keras, fully-connected layers are called [Dense](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense) layers.\n", "\n", "In this section, we will use the simple [sequential model definition approach](https://www.tensorflow.org/guide/keras/sequential_model). This assumes a single input array and a single output array, with the model layers (building blocks) applied one after another.\n", "\n", "Fill in the ? placeholders in the code below to create an MLP model, then draw a diagram of it." 
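, "\n", "\n", "As a worked example of the layer sizes involved, the short Python sketch below counts the trainable parameters of the 784-100-60-10 MLP defined in the next cell. It is an illustration added here, not part of the original exercise. It assumes each Dense layer has an (n_in, n_out) weight matrix plus one bias per output unit; calling `mlp_model.summary()` should report the same total.\n", "\n", "```python\n", "# Layer widths: 28*28 = 784 flattened pixels, two hidden layers, 10 classes\n", "widths = [28 * 28, 100, 60, 10]\n", "\n", "# A Dense layer mapping n_in -> n_out has n_in * n_out weights plus n_out biases\n", "params = [n_in * n_out + n_out for n_in, n_out in zip(widths[:-1], widths[1:])]\n", "print(params)       # [78500, 6060, 610]\n", "print(sum(params))  # 85170\n", "```" 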
] }, { "cell_type": "code", "metadata": { "id": "2n4oyvOublI4", "colab": { "base_uri": "https://localhost:8080/", "height": 275 }, "outputId": "0d1fb4db-eca0-4a8a-db98-14f35393911a" }, "source": [ "def build_mlp_model():\n", " # Create an MLP model using the sequential model API\n", " # The Input shape is the shape of a single example: each image is 28x28 pixels\n", " layers_list = [keras.Input(shape=(28,28), name=\"image\"),\n", " # First flatten out the spatial dimensions, to make one long vector of 784 values\n", " layers.Flatten(),\n", " # Next, add dense (hidden) layers with a suitable activation function\n", " layers.Dense(100, activation='relu'),\n", " layers.Dense(60, activation='relu'),\n", " # Finally, add a dense output layer with one unit per class\n", " # This layer should use a softmax activation, which converts the outputs into class probabilities\n", " layers.Dense(10, activation='softmax')]\n", "\n", " # Create a model based on the layer list\n", " model = keras.Sequential(layers_list)\n", " return model\n", "\n", "# Create the model\n", "mlp_model = build_mlp_model()\n", "# Draw a diagram of the MLP model, including each layer's input and output shapes\n", "keras.utils.plot_model(mlp_model, \"mlp_model.png\", show_shapes=True)\n", "Image(retina=True, filename='mlp_model.png')\n" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "image/png": { "width": 164, "height": 258 } }, "execution_count": 4 } ] }, { "cell_type": "markdown", "metadata": { "id": "Qw4YC1qxKO6p" }, "source": [ "## Training the model\n", "In order to train the model, we need to [compile](https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile) it. This involves specifying the loss function to minimise, any \"metrics\" to calculate on the data, and the optimizer used to update the model's weights.\n", "\n", "1. 
Read the [Keras losses reference](https://keras.io/api/losses) to identify a suitable loss (and matching metrics) for compiling the model. Investigate the available optimizers and the \"learning_rate\" hyper-parameter.\n", "2. Call the model's [fit](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit) method. To understand what happens in fit and how you can change it, see this [tutorial](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit/). Remember that an epoch is one complete pass over every example in the training dataset; the epochs argument sets how many passes to make. With 50,000 training examples in batches of 64, each epoch takes 782 steps." ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "UWf2AHi0c6VY", "outputId": "193bdfd8-303b-4b97-87c4-d5a8d00ca7ae" }, "source": [ "mlp_model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001), loss=\"sparse_categorical_crossentropy\", metrics=[\"sparse_categorical_accuracy\"])\n", "mlp_model.fit(train_dataset, epochs=20, validation_data=val_dataset)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/20\n", "782/782 [==============================] - 14s 15ms/step - loss: 0.3271 - sparse_categorical_accuracy: 0.9078 - val_loss: 0.1765 - val_sparse_categorical_accuracy: 0.9496\n", "Epoch 2/20\n", "782/782 [==============================] - 7s 9ms/step - loss: 0.1424 - sparse_categorical_accuracy: 0.9574 - val_loss: 0.1175 - val_sparse_categorical_accuracy: 0.9667\n", "Epoch 3/20\n", "782/782 [==============================] - 9s 12ms/step - loss: 0.0972 - sparse_categorical_accuracy: 0.9710 - val_loss: 0.1019 - val_sparse_categorical_accuracy: 0.9700\n", "Epoch 4/20\n", "782/782 [==============================] - 4s 5ms/step - loss: 0.0749 - sparse_categorical_accuracy: 0.9771 - val_loss: 0.0932 - val_sparse_categorical_accuracy: 0.9727\n", "Epoch 5/20\n", "782/782 [==============================] - 4s 5ms/step - loss: 0.0586 - sparse_categorical_accuracy: 0.9820 - val_loss: 
0.0906 - val_sparse_categorical_accuracy: 0.9744\n", "Epoch 6/20\n", "782/782 [==============================] - 5s 7ms/step - loss: 0.0473 - sparse_categorical_accuracy: 0.9854 - val_loss: 0.0944 - val_sparse_categorical_accuracy: 0.9740\n", "Epoch 7/20\n", "782/782 [==============================] - 4s 5ms/step - loss: 0.0378 - sparse_categorical_accuracy: 0.9879 - val_loss: 0.0921 - val_sparse_categorical_accuracy: 0.9748\n", "Epoch 8/20\n", "782/782 [==============================] - 4s 5ms/step - loss: 0.0300 - sparse_categorical_accuracy: 0.9906 - val_loss: 0.0935 - val_sparse_categorical_accuracy: 0.9755\n", "Epoch 9/20\n", "782/782 [==============================] - 5s 7ms/step - loss: 0.0249 - sparse_categorical_accuracy: 0.9920 - val_loss: 0.0980 - val_sparse_categorical_accuracy: 0.9752\n", "Epoch 10/20\n", "782/782 [==============================] - 4s 5ms/step - loss: 0.0238 - sparse_categorical_accuracy: 0.9921 - val_loss: 0.1001 - val_sparse_categorical_accuracy: 0.9756\n", "Epoch 11/20\n", "782/782 [==============================] - 6s 8ms/step - loss: 0.0188 - sparse_categorical_accuracy: 0.9939 - val_loss: 0.1070 - val_sparse_categorical_accuracy: 0.9755\n", "Epoch 12/20\n", "782/782 [==============================] - 4s 5ms/step - loss: 0.0165 - sparse_categorical_accuracy: 0.9944 - val_loss: 0.0983 - val_sparse_categorical_accuracy: 0.9766\n", "Epoch 13/20\n", "782/782 [==============================] - 4s 5ms/step - loss: 0.0158 - sparse_categorical_accuracy: 0.9946 - val_loss: 0.1041 - val_sparse_categorical_accuracy: 0.9756\n", "Epoch 14/20\n", "782/782 [==============================] - 6s 7ms/step - loss: 0.0131 - sparse_categorical_accuracy: 0.9956 - val_loss: 0.1121 - val_sparse_categorical_accuracy: 0.9766\n", "Epoch 15/20\n", "782/782 [==============================] - 4s 5ms/step - loss: 0.0111 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.1159 - val_sparse_categorical_accuracy: 0.9738\n", "Epoch 16/20\n", "782/782 
[==============================] - 5s 7ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9957 - val_loss: 0.1196 - val_sparse_categorical_accuracy: 0.9759\n", "Epoch 17/20\n", "782/782 [==============================] - 4s 5ms/step - loss: 0.0095 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.1219 - val_sparse_categorical_accuracy: 0.9753\n", "Epoch 18/20\n", "782/782 [==============================] - 5s 6ms/step - loss: 0.0097 - sparse_categorical_accuracy: 0.9969 - val_loss: 0.1260 - val_sparse_categorical_accuracy: 0.9763\n", "Epoch 19/20\n", "782/782 [==============================] - 4s 5ms/step - loss: 0.0093 - sparse_categorical_accuracy: 0.9969 - val_loss: 0.1354 - val_sparse_categorical_accuracy: 0.9736\n", "Epoch 20/20\n", "782/782 [==============================] - 5s 6ms/step - loss: 0.0120 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1254 - val_sparse_categorical_accuracy: 0.9763\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 5 } ] }, { "cell_type": "markdown", "metadata": { "id": "uTNm1Vp2dzfM" }, "source": [ "## Define a convolutional neural network model\n", "Let's follow the same process, but build a convolutional neural network model. We'll also try out the [Functional API](https://www.tensorflow.org/guide/keras/functional). This offers more flexibility than the Sequential API, and makes it easier to combine existing components into more complex structures such as residual blocks.\n", "\n", "The key layers to use are:\n", "+ [layers.Conv2D](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2D)\n", "+ [layers.MaxPool2D](https://www.tensorflow.org/api_docs/python/tf/keras/layers/MaxPool2D)\n", "\n", "Build a suitable convolutional model (don't add too many filters, and stick with a kernel_size of 3 to keep it quick). 
Remember that you'll still need a Dense layer with a softmax activation for the final classification.\n" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 496 }, "id": "S-OWWFvbbNXY", "outputId": "4c509626-9401-422d-ba28-8e01b3f5bcd6" }, "source": [ "def build_conv_model():\n", " # Build and return a convolutional model using the functional API\n", " # The Input shape is the shape of a single example (one 28x28 image)\n", " inputs = keras.Input(shape=(28,28), name=\"image\")\n", " net = inputs\n", " # Conv2D expects a channel axis, so reshape each image to 28x28x1 (one grayscale channel)\n", " # Note that we create the layer object and then immediately call it\n", " net = layers.Reshape((28,28,1))(net)\n", "\n", " # Alternatively, we can store the layer object in a Python variable and call it later\n", " # We're more likely to want to do this for layers that contain model weights\n", " conv1 = layers.Conv2D(filters=4, kernel_size=3, padding='same', activation='relu')\n", " net = conv1(net)\n", " net = layers.MaxPool2D()(net)\n", "\n", " conv2 = layers.Conv2D(filters=8, kernel_size=3, padding='same', activation='relu')\n", " net = conv2(net)\n", " net = layers.MaxPool2D()(net)\n", "\n", " dense1 = layers.Dense(units=84, activation='relu')\n", " # Flatten the spatial dimensions before passing the features to the dense layers\n", " net = layers.Flatten()(net)\n", " net = dense1(net)\n", " dense2 = layers.Dense(units=10, activation='softmax')\n", " net = dense2(net)\n", "\n", " return keras.Model(inputs=inputs, outputs=net)\n", "\n", "conv_model = build_conv_model()\n", "\n", "# Draw the network diagram\n", "keras.utils.plot_model(conv_model, \"conv_model.png\", show_shapes=True)\n", "Image(retina=True, filename='conv_model.png')\n" ], "execution_count": null, "outputs": [ { "output_type": "execute_result", "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": { "image/png": { "width": 193, "height": 479 } }, "execution_count": 6 } 
] }, { "cell_type": "markdown", "metadata": { "id": "u7nHeLJtM0AA" }, "source": [ "## Train the model\n", "Compile and fit the convolutional model in the same way as the MLP.\n", "1. Experiment with the model architecture and learning rates. See which settings give the best performance on the validation metrics." ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "iU_IkfMsBUrV", "outputId": "24edd887-7d4b-480e-b015-7aac07c9cb6e" }, "source": [ "conv_model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001), loss=\"sparse_categorical_crossentropy\", metrics=[\"sparse_categorical_accuracy\"])\n", "# train_dataset is already batched (64 per batch), so batch_size is not passed to fit\n", "conv_model.fit(train_dataset, epochs=20, validation_data=val_dataset)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/20\n", "782/782 [==============================] - 25s 30ms/step - loss: 0.4111 - sparse_categorical_accuracy: 0.8782 - val_loss: 0.1566 - val_sparse_categorical_accuracy: 0.9553\n", "Epoch 2/20\n", "782/782 [==============================] - 24s 31ms/step - loss: 0.1348 - sparse_categorical_accuracy: 0.9590 - val_loss: 0.1085 - val_sparse_categorical_accuracy: 0.9690\n", "Epoch 3/20\n", "782/782 [==============================] - 24s 31ms/step - loss: 0.0922 - sparse_categorical_accuracy: 0.9715 - val_loss: 0.0852 - val_sparse_categorical_accuracy: 0.9764\n", "Epoch 4/20\n", "782/782 [==============================] - 22s 29ms/step - loss: 0.0739 - sparse_categorical_accuracy: 0.9774 - val_loss: 0.0724 - val_sparse_categorical_accuracy: 0.9796\n", "Epoch 5/20\n", "782/782 [==============================] - 24s 31ms/step - loss: 0.0618 - sparse_categorical_accuracy: 0.9815 - val_loss: 0.0646 - val_sparse_categorical_accuracy: 0.9812\n", "Epoch 6/20\n", "782/782 [==============================] - 24s 30ms/step - loss: 0.0517 - sparse_categorical_accuracy: 0.9839 - val_loss: 0.0615 - val_sparse_categorical_accuracy: 0.9826\n", 
"Epoch 7/20\n", "782/782 [==============================] - 25s 32ms/step - loss: 0.0450 - sparse_categorical_accuracy: 0.9858 - val_loss: 0.0617 - val_sparse_categorical_accuracy: 0.9815\n", "Epoch 8/20\n", "782/782 [==============================] - 27s 35ms/step - loss: 0.0393 - sparse_categorical_accuracy: 0.9877 - val_loss: 0.0633 - val_sparse_categorical_accuracy: 0.9820\n", "Epoch 9/20\n", "782/782 [==============================] - 24s 30ms/step - loss: 0.0338 - sparse_categorical_accuracy: 0.9892 - val_loss: 0.0596 - val_sparse_categorical_accuracy: 0.9832\n", "Epoch 10/20\n", "782/782 [==============================] - 23s 29ms/step - loss: 0.0301 - sparse_categorical_accuracy: 0.9909 - val_loss: 0.0573 - val_sparse_categorical_accuracy: 0.9847\n", "Epoch 11/20\n", "782/782 [==============================] - 25s 32ms/step - loss: 0.0258 - sparse_categorical_accuracy: 0.9918 - val_loss: 0.0587 - val_sparse_categorical_accuracy: 0.9849\n", "Epoch 12/20\n", "782/782 [==============================] - 24s 31ms/step - loss: 0.0218 - sparse_categorical_accuracy: 0.9932 - val_loss: 0.0605 - val_sparse_categorical_accuracy: 0.9854\n", "Epoch 13/20\n", "782/782 [==============================] - 24s 31ms/step - loss: 0.0206 - sparse_categorical_accuracy: 0.9934 - val_loss: 0.0619 - val_sparse_categorical_accuracy: 0.9841\n", "Epoch 14/20\n", "782/782 [==============================] - 23s 30ms/step - loss: 0.0179 - sparse_categorical_accuracy: 0.9945 - val_loss: 0.0671 - val_sparse_categorical_accuracy: 0.9834\n", "Epoch 15/20\n", "782/782 [==============================] - 24s 31ms/step - loss: 0.0167 - sparse_categorical_accuracy: 0.9946 - val_loss: 0.0662 - val_sparse_categorical_accuracy: 0.9829\n", "Epoch 16/20\n", "782/782 [==============================] - 24s 31ms/step - loss: 0.0138 - sparse_categorical_accuracy: 0.9956 - val_loss: 0.0583 - val_sparse_categorical_accuracy: 0.9866\n", "Epoch 17/20\n", "782/782 [==============================] - 25s 
32ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9956 - val_loss: 0.0617 - val_sparse_categorical_accuracy: 0.9854\n", "Epoch 18/20\n", "782/782 [==============================] - 26s 33ms/step - loss: 0.0101 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.0621 - val_sparse_categorical_accuracy: 0.9862\n", "Epoch 19/20\n", "782/782 [==============================] - 25s 32ms/step - loss: 0.0110 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.0728 - val_sparse_categorical_accuracy: 0.9844\n", "Epoch 20/20\n", "782/782 [==============================] - 26s 33ms/step - loss: 0.0094 - sparse_categorical_accuracy: 0.9968 - val_loss: 0.0811 - val_sparse_categorical_accuracy: 0.9833\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 7 } ] }, { "cell_type": "markdown", "metadata": { "id": "YbvtjpYoA3wu" }, "source": [ "## Evaluate the models\n", "Once we've experimented with various models and tuned our architectures and hyper-parameters on the validation set, we can evaluate them on the test data.\n", "\n", "1. Use the [model.evaluate](https://www.tensorflow.org/api_docs/python/tf/keras/Model#evaluate) method to compare metrics between the models on the test set. See an example in the section [API overview: a first end-to-end example](https://keras.io/guides/training_with_built_in_methods/). Which model works best on the test data?" 
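, "\n", "\n", "For reference, `evaluate` returns the loss followed by the compiled metrics. The NumPy sketch below shows what sparse categorical cross-entropy and sparse categorical accuracy compute. It is an illustration with made-up toy values for 3 classes, not lab results. It assumes `probs` holds softmax outputs of shape (N, num_classes) and `labels` holds integer class ids. Keras also clips the probabilities slightly to avoid log(0), which this sketch omits.\n", "\n", "```python\n", "import numpy as np\n", "\n", "# Toy softmax outputs for 4 examples and 3 classes (made-up values)\n", "probs = np.array([[0.7, 0.2, 0.1],\n", "                  [0.1, 0.8, 0.1],\n", "                  [0.3, 0.3, 0.4],\n", "                  [0.5, 0.4, 0.1]])\n", "labels = np.array([0, 1, 2, 1])  # integer class ids\n", "\n", "# Cross-entropy: mean negative log-probability assigned to the true class\n", "loss = -np.mean(np.log(probs[np.arange(len(labels)), labels]))\n", "\n", "# Accuracy: fraction of examples whose highest-probability class matches the label\n", "acc = np.mean(np.argmax(probs, axis=1) == labels)\n", "print(round(float(loss), 4), float(acc))  # 0.6031 0.75\n", "```" 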
] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "hso2RBZvB20-", "outputId": "fb5ea9cd-5953-4de3-9856-ca6a3945d04b" }, "source": [ "# Evaluate both models on the test data using `evaluate`\n", "print(\"Evaluate on test data\")\n", "results = conv_model.evaluate(x_test, y_test, batch_size=128)\n", "print(\"Conv model test loss, test acc:\", results)\n", "results = mlp_model.evaluate(x_test, y_test, batch_size=128)\n", "print(\"MLP model test loss, test acc:\", results)\n" ], "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Evaluate on test data\n", "79/79 [==============================] - 2s 21ms/step - loss: 0.0622 - sparse_categorical_accuracy: 0.9841\n", "Conv model test loss, test acc: [0.06215307116508484, 0.9840999841690063]\n", "79/79 [==============================] - 0s 3ms/step - loss: 0.1214 - sparse_categorical_accuracy: 0.9743\n", "MLP model test loss, test acc: [0.12144602090120316, 0.9743000268936157]\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "ynvtcZftFTMK" }, "source": [ "## Switch to a tougher dataset\n", "[FashionMNIST](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/fashion_mnist) is similar to MNIST, but contains grayscale images of fashion items instead of digits. It is harder to classify, but the images are the same size (28x28). Make sure you apply the same preprocessing to this dataset.\n", "1. Have a go at running your MLP and convolutional models on this dataset. Do they perform as well as they did on MNIST?\n", "2. 
What changes can you make to get classification accuracy above 90% on the validation set?\n" ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 408 }, "id": "nw0zYBR8FV9N", "outputId": "3878351c-29c2-4c39-d44b-003570460de8" }, "source": [ "(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()\n", "\n", "plt.imshow(x_train[0,...],cmap='gray')\n", "plt.show()\n", "\n", "# Divide the image data to put it in the right range and convert to floating point numbers\n", "x_train = x_train.astype(\"float32\") / 255\n", "x_test = x_test.astype(\"float32\") / 255\n", "# Convert the labels to floating point\n", "y_train = y_train.astype(\"float32\")\n", "y_test = y_test.astype(\"float32\")\n", "\n", "# Reserve the last 10,000 samples for validation. We can use these to optimise our hyper-parameters.\n", "x_val = x_train[-10000:]\n", "y_val = y_train[-10000:]\n", "x_train = x_train[:-10000]\n", "y_train = y_train[:-10000]\n", "\n", "# Shuffle the training examples and group them into batches of 64\n", "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n", "train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)\n", "\n", "# Group the validation examples into batches of 64\n", "val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))\n", "val_dataset = val_dataset.batch(64)\n" ], "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz\n", "29515/29515 [==============================] - 0s 0us/step\n", "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz\n", "26421880/26421880 [==============================] - 0s 0us/step\n", "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz\n", "5148/5148 [==============================] - 0s 0us/step\n", "Downloading data from 
https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz\n", "4422102/4422102 [==============================] - 0s 0us/step\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAAsTAAALEwEAmpwYAAAR10lEQVR4nO3db2yVdZYH8O+xgNqCBaxA+RPBESOTjVvWikbRjI4Q9IUwanB4scGo24kZk5lkTNa4L8bEFxLdmcm+IJN01AyzzjqZZCBi/DcMmcTdFEcqYdtKd0ZACK2lBUFoS6EUzr7og+lgn3Pqfe69z5Xz/SSk7T393fvrvf1yb+95fs9PVBVEdOm7LO8JEFF5MOxEQTDsREEw7ERBMOxEQUwq542JCN/6JyoxVZXxLs/0zC4iq0TkryKyV0SeyXJdRFRaUmifXUSqAPwNwAoAXQB2AlinqnuMMXxmJyqxUjyzLwOwV1X3q+owgN8BWJ3h+oiohLKEfR6AQ2O+7kou+zsi0iQirSLSmuG2iCijkr9Bp6rNAJoBvownylOWZ/ZuAAvGfD0/uYyIKlCWsO8EsFhEFonIFADfB7C1ONMiomIr+GW8qo6IyFMA3gNQBeBVVf24aDMjoqIquPVW0I3xb3aikivJQTVE9M3BsBMFwbATBcGwEwXBsBMFwbATBcGwEwXBsBMFwbATBcGwEwXBsBMFwbATBcGwEwVR1lNJU/mJjLsA6ktZVz1OmzbNrC9fvjy19s4772S6be9nq6qqSq2NjIxkuu2svLlbCn3M+MxOFATDThQEw04UBMNOFATDThQEw04UBMNOFAT77Je4yy6z/z8/d+6cWb/++uvN+hNPPGHWh4aGUmuDg4Pm2NOnT5v1Dz/80Kxn6aV7fXDvfvXGZ5mbdfyA9XjymZ0oCIadKAiGnSgIhp0oCIadKAiGnSgIhp0oCPbZL3FWTxbw++z33HOPWb/33nvNeldXV2rt8ssvN8dWV1eb9RUrVpj1l19+ObXW29trjvXWjHv3m2fq1KmptfPnz5tjT506VdBtZgq7iBwA0A/gHIARVW3Mcn1EVDrFeGa/W1WPFuF6iKiE+Dc7URBZw64A/igiH4lI03jfICJNItIqIq0Zb4uIMsj6Mn65qnaLyCwA20Tk/1T1/bHfoKrNAJoBQESynd2QiAqW6ZldVbuTj30AtgBYVoxJEVHxFRx2EakRkWkXPgewEkBHsSZGRMWV5WX8bABbknW7kwD8l6q+W5RZUdEMDw9nGn/LLbeY9YULF5p1q8/vrQl/7733zPrSpUvN+osvvphaa22130Jqb283652dnWZ92TL7Ra51v7a0tJhjd+zYkVobGBhIrRUcdlXdD+AfCx1PROXF1htREAw7URAMO1EQDDtREAw7URCSdcver3VjPIKuJKzTFnuPr7dM1GpfAcD06dPN+tmzZ1Nr3lJOz86dO8363r17U2tZW5L19fVm3fq5AXvuDz/8sDl248aNqbXW1lacPHly3F8IPrMTBcGwEwXBsBMFwbATBcGwEwXBsBMFwbATBcE+ewXwtvfNwnt8P/jgA7PuLWH1WD+bt21x1l64teWz1+PftWuXWbd6+ID/s61atSq1dt1115lj582bZ9ZVlX12osgYdqIgGHaiIBh2oiAYdqIgGHaiIBh2oiC4ZXMFKOexDhc7fvy4WffWbQ8NDZl1a1vmSZPsXz9rW2PA7qMDwJVXXpla8/rsd955p1m//fbbzbp3muxZs2al1t59tzRnZOczO1EQDDtREAw7URAMO1EQDDtREAw7URAMO1EQ7LMHV11dbda9frFXP3XqVGrtxIkT5tjPP//crHtr7a3jF7xzCHg/l3e/nTt3zqxbff4FCxaYYwvlPrOLyKsi0iciHWMumyki20Tkk+TjjJLMjoiKZiIv438N4OLTajwDYLuqLgawPfmaiCqYG3ZVfR/AsYsuXg1gU/L5JgBrijstIiq2Q
v9mn62qPcnnhwHMTvtGEWkC0FTg7RBRkWR+g05V1TqRpKo2A2gGeMJJojwV2nrrFZF6AEg+9hVvSkRUCoWGfSuA9cnn6wG8UZzpEFGpuC/jReR1AN8BUCciXQB+CmADgN+LyOMADgJYW8pJXuqy9nytnq63Jnzu3Llm/cyZM5nq1np277zwVo8e8PeGt/r0Xp98ypQpZr2/v9+s19bWmvW2trbUmveYNTY2ptb27NmTWnPDrqrrUkrf9cYSUeXg4bJEQTDsREEw7ERBMOxEQTDsREFwiWsF8E4lXVVVZdat1tsjjzxijp0zZ45ZP3LkiFm3TtcM2Es5a2pqzLHeUk+vdWe1/c6ePWuO9U5z7f3cV199tVnfuHFjaq2hocEca83NauPymZ0oCIadKAiGnSgIhp0oCIadKAiGnSgIhp0oCCnndsE8U834vJ7uyMhIwdd96623mvW33nrLrHtbMmc5BmDatGnmWG9LZu9U05MnTy6oBvjHAHhbXXusn+2ll14yx7722mtmXVXHbbbzmZ0oCIadKAiGnSgIhp0oCIadKAiGnSgIhp0oiG/UenZrra7X7/VOx+ydztla/2yt2Z6ILH10z9tvv23WBwcHzbrXZ/dOuWwdx+Gtlfce0yuuuMKse2vWs4z1HnNv7jfddFNqzdvKulB8ZicKgmEnCoJhJwqCYScKgmEnCoJhJwqCYScKoqL67FnWRpeyV11qd911l1l/6KGHzPodd9yRWvO2PfbWhHt9dG8tvvWYeXPzfh+s88IDdh/eO4+DNzePd78NDAyk1h588EFz7JtvvlnQnNxndhF5VUT6RKRjzGXPiUi3iOxO/t1f0K0TUdlM5GX8rwGsGufyX6hqQ/LPPkyLiHLnhl1V3wdwrAxzIaISyvIG3VMi0pa8zJ+R9k0i0iQirSLSmuG2iCijQsP+SwDfAtAAoAfAz9K+UVWbVbVRVRsLvC0iKoKCwq6qvap6TlXPA/gVgGXFnRYRFVtBYReR+jFffg9AR9r3ElFlcM8bLyKvA/gOgDoAvQB+mnzdAEABHADwA1XtcW8sx/PGz5w506zPnTvXrC9evLjgsV7f9IYbbjDrZ86cMevWWn1vXba3z/hnn31m1r3zr1v9Zm8Pc2//9erqarPe0tKSWps6dao51jv2wVvP7q1Jt+633t5ec+ySJUvMetp5492DalR13TgXv+KNI6LKwsNliYJg2ImCYNiJgmDYiYJg2ImCqKgtm2+77TZz/PPPP59au+aaa8yx06dPN+vWUkzAXm75xRdfmGO95bdeC8lrQVmnwfZOBd3Z2WnW165da9ZbW+2joK1tmWfMSD3KGgCwcOFCs+7Zv39/as3bLrq/v9+se0tgvZam1fq76qqrzLHe7wu3bCYKjmEnCoJhJwqCYScKgmEnCoJhJwqCYScKoux9dqtfvWPHDnN8fX19as3rk3v1LKcO9k557PW6s6qtrU2t1dXVmWMfffRRs75y5Uqz/uSTT5p1a4ns6dOnzbGffvqpWbf66IC9LDnr8lpvaa/Xx7fGe8tnr732WrPOPjtRcAw7URAMO1EQDDtREAw7URAMO1EQDDtREGXts9fV1ekDDzyQWt+wYYM5ft++fak179TAXt3b/tfi9VytPjgAHDp0yKx7p3O21vJbp5kGgDlz5pj1NWvWmHVrW2TAXpPuPSY333xzprr1s3t9dO9+87Zk9ljnIPB+n6zzPhw+fBjDw8PssxNFxrATBcGwEwXBsBMFwbATBcGwEwXBsBMF4e7iWkwjIyPo6+tLrXv9ZmuNsLetsXfdXs/X6qt65/k+duyYWT948KBZ9+ZmrZf31ox757TfsmWLWW9vbzfrVp/d20bb64V75+u3tqv2fm5vTbnXC/fGW312r4dvbfFt3SfuM7uILBCRP4vIHhH5WER+lFw+U0S2icgnyUf7jP9ElKuJvIwfAfATVf02gNsA/FBEvg3gGQDbVXUxgO3J10RUodywq2qPqu5KPu8H0
AlgHoDVADYl37YJwJoSzZGIiuBrvUEnIgsBLAXwFwCzVbUnKR0GMDtlTJOItIpIq/c3GBGVzoTDLiJTAfwBwI9V9eTYmo6uphl3RY2qNqtqo6o2Zl08QESFm1DYRWQyRoP+W1XdnFzcKyL1Sb0eQPrb7ESUO7f1JqM9glcAdKrqz8eUtgJYD2BD8vEN77qGh4fR3d2dWveW23Z1daXWampqzLHeKZW9Ns7Ro0dTa0eOHDHHTppk383e8lqvzWMtM/VOaewt5bR+bgBYsmSJWR8cHEytee3Q48ePm3XvfrPmbrXlAL815433tmy2lhafOHHCHNvQ0JBa6+joSK1NpM9+B4B/BtAuIruTy57FaMh/LyKPAzgIwN7Im4hy5YZdVf8HQNoRAN8t7nSIqFR4uCxREAw7URAMO1EQDDtREAw7URBlXeI6NDSE3bt3p9Y3b96cWgOAxx57LLXmnW7Z297XWwpqLTP1+uBez9U7stDbEtpa3uttVe0d2+BtZd3T02PWrev35uYdn5DlMcu6fDbL8lrA7uMvWrTIHNvb21vQ7fKZnSgIhp0oCIadKAiGnSgIhp0oCIadKAiGnSiIsm7ZLCKZbuy+++5LrT399NPm2FmzZpl1b9221Vf1+sVen9zrs3v9Zuv6rVMWA36f3TuGwKtbP5s31pu7xxpv9aonwnvMvFNJW+vZ29razLFr19qryVWVWzYTRcawEwXBsBMFwbATBcGwEwXBsBMFwbATBVH2Prt1nnKvN5nF3XffbdZfeOEFs2716Wtra82x3rnZvT6812f3+vwWawttwO/DW/sAAPZjOjAwYI717hePNXdvvbm3jt97TLdt22bWOzs7U2stLS3mWA/77ETBMexEQTDsREEw7ERBMOxEQTDsREEw7ERBuH12EVkA4DcAZgNQAM2q+h8i8hyAfwFwYXPyZ1X1bee6ytfUL6Mbb7zRrGfdG37+/Plm/cCBA6k1r5+8b98+s07fPGl99olsEjEC4CequktEpgH4SEQuHDHwC1X992JNkohKZyL7s/cA6Ek+7xeRTgDzSj0xIiqur/U3u4gsBLAUwF+Si54SkTYReVVEZqSMaRKRVhFpzTZVIspiwmEXkakA/gDgx6p6EsAvAXwLQANGn/l/Nt44VW1W1UZVbcw+XSIq1ITCLiKTMRr036rqZgBQ1V5VPaeq5wH8CsCy0k2TiLJywy6jp+h8BUCnqv58zOX1Y77tewA6ij89IiqWibTelgP4bwDtAC6sV3wWwDqMvoRXAAcA/CB5M8+6rkuy9UZUSdJab9+o88YTkY/r2YmCY9iJgmDYiYJg2ImCYNiJgmDYiYJg2ImCYNiJgmDYiYJg2ImCYNiJgmDYiYJg2ImCYNiJgpjI2WWL6SiAg2O+rksuq0SVOrdKnRfAuRWqmHO7Nq1Q1vXsX7lxkdZKPTddpc6tUucFcG6FKtfc+DKeKAiGnSiIvMPenPPtWyp1bpU6L4BzK1RZ5pbr3+xEVD55P7MTUZkw7ERB5BJ2EVklIn8Vkb0i8kwec0gjIgdEpF1Edue9P12yh16fiHSMuWymiGwTkU+Sj+PusZfT3J4Tke7kvtstIvfnNLcFIvJnEdkjIh+LyI+Sy3O974x5leV+K/vf7CJSBeBvAFYA6AKwE8A6Vd1T1omkEJEDABpVNfcDMETkLgADAH6jqv+QXPYigGOquiH5j3KGqv5rhcztOQADeW/jnexWVD92m3EAawA8ihzvO2Nea1GG+y2PZ/ZlAPaq6n5VHQbwOwCrc5hHxVPV9wEcu+ji1QA2JZ9vwugvS9mlzK0iqGqPqu5KPu8HcGGb8VzvO2NeZZFH2OcBODTm6y5U1n7vCuCPIvKRiDTlPZlxzB6zzdZhALPznMw43G28y+mibcYr5r4rZPvzrPgG3VctV9V/AnAfgB8mL1crko7+DVZJvdMJbeNdLuNsM/6lPO+7Qrc/zyqPsHcDWDDm6/nJZRVBVbuTj30AtqDytqLuv
bCDbvKxL+f5fKmStvEeb5txVMB9l+f253mEfSeAxSKySESmAPg+gK05zOMrRKQmeeMEIlIDYCUqbyvqrQDWJ5+vB/BGjnP5O5WyjXfaNuPI+b7LfftzVS37PwD3Y/Qd+X0A/i2POaTM6zoA/5v8+zjvuQF4HaMv685i9L2NxwFcDWA7gE8A/AnAzAqa239idGvvNowGqz6nuS3H6Ev0NgC7k3/3533fGfMqy/3Gw2WJguAbdERBMOxEQTDsREEw7ERBMOxEQTDsREEw7ERB/D/+XzeWfiVg0AAAAABJRU5ErkJggg==\n" }, "metadata": { "needs_background": "light" } } ] }, { "cell_type": "code", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "mjH72fYTo3sh", "outputId": "9ad482cc-3859-41c9-9cc3-da4f79f5330e" }, "source": [ "\n", "# Remember to build a fresh model, the old weights won't be relevant\n", "conv_model = build_conv_model()\n", "\n", "# Fit the model parameters\n", "\n", "conv_model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001), loss=\"sparse_categorical_crossentropy\", metrics=[\"sparse_categorical_accuracy\"])\n", "conv_model.fit(train_dataset, batch_size=64, epochs=20, validation_data=val_dataset)\n", "\n", "results = conv_model.evaluate(x_test, y_test, batch_size=128)\n", "print(\"Conv model test loss, test acc:\", results)" ], "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Epoch 1/20\n", "782/782 [==============================] - 27s 33ms/step - loss: 0.5765 - sparse_categorical_accuracy: 0.7967 - val_loss: 0.4383 - val_sparse_categorical_accuracy: 0.8418\n", "Epoch 2/20\n", "782/782 [==============================] - 26s 34ms/step - loss: 0.3934 - sparse_categorical_accuracy: 0.8602 - val_loss: 0.4248 - val_sparse_categorical_accuracy: 0.8455\n", "Epoch 3/20\n", "782/782 [==============================] - 25s 32ms/step - loss: 0.3515 - sparse_categorical_accuracy: 0.8731 - val_loss: 0.3482 - val_sparse_categorical_accuracy: 0.8735\n", "Epoch 4/20\n", "782/782 [==============================] - 27s 35ms/step - loss: 0.3286 - sparse_categorical_accuracy: 0.8808 - val_loss: 0.3342 - val_sparse_categorical_accuracy: 0.8787\n", "Epoch 5/20\n", "782/782 
[==============================] - 24s 30ms/step - loss: 0.3099 - sparse_categorical_accuracy: 0.8873 - val_loss: 0.3277 - val_sparse_categorical_accuracy: 0.8806\n", "Epoch 6/20\n", "782/782 [==============================] - 25s 33ms/step - loss: 0.2967 - sparse_categorical_accuracy: 0.8911 - val_loss: 0.3141 - val_sparse_categorical_accuracy: 0.8876\n", "Epoch 7/20\n", "782/782 [==============================] - 25s 32ms/step - loss: 0.2830 - sparse_categorical_accuracy: 0.8975 - val_loss: 0.3010 - val_sparse_categorical_accuracy: 0.8900\n", "Epoch 8/20\n", "782/782 [==============================] - 25s 32ms/step - loss: 0.2709 - sparse_categorical_accuracy: 0.9014 - val_loss: 0.3135 - val_sparse_categorical_accuracy: 0.8847\n", "Epoch 9/20\n", "782/782 [==============================] - 25s 32ms/step - loss: 0.2612 - sparse_categorical_accuracy: 0.9047 - val_loss: 0.2979 - val_sparse_categorical_accuracy: 0.8933\n", "Epoch 10/20\n", "782/782 [==============================] - 25s 32ms/step - loss: 0.2513 - sparse_categorical_accuracy: 0.9078 - val_loss: 0.2895 - val_sparse_categorical_accuracy: 0.8972\n", "Epoch 11/20\n", "782/782 [==============================] - 27s 35ms/step - loss: 0.2426 - sparse_categorical_accuracy: 0.9116 - val_loss: 0.3101 - val_sparse_categorical_accuracy: 0.8868\n", "Epoch 12/20\n", "782/782 [==============================] - 27s 34ms/step - loss: 0.2352 - sparse_categorical_accuracy: 0.9139 - val_loss: 0.2873 - val_sparse_categorical_accuracy: 0.8978\n", "Epoch 13/20\n", "782/782 [==============================] - 24s 31ms/step - loss: 0.2281 - sparse_categorical_accuracy: 0.9150 - val_loss: 0.2898 - val_sparse_categorical_accuracy: 0.8960\n", "Epoch 14/20\n", "782/782 [==============================] - 25s 32ms/step - loss: 0.2209 - sparse_categorical_accuracy: 0.9183 - val_loss: 0.2959 - val_sparse_categorical_accuracy: 0.8974\n", "Epoch 15/20\n", "782/782 [==============================] - 23s 29ms/step - loss: 0.2135 - 
sparse_categorical_accuracy: 0.9212 - val_loss: 0.3009 - val_sparse_categorical_accuracy: 0.8935\n", "Epoch 16/20\n", "782/782 [==============================] - 26s 33ms/step - loss: 0.2093 - sparse_categorical_accuracy: 0.9222 - val_loss: 0.2926 - val_sparse_categorical_accuracy: 0.8963\n", "Epoch 17/20\n", "782/782 [==============================] - 25s 32ms/step - loss: 0.2015 - sparse_categorical_accuracy: 0.9258 - val_loss: 0.2932 - val_sparse_categorical_accuracy: 0.8987\n", "Epoch 18/20\n", "782/782 [==============================] - 23s 30ms/step - loss: 0.1966 - sparse_categorical_accuracy: 0.9267 - val_loss: 0.2948 - val_sparse_categorical_accuracy: 0.8991\n", "Epoch 19/20\n", "782/782 [==============================] - 28s 36ms/step - loss: 0.1891 - sparse_categorical_accuracy: 0.9300 - val_loss: 0.3059 - val_sparse_categorical_accuracy: 0.8951\n", "Epoch 20/20\n", "782/782 [==============================] - 27s 35ms/step - loss: 0.1838 - sparse_categorical_accuracy: 0.9316 - val_loss: 0.2951 - val_sparse_categorical_accuracy: 0.9001\n", "79/79 [==============================] - 2s 21ms/step - loss: 0.3045 - sparse_categorical_accuracy: 0.8970\n", "Conv model test loss, test acc: [0.30447232723236084, 0.8970000147819519]\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "7lavD5hPDZJT" }, "source": [ "## Extension Tasks\n", "You don't need to do these extensions, but you'll certainly learn something by giving them a go.\n", "1. Find some validation/test examples where your model performs poorly. Visualise those examples. Can you think of a reason why it performs poorly on them? Are they ambiguous?\n", "2. Try a more complex (but still low-resolution) image dataset, such as [CIFAR-10](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/cifar10) or [CIFAR-100](https://www.tensorflow.org/api_docs/python/tf/keras/datasets/cifar100).\n", "3. 
Try adding [BatchNorm](https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization) or [Dropout](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dropout) layers to your model.\n", "4. Implement a residual block; you could try subclassing the [Layer base class](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Layer).\n", "5. Learn how to use [TensorBoard](https://www.tensorflow.org/tensorboard), an excellent tool for visualising model training, which you can either run in [Colab](https://colab.research.google.com/github/tensorflow/tensorboard/blob/master/docs/tensorboard_in_notebooks.ipynb) or install on your local machine.\n" ] } ] }