Computer_Vision/Chapter07/Training_RCNN.ipynb

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/PacktPublishing/Hands-On-Computer-Vision-with-PyTorch/blob/master/Chapter07/Training_RCNN.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 140,
"resources": {
"http://localhost:8080/nbextensions/google.colab/files.js": {
"data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSBQeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE
1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk
"headers": [
[
"content-type",
"application/javascript"
]
],
"ok": true,
"status": 200,
"status_text": ""
}
}
},
"id": "pIOo-A9m3_Th",
"outputId": "440b254e-3d29-4ef1-fd7e-6d6deb3661ba"
},
"outputs": [],
"source": [
"!pip install -q --upgrade selectivesearch torch_snippets\n",
"from torch_snippets import *\n",
"import selectivesearch\n",
"#from google.colab import files\n",
"#files.upload() # upload kaggle.json file which you can get \n",
"# by clicking on Create New API token in your personal account\n",
"#!mkdir -p ~/.kaggle\n",
"#!mv kaggle.json ~/.kaggle/\n",
"#!ls ~/.kaggle\n",
"#!chmod 600 /root/.kaggle/kaggle.json\n",
"#!kaggle datasets download -d sixhky/open-images-bus-trucks/\n",
"#!unzip -qq open-images-bus-trucks.zip\n",
"from torchvision import transforms, models, datasets\n",
"from torch_snippets import Report\n",
"from torchvision.ops import nms\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">cpu\n",
"</pre>\n"
],
"text/plain": [
"cpu\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"print(device)\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 153
},
"id": "14-PnjIp4Le_",
"outputId": "1e839b05-4bf8-4b64-e4e3-b2f8be285772",
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"> ImageID Source LabelName Confidence XMin XMax \\\n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> 0000599864fd15b3 xclick Bus <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.343750</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.908750</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> 00006bdb1eb5cd74 xclick Truck <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.276667</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.697500</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> 00006bdb1eb5cd74 xclick Truck <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.702500</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.999167</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> 00010bf498b64bab xclick Bus <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.156250</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.371250</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> 00013f14dd4e168f xclick Bus <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.287500</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.999375</span> \n",
"\n",
" YMin YMax IsOccluded IsTruncated <span style=\"color: #808000; text-decoration-color: #808000\">...</span> IsDepiction IsInside \\\n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.156162</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.650047</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> <span style=\"color: #808000; text-decoration-color: #808000\">...</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.141604</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.437343</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> <span style=\"color: #808000; text-decoration-color: #808000\">...</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.204261</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.409774</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> <span style=\"color: #808000; text-decoration-color: #808000\">...</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.269188</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.705228</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> <span style=\"color: #808000; text-decoration-color: #808000\">...</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.194184</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.999062</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> <span style=\"color: #808000; text-decoration-color: #808000\">...</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> \n",
"\n",
" XClick1X XClick2X XClick3X XClick4X XClick1Y XClick2Y XClick3Y \\\n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.421875</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.343750</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.795000</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.908750</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.156162</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.512700</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.650047</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.299167</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.276667</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.697500</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.659167</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.141604</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.241855</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.352130</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.849167</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.702500</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.906667</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.999167</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.204261</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.398496</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.409774</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.274375</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.371250</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.311875</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.156250</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.269188</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.493882</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.705228</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.920000</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.999375</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.648750</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.287500</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.194184</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.303940</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.999062</span> \n",
"\n",
" XClick4Y \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.457197</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.437343</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.295739</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.521691</span> \n",
"<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.523452</span> \n",
"\n",
"<span style=\"font-weight: bold\">[</span><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">5</span> rows x <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">21</span> columns<span style=\"font-weight: bold\">]</span>\n",
"</pre>\n"
],
"text/plain": [
" ImageID Source LabelName Confidence XMin XMax \\\n",
"\u001b[1;36m0\u001b[0m 0000599864fd15b3 xclick Bus \u001b[1;36m1\u001b[0m \u001b[1;36m0.343750\u001b[0m \u001b[1;36m0.908750\u001b[0m \n",
"\u001b[1;36m1\u001b[0m 00006bdb1eb5cd74 xclick Truck \u001b[1;36m1\u001b[0m \u001b[1;36m0.276667\u001b[0m \u001b[1;36m0.697500\u001b[0m \n",
"\u001b[1;36m2\u001b[0m 00006bdb1eb5cd74 xclick Truck \u001b[1;36m1\u001b[0m \u001b[1;36m0.702500\u001b[0m \u001b[1;36m0.999167\u001b[0m \n",
"\u001b[1;36m3\u001b[0m 00010bf498b64bab xclick Bus \u001b[1;36m1\u001b[0m \u001b[1;36m0.156250\u001b[0m \u001b[1;36m0.371250\u001b[0m \n",
"\u001b[1;36m4\u001b[0m 00013f14dd4e168f xclick Bus \u001b[1;36m1\u001b[0m \u001b[1;36m0.287500\u001b[0m \u001b[1;36m0.999375\u001b[0m \n",
"\n",
" YMin YMax IsOccluded IsTruncated \u001b[33m...\u001b[0m IsDepiction IsInside \\\n",
"\u001b[1;36m0\u001b[0m \u001b[1;36m0.156162\u001b[0m \u001b[1;36m0.650047\u001b[0m \u001b[1;36m1\u001b[0m \u001b[1;36m0\u001b[0m \u001b[33m...\u001b[0m \u001b[1;36m0\u001b[0m \u001b[1;36m0\u001b[0m \n",
"\u001b[1;36m1\u001b[0m \u001b[1;36m0.141604\u001b[0m \u001b[1;36m0.437343\u001b[0m \u001b[1;36m1\u001b[0m \u001b[1;36m0\u001b[0m \u001b[33m...\u001b[0m \u001b[1;36m0\u001b[0m \u001b[1;36m0\u001b[0m \n",
"\u001b[1;36m2\u001b[0m \u001b[1;36m0.204261\u001b[0m \u001b[1;36m0.409774\u001b[0m \u001b[1;36m1\u001b[0m \u001b[1;36m1\u001b[0m \u001b[33m...\u001b[0m \u001b[1;36m0\u001b[0m \u001b[1;36m0\u001b[0m \n",
"\u001b[1;36m3\u001b[0m \u001b[1;36m0.269188\u001b[0m \u001b[1;36m0.705228\u001b[0m \u001b[1;36m0\u001b[0m \u001b[1;36m0\u001b[0m \u001b[33m...\u001b[0m \u001b[1;36m0\u001b[0m \u001b[1;36m0\u001b[0m \n",
"\u001b[1;36m4\u001b[0m \u001b[1;36m0.194184\u001b[0m \u001b[1;36m0.999062\u001b[0m \u001b[1;36m0\u001b[0m \u001b[1;36m1\u001b[0m \u001b[33m...\u001b[0m \u001b[1;36m0\u001b[0m \u001b[1;36m0\u001b[0m \n",
"\n",
" XClick1X XClick2X XClick3X XClick4X XClick1Y XClick2Y XClick3Y \\\n",
"\u001b[1;36m0\u001b[0m \u001b[1;36m0.421875\u001b[0m \u001b[1;36m0.343750\u001b[0m \u001b[1;36m0.795000\u001b[0m \u001b[1;36m0.908750\u001b[0m \u001b[1;36m0.156162\u001b[0m \u001b[1;36m0.512700\u001b[0m \u001b[1;36m0.650047\u001b[0m \n",
"\u001b[1;36m1\u001b[0m \u001b[1;36m0.299167\u001b[0m \u001b[1;36m0.276667\u001b[0m \u001b[1;36m0.697500\u001b[0m \u001b[1;36m0.659167\u001b[0m \u001b[1;36m0.141604\u001b[0m \u001b[1;36m0.241855\u001b[0m \u001b[1;36m0.352130\u001b[0m \n",
"\u001b[1;36m2\u001b[0m \u001b[1;36m0.849167\u001b[0m \u001b[1;36m0.702500\u001b[0m \u001b[1;36m0.906667\u001b[0m \u001b[1;36m0.999167\u001b[0m \u001b[1;36m0.204261\u001b[0m \u001b[1;36m0.398496\u001b[0m \u001b[1;36m0.409774\u001b[0m \n",
"\u001b[1;36m3\u001b[0m \u001b[1;36m0.274375\u001b[0m \u001b[1;36m0.371250\u001b[0m \u001b[1;36m0.311875\u001b[0m \u001b[1;36m0.156250\u001b[0m \u001b[1;36m0.269188\u001b[0m \u001b[1;36m0.493882\u001b[0m \u001b[1;36m0.705228\u001b[0m \n",
"\u001b[1;36m4\u001b[0m \u001b[1;36m0.920000\u001b[0m \u001b[1;36m0.999375\u001b[0m \u001b[1;36m0.648750\u001b[0m \u001b[1;36m0.287500\u001b[0m \u001b[1;36m0.194184\u001b[0m \u001b[1;36m0.303940\u001b[0m \u001b[1;36m0.999062\u001b[0m \n",
"\n",
" XClick4Y \n",
"\u001b[1;36m0\u001b[0m \u001b[1;36m0.457197\u001b[0m \n",
"\u001b[1;36m1\u001b[0m \u001b[1;36m0.437343\u001b[0m \n",
"\u001b[1;36m2\u001b[0m \u001b[1;36m0.295739\u001b[0m \n",
"\u001b[1;36m3\u001b[0m \u001b[1;36m0.521691\u001b[0m \n",
"\u001b[1;36m4\u001b[0m \u001b[1;36m0.523452\u001b[0m \n",
"\n",
"\u001b[1m[\u001b[0m\u001b[1;36m5\u001b[0m rows x \u001b[1;36m21\u001b[0m columns\u001b[1m]\u001b[0m\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"IMAGE_ROOT = 'images/images'\n",
"DF_RAW = pd.read_csv('images/df.csv')\n",
"print(DF_RAW.head())"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 404
},
"id": "P-kZg83t5rwY",
"outputId": "0ab8b4fd-4e2a-400b-ff81-577fef95d9c9"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAxoAAAIZCAYAAAAho0lIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9abBla3rXB757noczj5nn5Jw3885Vt+qWKKkoiULCBgl3AAZD2yHR0V9sgwlHY+RwEERA2LTVdjcR3Q7L4cCNmQxhUKMSmko136Fu3fnmPJ+TZx72PM/94VLr/1vJ3qpT5iCF8fP/Um/tu85a7/yunc9v/5/AaDQaOZPJZDKZTCaTyWQ6RQV/vytgMplMJpPJZDKZ/s2TfdEwmUwmk8lkMplMpy77omEymUwmk8lkMplOXfZFw2QymUwmk8lkMp267IuGyWQymUwmk8lkOnXZFw2TyWQymUwmk8l06rIvGiaTyWQymUwmk+nUZV80TCaTyWQymUwm06krfNILo5moylGVhwHl+wuFQl55MOp75ampKa9ca1S9ci6X0336A68ci8W88pNHm145GAzomqDqEA6rGclYwis3Gg2v3EdawgDqHAz6v2vF4rrX0Om64VDtCYf1N/z70Uj1a7VaXrk70N9ms1nVqa/PZ2dnvfLOzpZXjkQiur9u7xuDVCbllQMYg1arPbYtxUJFNxqp33mfeFzPXVqa8crlctkrdxu6f7+nWw576sN+R/fv1XVRBH3YH6mvkkm1i22pVnRNOpdxVLPZxH+Le+VWS+OfzCS98uzstOoR1edHR0dqA/qr3ex45VBAdQoMNSC1mp7FORGLaj7G46pbvaU6cx70+101DPM0HMU9Y+rfkVP/XrlyyStn0EeJtOpw8+ZNr/yH/8hPeeUPPnjPK1cqNa8cCaW98ubGnldOJTRXHPoqm1N/hiOq2/MvX/bKxdKBo/J5PWNra8crV2tDr9yoawySCa0hrrPyUdkrBwIam8BQY8b1FA1rrg2wRqdntS+VKwWvnM2qnpGo7l8oaN5kc+qXMyurqudA4726ftYrH6EvItjfGjW1pVYteuXFea3FdFr1iSfUlqWVM175YF/3uXv70Cu7EdYZ9kbupe2a1vfM7LxX3ni666hWt+6Vz186pzY0S145qG53UczlBw/ve+Xl5UWv3GxoDnba2jeeu/Si7t9Qn3Y6uiab03zPTqtcRT+eXbvglT/5+I5XLhzrfMpm8145FVN/dXvYrxLqr0RC59ajRxte+dw59UkM94nHVS6WNDbz8zoLKjXt1fu7x17Z9dSu9TPndR/MxXZb41cuq13OORcJ6+9zac3342PN93RK66zb1fpIxvJe+f69R1454DTIw6HWbjKlfW/ocI4mcFZ1NJa9niZkMqn9ZDTQPXkGD7B/fvGLX/DKd+/d9spzc3oHOS6qjRwPPovroFbXeK+vr3vljY0Nr7y7qzURxmRPpbQfREJ6Fs+sLNYx3zNiEV1frmjuOudcbmr8e0QsoXEtFvU3ybTqUSppTk3ldBZ2uzp7mm2dZ9PTmh98BxkMtOaSGdWn09Fe3W71xn4eC6uPanXVZ2luwStHsE/kcf9WV+MRxXtKtao5HhxqDi1Ma+/a2dI4VcvaY6JRrd3189qf87Oau9v7Ov8ePdZ6bWF/CuFMYdudcy7g9Iwm+sIFNX6DgMqhAM/2i175zJr6qFrVHruC/bOB/ZPv5czPzfdVzl/OFf5tKKTr+R74/a8+cT9MFtEwmUwmk8lkMplMpy77omEymUwmk8lkMplOXSdGp4giMASWBG6TQNiu1QG2gzAqQ9w+5CVJFEPK5xW2q1QqY68hIhXQoxzIFudG+g+IHvmQiU//v/5oCCzFf40+Z9tCIWAZQJuIbvR6vbHXlEoKgTFkH41qiOpNtTMUQugY9ZmbViiUdWPIc35B+IUPFxqibkDQHj9+rM/x3GhYYUVfSA79xroFw0Drhvo8HCYepooyJB4B0larKSzonHOpTBp/o/smgPfEYqorbuu6PfUp8Sfie33c0wGLCqB/uT44rpzjXDd9zEcfQojQeSAIfAvh4jDCs8mUytvb2175xy/8BD4Xire2tuaV792755WXlpa8MkPuqYTGhsgh59aza+gHOndOeMrykjAiIizOOdduK2Q/h9B5rb7vlYkJMWy7srDilVs1ITyc8P2O6krMIIBrIhHNLx9yAOwjAGxuABazXtf8mJ/PoKxQ9tbOU6/80Uc3vPL0rJCOuTPLXjkRU7+vndXn8ajqcO/+La/M8atUEE5f0d/ev6fQfxHrvlbVvE9EtWZaDc3XTld92Abm4pxzL74knCkY0h5Srurvr1+56pUbLY1TLPaCV15F+0cD7iHAcGKaB7dv3fXKzab2hOOC2jnX1H545qzmCuf+3r5wvekpjVm3ozXngDQOgQU3GzqTrl275pVXloWvNYDe1OtqezSqMU4kNc/2DzU2xLHcCHsv9pjDQ2EcQXJw0Pzcou//c81tbWl/yABR8WF0bbX5wYMHqhLWULujeUREsdlUm4nQxdJ6FvfbFNpMHIYYx3CoviCGe/eu5kQAKGK/rzmUwFnQwhgTNeMZw/NpgDOSe2kmq3VDdIp7cgNzlOci0Ut+7tz4fcs550IRtTkHND2fz4+9Vw97NBHeQoEYmfr94kXheNWq5nh2Su3ku4wLaq7wLIzjPaLfUR04f3td9Htd433urNZQCygX34n4ORGsANbKdF7vO3wvHeAMDmIOFUpCFBs9jWWri7EHQkd0qoYxTj1zLA4Hmr8ZvKfmgDm3+5qPiZiuTwCN5X6SSAj329vTvtduN3EN9nSgxnyf5vkRjWsecP4SG7x0ac79KLKIhslkMplMJpPJZDp12RcNk8lkMplMJpPJdOo6MTpFJx2GvSIImTFcWq0Tb1Gojk4KnZbCtxk4XLRaCrmvrCjczbAzUR3WjbgUMRqG8waj8WiPc87hR/++X9nT5YL38qEYCLFGE2pzD6EuIhdxOCExBBhGPzKkR3yETj8RuBCVywr7RYEqBVCOpFW3UFDtj0R0TQyOKMmYQo/U0b4Qk3BIIVIiQpGwwnDdjsbbPx66Z4w4Fsami7B8Isl+cK4FHI94TyKE0D8cIGpVhhVV7zDm8pBhT0yqYEDj0WprXHs9zindP5VUWJSh6RHC+pzLdP4gBxgBZkGVigp/9tFfv/5rv+GVv/LTf8grv/H2m155fl5YSRSh+CHC7FNwhArCBWMWTiTFYhk1wjrB9vLWm+965XPnVx3VaqsNT57IZW5uTuFcN0K4vK650EOo+YUXnvPKR0dCAu7eFLZEJGk0GI+CJZLYFhHirlQU1qebTCapvavX1rje+Eh4zjCoGy0syNVkZ1fYTiyoud9saG1duiikiCjXH/zyj3vlN9/8ruqP9X37zsde+eWX5Vzy+KHcrm5/IucgOnGFUlobh/vCO1zQj+fs7Sm8nskB+VpXvZ9uCb+kuw+deB7c1TXEZ7gmlpaE1r3w4hWv/Ou//pteOZcTSkIXqSCQlkJB++RZI
Fu1qtAK7p8p9EUDiF5uSo42W0/VpyvLmuOsfwhuZdvbG145ltBaTCW1trhnnFnTXrL9WC4xMTj1TU/LsWp3R+jh/qHf6Y04xSLwujrw0UZb+/jegfCsHs9MICopH2YI1zfMl25f9yQyEwCSS3cpYk5B7M/EsVo17QHDHvoaZycRNB/mivnehEuhH1XS9a9/7nOqP9rS76ktva76JwukarOkfSgGpKgHHMuPe+k+RCydc64FzCuJ/bqEPapY1r7K95xoRH2awFokElmrYd2ENB4HB9qvuI7pZNVsqm7pVN4r54GEHwPd7MAd0+Ec2t/Ts3J5IPp4bg7na5zrIK914IAsLQPHqgIdauOdpTvSePTb6jeuh6kZjUd6SmdkY6j33vlp/5iFsL8Xj/TsVFbtmYtpH8ik9Xmlqv2qT9Ssqueto22lks4/voNF03Ra1LrM5/Xchw8feuWlBSGXczPq0xs3hP+eRBbRMJlMJpPJZDKZTKcu+6JhMplMJpPJZDKZTl2BEbmN30XhlMKWdMkhLuV
"text/plain": [
"<Figure size 1000x1000 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"class OpenImages(Dataset):\n",
" def __init__(self, df, image_folder=IMAGE_ROOT):\n",
" self.root = image_folder\n",
" self.df = df\n",
" self.unique_images = df['ImageID'].unique()\n",
" def __len__(self): return len(self.unique_images)\n",
" def __getitem__(self, ix):\n",
" image_id = self.unique_images[ix]\n",
" image_path = f'{self.root}/{image_id}.jpg'\n",
" image = cv2.imread(image_path, 1)[...,::-1] # conver BGR to RGB\n",
" h, w, _ = image.shape\n",
" df = self.df.copy()\n",
" df = df[df['ImageID'] == image_id]\n",
" boxes = df['XMin,YMin,XMax,YMax'.split(',')].values\n",
" boxes = (boxes * np.array([w,h,w,h])).astype(np.uint16).tolist()\n",
" classes = df['LabelName'].values.tolist()\n",
" return image, boxes, classes, image_path\n",
"ds = OpenImages(df=DF_RAW)\n",
"im, bbs, clss, _ = ds[9]\n",
"show(im, bbs=bbs, texts=clss, sz=10)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "bt_LSq5_55TZ"
},
"outputs": [],
"source": [
"def extract_candidates(img):\n",
" img_lbl, regions = selectivesearch.selective_search(img, scale=200, min_size=100)\n",
" img_area = np.prod(img.shape[:2])\n",
" candidates = []\n",
" for r in regions:\n",
" if r['rect'] in candidates: continue\n",
" if r['size'] < (0.05*img_area): continue\n",
" if r['size'] > (1*img_area): continue\n",
" x, y, w, h = r['rect']\n",
" candidates.append(list(r['rect']))\n",
" return candidates\n",
"def extract_iou(boxA, boxB, epsilon=1e-5):\n",
" x1 = max(boxA[0], boxB[0])\n",
" y1 = max(boxA[1], boxB[1])\n",
" x2 = min(boxA[2], boxB[2])\n",
" y2 = min(boxA[3], boxB[3])\n",
" width = (x2 - x1)\n",
" height = (y2 - y1)\n",
" if (width<0) or (height <0):\n",
" return 0.0\n",
" area_overlap = width * height\n",
" area_a = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])\n",
" area_b = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])\n",
" area_combined = area_a + area_b - area_overlap\n",
" iou = area_overlap / (area_combined+epsilon)\n",
" return iou"
]
},
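{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sanity check (not part of the original notebook): extract_iou on two\n",
"# hypothetical boxes in (x1, y1, x2, y2) format. boxB covers the left half of boxA,\n",
"# so the expected IoU is 5000 / 10000 = 0.5.\n",
"boxA = [0, 0, 100, 100]\n",
"boxB = [0, 0, 50, 100]\n",
"print(extract_iou(boxA, boxB)) # ~0.5\n"
]
},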
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "TtCQPF8J6CGB"
},
"outputs": [],
"source": [
"FPATHS, GTBBS, CLSS, DELTAS, ROIS, IOUS = [], [], [], [], [], []\n",
"N = 500\n",
"for ix, (im, bbs, labels, fpath) in enumerate(ds):\n",
" if(ix==N):\n",
" break\n",
" H, W, _ = im.shape\n",
" candidates = extract_candidates(im)\n",
" candidates = np.array([(x,y,x+w,y+h) for x,y,w,h in candidates])\n",
" ious, rois, clss, deltas = [], [], [], []\n",
" ious = np.array([[extract_iou(candidate, _bb_) for candidate in candidates] for _bb_ in bbs]).T\n",
" for jx, candidate in enumerate(candidates):\n",
" cx,cy,cX,cY = candidate\n",
" candidate_ious = ious[jx]\n",
" best_iou_at = np.argmax(candidate_ious)\n",
" best_iou = candidate_ious[best_iou_at]\n",
" best_bb = _x,_y,_X,_Y = bbs[best_iou_at]\n",
" if best_iou > 0.3: clss.append(labels[best_iou_at])\n",
" else : clss.append('background')\n",
" delta = np.array([_x-cx, _y-cy, _X-cX, _Y-cY]) / np.array([W,H,W,H])\n",
" deltas.append(delta)\n",
" rois.append(candidate / np.array([W,H,W,H]))\n",
" FPATHS.append(fpath)\n",
" IOUS.append(ious)\n",
" ROIS.append(rois)\n",
" CLSS.append(clss)\n",
" DELTAS.append(deltas)\n",
" GTBBS.append(bbs)\n",
"FPATHS = [f'{IMAGE_ROOT}/{stem(f)}.jpg' for f in FPATHS] \n",
"FPATHS, GTBBS, CLSS, DELTAS, ROIS = [item for item in [FPATHS, GTBBS, CLSS, DELTAS, ROIS]]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "yxzZs0Gs7bQt"
},
"outputs": [],
"source": [
"targets = pd.DataFrame(flatten(CLSS), columns=['label'])\n",
"label2target = {l:t for t,l in enumerate(targets['label'].unique())}\n",
"target2label = {t:l for l,t in label2target.items()}\n",
"background_class = label2target['background']"
]
},
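{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A quick sanity check (not part of the original notebook): inspect the label/index\n",
"# mappings built from the proposal labels above.\n",
"print(label2target)\n",
"print(target2label)\n",
"print('background class index:', background_class)\n"
]
},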
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "vEtHtZOO725v"
},
"outputs": [],
"source": [
"normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
" std=[0.229, 0.224, 0.225])\n",
"def preprocess_image(img):\n",
" img = torch.tensor(img).permute(2,0,1)\n",
" img = normalize(img)\n",
" return img.to(device).float()\n",
"def decode(_y):\n",
" _, preds = _y.max(-1)\n",
" return preds\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "4vLi9hII7-WS"
},
"outputs": [],
"source": [
"class RCNNDataset(Dataset):\n",
" def __init__(self, fpaths, rois, labels, deltas, gtbbs):\n",
" self.fpaths = fpaths\n",
" self.gtbbs = gtbbs\n",
" self.rois = rois\n",
" self.labels = labels\n",
" self.deltas = deltas\n",
" def __len__(self): return len(self.fpaths)\n",
" def __getitem__(self, ix):\n",
" fpath = str(self.fpaths[ix])\n",
" image = cv2.imread(fpath, 1)[...,::-1]\n",
" H, W, _ = image.shape\n",
" sh = np.array([W,H,W,H])\n",
" gtbbs = self.gtbbs[ix]\n",
" rois = self.rois[ix]\n",
" bbs = (np.array(rois)*sh).astype(np.uint16)\n",
" labels = self.labels[ix]\n",
" deltas = self.deltas[ix]\n",
" crops = [image[y:Y,x:X] for (x,y,X,Y) in bbs]\n",
" return image, crops, bbs, labels, deltas, gtbbs, fpath\n",
" def collate_fn(self, batch):\n",
" input, rois, rixs, labels, deltas = [], [], [], [], []\n",
" for ix in range(len(batch)):\n",
" image, crops, image_bbs, image_labels, image_deltas, image_gt_bbs, image_fpath = batch[ix]\n",
" crops = [cv2.resize(crop, (224,224)) for crop in crops]\n",
" crops = [preprocess_image(crop/255.)[None] for crop in crops]\n",
" input.extend(crops)\n",
" labels.extend([label2target[c] for c in image_labels])\n",
" deltas.extend(image_deltas)\n",
" input = torch.cat(input).to(device)\n",
" labels = torch.Tensor(labels).long().to(device)\n",
" deltas = torch.Tensor(deltas).float().to(device)\n",
" return input, labels, deltas\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "dzwT5C-J8G0j"
},
"outputs": [],
"source": [
"n_train = 9*len(FPATHS)//10\n",
"train_ds = RCNNDataset(FPATHS[:n_train], ROIS[:n_train], CLSS[:n_train], DELTAS[:n_train], GTBBS[:n_train])\n",
"test_ds = RCNNDataset(FPATHS[n_train:], ROIS[n_train:], CLSS[n_train:], DELTAS[n_train:], GTBBS[n_train:])\n",
"\n",
"from torch.utils.data import TensorDataset, DataLoader\n",
"train_loader = DataLoader(train_ds, batch_size=2, collate_fn=train_ds.collate_fn, drop_last=True)\n",
"test_loader = DataLoader(test_ds, batch_size=2, collate_fn=test_ds.collate_fn, drop_last=True)"
]
},
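{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch (not part of the original notebook): pull one batch from the\n",
"# training loader to confirm the collate_fn output - crops stacked as\n",
"# (num_rois, 3, 224, 224), plus one class label and one 4-value delta per RoI.\n",
"sample_inputs, sample_labels, sample_deltas = next(iter(train_loader))\n",
"print(sample_inputs.shape, sample_labels.shape, sample_deltas.shape)\n"
]
},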
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 732,
"referenced_widgets": [
"a929d30788fe4772ab578390ce90597e",
"4f05bb580a544d788143bd4fdea8150a",
"804c3bc6fa9e48d9b59ce11e855b0999",
"aaf77f8c6ddc496b8975b648830878e6",
"2bde1618e6a244dba48a884d32e3dcd4",
"9610567dd5d1460f99516aac12d77581",
"1051fab2a4524a0a897e4af2f8186e0d",
"0de938505a4445caa4bff6402a9efa56"
]
},
"id": "QcLxxBJm8HUz",
"outputId": "18905158-d20e-4f9a-fc90-d63b444d2708"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\frakt\\anaconda3\\lib\\site-packages\\torchvision\\models\\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
" warnings.warn(\n",
"C:\\Users\\frakt\\anaconda3\\lib\\site-packages\\torchvision\\models\\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG16_Weights.IMAGENET1K_V1`. You can also use `weights=VGG16_Weights.DEFAULT` to get the most up-to-date weights.\n",
" warnings.warn(msg)\n",
"Downloading: \"https://download.pytorch.org/models/vgg16-397923af.pth\" to C:\\Users\\frakt/.cache\\torch\\hub\\checkpoints\\vgg16-397923af.pth\n",
"100%|███████████████████████████████████████████████████████████████████████████████| 528M/528M [01:04<00:00, 8.64MB/s]\n"
]
},
{
"data": {
"text/plain": [
"VGG(\n",
" (features): Sequential(\n",
" (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (1): ReLU(inplace=True)\n",
" (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (3): ReLU(inplace=True)\n",
" (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (6): ReLU(inplace=True)\n",
" (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (8): ReLU(inplace=True)\n",
" (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (11): ReLU(inplace=True)\n",
" (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (13): ReLU(inplace=True)\n",
" (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (15): ReLU(inplace=True)\n",
" (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (18): ReLU(inplace=True)\n",
" (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (20): ReLU(inplace=True)\n",
" (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (22): ReLU(inplace=True)\n",
" (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (25): ReLU(inplace=True)\n",
" (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (27): ReLU(inplace=True)\n",
" (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
" (29): ReLU(inplace=True)\n",
" (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" )\n",
" (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))\n",
" (classifier): Sequential()\n",
")"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vgg_backbone = models.vgg16(pretrained=True)\n",
"vgg_backbone.classifier = nn.Sequential()\n",
"for param in vgg_backbone.parameters():\n",
" param.requires_grad = False\n",
"vgg_backbone.eval().to(device)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "Fs3XfT418aGk"
},
"outputs": [],
"source": [
"class RCNN(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" feature_dim = 25088\n",
" self.backbone = vgg_backbone\n",
" self.cls_score = nn.Linear(feature_dim, len(label2target))\n",
" self.bbox = nn.Sequential(\n",
" nn.Linear(feature_dim, 512),\n",
" nn.ReLU(),\n",
" nn.Linear(512, 4),\n",
" nn.Tanh(),\n",
" )\n",
" self.cel = nn.CrossEntropyLoss()\n",
" self.sl1 = nn.L1Loss()\n",
" def forward(self, input):\n",
" feat = self.backbone(input)\n",
" cls_score = self.cls_score(feat)\n",
" bbox = self.bbox(feat)\n",
" return cls_score, bbox\n",
" def calc_loss(self, probs, _deltas, labels, deltas):\n",
" detection_loss = self.cel(probs, labels)\n",
" ixs, = torch.where(labels != 0)\n",
" _deltas = _deltas[ixs]\n",
" deltas = deltas[ixs]\n",
" self.lmb = 10.0\n",
" if len(ixs) > 0:\n",
" regression_loss = self.sl1(_deltas, deltas)\n",
" return detection_loss + self.lmb * regression_loss, detection_loss.detach(), regression_loss.detach()\n",
" else:\n",
" regression_loss = 0\n",
" return detection_loss + self.lmb * regression_loss, detection_loss.detach(), regression_loss"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "MiMmezgp9E-s"
},
"outputs": [],
"source": [
"def train_batch(inputs, model, optimizer, criterion):\n",
" input, clss, deltas = inputs\n",
" model.train()\n",
" optimizer.zero_grad()\n",
" _clss, _deltas = model(input)\n",
" loss, loc_loss, regr_loss = criterion(_clss, _deltas, clss, deltas)\n",
" accs = clss == decode(_clss)\n",
" loss.backward()\n",
" optimizer.step()\n",
" return loss.detach(), loc_loss, regr_loss, accs.cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "vNBqA98I9G6O"
},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def validate_batch(inputs, model, criterion):\n",
" input, clss, deltas = inputs\n",
" with torch.no_grad():\n",
" model.eval()\n",
" _clss,_deltas = model(input)\n",
" loss, loc_loss, regr_loss = criterion(_clss, _deltas, clss, deltas)\n",
" _, _clss = _clss.max(-1)\n",
" accs = clss == _clss\n",
" return _clss, _deltas, loss.detach(), loc_loss, regr_loss, accs.cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "JIobqz0W9I6s"
},
"outputs": [],
"source": [
"rcnn = RCNN().to(device)\n",
"criterion = rcnn.calc_loss\n",
"optimizer = optim.SGD(rcnn.parameters(), lr=1e-3)\n",
"n_epochs = 5\n",
"log = Report(n_epochs)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 514
},
"id": "hReU_vnH9Kk0",
"outputId": "b5f5094d-eeb6-4aec-fbcd-0ca419b8d0ee"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\frakt\\AppData\\Local\\Temp\\ipykernel_27780\\4275183504.py:32: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\torch\\csrc\\utils\\tensor_new.cpp:248.)\n",
" deltas = torch.Tensor(deltas).float().to(device)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"EPOCH: 0.409 trn_loss: 2.199 trn_loc_loss: 0.554 trn_regr_loss: 0.164 trn_acc: 0.792 (821.90s - 9228.49s remaining))"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_27780\\1533349013.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0m_n\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mix\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minputs\u001b[0m \u001b[1;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m loss, loc_loss, regr_loss, accs = train_batch(inputs, rcnn, \n\u001b[0m\u001b[0;32m 6\u001b[0m optimizer, criterion)\n\u001b[0;32m 7\u001b[0m \u001b[0mpos\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mepoch\u001b[0m \u001b[1;33m+\u001b[0m \u001b[1;33m(\u001b[0m\u001b[0mix\u001b[0m\u001b[1;33m+\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m/\u001b[0m\u001b[0m_n\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_27780\\1507503000.py\u001b[0m in \u001b[0;36mtrain_batch\u001b[1;34m(inputs, model, optimizer, criterion)\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m \u001b[0m_clss\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_deltas\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 6\u001b[0m \u001b[0mloss\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mloc_loss\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mregr_loss\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_clss\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0m_deltas\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mclss\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdeltas\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 7\u001b[0m \u001b[0maccs\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mclss\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mdecode\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_clss\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1499\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1501\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1502\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_27780\\2122026028.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 14\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msl1\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mL1Loss\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 15\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 16\u001b[1;33m \u001b[0mfeat\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbackbone\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 17\u001b[0m \u001b[0mcls_score\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcls_score\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfeat\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 18\u001b[0m \u001b[0mbbox\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbbox\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfeat\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1499\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1501\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1502\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torchvision\\models\\vgg.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, x)\u001b[0m\n\u001b[0;32m 64\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 65\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mx\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 66\u001b[1;33m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 67\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mavgpool\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 68\u001b[0m \u001b[0mx\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mflatten\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;36m1\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1499\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1501\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1502\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 215\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 216\u001b[0m \u001b[1;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 217\u001b[1;33m \u001b[0minput\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 218\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 219\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m 1499\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[1;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1500\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[1;32m-> 1501\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m*\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 1502\u001b[0m \u001b[1;31m# Do not call functions when jit is used\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 1503\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\nn\\modules\\conv.py\u001b[0m in \u001b[0;36mforward\u001b[1;34m(self, input)\u001b[0m\n\u001b[0;32m 461\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 462\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0minput\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 463\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_conv_forward\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 464\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 465\u001b[0m \u001b[1;32mclass\u001b[0m \u001b[0mConv3d\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0m_ConvNd\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;32m~\\anaconda3\\lib\\site-packages\\torch\\nn\\modules\\conv.py\u001b[0m in \u001b[0;36m_conv_forward\u001b[1;34m(self, input, weight, bias)\u001b[0m\n\u001b[0;32m 457\u001b[0m \u001b[0mweight\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstride\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 458\u001b[0m _pair(0), self.dilation, self.groups)\n\u001b[1;32m--> 459\u001b[1;33m return F.conv2d(input, weight, bias, self.stride,\n\u001b[0m\u001b[0;32m 460\u001b[0m self.padding, self.dilation, self.groups)\n\u001b[0;32m 461\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"for epoch in range(n_epochs):\n",
"\n",
" _n = len(train_loader)\n",
" for ix, inputs in enumerate(train_loader):\n",
" loss, loc_loss, regr_loss, accs = train_batch(inputs, rcnn, \n",
" optimizer, criterion)\n",
" pos = (epoch + (ix+1)/_n)\n",
" log.record(pos, trn_loss=loss.item(), trn_loc_loss=loc_loss, \n",
" trn_regr_loss=regr_loss, \n",
" trn_acc=accs.mean(), end='\\r')\n",
" \n",
" _n = len(test_loader)\n",
" for ix,inputs in enumerate(test_loader):\n",
" _clss, _deltas, loss, \\\n",
" loc_loss, regr_loss, accs = validate_batch(inputs, \n",
" rcnn, criterion)\n",
" pos = (epoch + (ix+1)/_n)\n",
" log.record(pos, val_loss=loss.item(), val_loc_loss=loc_loss, \n",
" val_regr_loss=regr_loss, \n",
" val_acc=accs.mean(), end='\\r')\n",
"\n",
"# Plotting training and validation metrics\n",
"log.plot_epochs('trn_loss,val_loss'.split(','))"
]
},
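{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional sketch (not in the original notebook): persist the trained weights so the\n",
"# model can be reloaded for inference later. The filename 'rcnn.pth' is arbitrary.\n",
"torch.save(rcnn.state_dict(), 'rcnn.pth')\n"
]
},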
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qksBAgNJ9NhK"
},
"outputs": [],
"source": [
"def test_predictions(filename, show_output=True):\n",
" img = np.array(cv2.imread(filename, 1)[...,::-1])\n",
" candidates = extract_candidates(img)\n",
" candidates = [(x,y,x+w,y+h) for x,y,w,h in candidates]\n",
" input = []\n",
" for candidate in candidates:\n",
" x,y,X,Y = candidate\n",
" crop = cv2.resize(img[y:Y,x:X], (224,224))\n",
" input.append(preprocess_image(crop/255.)[None])\n",
" input = torch.cat(input).to(device)\n",
" with torch.no_grad():\n",
" rcnn.eval()\n",
" probs, deltas = rcnn(input)\n",
" probs = torch.nn.functional.softmax(probs, -1)\n",
" confs, clss = torch.max(probs, -1)\n",
" candidates = np.array(candidates)\n",
" confs, clss, probs, deltas = [tensor.detach().cpu().numpy() for tensor in [confs, clss, probs, deltas]]\n",
"\n",
" ixs = clss!=background_class\n",
" confs, clss, probs, deltas, candidates = [tensor[ixs] for tensor in [confs, clss, probs, deltas, candidates]]\n",
" bbs = (candidates + deltas).astype(np.uint16)\n",
" ixs = nms(torch.tensor(bbs.astype(np.float32)), torch.tensor(confs), 0.05)\n",
" confs, clss, probs, deltas, candidates, bbs = [tensor[ixs] for tensor in [confs, clss, probs, deltas, candidates, bbs]]\n",
" if len(ixs) == 1:\n",
" confs, clss, probs, deltas, candidates, bbs = [tensor[None] for tensor in [confs, clss, probs, deltas, candidates, bbs]]\n",
" if len(confs) == 0 and not show_output:\n",
" return (0,0,224,224), 'background', 0\n",
" if len(confs) > 0:\n",
" best_pred = np.argmax(confs)\n",
" best_conf = np.max(confs)\n",
" best_bb = bbs[best_pred]\n",
" x,y,X,Y = best_bb\n",
" _, ax = plt.subplots(1, 2, figsize=(20,10))\n",
" show(img, ax=ax[0])\n",
" ax[0].grid(False)\n",
" ax[0].set_title('Original image')\n",
" if len(confs) == 0:\n",
" ax[1].imshow(img)\n",
" ax[1].set_title('No objects')\n",
" plt.show()\n",
" return\n",
" ax[1].set_title(target2label[clss[best_pred]])\n",
" show(img, bbs=bbs.tolist(), texts=[target2label[c] for c in clss.tolist()], ax=ax[1], title='predicted bounding box and class')\n",
" plt.show()\n",
" return (x,y,X,Y),target2label[clss[best_pred]],best_conf"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 350
},
"id": "0qpUkdh4E_zl",
"outputId": "0237ce9e-ddc5-4560-ca48-fd999ca80a52"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAABGoAAAF7CAYAAACKBW2fAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nOy9ebBlWXbW91t773Puve/lUFOrq6snDS2BCIfAQIMhwMYWNmhiskJgI0BgBRgwYpAYhBiEjBRmMGCQmSxbhAUCtYUZBIgQKBojoQEHBs0gqVtDd/VQ1ZXje+/ee87ee/mPtfa592VlZmW3VN2lrv1VvMrMd889Z49rfWdNW1SVjo6Ojo6Ojo6Ojo6Ojo6Ojo6PPMJHugEdHR0dHR0dHR0dHR0dHR0dHYZuqOno6Ojo6Ojo6Ojo6Ojo6Oh4haAbajo6Ojo6Ojo6Ojo6Ojo6OjpeIeiGmo6Ojo6Ojo6Ojo6Ojo6Ojo5XCLqhpqOjo6Ojo6Ojo6Ojo6Ojo+MVgm6o6ejo6Ojo6Ojo6Ojo6Ojo6HiFoBtqOjpeoRCRPywiX/WTfe0j3EtF5C0P+OwbReQ3/WQ8p6Ojo6Ojo+OjFyLyoyLyS/3vP2k85SWe+UtE5N0P+fyBHOfDCRH5GyLyJ/3vv1hE/sPL9Jx/ISKf/3Lc++XC8br5IL7zsT636eVqV0fHhxt9MXd0fBggIp8HfCHwCcAd4O8BX6yqtx70HVX9ike9/wdz7U8EqvppH47ndHR0dHR0dHz04FF5ioj8DeDdqvpHXt4WvXKgqt8C/LSPdDs6OjpeWegRNR0dLzNE5AuBPwX8fuA68J8Abwb+mYiMD/hON6J2dHR0dHR0vCLQeUlHR0fHhxfdUNPR8TJCRK4BfwL4Xar6T1V1VtUfBT4H+Fjgc/26LxWRrxeRvykid4DP89/9zaN7/UYR+TEReUFE/ug9IcXLtUfhn79JRH5cRD4gIl9ydJ+fJyLfLiK3ROS9IvKVDzIY3ac/SwitiHyeiPwrEfnzfq93isgv9N+/S0SeO06TEpHPEJF/KyJ3/PMvvefeD+tfEJE/JCLv8M/fJiJPfNAT0tHR0dHR0QEsKSZfLCLfLyI3ReSrRWTtn/0SEXm3iPxBEXkf8NUvpYtF5Dcc6fEvuedZ93KaXyQi3+b84V3OHX4r8OuBPyAiZyLyDX7tMyLyd0XkeRH5ERH5gqP7bDyN6KaIfD/w1kfo+qc7Z/mAiPwZEQl+ryAif8T78JyI/J8icv14PO4zfsc87G3+nbsi8n0i8nOPrv2PReT/88++DlgffXbp3n7fLxKR7xaR2yLydW1e/PM/4PztPSLy+fLS6VyfICL/2vnXP7hnzn6Ft/WWc7xP9t//QRH5TnEDnYj8dr9ufe/NReRxEflHPj83/e9vOPr8X4jI/yjGGe+KyDeJyFNHnz9w3dznWRsR+Z/9+tsi8q0isrnPdb9ZRH7An/dOEfltR5895W28JSI3RORbjtbAHxSRZ/17/0FEPvVh7enoeDnRDTUdHS8vfiGmjP/v41+q6hnwT4D/8ujXvxL4euAx4G8dXy8iPwP4yxiBeR0WmfP6l3j2L8JCaT8V+GNN+QIF+L3AU8Av8M9/xwfZr4afD3w38CTwtcDfwUjSWzAj1FeKyBW/9hz4jd6/zwB+u4j8qkfs3+8CfhXwnwHPADeB//VDbHNHR0dHR0eH4dcDvwxLzf4k4Djl6GngCSwK+LfyEF3sevyvAL/BP3sSeAP3gYi8GfhG4C8BrwF+FvDvVPWvY/znT6vqFVX9LH+B/gbguzBe8KnA7xGRX+a3++Pe9k/wfjxKHb1fDfxc4Gdj3Ou3+O8/z3/+c+DjgSvAVz7C/Rp+BcaDHgP+YfuumDPs7wNfg43n/wX81y9xr88BfjnwccCneLsQkV8O/D7gl2Jc65c8Qrt+I9bH1wEZ+It+r08C/jbwe7B5+CfAN3h7/wywB/6IiHwi8BXA56rq7j73D8BXY+vkTcCWF4/bfwv8ZuBjgBH4Im/DI68bx58Ffg7Gr58A/gBQ73Pdc8BnAtf8uX9eRH62f/aFwLu9z68F/jCgIvLTgP8BeKuqXsXW048+pC0dHS8ruqGmo+PlxVPAB1Q13+ez9/rnDd+uqn9fVauqbu+59rOBb1DVb1XVCfhjgL7Es/+Eqm5V9bswgvMzAVT136jqd6hq9uiev4aRrg8FP6KqX62qBfg64I3Al6nqXlW/CZgwIoGq/gtV/R7v33dj5KA996X6998DX6Kq71bVPfClwGdLD8Xu6Ojo6Oj4ieArVfVdqnoD+HLgvzn6rAJ/3HX6lofr4s8G/pGq/kv/7I9y/xdosJf2f66qf9sjjV9Q1X/3gGvfCrxGVb9MVSdVfSfwvwG/zj//HODLVfWGqr4LN0K8BP6UX//jwF846vOvB/6cqr7THWpfDPy6D4JrfKuq/hPnRF+D8y4s5X0A/oL39+uB//cl7vUXVfU9Pi/fgBmzWn+/WlW/T1UvsDl4KXyNqn6vqp5j8/I5IhKBXwv8Y1X9Z6o6Y0aQDfALVbViBp4vwIxOf1pV/+39bu7z93dV9UJV72Lr6F5e+dWq+oO+jt521J9HXjdutPstwO9W1WdVtajqt/n37m3TP1bVd6jh/wG+CfjF/vGMGa3e7PPxLaqqmCNzBfwMERlU9UdV9R0PH9qOjpcP3VDT0fHy4gPAUw9Q8q/zzxve9ZD7PHP8uSvnF17i2e87+vsF5hlCRD7JQz7fJ5Zm9RVcNhh9MHj/0d+33rZ7f9ee+/NF5O0eGnsbI3ztuS/VvzcDf8/DVG8BP4Ap1Nd+iO3u6Ojo6OjouMw9fgzTxw3P3xNB8TBdfK8eP+fBPOWNwKO+AL8ZeKY905/7hzno/0vP9T68FB7U52fu+f6PYQevPCrXuJd3rZ3/PQM868aAR23nfTkcL+7vw7jj/a75Mcxo9BT39NeNM+/CI5rdmfd2LFX/gVHMInIiIn/N05HuAP8SeMyNQR9Uf15i3TyFRam/5NoRkU8Tke/w1KZbwKdz4Jx/Bvhh4Js8LeoP+bN/GIsu+lLgORH5OyLyzH1u39HxYUE31HR0vLz4dix09Ncc/9LTgT4N+OajXz8sQua9HIWCej7ukx9im/4K8O+BT1TVaxjhkQ/xXh8MvhbzyrxRVa8Df/XouS/Vv3cBn6aqjx39rFX12Q9Duzs6Ojo6Oj5a8cajv78JeM/Rv+/lJQ/Txe89vpeInPBgnvIuLFXpfrjfM3/knmdeVdVP988vPdf78FJ4UJ/fgxmGjj/LmFPqHDhpH7gR4jWP8KzWxteLyDHXepR2Puhex6lBb3zQhQ+45k1YRMkHuKe/3r43As/6vz8DS5H/Zsy48SB8IZZq//OdV/6n7ZaP0LYPZt18ANjx4LXT7rEC/i4WIfRaVX0MS+sSAFW9q6pfqKofj6Wr/b5Wi0ZVv1ZVfxE2LoodBtLR8RFBN9R0dLyMUNXbWDHhvyQiv1xEBhH5WCzs891YaOyj4OuBzxIr1jti1v4P1bhyF
Tsi/ExEfjrw2z/E+3woz72hqjsR+XlY6HPDS/XvrwJf7nntiMhrRORXfpja3dHR0dHR8dGK3ykibxArMPslWBrzg/AwXfz1wGeKFQkegS/jwe8Zfwv4pSLyOSKSRORJEWmpMO/H6sM0/Gvgrhd53YhIFJH/SERa0eC3AV8sVtD2DVgdnZfC7/fr3wj87qM+/23g94rIx7lD7SuAr/P09R/EImQ+Q0QGrJbP6hGeBea0y8AXOA/8NcDPe8Tv3ou3Ab9ZRD7ZjRp/9BG+87ki8jP8+i8Dvt7Ts94GfIaIfKr36Qsx5+K3iRX7/Srg87G6P58lIp/+gPtfxSKob/k6+uMfRH8eed14xM//Afw5sQLTUUR+gRtmjjFic/M8kEXk04D/qn0oIp8pIm9xw9RtLCqsishPE5H/wu+38z49KH2vo+NlRzfUdHS8zFDVP41FrfxZzEDynZiH6FPvl1f7gHt8H0Y+/g7mfTjDCqU90vfvwRdhRpK7WJ73w0jZTyZ+B/BlInIXq0HztvbBI/Tvf8G
"text/plain": [
"<Figure size 1440x720 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light",
"tags": []
},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"((16, 60, 218, 133), 'Bus', 0.9851093)"
]
},
"execution_count": 18,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"image, crops, bbs, labels, deltas, gtbbs, fpath = test_ds[7]\n",
"test_predictions(fpath)"
]
},
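  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch (not part of the original notebook): it assumes the `test_ds` dataset and the `test_predictions` helper defined earlier, and simply repeats the single-image check above over a handful of test images to eyeball how consistent the detections are."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged sketch: relies only on `test_ds` and `test_predictions` from above.\n",
    "# Each call plots the image with its predicted box and returns a\n",
    "# (bounding_box, class_label, confidence) tuple, which we print per sample.\n",
    "for ix in range(5):\n",
    "    *_, fpath = test_ds[ix]  # the file path is the last element of the item\n",
    "    print(ix, test_predictions(fpath))"
   ]
  },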
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1OxuR9fjFB45"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"include_colab_link": true,
"name": "Training_RCNN.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"latex_envs": {
"LaTeX_envs_menu_present": true,
"autoclose": false,
"autocomplete": true,
"bibliofile": "biblio.bib",
"cite_by": "apalike",
"current_citInitial": 1,
"eqLabelWithNumbers": true,
"eqNumInitial": 1,
"hotkeys": {
"equation": "Ctrl-E",
"itemize": "Ctrl-I"
},
"labels_anchors": false,
"latex_user_defs": false,
"report_style_numbering": false,
"user_envs_cfg": false
  }
},
"nbformat": 4,
"nbformat_minor": 1
}