Computer_Vision/Chapter07/Training_Fast_R_CNN.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/PacktPublishing/Hands-On-Computer-Vision-with-PyTorch/blob/master/Chapter07/Training_Fast_R_CNN.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 225,
"resources": {
"http://localhost:8080/nbextensions/google.colab/files.js": {
"data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSBQeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE
1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk
"headers": [
[
"content-type",
"application/javascript"
]
],
"ok": true,
"status": 200,
"status_text": ""
}
}
},
"id": "Ebq6AcmaF7wn",
"outputId": "27b10e0f-ff09-4006-f30a-0f1a61e4b856"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[K |████████████████████████████████| 36.7MB 77kB/s \n",
"\u001b[K |████████████████████████████████| 61kB 9.2MB/s \n",
"\u001b[K |████████████████████████████████| 102kB 13.6MB/s \n",
"\u001b[?25h Building wheel for selectivesearch (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Building wheel for contextvars (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
]
},
{
"data": {
"text/html": [
"\n",
" <input type=\"file\" id=\"files-5ca4c4b8-0766-4a39-aa43-80979f547422\" name=\"files[]\" multiple disabled\n",
" style=\"border:none\" />\n",
" <output id=\"result-5ca4c4b8-0766-4a39-aa43-80979f547422\">\n",
" Upload widget is only available when the cell has been executed in the\n",
" current browser session. Please rerun this cell to enable.\n",
" </output>\n",
" <script src=\"/nbextensions/google.colab/files.js\"></script> "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Saving kaggle.json to kaggle.json\n",
"kaggle.json\n",
"Downloading open-images-bus-trucks.zip to /content\n",
" 94% 344M/367M [00:02<00:00, 111MB/s]\n",
"100% 367M/367M [00:03<00:00, 128MB/s]\n"
]
}
],
"source": [
"!pip install -q --upgrade selectivesearch torch_snippets\n",
"from torch_snippets import *\n",
"import selectivesearch\n",
"from google.colab import files\n",
"files.upload() # upload kaggle.json file which you can get \n",
"# by clicking on Create New API token in your personal account\n",
"!mkdir -p ~/.kaggle\n",
"!mv kaggle.json ~/.kaggle/\n",
"!ls ~/.kaggle\n",
"!chmod 600 /root/.kaggle/kaggle.json\n",
"!kaggle datasets download -d sixhky/open-images-bus-trucks/\n",
"!unzip -qq open-images-bus-trucks.zip\n",
"from torchvision import transforms, models, datasets\n",
"from torch_snippets import Report\n",
"from torchvision.ops import nms\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
},
"id": "uzzreNa9GHg5",
"outputId": "5ebb0013-852d-4422-da2b-2407371b195a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" ImageID Source LabelName ... XClick2Y XClick3Y XClick4Y\n",
"0 0000599864fd15b3 xclick Bus ... 0.512700 0.650047 0.457197\n",
"1 00006bdb1eb5cd74 xclick Truck ... 0.241855 0.352130 0.437343\n",
"2 00006bdb1eb5cd74 xclick Truck ... 0.398496 0.409774 0.295739\n",
"3 00010bf498b64bab xclick Bus ... 0.493882 0.705228 0.521691\n",
"4 00013f14dd4e168f xclick Bus ... 0.303940 0.999062 0.523452\n",
"\n",
"[5 rows x 21 columns]\n"
]
}
],
"source": [
"IMAGE_ROOT = 'images/images'\n",
"DF_RAW = pd.read_csv('df.csv')\n",
"print(DF_RAW.head())"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 404
},
"id": "bZKVc06QGRGB",
"outputId": "b39388f9-1462-486f-c0d2-37e8f21ef389"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjwAAAGDCAYAAAAve8qnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOy9V5MkWZKl911mxN2DZGZldc9s784CI4v9/4IfgHfgZV8gsoslPaRZ0QzixOxSPKhec4+s3u55gUBkJKwkJCsjnZhdonr06FG9prXG+/V+vV/v1/v1fr1f79e/5sv+/30D79f79X69X+/X+/V+vV//X1/vgOf9er/er/fr/Xq/3q9/9dc74Hm/3q/36/16v96v9+tf/fUOeN6v9+v9er/er/fr/fpXf70Dnvfr/Xq/3q/36/16v/7VX++A5/16v96v9+v9er/er3/1l/9L/zjcDQ3Ae4+xhgZgwBiDMYZGo1FxznE4HAC4rGdqrex2O7z3tNqgNay1eO+5nC/88P1P1Fox+lnBeKyxWCs/3nmC8+SciTHSGhQAuYPtfcYYnLc4b7Z/ba3SWgXAObO9DvQ1DXLO5JwptVKa3P88zxhjKKVgjOHu7o4QAs/PT5xOx+3eMPI5xhq891hrGecR51y/Mf38ouMDMSZOxwutVqCCgXEa8d7jvcN7y34/c3c3k1LifD5TSyWvkVahFGgVWrG0aqil0UqjxEKNMo7OGqBRiEBjHD3O2e2e4ppZ14wPnnGeqLUSYwQD0y5grSHGlVoLYQz44Jnnif1+h7Ue5wdyzry8vFBq0bFupJippWJu/qNByZVlWWmtbXPg/YCzDu89zjlSzsScqLVSq8xbKXmbZwxYZ7DO4JzFOQOmAY3dbsff/M2vGIaBaTfhvGMYB4w1/PGPf+TLly/8+//17/gP/9vf8+XLF/7pn/6BGBPn84XWwNkJmuGnH585nxeCDwQf+vTivWWaR6wF5yshOH7zd79mf9hxOr8Q48I4BqZp4Hy+8OXLEylVLkujlMa6JGptDMOMs35bc+tl5XK8bHtMxstiMDjnsNbirMMaS2uN1irDGNgfZnLJnM+vGAPzPOGcw3mZ4vP5yLIsDINnnAZ288zHDx/BNGKNWGf5+PkTwzjwen5ljQvWSMSTUySuF3JOLBdZ648PB4L3TNMk6zQ4vHfsD3fc3T+wLomXlzPnU+IPv38mxSrmpNk+dVhrZU+lQl4zIQwc7u7JufDDj1+IKRLzSqXwzedvuLs/kPJCSovYGQfWgveGnBN//NMfiTHy+PjANI2kuJJzouRCzpXD/o6//dW/BSzrulJrJaVCa41pDgyDZ5g84+xJKbIsJ8Zx4tM3v6KUwj/+4+85nxaOxwspZsZxZhxHvLMMztFaJZeINTAMDmsNwyBr5ocffuJ4PPH4+Mjj4yPOyXg5ZwnBU2vheHqh1sL9/R3DEFjWhTWunE8XXl/OUC2UgWGY+PzpW6yxnI4vpJRIKVFKYVlW1jVijcNaWbPzOFNr5fX1SC2VcZxxzlFKpdaGdyPBTby+HvnuTz8gnUgcNNnDAMMYcN7SqEDDOIMNRuxEln2cs6zZcRzFFrYGTdaftUbWa63c39/xH//jf8AYw+//8DtiXLm72zEMgdP5zLIuup/FFgzDILbcOVqDZYnU1vj06RP7/Z6npyeen585n888PT1hMPr8lnEU2+usx2BIKZFzZgiBcRgA2UPGGILztFY5nU+Ukpn3s85NpbWGdQ4fAjlnjsej2OlxAgOn04WcM7tpxzCM1FoppZBLJqZV9sZ+h3VWfVGj1gI0wjjiw0CtVfxOqaRYtr+31gjWYYxh0X142O057PZq+8TX7KcZDMQcaa3hg8NYw7quYstbwzQYw8DD/p6cMz//9IUUE5fTQikF58T2Pn544OHxnmFy7A6BNUZ+fn5iWSLfff/CuiZiKtTaMNZhjKXkSs4VsRqeWhsxJfFyptCM4IFmKoaGofH4eM+3337DvJv49M29YITLCah8eHxkHAei7uOrr2bz1QA5F758+UJKiWEc8cFvr7XWYIz49PP5TGuN//N//+31Q766/iLguX5xwzTTMcObG2tN/i6O7eqrvu7vsz2IuQUhb18jBr5dn7T/nj/fK0i+o0ETMPbnXtcdrnzX23EwxuiGNW9/p++rtb55XmvtBmK+ukHZMNZi1MjrpwFtm5iKQbHY9sb+UnH2hVrLdRy2D++f9fWz//lr+8xfvOv63v5+01/Yn90obLkBiWI0+o+CL9PefJ65/dJ286H9czHQbj+nvb0PfY0xVt5rruD27eegoJsNKBljNoDQ59NaMTxiUKoC1pt5BtqfGZ2m997vXj7LYZ2TP63Fml8So8ZY5CvL9j5Zdm+fUcaYN+vuZii/mtfr/3ejbLhdW2yT3BrU2mj0/WW39dqfq5RCyoVa5LMwMiYdzFtr1SAanHXqgJquSfuLOZP1fhts9MXR3uyJWiulVpzOV61120fOWawxOJ0fQ3+fgFy5P7ZAqPqqgYLfAhtrHMYUvA9Y525skM6l7mXZX4Zara5ZCxhqKRQdF3FS1/nr9ujNmgVqq5hmNZ64jp/Ywuseq9szX9dA/72M4+0avAnMaqWZX5jCbW5v7Uz/jv4RPSDsY9Bqo1C218g6urEvfQG1bkVlZxhjtwC3NZCp/soh3e79m9/3wPFrG//VS7exuL3fWiu1NUoplFK2tWadBM0AzogTvrXX/c+3P3UDPMWUm981Wr21s1e78WafXt2WBvlvv8PAFqj0t2A18KNu/vFrH2OswaLfBRgr+8DePI+sPbO9piGApo/qdW/rGOoA9jXap/drm/ZmnGrbxvt2XzvncBWg6lqzVAvG6j0h0ZJ1updkK1EpVAzWNPmx12dsjZt1Z6it3uyP+sZuy35p1Mq2Bt7ijqtd69/R7dVfuv5FgEcMRcUHLyhYI/TaKinnNzeZcxHj5iRyD97jFL2CLI7dbpZod137KNB0gbfWqL7KRLQq8UYfrW2yrg9cqyzG9gaKtO1++uB0QwFmY5tqa2q0rsyOcw5gQ83GGOZ5p4vAUmol5bRFNaZHNq0xDgMhBC7WbnfR9JmdcxRFobVWaFXu3Xp8sCzLhfP5tVsouc83Gcem0ZfZFimAcRJp1dYwpqnxb5tlqd2QWXBBHOCyLBhrN6Tcx0qcnScMXubaOkoRA0ReSTmzLMICWSdzKk6i0pDNamhwY0j6mBpkjFMSpsNaGctSC857xnGELSpiY3JKKzRl4ax1OG/wXpzL8/Mz0zzx6ZtPDNPI6XQkpcQ0TnzzzTe02vjTn/5Ea5X7+3uWZeF0OlFrxTqLwSkIUvNxYyT77533fPwojMLHD5+Y9xM5R2JaaQ1SyljruLu7Y10zl+VIa5VhGGgNYszUGJmnmf2844Qhresbp9zXc9/0BYmWugEouXA6nYTJ8+7KNqIgpECKsK6VMFiCn5mmHXd3D9RayK+FnDPf/ekHKg0/eJy3HPYju91MyZEheIxpOPsBaw3zKKbh6elnj
qeFw+HANE3EuHA+e5wLPD4+4NyFP/lnSJn1spBzJa3CRngXcNZTUiFFYXhSFuO6rgsN+ObzN0zzxDA6QrCUslBqZpp3fPrmAwC5RNmf1lJL5cPHB+Z5oqpdarVRSmUIA/v5npwSy7JQaybGlVIKl0WA77yf2O1HhnFgt9vTWuP7H34gxsTPX34ixkTwE+M0YmjkFGnFUq0FKrUJ3Rpj0ah+zziO3N8Xgp9wzhLXqEBBQFUIbgNGxhrOlwWWJs7NqIFrYp+cF/tzPJ50f+UbI24Yx4lx3OkzN0ouPD090ZrsXwHlTteHOPvj8chySeRUNqAlDPTVWaRcycWIQ7MQnMcYsd1NQYbB0xrklEgpbw6oNUOtAhCtMaQU+e6777HWkFLe7EspDWssQwjCbqgtEBa/bX9uqI3K6XxkHEc+fnxkH3fsdvPG8NRaOR2PZAXyrbZtTayxkOLyJnC16idSSpuTrm0gDNefaZpISVlnBdRiwz0gzxbXuO3NcRr4ePdIA0oRvzBMA9YZ9WcVY70GTGY
"text/plain": [
"<Figure size 720x720 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"class OpenImages(Dataset):\n",
" def __init__(self, df, image_folder=IMAGE_ROOT):\n",
" self.root = image_folder\n",
" self.df = df\n",
" self.unique_images = df['ImageID'].unique()\n",
" def __len__(self): return len(self.unique_images)\n",
" def __getitem__(self, ix):\n",
" image_id = self.unique_images[ix]\n",
" image_path = f'{self.root}/{image_id}.jpg'\n",
" image = cv2.imread(image_path, 1)[...,::-1] # conver BGR to RGB\n",
" h, w, _ = image.shape\n",
" df = self.df.copy()\n",
" df = df[df['ImageID'] == image_id]\n",
" boxes = df['XMin,YMin,XMax,YMax'.split(',')].values\n",
" boxes = (boxes * np.array([w,h,w,h])).astype(np.uint16).tolist()\n",
" classes = df['LabelName'].values.tolist()\n",
" return image, boxes, classes, image_path\n",
"ds = OpenImages(df=DF_RAW)\n",
"im, bbs, clss, _ = ds[9]\n",
"show(im, bbs=bbs, texts=clss, sz=10)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "wYp_nxdDGTS-"
},
"outputs": [],
"source": [
"def extract_candidates(img):\n",
" img_lbl, regions = selectivesearch.selective_search(img, scale=200, min_size=100)\n",
" img_area = np.prod(img.shape[:2])\n",
" candidates = []\n",
" for r in regions:\n",
" if r['rect'] in candidates: continue\n",
" if r['size'] < (0.05*img_area): continue\n",
" if r['size'] > (1*img_area): continue\n",
" x, y, w, h = r['rect']\n",
" candidates.append(list(r['rect']))\n",
" return candidates\n",
"def extract_iou(boxA, boxB, epsilon=1e-5):\n",
" x1 = max(boxA[0], boxB[0])\n",
" y1 = max(boxA[1], boxB[1])\n",
" x2 = min(boxA[2], boxB[2])\n",
" y2 = min(boxA[3], boxB[3])\n",
" width = (x2 - x1)\n",
" height = (y2 - y1)\n",
" if (width<0) or (height <0):\n",
" return 0.0\n",
" area_overlap = width * height\n",
" area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])\n",
" area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])\n",
" area_combined = area_a + area_b - area_overlap\n",
" iou = area_overlap / (area_combined+epsilon)\n",
" return iou"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "HziOjkZqGWCI"
},
"outputs": [],
"source": [
"FPATHS, GTBBS, CLSS, DELTAS, ROIS, IOUS = [], [], [], [], [], []\n",
"N = 500\n",
"for ix, (im, bbs, labels, fpath) in enumerate(ds):\n",
" if(ix==N):\n",
" break\n",
" H, W, _ = im.shape\n",
" candidates = extract_candidates(im)\n",
" candidates = np.array([(x,y,x+w,y+h) for x,y,w,h in candidates])\n",
" ious, rois, clss, deltas = [], [], [], []\n",
" ious = np.array([[extract_iou(candidate, _bb_) for candidate in candidates] for _bb_ in bbs]).T\n",
" for jx, candidate in enumerate(candidates):\n",
" cx,cy,cX,cY = candidate\n",
" candidate_ious = ious[jx]\n",
" best_iou_at = np.argmax(candidate_ious)\n",
" best_iou = candidate_ious[best_iou_at]\n",
" best_bb = _x,_y,_X,_Y = bbs[best_iou_at]\n",
" if best_iou > 0.3: clss.append(labels[best_iou_at])\n",
" else : clss.append('background')\n",
" delta = np.array([_x-cx, _y-cy, _X-cX, _Y-cY]) / np.array([W,H,W,H])\n",
" deltas.append(delta)\n",
" rois.append(candidate / np.array([W,H,W,H]))\n",
" FPATHS.append(fpath)\n",
" IOUS.append(ious)\n",
" ROIS.append(rois)\n",
" CLSS.append(clss)\n",
" DELTAS.append(deltas)\n",
" GTBBS.append(bbs)\n",
"FPATHS = [f'{IMAGE_ROOT}/{stem(f)}.jpg' for f in FPATHS] \n",
"FPATHS, GTBBS, CLSS, DELTAS, ROIS = [item for item in [FPATHS, GTBBS, CLSS, DELTAS, ROIS]]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "_nU6cn_7GYWJ"
},
"outputs": [],
"source": [
"targets = pd.DataFrame(flatten(CLSS), columns=['label'])\n",
"label2target = {l:t for t,l in enumerate(targets['label'].unique())}\n",
"target2label = {t:l for l,t in label2target.items()}\n",
"background_class = label2target['background']"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "Pke_rPIxGaVq"
},
"outputs": [],
"source": [
"normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
" std=[0.229, 0.224, 0.225])\n",
"def preprocess_image(img):\n",
" img = torch.tensor(img).permute(2,0,1)\n",
" img = normalize(img)\n",
" return img.to(device).float()\n",
"def decode(_y):\n",
" _, preds = _y.max(-1)\n",
" return preds\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "2BKC4AgmGb4i"
},
"outputs": [],
"source": [
"class FRCNNDataset(Dataset):\n",
" def __init__(self, fpaths, rois, labels, deltas, gtbbs):\n",
" self.fpaths = fpaths\n",
" self.gtbbs = gtbbs\n",
" self.rois = rois\n",
" self.labels = labels\n",
" self.deltas = deltas\n",
" def __len__(self): return len(self.fpaths)\n",
" def __getitem__(self, ix):\n",
" fpath = str(self.fpaths[ix])\n",
" image = cv2.imread(fpath, 1)[...,::-1]\n",
" gtbbs = self.gtbbs[ix]\n",
" rois = self.rois[ix]\n",
" labels = self.labels[ix]\n",
" deltas = self.deltas[ix]\n",
" assert len(rois) == len(labels) == len(deltas), f'{len(rois)}, {len(labels)}, {len(deltas)}'\n",
" return image, rois, labels, deltas, gtbbs, fpath\n",
"\n",
" def collate_fn(self, batch):\n",
" input, rois, rixs, labels, deltas = [], [], [], [], []\n",
" for ix in range(len(batch)):\n",
" image, image_rois, image_labels, image_deltas, image_gt_bbs, image_fpath = batch[ix]\n",
" image = cv2.resize(image, (224,224))\n",
" input.append(preprocess_image(image/255.)[None])\n",
" rois.extend(image_rois)\n",
" rixs.extend([ix]*len(image_rois))\n",
" labels.extend([label2target[c] for c in image_labels])\n",
" deltas.extend(image_deltas)\n",
" input = torch.cat(input).to(device)\n",
" rois = torch.Tensor(rois).float().to(device)\n",
" rixs = torch.Tensor(rixs).float().to(device)\n",
" labels = torch.Tensor(labels).long().to(device)\n",
" deltas = torch.Tensor(deltas).float().to(device)\n",
" return input, rois, rixs, labels, deltas"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "yIZ1grVAJBoi"
},
"outputs": [],
"source": [
"n_train = 9*len(FPATHS)//10\n",
"train_ds = FRCNNDataset(FPATHS[:n_train], ROIS[:n_train], CLSS[:n_train], DELTAS[:n_train], GTBBS[:n_train])\n",
"test_ds = FRCNNDataset(FPATHS[n_train:], ROIS[n_train:], CLSS[n_train:], DELTAS[n_train:], GTBBS[n_train:])\n",
"\n",
"from torch.utils.data import TensorDataset, DataLoader\n",
"train_loader = DataLoader(train_ds, batch_size=2, collate_fn=train_ds.collate_fn, drop_last=True)\n",
"test_loader = DataLoader(test_ds, batch_size=2, collate_fn=test_ds.collate_fn, drop_last=True)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "Z64bqciDJvYp"
},
"outputs": [],
"source": [
"from torchvision.ops import RoIPool\n",
"class FRCNN(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" rawnet = torchvision.models.vgg16_bn(pretrained=True)\n",
" for param in rawnet.features.parameters():\n",
" param.requires_grad = True\n",
" self.seq = nn.Sequential(*list(rawnet.features.children())[:-1])\n",
" self.roipool = RoIPool(7, spatial_scale=14/224)\n",
" feature_dim = 512*7*7\n",
" self.cls_score = nn.Linear(feature_dim, len(label2target))\n",
" self.bbox = nn.Sequential(\n",
" nn.Linear(feature_dim, 512),\n",
" nn.ReLU(),\n",
" nn.Linear(512, 4),\n",
" nn.Tanh(),\n",
" )\n",
" self.cel = nn.CrossEntropyLoss()\n",
" self.sl1 = nn.L1Loss()\n",
" def forward(self, input, rois, ridx):\n",
" res = input\n",
" res = self.seq(res)\n",
" rois = torch.cat([ridx.unsqueeze(-1), rois*224], dim=-1)\n",
" res = self.roipool(res, rois)\n",
" feat = res.view(len(res), -1)\n",
" cls_score = self.cls_score(feat)\n",
" bbox = self.bbox(feat) # .view(-1, len(label2target), 4)\n",
" return cls_score, bbox\n",
" def calc_loss(self, probs, _deltas, labels, deltas):\n",
" detection_loss = self.cel(probs, labels)\n",
" ixs, = torch.where(labels != background_class)\n",
" _deltas = _deltas[ixs]\n",
" deltas = deltas[ixs]\n",
" self.lmb = 10.0\n",
" if len(ixs) > 0:\n",
" regression_loss = self.sl1(_deltas, deltas)\n",
" return detection_loss + self.lmb * regression_loss, detection_loss.detach(), regression_loss.detach()\n",
" else:\n",
" regression_loss = 0\n",
" return detection_loss + self.lmb * regression_loss, detection_loss.detach(), regression_loss\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "bv5RfeAWKP_f"
},
"outputs": [],
"source": [
"def train_batch(inputs, model, optimizer, criterion):\n",
" input, rois, rixs, clss, deltas = inputs\n",
" model.train()\n",
" optimizer.zero_grad()\n",
" _clss, _deltas = model(input, rois, rixs)\n",
" loss, loc_loss, regr_loss = criterion(_clss, _deltas, clss, deltas)\n",
" accs = clss == decode(_clss)\n",
" loss.backward()\n",
" optimizer.step()\n",
" return loss.detach(), loc_loss, regr_loss, accs.cpu().numpy()\n",
"def validate_batch(inputs, model, criterion):\n",
" input, rois, rixs, clss, deltas = inputs\n",
" with torch.no_grad():\n",
" model.eval()\n",
" _clss,_deltas = model(input, rois, rixs)\n",
" loss, loc_loss, regr_loss = criterion(_clss, _deltas, clss, deltas)\n",
" _clss = decode(_clss)\n",
" accs = clss == _clss\n",
" return _clss, _deltas, loss.detach(), loc_loss, regr_loss, accs.cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 557,
"referenced_widgets": [
"51c017bf540f4b4b99b98c3b45e9c683",
"8e5a5f297879447eb77a3d1a69e4d823",
"bde7f64fb9ff4bb39448d90f70996b30",
"6227ae96f0c8475ea3a056f1dd1a62f9",
"737a4e3955454a29a6aa33fcc78416b2",
"c3d91b7255274b3f8894d84dc34fa381",
"ae9603e8db914d4f87646bb7726e7211",
"f6997519e2f6415590f2cd8a4779c586"
]
},
"id": "EYAKuLTqKQUf",
"outputId": "08f760ff-079b-467b-b107-beec795c5b79"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Downloading: \"https://download.pytorch.org/models/vgg16_bn-6c64b313.pth\" to /root/.cache/torch/hub/checkpoints/vgg16_bn-6c64b313.pth\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "51c017bf540f4b4b99b98c3b45e9c683",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=553507836.0), HTML(value='')))"
]
},
"metadata": {
"tags": []
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"EPOCH: 4.760\tval_loss: 1.340\tval_loc_loss: 0.679\tval_regr_loss: 0.066\tval_acc: 0.765\t(64.17s - 3.24s remaining)"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/5 [00:00<?, ?it/s]/usr/local/lib/python3.6/dist-packages/numpy/core/fromnumeric.py:3335: RuntimeWarning: Mean of empty slice.\n",
" out=out, **kwargs)\n",
"/usr/local/lib/python3.6/dist-packages/numpy/core/_methods.py:161: RuntimeWarning: invalid value encountered in double_scalars\n",
" ret = ret.dtype.type(ret / rcount)\n",
"100%|██████████| 5/5 [00:00<00:00, 739.32it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"EPOCH: 4.800\tval_loss: 0.984\tval_loc_loss: 0.369\tval_regr_loss: 0.061\tval_acc: 0.854\t(64.19s - 2.67s remaining)\r",
"EPOCH: 4.840\tval_loss: 1.032\tval_loc_loss: 0.419\tval_regr_loss: 0.061\tval_acc: 0.779\t(64.21s - 2.12s remaining)\r",
"EPOCH: 4.880\tval_loss: 1.138\tval_loc_loss: 0.489\tval_regr_loss: 0.065\tval_acc: 0.691\t(64.24s - 1.58s remaining)\r",
"EPOCH: 4.920\tval_loss: 1.693\tval_loc_loss: 0.811\tval_regr_loss: 0.088\tval_acc: 0.609\t(64.26s - 1.04s remaining)\r",
"EPOCH: 4.960\tval_loss: 1.475\tval_loc_loss: 0.567\tval_regr_loss: 0.091\tval_acc: 0.824\t(64.28s - 0.52s remaining)\r",
"EPOCH: 5.000\tval_loss: 1.979\tval_loc_loss: 0.967\tval_regr_loss: 0.101\tval_acc: 0.455\t(64.30s - 0.00s remaining)"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfEAAAF0CAYAAAAzTwAWAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOzdd3QV1d7G8e8+qaSQkE5IIKG30BJAelAEBQRUEBERFS5WEPV6La+9XPVeRS827IAFFERFsCBqQBSVBELvRQg1dBIMhGTePyYYgpQgOZmU57PWWSRn5sz5ZQN5zp7Zs7exLAsREREpf1xOFyAiIiJ/j0JcRESknFKIi4iIlFMKcRERkXJKIS4iIlJOKcRFRETKKbeFuDEm1hjzgzFmhTFmuTHmjlPsY4wxY40x64wxS4wxrdxVj4iISEXj6cZjHwPutixroTEmEEgzxnxrWdaKE/a5FKhX8GgLvFbw52mFhYVZcXFxJVZkdnY2/v7+JXa88k7tUZTao5Daoii1R1Fqj0LuaIu0tLTdlmWFn/y820LcsqztwPaCrw8ZY1YCNYATQ7wvMNGyZ5z5xRgTbIypXvDaU4qLiyM1NbXE6kxJSSE5ObnEjlfeqT2KUnsUUlsUpfYoSu1RyB1tYYz5/ZTPl8aMbcaYOGAu0NSyrIMnPD8DeMayrHkF338H3GtZVupJrx8BjACIjIxMnDx5conVlpWVRUBAQIkdr7xTexSl9iiktihK7VGU2qOQO9qia9euaZZlJZ38vDtPpwNgjAkAPgFGnxjg58KyrDeANwCSkpKskvyEo0+PRak9ilJ7FFJbFKX2KErtUag028Kto9ONMV7YAf6BZVnTTrHLViD2hO9jCp4TERGRs3BbT9wYY4C3gZWWZY05zW7TgduNMZOxB7QdONP1cBERKX9yc3PJyMggJyfH6VJKRVBQECtXrvxbr/X19SUmJgYvL69i7e/O0+kdgCHAUmNMesFzDwA1ASzLGgd8CfQE1gGHgRvcWI+IiDggIyODwMBA4uLisPt3FduhQ4cIDAw859dZlsWePXvIyMggPj6+WK9x5+j0ecAZ/7YKRqXf5q4aRETEeTk5OZUmwM+HMYbQ0FAyMzOL/RrN2CYiIm6nAC+ec20nhbiIiEg5pRAXEZEKb//+/bz66qslcqzk5OQSnXTsfCjERUSkwjtdiB87dsyBakqO2yd7EREROe6xL5azYtvfmvfrtBpHV+WRy5qccZ/77ruP9evX06JFC7y8vPD19aVatWqsWrWKN954g0cffZSwsDCWLVtGYmIi77//frGuT0+aNIl///vfWJZFr169ePbZZ8nLy+P6668nNTUVYww33ngjd955J2PHjmXcuHF4enrSuHFjSmL20Uod4vn5Fr9sO0YXy9KgCxGRCuyZZ55h2bJlpKenk5KSQq9evVi2bBnx8fGkpKSwaNEili9fTnR0NB06dOCnn36iY8eOZzzmtm3buPfee0lLS6NatWp0796dzz77jJCQELZu3cqyZcsA+yzA8Ro2btyIj4/Pn8+dr0od4rNW7GDckiP4RqxldLf6TpcjIlLhna3HXFratGlT5F7sNm3aEBMTA0CLFi3YtGnTWUN8wYIFJCcnEx5uLy42ePBg5s6dy+jRo9mwYQMjR46kV69edO/eHYBmzZoxePBg+vXrR79+/Urk56jU18R7NImiUw1PXpy9ls/TNduriEhlcfJSoT4+Pn9+7eHhcV7XyqtVq8bixYtJTk5m3LhxDB8+HICZM2dy2223sXDhQlq3bl0i1+MrdYgbYxjaxJs28SHcM3UJab/vc7okERFxg8DAQA4dOlSix2zTpg1z5sxh9+7d5OXlMWnSJLp06cKePXvIz8/nyiuv5Mknn2ThwoXk5+ezZcsWunbtyrPPPsuBAwfIyso67xoq9el0AE+X4fVrE7n81Z+46b1UPr21A7Ehfk6XJSIiJSg0NJQOHTrQtGlTqlSpQmRk5Hkfs3r16jzzzDN07dr1z4Ftffv25eeff+aKK64gPz8fgKeffpq8vDyuvfZaDhw4gGVZjBo1iuDg4POuodKHOEA1f2/evr41l7/yE8MnpDL1lnYE+hZv8nkRESkfPvzww1M+n5ycXGTp0JdffvmMx0lJSfnz60GDBjFo0KAi2xMSEli4cOFfXjdv3rziF1tMlfp0+onqhAfw2rWJrM/MYuSkRRzLy3e6JBERkTNSiJ+gQ90wnujXlJTVmTw58+8tIyciIhXD5ZdfTosWLYo8vvnmG6fLKkKn008yqE1N1u/K4q15G6kT7s+QdnFOlyQiIg749NNPnS7hrBTip3B/z0Zs2pPNo1+soGaoP13qhztdkoiIyF/odPopeLgM/7u6JfUjA7n9g4Ws3VmytyWIiIiUBIX4afj7ePL20CR8vT24ccIC9mQdcbokERGRIhTiZxAdXIW3rkti18Ej3PReGkeO5TldkoiIyJ8U4mfRPDaYMVe1IPX3fdz3yVIsy3K6JBERcbOAgIDTbtu0aRNNmzYtxWpOTyFeDL2aVeef3evz6aKtvPLDOqfLERERATQ6vdhu61qXDZnZPDdrDXFh/vRuFu10SSIi5dO7vf76XJN+0OYfcPQwfDDgr9tbXAMtB0P2Hvj4uqLbbph51re87777iI2N5bbbbgPg0UcfxdPTkx9++IF9+/aRm5vLk08+Sd++fc/pR8nJyeGWW24hNTUVT09PxowZQ1JSEsuXL+eGG27g6NGj5Ofn88knnxAdHc1VV11FRkYGeXl5PPTQQwwcOPCc3u9kCvFiMsbw9JUJbN57mLs/XkxMNT9axJ7/vLciIuJ+AwcOZPTo0X+G+Mcff8w333zDqFGjqFq1Krt37+aCCy6gT58+GGOKfdxXXnkFYwxLly5l1apVdO/enbS0NMaNG8cdd9zB4MGDOXr0KHl5eXz55ZdER0czc6b9oePAgQPn/XMpxM+Bj6cHrw9JpN+r9hzrn9/egRrBVZwuS0SkfDlTz9nb78zb/UOL1fM+WcuWLdm1axfbtm0jMzOTatWqERUVxZ133sncuXNxuVxs3bqVnTt3EhUVVezjzps3j5EjRwLQsGFDatWqxbp162jXrh1PPfUUGRkZXHHFFdSrV4+EhATuvvtu7r33Xnr37k2nTp3O+ec4ma6Jn6PQAB/eGdqaI7l5DBu/gKwj578erIiIuN+AAQOYOnUqH330EQMHDuSDDz4gMzOTtLQ00tPTiYyMJCcnp0Te65prrmH69OlUqVKFnj178v3331O/fn0WLlxIQkICDz74II8//vh5v49C/G+oFxnIK4NbsXZXFndMWkRevkasi4iUdQMHDmTy5MlMnTqVAQMGcODAASIiIvDy8uKHH37g999/P+djdurUiQ8++ACANWvWsHnzZurVq8eGDRuoXbs2o0aNom/fvixZsoRt27bh5+fHtddeyz333HPKlc7OlUL8b+pcP5xH+zThu1W7ePpLLZYiIlLWNWnShEOHDlGjRg2qV6/O4MGDSU1NJSEhgYkTJ
9KwYcNzPuatt95Kfn4+CQkJDBw4kPHjx+Pj48PHH39M06ZNadGiBcuWLeO6665j6dKltGnThhYtWvDYY4/x4IMPnvfPpGvi52HIBbX+XCyldngA17St6XRJIiJyBkuXLv3z67CwMObPn3/K/bKysk57jLi4OJYtWwaAr68v7777bpHthw4d4r777uO+++4r8nyPHj3o0aPH3y39lBTi5+mh3o35fU82D3++jFqhfnSoG+Z0SSIiUknodPp58nAZxg5qSZ3wAG5+P411u07/6U1ERMqPpUuX/mU98bZt2zpdVhHqiZeAQF8v3hqaxOWv/sSwCQv47NYOVPP3drosEZEyw7Ksc7r/uixISEggPT29VN/zXKf2Vk+8hMSG+PH6kCS2H8jhpve1WIqIyHG+vr7s2bNHa0+chWVZ7NmzB19f32K/Rj3xEpRYqxrPDWjOqEmLeGDaMp4b0KzcffIUESlpMTExZGRkkJmZ6XQppSInJ+ecgvhEvr6+xMTEFHt/hXgJ69M8mg2ZWbw4ey11IwK
"text/plain": [
"<Figure size 576x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"frcnn = FRCNN().to(device)\n",
"criterion = frcnn.calc_loss\n",
"optimizer = optim.SGD(frcnn.parameters(), lr=1e-3)\n",
"\n",
"n_epochs = 5\n",
"log = Report(n_epochs)\n",
"for epoch in range(n_epochs):\n",
"\n",
" _n = len(train_loader)\n",
" for ix, inputs in enumerate(train_loader):\n",
" loss, loc_loss, regr_loss, accs = train_batch(inputs, frcnn, \n",
" optimizer, criterion)\n",
" pos = (epoch + (ix+1)/_n)\n",
" log.record(pos, trn_loss=loss.item(), trn_loc_loss=loc_loss, \n",
" trn_regr_loss=regr_loss, \n",
" trn_acc=accs.mean(), end='\\r')\n",
" \n",
" _n = len(test_loader)\n",
" for ix,inputs in enumerate(test_loader):\n",
" _clss, _deltas, loss, \\\n",
" loc_loss, regr_loss, accs = validate_batch(inputs, \n",
" frcnn, criterion)\n",
" pos = (epoch + (ix+1)/_n)\n",
" log.record(pos, val_loss=loss.item(), val_loc_loss=loc_loss, \n",
" val_regr_loss=regr_loss, \n",
" val_acc=accs.mean(), end='\\r')\n",
" \n",
" log.report_avgs(epoch+1)\n",
"\n",
"# Plotting training and validation metrics\n",
"log.plot_epochs('trn_loss,val_loss'.split(','))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "HlQozCQsKS2B"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"%matplotlib inline\n",
"import matplotlib.patches as mpatches\n",
"from torchvision.ops import nms\n",
"from PIL import Image\n",
"def test_predictions(filename):\n",
" img = cv2.resize(np.array(Image.open(filename)), (224,224))\n",
" candidates = extract_candidates(img)\n",
" candidates = [(x,y,x+w,y+h) for x,y,w,h in candidates]\n",
" input = preprocess_image(img/255.)[None]\n",
" rois = [[x/224,y/224,X/224,Y/224] for x,y,X,Y in candidates]\n",
" rixs = np.array([0]*len(rois))\n",
" rois, rixs = [torch.Tensor(item).to(device) for item in [rois, rixs]]\n",
" with torch.no_grad():\n",
" frcnn.eval()\n",
" probs, deltas = frcnn(input, rois, rixs)\n",
" confs, clss = torch.max(probs, -1)\n",
" candidates = np.array(candidates)\n",
" confs, clss, probs, deltas = [tensor.detach().cpu().numpy() for tensor in [confs, clss, probs, deltas]]\n",
" \n",
" ixs = clss!=background_class\n",
" confs, clss, probs, deltas, candidates = [tensor[ixs] for tensor in [confs, clss, probs, deltas, candidates]]\n",
" bbs = candidates + deltas\n",
" ixs = nms(torch.tensor(bbs.astype(np.float32)), torch.tensor(confs), 0.05)\n",
" confs, clss, probs, deltas, candidates, bbs = [tensor[ixs] for tensor in [confs, clss, probs, deltas, candidates, bbs]]\n",
" if len(ixs) == 1:\n",
" confs, clss, probs, deltas, candidates, bbs = [tensor[None] for tensor in [confs, clss, probs, deltas, candidates, bbs]]\n",
" \n",
" bbs = bbs.astype(np.uint16)\n",
" _, ax = plt.subplots(1, 2, figsize=(20,10))\n",
" show(img, ax=ax[0])\n",
" ax[0].grid(False)\n",
" ax[0].set_title(filename.split('/')[-1])\n",
" if len(confs) == 0:\n",
" ax[1].imshow(img)\n",
" ax[1].set_title('No objects')\n",
" plt.show()\n",
" return\n",
" else:\n",
" show(img, bbs=bbs.tolist(), texts=[target2label[c] for c in clss.tolist()], ax=ax[1])\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 554
},
"id": "fPutsAY4LYhC",
"outputId": "7de26a81-91b5-4500-a98f-f68d54753405"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABGoAAAIZCAYAAAD+0dlTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOy9e7B9W3bX9RlzzrX2Puf8zv3dV7/ShG4TEwsSKiAWDxGEKiQWYBkeKhKCEv0DEKp8UIWKUZ6WDwTKV1JiCixjrIQU8qwClQIs1JQoAWOCwbz6QXc63fd5XnuvNecc/jHGXGvtc3/3d28nt7vPTc/vvft3ztl7r7XmmmvOMb5zvKaoKh0dHR0dHR0dHR0dHR0dHR0dX3iEL3QDOjo6Ojo6Ojo6Ojo6Ojo6OjoM3VDT0dHR0dHR0dHR0dHR0dHR8UDQDTUdHR0dHR0dHR0dHR0dHR0dDwTdUNPR0dHR0dHR0dHR0dHR0dHxQNANNR0dHR0dHR0dHR0dHR0dHR0PBN1Q09HR0dHR0dHR0dHR0dHR0fFA0A01HR2fA4jIrxGRj4nItYj8nM/hdT4sIioi6U0+/70i8m2fq+t/NhCRHxWRX/4FuO61iHzZ5/u6HR0dHR0dHR0dHR0dPxF0Q03HFzVE5HkR+R9E5EZEPiIiv9Hf/1Ui8jdE5FUR+TER+a9F5HJz3D8rIv+biNyKyF97wqn/MPA7VPWRqn7P5rivEJHD1ngiIr9MRL7Xr/WSt+eDm88/KCJ/VkReFpGPi8hvfYf74DeIyN/1PvghEfnF/v4oIt/lBhYVkV/6Dl7zq0XkL4vIZ0REP8tjn9pf9+HP4Id/8q3u6Ojo6Ojo6Ojo6Oj43KMbajq+2PFfABPwPuDrgW8Wka8CHgN/EPgS4GcAHwT+481xLwN/DPgP3uS8HwK+702u9zfvvff9wNeq6rN+vf8P+ObN598G/Ii38VcB/76I/LK3eX9PhYj8E8B/CPwW4BL4JcDWqPE3gN8E/Ng7cb0NZuA7gX/pJ3DsW/VXR0dHR0dHR0dHR0fHuxbdUNPxRQsRuQB+HfBNqnqtqn8D+HPAN6jqt6vqX1LVW1V9BfjjwC9qx6rq/6yq3wl84t45dyJyDUTg74jID20++w3Aq8Bf2R6jqp9S1e15CvAP+jGPgF8K/CFVnVX17wDfBXzjvdv5RhH5hIh8UkR+173P9iLyHSJyJSJ/S0S+ZvPZ7wN+v6p+t6pWVf37qvr3vV2Tqv4x75fyhP77VSLyPSLyuqd5/d57n3+DRym9JCK/5949/4CqfitPNmYhIv+mR/dcicj3i8iveTv99SbnUhFp/fknReRbROR/8nP/dRH50Oa7v0JEfkBEXhOR/9I//5ff7NwdHR0dHR0dHR0dHR3vNLqhpuOLGV8JZFX9e5v3/g7wVU/47i/hTYwKW6jqUVUf+Z9fo6pfDiAizwC/H/jXn3SciPx0EXkVuAN+F/AftY/u/Wy/f/W9U/wy4CuAXwH87nu1YP5p4E8BzwPfDvwZERlEJAL/CPAeEflBT6v6z0Xk7K3u03ED/GbgWSzS57eJyNf5/fxMLMrlG7ColxeAn/Y2zwvwQ8AvxiKbfh/wbSLygfbhU/rr7eDrgT8AvAj8beC/83O+iBnB/i1v7w8A/+hncd6Ojo6Ojo6Ojo6Ojo6fNLqhpuOLGY+A1++99xqWArTA04P+BeDf/Ulc6w8A36qqH3/Sh6r6UU/leRH4d4D/19+/Av5X4JtEZC8i/zAWBXR+7xS/T1VvVPV7gT8B/PObz/4vVf0uVZ2BPwLsgV+ApVINwK/HjCI/G/g5fv23hKr+NVX9Xo/E+b+B/x74x/3jXw/8BVX9X1T1CHwTUN/Oef3cf0pVP+Hn/g4svennbT5/Yn+9TfzFTbt+D/ALReRLgV8JfJ+q/mlVzcB/yjuf8tXR0dHR0dHR0dHR0fFUdENNxxczroFn7r33DHDV/hCRX4BFofz6e5E3bxsi8rOBXw780bf6rqq+DPw3wJ/d7OT09cA/AHwMi1L5NuC+wedjm98/gkWxvOEzVa1+7Jdg0SgA/5mqflJVP4MZcn7l27yvny8if1VEPi0irwG/FTOc4OffXvcGeOntnNfP/ZtF5G97weBXsQiiF+9/7036662wbdc1Vm/oS57QZuWN/dzR0dHR0dHR0dHR0fE5RTfUdHwx4+8BSUS+YvPe1+ApTmLbav854BtV9a884fi3i18KfBj4qIj8GJaq8+tE5G+9yfcT8F7ciKSqH1HVX62q71HVn48ZLP6Pe8d86eb3n85p7ZzlMxEJWArSJ7z2zseB7a5Ln80OTN+O9c+Xqupj4FtYU7Q+ee+651g60VvCa8b8ceB3AC945Mz/w2n61xYn/fU2sG3XIywl7BPe5p+2+Uz47NK1Ojo6Ojo6Ojo6Ojo6ftLohpqOL1p4lMefBn6/iFyIyC/C6rn8tyLy1cBfAn6nqv75+8eKSBSRPWYkCJ6WNLzJpf4r4Mux1KKfjRk0/iLwtX6uXysi/5CIBBF5DxbV8j0eLYKI/AwRufTtsn8TVofmj9y7xjeJyLnvWPVbgO/YfPZz/RoJ+FeBI/Dd/tmfAH6niLxXRJ4D/jXgL2zuc+f3CTD6fTaDySXwsqoeROTnAb9xc83vAn61iPxjIjJi9XnC5rzi5x39772I7PzjC8xg9Gn/7LewqcnzVv31NvArN+36A8B3q+rHsGfys0Tk67yv/hXg/W/znB0dHR0dHR0dHR0dHe8IuqGm44sdvx04A34cq7Hy21T1+4B/A3gP8K0icu2vbTHhb8BSh74Zq+9yh0WBvAG+c9SPtReWcnVQ1U/7Vz6IGYWugO/Farn8ms0pvhbbMvsVLL3on9wc2/DXgR/EdpT6w6r6P24++7PAP+fHfwPwa71eDZih4m9i0UV/F/ge4A9tjv0Bv7cPAn/Zf2+7JP12zMh1hdXv+c7NPX8fZuj4dixSpUXvNHzIz9X69M6vhap+P/CfAP878CngZ2F1ehqe2l++q9O38Ob4duDfw1Kefi62/Tie+vXPYIWJXwJ+JvB/Yoatjo6Ojo6Ojo6Ojo6OzwvEyjB0dHR0/NSDp3oV4EOq+lER+ZPAx1X1LQsm+7EfB75eVf/q57alHR0dHR0dHR0dHR0dhh5R09HR8VMZXw0ceJu7N4nI14rIs56G9W9jdXG++y0O6+jo6Ojo6Ojo6OjoeMfQDTUdHR0/JSEivw74q8DvVtXpbR72C4EfAj4D/FPA16nq3dMP6ejo6Ojo6Ojo6OjoeOfQU586Ojo6Ojo6Ojo6Ojo6Ojo6Hgh6RE1HR0dHR0dHR0dHR0dHR0fHA0E31HR0dHR0dHR0dHR0dHR0dHQ8EKSnffgHf+e/qKhSjxlyYZcSZ2kghchuGJEQCENCQoAxITGwvzjn/JlLQkoMuz2qys3tLfM8U2ul1sqjR4944fnn0FI53N4wHY989Id+hFdfeYV8PFKmIzWN5P05548e8ZVf8ZU8evSIEBMigVdef42XX32VECPjbscwj
rz3xRfY7Uamq2vy3YFP/shH+dHv/wFCrZxLJIgQQ0REyCFQonD54gu850NfynBxzuX730fcjTAMEAOHj3+Mu098nNdefpWP/8jHKFVJw4407vjwV30V7/ngB3nhAx/gPV/60/ihH/4R/syf+/O8+sprvPLSKxwPR0rJlJpRQFEQiICIEEJAAIKgIkt/C0LAPhMRBKhaqVpBQIMiAsj2u0KUSCARQ2I37EkxcfnoGXb7PV/1s76KD3/Zh3n/Bz/Al33llxNiQK0hqGAv7GV1UwGUgLK2rOPdgm0qo2zG1pPeV9XlJSL2vkIphTxnPvKRH+Xll1/i9uaG6+trzs/OeOHFF0gxkpLNqVoLqpVaKqWWNzk/qAIiiIT1PUAkICL2OUKeZw6HAyJCipEQAo8enbPb7ZinI9PxwO3tDT/+qU9RSkHc1LyOYbtOVWWeM7UqpVQUeP/738973vMiZ/s9l888AlWm42RySSu1Kp++nvj0zcThcOD1115DVRljIIXAe559zOXFGWNM7FLkOE28+tp
"text/plain": [
"<Figure size 1440x720 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
}
],
"source": [
"test_predictions(test_ds[29][-1])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "mri7kXHfLbC0"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"include_colab_link": true,
"name": "Training_Fast_R_CNN.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"latex_envs": {
"LaTeX_envs_menu_present": true,
"autoclose": false,
"autocomplete": true,
"bibliofile": "biblio.bib",
"cite_by": "apalike",
"current_citInitial": 1,
"eqLabelWithNumbers": true,
"eqNumInitial": 1,
"hotkeys": {
"equation": "Ctrl-E",
"itemize": "Ctrl-I"
},
"labels_anchors": false,
"latex_user_defs": false,
"report_style_numbering": false,
"user_envs_cfg": false
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"51c017bf540f4b4b99b98c3b45e9c683": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_bde7f64fb9ff4bb39448d90f70996b30",
"IPY_MODEL_6227ae96f0c8475ea3a056f1dd1a62f9"
],
"layout": "IPY_MODEL_8e5a5f297879447eb77a3d1a69e4d823"
}
},
"6227ae96f0c8475ea3a056f1dd1a62f9": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_f6997519e2f6415590f2cd8a4779c586",
"placeholder": "",
"style": "IPY_MODEL_ae9603e8db914d4f87646bb7726e7211",
"value": " 528M/528M [00:05&lt;00:00, 97.5MB/s]"
}
},
"737a4e3955454a29a6aa33fcc78416b2": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"8e5a5f297879447eb77a3d1a69e4d823": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"ae9603e8db914d4f87646bb7726e7211": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"bde7f64fb9ff4bb39448d90f70996b30": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "100%",
"description_tooltip": null,
"layout": "IPY_MODEL_c3d91b7255274b3f8894d84dc34fa381",
"max": 553507836,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_737a4e3955454a29a6aa33fcc78416b2",
"value": 553507836
}
},
"c3d91b7255274b3f8894d84dc34fa381": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"f6997519e2f6415590f2cd8a4779c586": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
}
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}