SW-ramon-dyzman/cnn/visualize.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"from torch.utils.data import DataLoader, Dataset\n",
"from torch.utils.data import random_split\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"import torchvision\n",
"from torch.utils.tensorboard import SummaryWriter\n",
"\n",
"writer = SummaryWriter()"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"class TreesDataset(Dataset):\n",
" def __init__(self, data_links) -> None:\n",
" self.X, self.Y = readData(data_links)\n",
"\n",
" def __len__(self):\n",
" return len(self.X)\n",
"\n",
" def __getitem__(self, index):\n",
" return (self.X[index], self.Y[index])\n",
"\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv1 = nn.Conv2d(3, 6, 5)\n",
" self.pool = nn.MaxPool2d(2, 2)\n",
" self.conv2 = nn.Conv2d(6, 16, 5)\n",
" self.fc1 = nn.Linear(3264, 120)\n",
" self.fc2 = nn.Linear(120, 84)\n",
" self.fc3 = nn.Linear(84, 2)\n",
"\n",
" def forward(self, x):\n",
" x = self.pool(F.relu(self.conv1(x)))\n",
" x = self.pool(F.relu(self.conv2(x)))\n",
" x = torch.flatten(x, 1)\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def create_datalinks(root_dir):\n",
" data_links = os.listdir(root_dir)\n",
" data_links = [root_dir + \"/\" + x for x in data_links]\n",
" return data_links\n",
"\n",
"def preprocess(img):\n",
" scale_percent = 10\n",
" width = int(img.shape[1] * scale_percent / 100)\n",
" height = int(img.shape[0] * scale_percent / 100)\n",
" dim = (width, height)\n",
" resized = cv2.resize(img, dim, interpolation = cv2.INTER_AREA)\n",
" resized = torchvision.transforms.functional.to_tensor(resized)\n",
" return resized\n",
"\n",
"def readData(data_links):\n",
" x, y = [], []\n",
" for link in data_links:\n",
" img = cv2.imread(link, cv2.IMREAD_COLOR)\n",
" img = preprocess(img)\n",
" if(\"ground\" in link):\n",
" label = 1\n",
" elif(\"AS12\" in link):\n",
" label = 0\n",
" else:\n",
" label = 0\n",
" x.append(img)\n",
" y.append(label)\n",
"\n",
" return x, y"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"links_3_plus_ground = create_datalinks(\"new_data/AS12_3\") + create_datalinks(\"new_data/ground\")\n",
"\n",
"dataset = TreesDataset(links_3_plus_ground)\n",
"\n",
"train_set, test_set = random_split(dataset, [300, 50], generator=torch.Generator().manual_seed(42))\n",
"\n",
"trainloader = DataLoader(train_set, batch_size=10, shuffle=True, num_workers=2)\n",
"testloader = DataLoader(test_set, batch_size=10, shuffle=True, num_workers=2)\n",
"\n",
"classes = ('tree', 'ground')\n",
"epochs_num = 15\n",
"\n",
"net = Net()\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1, 1] loss: 0.074\n",
"[1, 11] loss: 0.725\n",
"[1, 21] loss: 0.695\n",
"[2, 1] loss: 0.063\n",
"[2, 11] loss: 0.606\n",
"[2, 21] loss: 0.594\n",
"[3, 1] loss: 0.071\n",
"[3, 11] loss: 0.405\n",
"[3, 21] loss: 0.477\n",
"[4, 1] loss: 0.015\n",
"[4, 11] loss: 0.327\n",
"[4, 21] loss: 0.484\n",
"[5, 1] loss: 0.052\n",
"[5, 11] loss: 0.486\n",
"[5, 21] loss: 0.370\n",
"[6, 1] loss: 0.014\n",
"[6, 11] loss: 0.454\n",
"[6, 21] loss: 0.317\n",
"[7, 1] loss: 0.052\n",
"[7, 11] loss: 0.434\n",
"[7, 21] loss: 0.467\n",
"[8, 1] loss: 0.051\n",
"[8, 11] loss: 0.438\n",
"[8, 21] loss: 0.457\n",
"[9, 1] loss: 0.071\n",
"[9, 11] loss: 0.422\n",
"[9, 21] loss: 0.358\n",
"[10, 1] loss: 0.013\n",
"[10, 11] loss: 0.447\n",
"[10, 21] loss: 0.373\n",
"[11, 1] loss: 0.052\n",
"[11, 11] loss: 0.314\n",
"[11, 21] loss: 0.387\n",
"[12, 1] loss: 0.037\n",
"[12, 11] loss: 0.437\n",
"[12, 21] loss: 0.395\n",
"[13, 1] loss: 0.013\n",
"[13, 11] loss: 0.431\n",
"[13, 21] loss: 0.423\n",
"[14, 1] loss: 0.017\n",
"[14, 11] loss: 0.466\n",
"[14, 21] loss: 0.371\n",
"[15, 1] loss: 0.032\n",
"[15, 11] loss: 0.441\n",
"[15, 21] loss: 0.324\n",
"Finished Training\n"
]
}
],
"source": [
"\n",
"\n",
"for epoch in range(epochs_num):\n",
" correct = 0\n",
" total = 0\n",
" running_loss = 0.0\n",
" for i, data in enumerate(trainloader, 0):\n",
" inputs, labels = data\n",
" optimizer.zero_grad()\n",
" outputs = net(inputs)\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
" running_loss += loss.item()\n",
"\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" total += labels.size(0)\n",
" correct += (predicted == labels).sum().item()\n",
"\n",
" if i % 10 == 0: \n",
" print('[%d, %5d] loss: %.3f' %\n",
" (epoch + 1, i + 1, running_loss / 10))\n",
" running_loss = 0.0\n",
"\n",
" writer.add_scalar(\"Loss/train\", loss.item(), i + epoch)\n",
" writer.add_scalar(\"Accuracy/train\", correct/total, i + epoch)\n",
"\n",
"print('Finished Training')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy : 84 %\n"
]
}
],
"source": [
"correct = 0\n",
"total = 0\n",
"i=0\n",
"with torch.no_grad():\n",
" for data in testloader:\n",
" images, labels = data\n",
" outputs = net(images)\n",
" loss = criterion(outputs, labels)\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" total += labels.size(0)\n",
" correct += (predicted == labels).sum().item()\n",
" writer.add_scalar(\"Accuracy/test\", correct/total, i + epoch)\n",
" writer.add_scalar(\"Loss/test\", loss.item(), i + epoch)\n",
" i += 1\n",
"\n",
"print('Accuracy : %d %%' % (100 * correct / total))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"\n",
"images, labels = next(iter(trainloader))\n",
"grid = torchvision.utils.make_grid(images)\n",
"writer.add_image('images', grid, 0)\n",
"writer.add_graph(net, images)\n",
"writer.close()"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"multi = create_datalinks(\"new_data/multi\")\n",
"multiset = TreesDataset(multi)\n",
"multiloader = DataLoader(multiset, batch_size=10, shuffle=True, num_workers=2)\n",
"criterion2 = nn.Softmax(dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[0.8594, 0.1406],\n",
" [0.8276, 0.1724],\n",
" [0.8850, 0.1150],\n",
" [0.8887, 0.1113],\n",
" [0.8737, 0.1263],\n",
" [0.8814, 0.1186],\n",
" [0.8911, 0.1089],\n",
" [0.8364, 0.1636],\n",
" [0.8846, 0.1154],\n",
" [0.8726, 0.1274]])\n",
"tensor([[0.8452, 0.1548],\n",
" [0.8533, 0.1467],\n",
" [0.8536, 0.1464],\n",
" [0.8593, 0.1407],\n",
" [0.8557, 0.1443],\n",
" [0.8811, 0.1189],\n",
" [0.8727, 0.1273],\n",
" [0.8529, 0.1471],\n",
" [0.9053, 0.0947],\n",
" [0.8824, 0.1176]])\n",
"tensor([[0.8593, 0.1407],\n",
" [0.8952, 0.1048],\n",
" [0.8780, 0.1220],\n",
" [0.8724, 0.1276],\n",
" [0.8451, 0.1549],\n",
" [0.8424, 0.1576],\n",
" [0.8332, 0.1668],\n",
" [0.8567, 0.1433],\n",
" [0.8487, 0.1513],\n",
" [0.8839, 0.1161]])\n",
"tensor([[0.8759, 0.1241],\n",
" [0.8340, 0.1660],\n",
" [0.9141, 0.0859],\n",
" [0.9075, 0.0925],\n",
" [0.8674, 0.1326],\n",
" [0.8431, 0.1569],\n",
" [0.8933, 0.1067],\n",
" [0.8475, 0.1525],\n",
" [0.8363, 0.1637],\n",
" [0.8789, 0.1211]])\n",
"tensor([[0.8495, 0.1505],\n",
" [0.8402, 0.1598],\n",
" [0.8482, 0.1518],\n",
" [0.8470, 0.1530],\n",
" [0.8733, 0.1267],\n",
" [0.8362, 0.1638],\n",
" [0.8909, 0.1091],\n",
" [0.8568, 0.1432],\n",
" [0.8577, 0.1423],\n",
" [0.8678, 0.1322]])\n",
"1.0\n"
]
}
],
"source": [
"correct = 0\n",
"total = 0\n",
"i=0\n",
"with torch.no_grad():\n",
" for data in multiloader:\n",
" images, labels = data\n",
" outputs = net(images)\n",
" loss = criterion2(outputs)\n",
" loss2 = criterion(outputs, labels)\n",
" print(loss)\n",
" _, predicted = torch.max(outputs, 1)\n",
" total += labels.size(0)\n",
" correct += (predicted == labels).sum().item()\n",
" writer.add_scalar(\"Multi/Accuracy\", correct/total, i)\n",
" writer.add_scalar(\"Multi/Loss\", loss2.item(), i + epoch)\n",
"\n",
" i += 1\n",
"\n",
"print(correct/total)"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Image\n",
"from torchvision import models\n",
"from torchsummary import summary\n",
"from matplotlib import pyplot"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 6, 76, 56] 456\n",
" MaxPool2d-2 [-1, 6, 38, 28] 0\n",
" Conv2d-3 [-1, 16, 34, 24] 2,416\n",
" MaxPool2d-4 [-1, 16, 17, 12] 0\n",
" Linear-5 [-1, 120] 391,800\n",
" Linear-6 [-1, 84] 10,164\n",
" Linear-7 [-1, 2] 170\n",
"================================================================\n",
"Total params: 405,006\n",
"Trainable params: 405,006\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.05\n",
"Forward/backward pass size (MB): 0.37\n",
"Params size (MB): 1.54\n",
"Estimated Total Size (MB): 1.97\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"summary(net, (3, 80, 60))\n",
"link = \"new_data/AS12_3/AS12_3_1.png\""
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 6, 28, 38])\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA44AAAMXCAYAAAB1ngsfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAA6gklEQVR4nO3dy69tSZ4f9F+svfc5574yb2ZWVVe/qtxuv7EA2RIg8AQhTxDIMgMmCCHBAGSJAWbsCUJCAvEPMGAAAySEGIChBzBAiJfcArWETdPdsruruqorqyof93Xueey9VzBICSZ1g2/kPbtO38zPZ5j61S9ixYpYa33PLum23nsBAADAmyz3PQEAAAD+eBMcAQAAGBIcAQAAGBIcAQAAGBIcAQAAGBIcAQAAGNrOFG8ePuq7px9GtS38Vz56m5nBu+OreP3pNVXNXddMX+Cr5frjH3zSe//mfc/jXTbzbq77fuec4nl/39f0Lvm6v29n9srXfa34Wrv50c9+N08Fx93TD+tP/Gt/M6o9SXB6hw684Hiavtwj94nQzJn+7X/vb37vdDP5etg9/bC++69n7+Z7D1mC4/36uj/H36HvyHsf/wRO9R15ivFP4h0a/3f+3Z/9bvZ/VQUAAGBIcAQAAGBIcAQAAGBIcAQAAGBIcAQAAGBIcAQAAGBIcAQAAGBIcAQAAGBoO1XdqtazO57BPf9rnKf4B0Zn3Pc/Rnrf1191/2tw39J7MLVOMzc2bXyKnrN973H4ufWfqJ2RLtV6ovH542mpWs/ejQfpSc7bO+S+/1HzUz3Gv9bepXV6l+Yamtr/p5rEPX9M/zzPql8cAQAAGBIcAQAAGBIcAQAAGBIcAQAAGBIcAQAAGBIcAQAAGBIcAQAAGBIcAQAAGBIcAQAAGNpOVbeq9bxHpVnVCbX7nsA7wjrdv5MclhM07ac61Xff99639buzVNVOMdd7fwF8vfSqWjdh8SkOx0zPcG+cbAvd88Oh3/P4M+c9LX2Xrukk7nv8Uwmv677ft388Pk1O8B1zgrN6F/ziCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwNB2prgvVcfzfrczaHfb7v+VTnNi/D4z1xZOYKZp2nPGzDWdYPipOcyMf6p9lZr5k8yalfVdvgDt9lQbO53AxM0Kxz/F9v9i/Hek54R3aq14e61qPXtH3s33/Wy+Z6d43J7KKT5j7tvMs7Eds7qpszc1/j0v7Ck+OWd6vkvvmxN8x0xd/s9xrfziCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwNB2qrr1Wh+sJ5rKHWv9BD0natPhv6rR/RTrP6PP3KzQNt/7m/NjXHt8cZYVnuXj92Xi+k+wVPH+P1XLU2y/U+ypU5m5/ona9i6twddJq+q78Eae4haeoGc/1TsknevM8O/Ims7q9/waT9egHScWa5/X9vD7rD29jXsum3xRj5+fx7UnkU51Yp+cZEutM987M5PN+7bwsNz/d8zbt/iqxhYAAADuiOAIAADAkOAIAADAkOAIAADAkOAIAADAkOAIAADAkOAIAADAkOAIAADAkOAIAADAkOAIAADA0HaqeqlqDw5ZbQt79nz4tuTFvWcTmOl5Ci1dp6qqls91pm3cc2L8r6Lzi31c+/jiJq599eA8q/vkUdyzna1x7fY8O9PpmaqaOta5mfFPMYETbf++nuC0TqzVVNt0DU40Pm/QevVdeHNO8Ryfeo+ldSc6cO/K1pyZ58xSvSvXP2Hm8jeXu7j2+DB7jz54eBv3/OaTy7j2e1ffyApP8Q6pyhf2vj8NZ943U3OdyRxZXTvZXH9+Pf3iCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwJDgCAAAwNB2prgtvc4f7rPa1r/UhO6qZ2t333NG2jec5rTNsp6oc+YU67qcarFC/9yv/t249jc//25c+1//xf8kqvtrf+9fjnt+/H9/K6598usvo7reT3MD0r7rxPinONUz17+ueW36rJrpOzP+jHQN+onG5w2WqjoLn/npu+lUtzB9N5xo/JO8808w15Ot/1fQ8WYT1z74OF/Yq29nv638G3/uf4p7/vbrX4xrf/DTD6K69TixWU6w/U/2vE/fuX3iomauf+abJ+zbZ+Y640TfZz+LXxwBAAAYEhwBAAAYEhwBAAAYEhwBAAAYEhwBAAAYEhwBAAAYEhwBAAAYEhwBAAAYEhwBAAAYEhwBAAAY2s4Ub5a1njy8jmqX1uOeqRZX5uOndVVVbaL2FOMvldemZtZ/xtR1TdSeQrquf3Tzftzzb/+Z34hr/4nf+leiuh//OB+/Lfma/vJ7L6K6w3qavzMdT9Q3dbtu7rxn7/nTamZd07u6Tow/M9f0Xq33e6S/flqvZXcMa08w/EzP8Hk/1XLmHZKOf4J1+qLvuzH+jJlnyCnGv7rMP2U/+N1DXPvgp9m74Zt/LXuHVlX9Ix99L6598WsXUd3f+f534p51gnfDzB3t6wk29sQEpsafWqu77zm3sGHxzPhv4BdHAAAAhgRHAAAAhgRHAAAAhgRHAAAAhgRHAAAAhgRHAAAAhgRHAAAAhgRHAAAAhgRHAAAAhgRHAAAAhrYzxbvNsX7p8Yus8XIM69Z4/KV6XJuOv2kTPdvEXMPamfFnrj/uOXVN+fibyvueYvxT+NMPfnySvv/qr/0vUd1/+bf+6bjn7/yNB3Htn378k6huZv33fRPXptbeJmrzv4ntw9rDml/TWvlc9xN9b8PambWaua5DvFb+Jvnztmyy89lO8Ryd6NnyrRlblvt9N8ysaVp7inU6lTbxHXES4d6vqnr095/FtZd/5RtR3X/wu3817vm//+X/PK79j8K6uf2f1/aJ90jcc+bVEI4/M88+sVfS8Wf0iaMytf7pZd3BNXm7AwAAMCQ4AgAAMCQ4AgAAMCQ4AgAAMCQ4AgAAMCQ4AgAAMCQ4AgAAMCQ4AgAAMCQ4AgAAMLSdKV57q8vDWVS7VI/qWsvqqqqWE9RO9Qyv6Yu+a1yb98zHP4WZ8dfe4trtcvdrNSO9r3/1vb8b9/zfro9x7W+9+k5U93v/0pO45+ZZXFo/un4/qlsrv6czZvbKKXqmtTP7/3bNH62nmOtxzf8meOh57SnuFXdjXbN7007w5+KZlj08Rm3Jz1vac0a7560+823UJ87lcoL37cz4M9I1aJebvOmS79ZHP8re46/y0ev1ehvX/vKDZ1Hd/7H8StwzfU7M6CfoWTXxvpnYf31i+59kX8+s1cwz4ET34GfxiyMAAABDgiMAAABDgiMAAABDgiMAAABDgiMAAABDgiMAAABDgiMAAABDgiMAAABDgiMAAABDgiMAAABD25nim/22/uDHH0W169qiurb0mSnElmUN6/LxD/t8udLxq01cf8/WdKZvX/O/HSyb8Jpqbl33+01Ut5kYf0YPp
/o7n/31uOff+FP/Y1z7P/xXfymqe+/TuGWdvczX/zd/8btR3cz97+H5r6q5MxBPIB8/fQatx/ysbHfHuPYQ7v+qqnaCteoTa5U+1/k52y/VfnSR1YZ7qE3c6+PZxL5MHyMTW61vZ96j4fAT1983E+Onj5GZoz5Te4KfC2auvx0nbmy4Vy5+nD9D14dnce35pzdR3Q9/98O45z97/i/GtZ+8ehTV3Xz/cdxz5qy0Q/gdP/Fp1mf2X7hVWv66rSW8ponhqyo/gjNrNXNd8afBTI54A784AgAAMCQ4AgAAMCQ4AgAAMCQ4AgAAMCQ4AgAAMCQ4AgAAMCQ4AgAAMCQ4AgAAMCQ4AgAAMCQ4AgAAMLSdKn6x1Ie/8eBOJ9BnomvLS9f0ynrec3Ob1/ZNVtfWiZ4T178csgs7XORNl2M+/sx1xSbu1fHs7offP34c1/6Hf+dfiGvPP8/q3v/ePu65e57Xtv/mIiyMW9a6yYvTs7rJL6mWfb5Z0rN6PJs4K4e4tNoxn2ubOAOpmevaXmcTOO7
"text/plain": [
"<Figure size 1152x1152 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"img = cv2.imread(link, cv2.IMREAD_COLOR)\n",
"img = preprocess(img)\n",
"img = img.view(1, 3, 60, 80)\n",
"\n",
"output = net.pool(F.relu(net.conv1(img)))\n",
"print(output.size())\n",
"\n",
"square = 2\n",
"ix = 1\n",
"pyplot.figure(figsize=(16, 16))\n",
"with torch.no_grad():\n",
" for _ in range(square):\n",
" for _ in range(square):\n",
" ax = pyplot.subplot(square, square, ix)\n",
" ax.set_xticks([])\n",
" ax.set_yticks([])\n",
" pyplot.imshow(output[0, ix-1, :, :])\n",
" ix += 1\n",
"\n",
"pyplot.show()"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA44AAANDCAYAAADigUH+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAA7vUlEQVR4nO3de8xt+XkX9t9aa+/3ct4zZ+bM1Y4nnsSpHUhDlJIQx6jQQlqlUqOqN7VCUEpbilQViKCIIiTUlpYWVKkRaQm0Uqu2XNSKiHItFHCgJDiEJMUUQm7j2DP22HM5M3Nu73XvtX79wx6D3DMPz+OZfc47M5/Pn0ff95nfXnut9VvfvUfaQ++9AQAAwJsZH/QCAAAAuNwURwAAAEKKIwAAACHFEQAAgJDiCAAAQEhxBAAAILSqhKeHjvrq8eu7WcngZ0G+0jDks9VfVdnl7NYqw6uzdyi5lu2rr7f5znHhRd4/q8Ojvnft0Qe9jLLKabDLA7/LU716i9vlpfFuv91e3HmtbU8v5zW6N+z3g3b0oJcBD9RZO24X/fzSXaPT0VFfPfrO20PfE6bixjVfjmfRS3eSJ52/8LkbvfcnvvLfS8Vx9fj19r7/9DenssNYexeq+drwQrT6DleewHpt+Gq9TWc356W3sk3rOZ2dN1Npdq+8zqV2TEoltni8szeZF//z76/NvY/2rj3a/olf89sf9DLKKm/VLktPr/4/GIV1j5va6KV22ZWM+VtLXeH9GXb0O8I/9ye/bydz3w4H7ah9dPiuB70MeKB+rH/8QS/hnlaPPtqe/t7flspWHzHe9Z/YtfZVHJS8+Vpt45ru5DfRYbu7dZcnl55za6Mrp+Czv/s/eu5e/+5/VQUAACCkOAIAABBSHAEAAAgpjgAAAIQURwAAAEKKIwAAACHFEQAAgJDiCAAAQKj0q/Gr1dyeePJ2KjuNS2khU+FXKddT/sfrW2ttVVjLWPw1zcrrXIo/jHpt7yyd/Ymf/HBp9rd/+8+ms6+cXS3N3i75zyN68ZhsCrM3c+0X1LdzbvYr69r5dz8tq9ZOn7gEPzK8u9/S3al5v/gHY/5YT2fFc/1a/t4ybGqz13fy11H5N6tLP15cW3d2LUtpZwP4kqm37aObXLa6z1XyxRvvcEn23L4tLuQivxf9im/JP7e21tpPvPDBdPb0xpXS7DYV3p9aHSrvixXD/NZn+8YRAACAkOIIAABASHEEAAAgpDgCAAAQUhwBAAAIKY4AAACEFEcAAABCiiMAAAAhxREAAICQ4ggAAEBIcQQAACC0qoSncWkPH5ylsutxri1kXNLZvXFbmj0OPb+OIb+O1lpbFV7nr77+M6XZv+Hay+nsh3/kI6XZ3/HIp9PZZ0+eKs2+cXGUzl7MU2n20vOfdVwstdmbZL5yrt5vfa+3zQfPU9lhzF8XX/yDQrSQ/eIfFNdSGV3IXrt6Wpp9fLqfzi7P5a+L1lr7hm/6fH52rx3wT3/+8Xy4+Nb0pbCWSraQ73u7O5+Ad69h6m3/Wm4PHYt7aCU/XuLnjMj5+bqUX33yajr7v/7Lf6O2mGfy0Q/9lX+3NHrvcJPOLsV9ri/559zqTtfn6sPZ/59vHAEAAAgpjgAAAIQURwAAAEKKIwAAACHFEQAAgJDiCAAAQEhxBAAAIKQ4AgAAEFIcAQAACCmOAAAAhFaV8PnJXvuFv/uBVLZP1ZX0fHYpzi7U4z4U1tFaa0M++sMHHy6N/uW/+vvT2fHrj0uzv/+HvjudffiDt0qzbz33cD7cCwewtTYU3vthW5ydfOvPT/ZKc++n9x3dar/rO/5SKrs3bEuzx8KFdzRelGavC2uZKidBa23d5nR2b8hnW2vtj9/4WDr7N6evL83+Z574+XT2ex76u6XZP/fBJ9PZpdc+Xzzr63T2znxYnJ3bsv7bq3dKc3n3mv7a16SzX/iTX1ea/eQPfKK4Gi69obdxzD0MTFNtL+rF552KqfDsOlSfcws+9L6XSvmfevgonf2zx1dKs791/+V0drVfex7a28vn57m2h25rSynpw1s/B33jCAAAQEhxBAAAIKQ4AgAAEFIcAQAACCmOAAAAhBRHAAAAQoojAAAAIcURAACAkOIIAABASHEEAAAgpDgCAAAQWlXCBy+et2/8A7+Qyg57e7WVLEstXzHP6WjvvTR6ODzIz755qzT7X/rtvyOdvf6ztXU//Mf/Vjq7+roPlmZ/zfnNdLYfn5RmD48+ks4ur75emz0MqdyNu6eluffT1fG8/YrDZ1PZh8baNTcVsuvksfyHs/P5sTh7XVj50mrH5C/v3UlnP3C9dv3vj5t09lv390uzn5peSGfPareWkju9tAW1Tc991vk/j2dfzXLui2F/r00f/PqdzO5X8ufBfLV2zrTCeTBui/t5Yf+fXq5dR3/hG/98Ovv13/KbSrOfLGSnj3xDafb86FEpXzGeXKSzfX9dm318nsoNn/qR0tz7pS9DOz+rvebK7F0ZxvwFWtxC2zDkZz/2ZH5vaa217SP55/Pv/fivK80+eCz/rLY5rnWWZS48ERWOX2utLdv8d3p9rr2ZlfPkzfjGEQAAgJDiCAAAQEhxBAAAIKQ4AgAAEFIcAQAACCmOAAAAhBRHAAAAQoojAAAAIcURAACAkOIIAABAaFVKz3Prt27nsoeHtZX0JR1dTs9qoy8uCuFemr1L7//ENp3du3m+s3VsP/P8zmaX3U6efzvUC+fq/fbszafav/jnvjcXHorDKy+7F4cPl+O6G+bauq8+n//sbVnX1vI/PP50OvsDV+fS7PF8h58ZFt7KoXgpDUvu/fnc699XG3wf9fOLNj/76Qe9jPLlX7HLqzm/K37RL/29/0E6+5E/8qPF6Xnzz31qZ7OrdrmDZe9Eve/umeWtGE/Hdvj3cs+vvXgbHWu36ZJl2t06KrP/5rO/pDT76p38nWg+KI1uj//FfA957RfX6lDlJlc9T6ZCZSl7G27OvnEEAAAgpDgCAAAQUhwBAAAIKY4AAACEFEcAAABCiiMAAAAhxREAAICQ4ggAAEBIcQQAACCkOAIAABBSHAEAAAitSulpbMPVo1S0n57VVjLP6Wg/Py+NHlb5l9m329rs9V5+9uaiNPvKz7yUzm4/83xpNu9Oq+PWnvrRIZUdN700e1jy2WVdGt2GJbfm1lprvbbuPuVnV4/JWLhfDPlbXGuttc1R/nO90vFrrfXCR4ZD7ZC0YSn+QUXyvX/peHdL4J3liT/yow96CbyDTOetPfJs7mY9bmv3utVx4Tl3Vfxep7AvDnNt3dMmv/mfPZZ/Jm6ttV7Yug5e3ZRmj4V1v+/Has/+fcwvfFnX3svpPH+eVGdX3/t78Y0jAAAAIcURAACAkOIIAABASHEEAAAgpDgCAAAQUhwBAAAIKY4AAACEFEcAAABCiiMAAAAhxREAAIDQqhLu27nNN17NDX76A6WFbD/3Qilf0bfbnc2eHruezm5ffKk0e/uZ56vL4T1u3PZ2eGN353vWsO07mz1ullJ+3s9/PjYW1z3MhXzxY7q9W5V11I5Jq
7zMcaiNXhXylePXWmtDbnb1HAForbXWWxuSt48rL5wUZ1f2i919r7N5aK+U74U94KG/f6M0ezk6yK9jXTwmhXWvXzkrjR4K7+WyV6pabTy7yM8+qL2Xw/LW90bfOAIAABBSHAEAAAgpjgAAAIQURwAAAEKKIwAAACHFEQAAgJDiCAAAQEhxBAAAIKQ4AgAAEFIcAQAACCmOAAAAhFa7Grz93Au7Gn2pbF986UEv4T1l+vCH0tn5539hhyu5nMbN0vZfvJvKLlf2arNPLtLZYe6l2X0Y8rOXpTR7Xcguh5V0a+Otk3y48Bpba63t5ddSOX6ttTbMcyFcm93H/OeR463cufrl2fu5YzJsCq8P4EvGi7kdffY4lZ1u3C7NXq5eSWeH89PS7OFik86Or9ce/edHr+bDxf15vJPfQ4e5Nruir2vHZDg5S2fHa0el2ctB/tlsev1OafZ8/aFS/l584wgAAEBIcQQAACCkOAIAABBSHAEAAAgpjgAAAIQURwAAAEKKIwAAACHFEQAAgJDiCAA
"text/plain": [
"<Figure size 1152x1152 with 16 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"img = cv2.imread(link, cv2.IMREAD_COLOR)\n",
"img = preprocess(img)\n",
"img = img.view(1, 3, 60, 80)\n",
"\n",
"output = net.pool(F.relu(net.conv1(img)))\n",
"output = net.pool(F.relu(net.conv2(output)))\n",
"output.size()\n",
"\n",
"square = 4\n",
"ix = 1\n",
"pyplot.figure(figsize=(16, 16))\n",
"with torch.no_grad():\n",
" for _ in range(square):\n",
" for _ in range(square):\n",
" ax = pyplot.subplot(square, square, ix)\n",
" ax.set_xticks([])\n",
" ax.set_yticks([])\n",
" pyplot.imshow(output[0, ix-1, :, :])\n",
" ix += 1\n",
"\n",
"pyplot.show()"
]
}
],
"metadata": {
"interpreter": {
"hash": "3c791669b07b322e46c2c9e5f9e8a4c39f8cc206c386431b147b2f78281d9ccb"
},
"kernelspec": {
"display_name": "Python 3.8.10 64-bit ('venv': venv)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}