Computer_Vision/Chapter05/Implementing_VGG16_for_image_classification.ipynb

571 lines
177 KiB
Plaintext
Raw Normal View History

2024-02-13 03:34:51 +01:00
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "Implementing_VGG16_for_image_classification.ipynb",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/PacktPublishing/Hands-On-Computer-Vision-with-PyTorch/blob/master/Chapter05/Implementing_VGG16_for_image_classification.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "X1mQ8Y9ERCpa"
},
"source": [
"import torchvision\n",
"import torch.nn as nn\n",
"import torch\n",
"import torch.nn.functional as F\n",
"from torchvision import transforms,models,datasets\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"import numpy as np\n",
"from torch import optim\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"import cv2, glob, numpy as np, pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from glob import glob\n",
"import torchvision.transforms as transforms\n",
"from torch.utils.data import DataLoader, Dataset"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "CQcFhpxVRNev",
"outputId": "87230d9b-ab1d-4b8e-ab5f-f2b0a916c338",
"colab": {
"resources": {
"http://localhost:8080/nbextensions/google.colab/files.js": {
"data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSB
QeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWl
zZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUo
"ok": true,
"headers": [
[
"content-type",
"application/javascript"
]
],
"status": 200,
"status_text": ""
}
},
"base_uri": "https://localhost:8080/",
"height": 91
}
},
"source": [
"%pip install -q kaggle\n",
"from google.colab import files\n",
"# Opens a browser upload widget; upload your kaggle.json API token here.\n",
"files.upload()\n",
"!mkdir -p ~/.kaggle\n",
"!cp kaggle.json ~/.kaggle/\n",
"!ls ~/.kaggle\n",
"# Use the same ~/.kaggle path as the commands above instead of hard-coding /root,\n",
"# so this also works when the home directory is not /root.\n",
"!chmod 600 ~/.kaggle/kaggle.json"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <input type=\"file\" id=\"files-954feacf-38d2-4ef8-8541-5b2f2491ae93\" name=\"files[]\" multiple disabled\n",
" style=\"border:none\" />\n",
" <output id=\"result-954feacf-38d2-4ef8-8541-5b2f2491ae93\">\n",
" Upload widget is only available when the cell has been executed in the\n",
" current browser session. Please rerun this cell to enable.\n",
" </output>\n",
" <script src=\"/nbextensions/google.colab/files.js\"></script> "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "stream",
"text": [
"Saving kaggle.json to kaggle.json\n",
"kaggle.json\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "FAYvATjiRPep"
},
"source": [
"!kaggle datasets download -d tongpython/cat-and-dog\n",
"# -o: overwrite existing files without the interactive 'replace?' prompt, so re-running\n",
"# the notebook does not hang; -q: keep the per-file listing out of the cell output.\n",
"!unzip -q -o cat-and-dog.zip"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "nCvdJ9U-RWb3"
},
"source": [
"# Train/test roots, presumably created by unzipping cat-and-dog.zip above.\n",
"# CatsDogs expects each root to contain cats/ and dogs/ subfolders of .jpg files.\n",
"train_data_dir = 'training_set/training_set'\n",
"test_data_dir = 'test_set/test_set'"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "NDfNnADpRYAV"
},
"source": [
"class CatsDogs(Dataset):\n",
"    \"\"\"Cats-vs-dogs image dataset built from `folder`/cats and `folder`/dogs.\n",
"\n",
"    Takes up to `n_per_class` .jpg files from each subfolder, in a deterministic\n",
"    shuffled order. Each item is (image, target): a normalized 3x224x224 float\n",
"    tensor and a 1-element float tensor (1.0 when the file name starts with 'dog',\n",
"    else 0.0), both moved to `device`.\n",
"\n",
"    Args:\n",
"        folder: directory containing cats/ and dogs/ subfolders of .jpg images.\n",
"        n_per_class: maximum number of images taken from each class (default 500).\n",
"    \"\"\"\n",
"    def __init__(self, folder, n_per_class=500):\n",
"        import os, random\n",
"        # glob returns files in arbitrary, filesystem-dependent order; sort so the first\n",
"        # n_per_class files (i.e. the dataset itself) are the same on every machine.\n",
"        cats = sorted(glob(folder+'/cats/*.jpg'))\n",
"        dogs = sorted(glob(folder+'/dogs/*.jpg'))\n",
"        self.fpaths = cats[:n_per_class] + dogs[:n_per_class]\n",
"        # Per-channel mean/std commonly used with torchvision's ImageNet-pretrained models.\n",
"        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])\n",
"        # Private seeded RNG: deterministic order without reseeding the global `random` module.\n",
"        random.Random(10).shuffle(self.fpaths)\n",
"        # Label is True for dogs; assumes dog images have file names starting with 'dog'.\n",
"        self.targets = [os.path.basename(fpath).startswith('dog') for fpath in self.fpaths]\n",
"\n",
"    def __len__(self): return len(self.fpaths)\n",
"\n",
"    def __getitem__(self, ix):\n",
"        f = self.fpaths[ix]\n",
"        target = self.targets[ix]\n",
"        im = cv2.imread(f)\n",
"        if im is None:\n",
"            # cv2.imread returns None for unreadable/corrupt files instead of raising.\n",
"            raise FileNotFoundError(f'could not read image: {f}')\n",
"        im = im[:,:,::-1]  # OpenCV loads BGR; flip the channel axis to RGB\n",
"        im = cv2.resize(im, (224,224))\n",
"        im = torch.tensor(im/255)\n",
"        im = im.permute(2,0,1)  # HWC -> CHW\n",
"        im = self.normalize(im)\n",
"        return im.float().to(device), torch.tensor([target]).float().to(device)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-90fyHONRah5"
},
"source": [
"# Training-set instance; the next cell uses it to inspect a single sample.\n",
"data = CatsDogs(train_data_dir)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "4VvoZixHRcNM",
"outputId": "3cc48845-a8b7-432e-cc4a-1e64eab13e58",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 305
}
},
"source": [
"im, label = data[200]\n",
"# Undo the Normalize(mean, std) applied in CatsDogs so imshow receives values in [0, 1];\n",
"# plotting the normalized tensor directly makes matplotlib clip it and distorts the colours.\n",
"mean = torch.tensor([0.485, 0.456, 0.406], device=im.device).view(3,1,1)\n",
"std = torch.tensor([0.229, 0.224, 0.225], device=im.device).view(3,1,1)\n",
"im_display = (im*std + mean).clamp(0, 1)\n",
"plt.imshow(im_display.permute(1,2,0).cpu())\n",
"print(label)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"tensor([0.], device='cuda:0')\n"
],
"name": "stdout"
},
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQEAAAD8CAYAAAB3lxGOAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOydd5glRfX3P919c5obZ+ZODrszm9jMsuQcBX8IkkVAJIMIKLpgIAk/FREwEEQREBRJAgICorCALLALbGDj7OzkPHPv3Jy6z/tHL7jILmzifXkf9/s8/czcruo61d1V3zp16pxqRUTYhV3Yhf9eqP+vK7ALu7AL/2+xiwR2YRf+y7GLBHZhF/7LsYsEdmEX/suxiwR2YRf+y7GLBHZhF/7L8ZmRgKIoRyiKskZRlDZFUb77WcnZhV3YhR2D8ln4CSiKogFrgUOBHuBt4BQRWbnThe3CLuzCDuGz0gTmAW0i0i4iBeBPwP98RrJ2YRd2YQdg+YzKrQa6N/ndA+yxpcyKogjAnDlzdkjo6HiWvr4eCpnEJ+abNm0adrt9s2lpgdXvLPmgYrROmY3HsX316evL09/fC8QAKAtWU10VwWnXNptfcgaKY+fwcqEIg4MxhgbbN56xEo1WUVUVBgqADtiAD+qiA8LOaBICjOfyrH9/jSlZCxEJuYjWBj6Sbzg+Sm9vL3quuMlZjUhlM9EKL9btrEo2Y7By1VLAQFU1qqubKC/3bV9h24j1nRAf+aD9VDNnduVm8xmik87F8TpDOywzb8DalW0U8gnQfMyY1oTF8vF2tGTJkhERiXwsQUR2+gF8Gbhnk9+nA7/8jzznAos3HqJpmuwolvTE5LBTviaY7fAjh6KoommaaJomqqoJKHLfvwYlmy9KsViUkl6SzmJRNhKSqJpVfvXQDlfpQ5x55vuiaQeKppmyj/rqrbK+K2bKLhliGIaMrX5e9FJh5wndiMf+OihNUy8VTdNEURRxgNz3m59KsZiQUqkkhqGLiCEi4yJS2omSDclks3L/Hx8WVTWfK84qoWzqJu9FEU3T5JQL75bVXTtPcqEgctevxkXTdt/4zt0ybdqpUiwWpVgqia7vPFn/CcMQWRUzPmxvdnuFKbdY/FBuvpiT9v7lYhjGTpUdjZ4nmuYUmCjJZEqKxaIYhiGGiACLZXP9dXMnd/QA9gSe3+T3AmDBJ+QXTdMkl8tJLpeX7X0shohceumlHyMAVbXIqae+JX19Zr7n/ipSXn6q2GxusVhsAsjeF8wVlH9fc/kN21mJrcDhh/9IbLao2Gw2UVWbnPj1F2RwKC2p4ecln0vv9IaxKRYs+JHYbH6xWW2iKqrsN3uSLHnzUcnl4lLI5zchhB2AYYgYJTH0tJSKGySZH5bfPfnjjc9WEVW1iqpaBVS59tprJZcrSKn0GfRKQ0R0kWJR5LWFA2KzHSdWq01qJh4otz9itrVCobjz5W6CQqEkNucCsdls4vZG5I4ncpLL5SSTLUlvLCuD+k5kvk0QbLxHbDa3KKpN3stkJJfL/V8nAQvQDjRi6pxLgamfkF/AIprFJU7nBOlJpSSVSkk6nZZ8fmsbhyHJfF7Oueiij5HAhd96WfoHN3eFyL1/LYrV5vqwgYJDQJENQx/UIbejXeITcfoZIoHAFHG5XAKKTKueLiPDI5JKpSSbK3x2hDCckDO+fJ64XAFxOlyiKqoc0Fguq5cuklQqKdlMfttlG4YYRlH03JCUNtwssccuk/sv/LLMaZojqmoXl8slZeEj5NTzl8sl394gLtc54nK5xWLZV37xixfN553J7URC2MgCm6CrX+T0S14Vl8stdnulnHTSVR/KLRQ/O/UgmzPkxAs2iMvlEqejXKq8P5IFCx6XpT2rRf8MSf+4c0Vcbo/YHa7/uyQgZsc+CnOFYD1w9afkFUW5Us68Ji8+X434fD7xen3i8TXIed9eLvF4XOLxcUkmM6
Jv4Xllsxm54orLNjsV+NvzL0jpQ/3PkJKRk/HxuAzE4zJ374dFUSaJgld8vslSE1ghPl+jWQefX+y+MyQWj0s8Hpfx8ZR8lgNHY8Ne4vP5xefzidPlk3nH3CixWMy8/0TqM5M7uLgkB8w7THw+n/g8XlEVRQ6bc4F0buiVeDwu+aL+yYSwsfMbhXHJj94jXXceLv84Z7bcfejBcufZ90qN7wY5/vg3N3tpSUSu/O6L4vPtJ16fTyKTj5e/vrhE4vG4pNIF0bf0wrcChpTEkA1bTH/ppWFpar5WfD6fVE/6svz4ziUSj49LOpP/zKYLuiGysmNIfL69xeX0SOtuU2T5wHoZTyQll//syOD1Vf+XpwPbepjTgTM/UulEQpcTT1gq4XBYwuGw+P21Mn3O+bJ6cEiGhoZkdGxMMrm8iIhkMhm56qqrNksAgEw64Avy2tvvyNDQkAwNdcirvd+XYLnrw3R/ICQTK96XUunfL8EQkbFCUcLhUyQcDkswWCE+/3HyxxdM+cMjI5JK53bOG/pAZsLUpEVE/vCMSHl5hSk7HBZ34yEb6z8kw8PDO1XupsitG5a5M2ZLOByWUDgsiqLIr59eLQMbn/sHc0xT5c+KUUyIkRsVI/aiJB67XhbdeIq8ccuF8vZ9P5MX73lIuldtPWvGROSkMy6ScDgqZf6wHHDs7+WdZT0yNDQkmWxJtl0x0cXQ45+aLysiP7vrIQmHJ4ivbLqcftF9smTVkAwNxSSX2/k2GhGRfE7kqb+sljJ/WEKhsEycdqTcdGeXDA2PSiK5c9vVB/j/gARaP/EGVqx4X6LRWolGoxIMBqV64lT533v+JL29vXLN9dcI2uYJ4NMOv8cpy9enP7WBxWI5qWq+Q6LRqEQiFeIJNMkFP7hbent7pa9/UOKJ7Fa9iG2FYYj0lQyJRqMSjUYlEKwUl8slvb290tvbJwMDY5+JXBGzc0yaNOlD2YC8vvBlKaZHxMiOSCnxCyktPkkKj+0v2R9eJLEf/kQ2/PpuWfTwQ/KnP/1O/rr4JRnU09sle/FKkb33O1Ki0ag4nOVy7a2LZf2GXuntHRZdN7ZqimYYJrFuC15+TeTIY66RaDQqTudhcuedf5Pe3l6JJ3Kys80WyZQhr72tS64g8vizSyUa3V/8kd3lpDN/I729fRIfT8jOFLklEvhMnIW2FYqiiGa5m1LxnK3K//LLL3PWWV8jlUoxMjKC2Z83KQ+wK6YThKZZEFRKhmAYCoqqoUgJq0VBjCI/O/8Ejv/OjwhEylHsXhRF+VT5GzbEOeig24F7KelQtFRy+JfO4pqLDwXA5/MRCu340s9/whD422t5LvrqJATIF1WkNJM33rgZULDb3VRVfXwFaIchOYqpBPW11ezuVvnFjw6kOqQw2L0Wa7pEWVklQz31lLx19BdVRtxh0qEm/I2TOGLvaTss/rwr3uefz5xFMT9MT+9k3n3nFtxuO+CkoaFiy+/sg2bx6a90s7jiil/z+OO/o1Aa5aiTfsY5p84iEnBQXR3BarNsb7EfIpvW6VyTYdJs74fnnnlmCRdffA25fAdHn3gql116Ml67h2DAj9tp3SF5iqIsEZG5Hzv/uSEB3yOUxr+81dek0jnuuOsBrrzim0Dmw/MqELTAXI+NCAYVkWpK9gj9KYWhtB2HJ4Q/O8SEqJNMYhXJ2DC1kx2ccv6lOPc4jVB5BJsvtFVkADAUg7sfeI0//PrrAIyPpzj00MO46qorAQgEygmFglh2skeGLrC0a4xTDz8IyFEoWHA65/DYYwsAcLk9VFXVYNm8S8KnIAeZFOlYDNBReJu+Z+/nzddeYp/9ZuHGg131Y6S82LUabJ56OgeKrO4YRg+7Wdi2hrq5M7j48m/uxDs2cfj/vEDX2m8jkmfNmt1YseJaNE3F5orSVFf2kbwiAqKjqDv28Nf1wHcvP5P3ly2ib2A6Dz70DSY0hXF5a6iJutHU7aODQk4Y7hSqWz++pv/uSrj1tt/y5i
s/RYvsyzlnncIRe9VTVVWJz+feLnmffxJwnkV37x1EA5t34tkUmVyOB/+ykEsvu5bswL8+PG8FwsBBVVBrV
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "JwhHv9VYRfhj"
},
"source": [
"def get_model():\n",
"    \"\"\"Build a binary classifier on top of a frozen, pretrained VGG16 backbone.\n",
"\n",
"    Returns:\n",
"        (model, loss_fn, optimizer): the model moved to `device`, a BCELoss, and an\n",
"        Adam optimizer (lr=1e-3) over the model's parameters; only the new head\n",
"        layers have requires_grad=True.\n",
"    \"\"\"\n",
"    vgg = models.vgg16(pretrained=True)\n",
"    # Freeze every pretrained weight; layers created afterwards remain trainable.\n",
"    vgg.requires_grad_(False)\n",
"    # Pool the final feature map to 1x1 so the head receives a 512-dim vector.\n",
"    vgg.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))\n",
"    head_layers = [\n",
"        nn.Flatten(),\n",
"        nn.Linear(512, 128),\n",
"        nn.ReLU(),\n",
"        nn.Dropout(0.2),\n",
"        nn.Linear(128, 1),\n",
"        nn.Sigmoid(),  # probability output, consumed by BCELoss\n",
"    ]\n",
"    vgg.classifier = nn.Sequential(*head_layers)\n",
"    optimizer = torch.optim.Adam(vgg.parameters(), lr= 1e-3)\n",
"    loss_fn = nn.BCELoss()\n",
"    return vgg.to(device), loss_fn, optimizer"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "WEodSA2URqK8",
"outputId": "6b602eff-94f6-4013-bde2-531571b75349",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
}
},
"source": [
"# Pinned to the torch_summary version recorded in this cell's saved output.\n",
"%pip install -q torch_summary==1.4.3\n",
"from torchsummary import summary\n",
"model, loss_fn, optimizer = get_model()\n",
"summary(model, torch.zeros(1,3,224,224))"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"Requirement already satisfied: torch_summary in /usr/local/lib/python3.6/dist-packages (1.4.3)\n",
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"├─Sequential: 1-1 [-1, 512, 7, 7] --\n",
"| └─Conv2d: 2-1 [-1, 64, 224, 224] (1,792)\n",
"| └─ReLU: 2-2 [-1, 64, 224, 224] --\n",
"| └─Conv2d: 2-3 [-1, 64, 224, 224] (36,928)\n",
"| └─ReLU: 2-4 [-1, 64, 224, 224] --\n",
"| └─MaxPool2d: 2-5 [-1, 64, 112, 112] --\n",
"| └─Conv2d: 2-6 [-1, 128, 112, 112] (73,856)\n",
"| └─ReLU: 2-7 [-1, 128, 112, 112] --\n",
"| └─Conv2d: 2-8 [-1, 128, 112, 112] (147,584)\n",
"| └─ReLU: 2-9 [-1, 128, 112, 112] --\n",
"| └─MaxPool2d: 2-10 [-1, 128, 56, 56] --\n",
"| └─Conv2d: 2-11 [-1, 256, 56, 56] (295,168)\n",
"| └─ReLU: 2-12 [-1, 256, 56, 56] --\n",
"| └─Conv2d: 2-13 [-1, 256, 56, 56] (590,080)\n",
"| └─ReLU: 2-14 [-1, 256, 56, 56] --\n",
"| └─Conv2d: 2-15 [-1, 256, 56, 56] (590,080)\n",
"| └─ReLU: 2-16 [-1, 256, 56, 56] --\n",
"| └─MaxPool2d: 2-17 [-1, 256, 28, 28] --\n",
"| └─Conv2d: 2-18 [-1, 512, 28, 28] (1,180,160)\n",
"| └─ReLU: 2-19 [-1, 512, 28, 28] --\n",
"| └─Conv2d: 2-20 [-1, 512, 28, 28] (2,359,808)\n",
"| └─ReLU: 2-21 [-1, 512, 28, 28] --\n",
"| └─Conv2d: 2-22 [-1, 512, 28, 28] (2,359,808)\n",
"| └─ReLU: 2-23 [-1, 512, 28, 28] --\n",
"| └─MaxPool2d: 2-24 [-1, 512, 14, 14] --\n",
"| └─Conv2d: 2-25 [-1, 512, 14, 14] (2,359,808)\n",
"| └─ReLU: 2-26 [-1, 512, 14, 14] --\n",
"| └─Conv2d: 2-27 [-1, 512, 14, 14] (2,359,808)\n",
"| └─ReLU: 2-28 [-1, 512, 14, 14] --\n",
"| └─Conv2d: 2-29 [-1, 512, 14, 14] (2,359,808)\n",
"| └─ReLU: 2-30 [-1, 512, 14, 14] --\n",
"| └─MaxPool2d: 2-31 [-1, 512, 7, 7] --\n",
"├─AdaptiveAvgPool2d: 1-2 [-1, 512, 1, 1] --\n",
"├─Sequential: 1-3 [-1, 1] --\n",
"| └─Flatten: 2-32 [-1, 512] --\n",
"| └─Linear: 2-33 [-1, 128] 65,664\n",
"| └─ReLU: 2-34 [-1, 128] --\n",
"| └─Dropout: 2-35 [-1, 128] --\n",
"| └─Linear: 2-36 [-1, 1] 129\n",
"| └─Sigmoid: 2-37 [-1, 1] --\n",
"==========================================================================================\n",
"Total params: 14,780,481\n",
"Trainable params: 65,793\n",
"Non-trainable params: 14,714,688\n",
"Total mult-adds (G): 15.36\n",
"==========================================================================================\n",
"Input size (MB): 0.57\n",
"Forward/backward pass size (MB): 103.36\n",
"Params size (MB): 56.38\n",
"Estimated Total Size (MB): 160.32\n",
"==========================================================================================\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"==========================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"==========================================================================================\n",
"├─Sequential: 1-1 [-1, 512, 7, 7] --\n",
"| └─Conv2d: 2-1 [-1, 64, 224, 224] (1,792)\n",
"| └─ReLU: 2-2 [-1, 64, 224, 224] --\n",
"| └─Conv2d: 2-3 [-1, 64, 224, 224] (36,928)\n",
"| └─ReLU: 2-4 [-1, 64, 224, 224] --\n",
"| └─MaxPool2d: 2-5 [-1, 64, 112, 112] --\n",
"| └─Conv2d: 2-6 [-1, 128, 112, 112] (73,856)\n",
"| └─ReLU: 2-7 [-1, 128, 112, 112] --\n",
"| └─Conv2d: 2-8 [-1, 128, 112, 112] (147,584)\n",
"| └─ReLU: 2-9 [-1, 128, 112, 112] --\n",
"| └─MaxPool2d: 2-10 [-1, 128, 56, 56] --\n",
"| └─Conv2d: 2-11 [-1, 256, 56, 56] (295,168)\n",
"| └─ReLU: 2-12 [-1, 256, 56, 56] --\n",
"| └─Conv2d: 2-13 [-1, 256, 56, 56] (590,080)\n",
"| └─ReLU: 2-14 [-1, 256, 56, 56] --\n",
"| └─Conv2d: 2-15 [-1, 256, 56, 56] (590,080)\n",
"| └─ReLU: 2-16 [-1, 256, 56, 56] --\n",
"| └─MaxPool2d: 2-17 [-1, 256, 28, 28] --\n",
"| └─Conv2d: 2-18 [-1, 512, 28, 28] (1,180,160)\n",
"| └─ReLU: 2-19 [-1, 512, 28, 28] --\n",
"| └─Conv2d: 2-20 [-1, 512, 28, 28] (2,359,808)\n",
"| └─ReLU: 2-21 [-1, 512, 28, 28] --\n",
"| └─Conv2d: 2-22 [-1, 512, 28, 28] (2,359,808)\n",
"| └─ReLU: 2-23 [-1, 512, 28, 28] --\n",
"| └─MaxPool2d: 2-24 [-1, 512, 14, 14] --\n",
"| └─Conv2d: 2-25 [-1, 512, 14, 14] (2,359,808)\n",
"| └─ReLU: 2-26 [-1, 512, 14, 14] --\n",
"| └─Conv2d: 2-27 [-1, 512, 14, 14] (2,359,808)\n",
"| └─ReLU: 2-28 [-1, 512, 14, 14] --\n",
"| └─Conv2d: 2-29 [-1, 512, 14, 14] (2,359,808)\n",
"| └─ReLU: 2-30 [-1, 512, 14, 14] --\n",
"| └─MaxPool2d: 2-31 [-1, 512, 7, 7] --\n",
"├─AdaptiveAvgPool2d: 1-2 [-1, 512, 1, 1] --\n",
"├─Sequential: 1-3 [-1, 1] --\n",
"| └─Flatten: 2-32 [-1, 512] --\n",
"| └─Linear: 2-33 [-1, 128] 65,664\n",
"| └─ReLU: 2-34 [-1, 128] --\n",
"| └─Dropout: 2-35 [-1, 128] --\n",
"| └─Linear: 2-36 [-1, 1] 129\n",
"| └─Sigmoid: 2-37 [-1, 1] --\n",
"==========================================================================================\n",
"Total params: 14,780,481\n",
"Trainable params: 65,793\n",
"Non-trainable params: 14,714,688\n",
"Total mult-adds (G): 15.36\n",
"==========================================================================================\n",
"Input size (MB): 0.57\n",
"Forward/backward pass size (MB): 103.36\n",
"Params size (MB): 56.38\n",
"Estimated Total Size (MB): 160.32\n",
"=========================================================================================="
]
},
"metadata": {
"tags": []
},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "wRSbFt3BRr5B"
},
"source": [
"def train_batch(x, y, model, opt, loss_fn):\n",
"    \"\"\"Run one optimization step on the batch (x, y).\n",
"\n",
"    Args:\n",
"        x, y: input batch and targets (assumed to already be on the model's device).\n",
"        model: network to train; switched to train mode here.\n",
"        opt: optimizer over the model's trainable parameters.\n",
"        loss_fn: callable(prediction, y) returning a scalar loss tensor.\n",
"    Returns:\n",
"        The batch loss as a Python float.\n",
"    \"\"\"\n",
"    model.train()\n",
"    prediction = model(x)\n",
"    batch_loss = loss_fn(prediction, y)\n",
"    batch_loss.backward()\n",
"    # Use the optimizer passed in rather than a global name, so any optimizer works.\n",
"    opt.step()\n",
"    opt.zero_grad()\n",
"    return batch_loss.item()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Fp2yASc_RuO2"
},
"source": [
"@torch.no_grad()\n",
"def accuracy(x, y, model):\n",
"    \"\"\"Per-sample correctness of thresholded predictions, as nested Python lists.\n",
"\n",
"    The model is put in eval mode and an output above 0.5 counts as the positive\n",
"    class; each element says whether that prediction matches the target in y.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    predicted_positive = model(x) > 0.5\n",
"    matches = predicted_positive == y\n",
"    return matches.cpu().numpy().tolist()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "tbOeDzCPSVfj"
},
"source": [
"def get_data(batch_size=32):\n",
"    \"\"\"Create the training and validation DataLoaders.\n",
"\n",
"    The training loader shuffles and drops the last incomplete batch. The validation\n",
"    loader keeps a fixed order and every sample, so validation accuracy covers the\n",
"    whole validation set rather than silently skipping the remainder batch.\n",
"\n",
"    Args:\n",
"        batch_size: samples per batch for both loaders (default 32).\n",
"    Returns:\n",
"        (trn_dl, val_dl)\n",
"    \"\"\"\n",
"    train = CatsDogs(train_data_dir)\n",
"    trn_dl = DataLoader(train, batch_size=batch_size, shuffle=True, drop_last = True)\n",
"    val = CatsDogs(test_data_dir)\n",
"    val_dl = DataLoader(val, batch_size=batch_size, shuffle=False, drop_last = False)\n",
"    return trn_dl, val_dl"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "hz7QoetLSXNI"
},
"source": [
"# Fresh data loaders plus a newly built model, loss and optimizer for the training run below.\n",
"trn_dl, val_dl = get_data()\n",
"model, loss_fn, optimizer = get_model()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "X_vtpUGRSYvZ",
"outputId": "342455c8-4da3-49c4-cec3-62b1e04aeb60",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 108
}
},
"source": [
"n_epochs = 5  # single source of truth for the epoch count\n",
"train_losses, train_accuracies = [], []\n",
"val_accuracies = []\n",
"for epoch in range(n_epochs):\n",
"    print(f\" epoch {epoch + 1}/{n_epochs}\")\n",
"    train_epoch_losses, train_epoch_accuracies = [], []\n",
"    val_epoch_accuracies = []\n",
"\n",
"    # One pass of gradient updates over the training set.\n",
"    for x, y in trn_dl:\n",
"        batch_loss = train_batch(x, y, model, optimizer, loss_fn)\n",
"        train_epoch_losses.append(batch_loss)\n",
"    train_epoch_loss = np.mean(train_epoch_losses)\n",
"\n",
"    # Training accuracy measured with the weights reached at the end of this epoch.\n",
"    for x, y in trn_dl:\n",
"        train_epoch_accuracies.extend(accuracy(x, y, model))\n",
"    train_epoch_accuracy = np.mean(train_epoch_accuracies)\n",
"\n",
"    for x, y in val_dl:\n",
"        val_epoch_accuracies.extend(accuracy(x, y, model))\n",
"    val_epoch_accuracy = np.mean(val_epoch_accuracies)\n",
"\n",
"    train_losses.append(train_epoch_loss)\n",
"    train_accuracies.append(train_epoch_accuracy)\n",
"    val_accuracies.append(val_epoch_accuracy)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
" epoch 1/5\n",
" epoch 2/5\n",
" epoch 3/5\n",
" epoch 4/5\n",
" epoch 5/5\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "5aFfpJGZSb5v",
"outputId": "017f433d-0176-4144-b2a3-42d883f28fa3",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 310
}
},
"source": [
"# Derive the x-axis from the recorded history instead of hard-coding the epoch count.\n",
"epochs = np.arange(len(train_accuracies))+1\n",
"import matplotlib.ticker as mtick\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib.ticker as mticker\n",
"%matplotlib inline\n",
"plt.plot(epochs, train_accuracies, 'bo', label='Training accuracy')\n",
"plt.plot(epochs, val_accuracies, 'r', label='Validation accuracy')\n",
"plt.gca().xaxis.set_major_locator(mticker.MultipleLocator(1))\n",
"plt.title('Training and validation accuracy with VGG16 \\nand 1K training data points')\n",
"plt.xlabel('Epochs')\n",
"plt.ylabel('Accuracy')\n",
"plt.ylim(0.95,1)\n",
"# A formatter keeps percentage labels in sync with the ticks; set_yticklabels on\n",
"# get_yticks() freezes labels that can drift from the actual tick positions.\n",
"plt.gca().yaxis.set_major_formatter(mtick.PercentFormatter(xmax=1.0, decimals=0))\n",
"plt.legend()\n",
"# grid(False), not grid('off'): recent matplotlib no longer maps 'off' to False, and the\n",
"# non-empty string is truthy, which would turn the grid on.\n",
"plt.grid(False)\n",
"plt.show()"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAY4AAAElCAYAAADz3wVRAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dd3xUVfrH8c8DqBBAEFwQpFooCgYIoqIIWNZeQEUQC+uu2BHX1XXXvpbfFta2awMLFhR7L+sKggVXREUFRBCNigoKSBNpyfP749yEIWSSuUkmk/J9v17zYm4/czPMc8859zzX3B0REZFU1cl0AUREpHpR4BARkVgUOEREJBYFDhERiUWBQ0REYlHgEBGRWBQ4qgEze9nMTqvodTPJzHLN7KA07NfNbJfo/Z1mdkUq65bhOMPN7NWyllNSY2b9zOyzEpZ3iP6O9SqzXLWdAkeamNnqhFe+mf2SMD08zr7c/TB3v7+i163p3P0sd7+2vPsp7sfJ3Se4+6/Lu28pmbu/6e6dC6bLc8FhZpea2RvFzN/ezNabWbdoupWZjTOz76L/r1+Y2Xgz65KwzdZmdqWZfWZmP5vZt9FF268T1jnPzGaY2TozG1/McbPM7HYzW2JmK4orW1WlwJEm7t6o4AV8DRyVMG9CwXq6UpKqpIZ/Hx8C+ppZxyLzhwKfuPssM2sOTAOygH5AY6AXMBU4OGGbJ4BjgFOB7YCOwC3AEQnrfAdcB9ybpDxjgWZA1+jfC8v8ySqbu+uV5heQCxwUvR8ALAT+CCwCHiR88V4AfgR+it63Sdh+CvC76P0I4C1gTLTul8BhZVy3I/AGsAp4DbgNeCjJZ0iljNcCb0f7exXYPmH5KcBXwFLgssRzUuQ4e0XnpW7CvEHAx9H7PsA7wHLge+DfwNYJ6zqwS/R+PHBdwrKLo22+A04vsu4RwIfASuAb4OqE7b6O1l0dvfYpOLcJ6/QF3gNWRP/2TfXcxDzPzYD7os/wE/BMwrJjgJnRZ1gAHFr0+xdNX13wdwY6RJ/tt9HnfCOa/3j0d1gRfUd2T9i+AfDP6O+5gvAdawC8CJxf5PN8DAwq5nPeD1wUvd8xKsO50fTOwDLChe0AYGE0/0EgH/gl+jtcklD+06LyLwEuK+H/4qvAlUXmTQcuiN5fB3wE1ClhHwdFZWiTbJ0i618HjC8yr0v0d9o2079PZXmpxpEZOxB+ANoDIwn/Qe6LptsRvpT/LmH7vYDPgO2BvwP3mJmVYd2HCf9pmhN+TE4p4ZiplPEk4DdAC2Br4A8AZrYbcEe0/9bR8doUdxB3fxf4GTigyH4fjt7nEa7Mtif8gB8InFNCuYnKcGhUnoOBXQn/+RP9TLh6bEoIImeb2bHRsv2jf5t6qDG+U2TfzQg/mrdGn+1G4MXo6jXxM2xxbopR2nl+kHA1vHu0r5uiMvQBHiAEx6ZRmXOTnY9i9Cdc+R4STb9MOE8tgA+ACQnrjgFyCMGyGeEHPJ8QDE4uWMnMsglB4cVijjeVEBQKjv0Fm85zf+BNd89P3MDdT2Hz2vvfExbvB3QmfB+uNLOuST7n/SR8z82sM9CDTd+vg4Cnix67iIOAd919YQnrlKYPIfBeEzVVfWJmx5Vjf5Ur05GrNrzYssaxHqhfwvo9gJ8SpqeweS3i84RlWYQrrh3irEv4UdoIZCUsf4gkNY4Uy3h5wvQ5wCvR+yuBiQnLGkbnYIsaR7T8OuDe6H1jwo96+yTrjib8Ry+YLrbGQWgu+GvCep0S1y1mvzcDN0XvO0Tr1ktYPoKoxkH4IZpeZPt3gBGlnZs45xloRfiB3q6Y9e4qKG9J379o+mq2rHHsVEIZmkbrNCEEtl+A7GLWq0+oBe0aTY8Bbk+yz52jdesAdwJnsqlmcT/w+4T/LwtL+CwF5U+slU0HhiY5bhbhSr9vNH098GzC8s+BsxKmjybUblcBr0bz7mbz73OzaJ
0VwNok3+fxReb9OSr31YQLif6EWlTXVL4XmX6pxpEZP7r72oKJqJPsLjP7ysxWEpoGmppZ3STbLyp44+5roreNYq7bGliWMA9CE02xUizjooT3axLK1Dpx3+7+M6HJKpmHgcFmtg0wGPjA3b+KytHJzF4ws0VROW4g1D5Ks1kZCFd7iZ9vLzN73cx+NLMVwFkp7rdg318VmfcV4Wq7QLJzs5lSznNbwt/sp2I2bUtoniqrwnNjZnXN7K9mtiAqQ260aPvoVb+4Y0Xf6UeBk82sDjCMUEPagrsvIFwQ9CD0JbwAfBfVAPoTaiRxpHR+o+/748CpUc17OKGmVmApIUAXrP+cuzcl1HK3TrLOsmidHGCbFMv7C7CBcGGz3t2nAq8D1eKGCwWOzCiakvgiQjV7L3fflk1V9mTNTxXhe6CZmWUlzGtbwvrlKeP3ifuOjtk82cruPofww3sYmzdTQWjymku4qt2WcOUWuwyEGleih4HngLbu3oRwFVyw39JSSH9HaFpK1A74NoVyFVXSef6G8DdrWsx23xCu4ovzM+FKu8AOxayT+BlPIvSXHESoZXRIKMMSYG0Jx7qf8GN8ILDGizTrFTEVOJ7QR/VtNH0aoZ9nZpJtSvtbpOJ+YAih2bIx8HzCsknAsVHgS2YSsKeZFdvcmqKPi5lXEZ+tUihwVA2NCVcgy6P28qvSfcDoCn4GcHV0a+E+wFFpKuMTwJFmtp+ZbQ38hdK/ew8DFxB+OB8vUo6VwOro9sizUyzDY8AIM9stClxFy9+YcDW/NuovOClh2Y+EJqKdkuz7JaCTmZ1kZvXM7ERgN8JVdFxJz7O7f0/oe7jdzLYzs63MrCCw3AP8xswONLM6ZrZjwu2jM4Gh0fq9CT/WpZVhHeHKOotQqysoQz6h2e9GM2sd1U72iWqHRIEin9B5XmxtI8FU4DxCrQpCk955hCbAvCTbLCb53yFVbxKalsYSmpzWJyy7kRC4HjSznS1oTKgZAeDurxJqB89ENdWtzWwrYO/Eg0TfhfpAXaCumdVPuGvtDUJ/zZ+i9fYFBgL/KednqxQKHFXDzYS7UpYA/wNeqaTjDid0MC8ltMM+SvjBKE6Zy+jus4FzCcHge0Lbdmkdi48Qmiwmu/uShPl/IPyorwLGRWVOpQwvR59hMqEde3KRVc4B/mJmqwh9Mo8lbLuG0Bb+tpktN7PNfiDcfSlwJKG2sJTQWXxkkXKnqrTzfAqhiWMu8AOhjwd3n07ofL+J0NY+lU21oCvY1KdwDZvX4IrzAKHG9y0wJypHoj8AnxDuHlsG/I3Nf0seALoT+sxKMpUQpAoCx1uEQFXSeIb/Ay6P/g7JbjAokYdOhgcI5+eBIsuWEALA2qg8qwiBtzGbX6QMIlwYPEQIQl8S/j8dkrDO5YSLgEsJNw38Es3D3TcQanWHE/5e44BT3X1uWT5TZbOoo0YEM3sUmOvuaa/xSM1lZqcCI919v0yXRdJDNY5azMz2jKrjdaLbVY8Bnsl0uaT6ipoBzyE0A0kNpcBRu+1AaFdeTRiDcLa7f5jREkm1ZWaHEPqDFlN6c5hUY2qqEhGRWFTjEBGRWBQ4JOPM7GozK+0OnCrDqmiaeytHmvjKYmbtLGScTTa4VaoBBQ6p0qJ75J+wkE7bzWxAkeXjzey6hOndzez7ZLdqVsSPq1fzNPeWwWdYuPvXHvJMJRunUSiT5ZSSKXBIdfAW4T74RSWtZGY9CQOzrnP3MWU5kH6kREqnwCFlYuGhOAvMbJWZzTGzQQnLRpjZW2Y2xsx+MrMvzeywhOUdzWxqtO1/KSEnVJTH52Z3f4uQGTdZefoA/wX+7O63JVmnYGDZR1FzyYlmNsDMFprZH81sEXBfNCr7BQt5q36K3rdJ2M8UM/tdip81zrodzeyN6Ly8Zma3ldSEZ2YXR7Wr78zs9CLLjjCzD81spZl9Y2ZXJywuOA/Lo/OwT3Rb9mQzW2ohW+sEKz61ScH+3cxGWXjI0RIz+4
dFaTqi27svt5Bv6wcze8DMmkTLNqtFROfnWjN7O/rcr5pZwfehuHLuEn13VkTHTWkAqFQsBQ4pqwWE5HRNC
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"tags": [],
"needs_background": "light"
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "jc88Ywn8TA67"
},
"source": [
""
],
"execution_count": null,
"outputs": []
}
]
}