stud-ai/1-intro/7-colours_nn.ipynb

{
"cells": [
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done\n"
]
}
],
"source": [
"import torch\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"from PIL import Image\n",
"import os\n",
"\n",
"transform = transforms.Compose(\n",
" [ transforms.Resize(32),\n",
" transforms.Pad(10, fill=255),\n",
" transforms.CenterCrop((32, 32)),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))\n",
" ]\n",
")\n",
"\n",
"def check_image(path):\n",
" try:\n",
" im = Image.open(path)\n",
" im.verify()\n",
" return True\n",
" \n",
" except:\n",
" print(path)\n",
" return False\n",
" finally:\n",
" im.close()\n",
"\n",
"\n",
"trainset = torchvision.datasets.ImageFolder(root='../datasets/Kolory_mini/', transform=transform,is_valid_file = check_image)\n",
"testset = torchvision.datasets.ImageFolder(root='../datasets/Kolory_mini/', transform=transform,is_valid_file = check_image)\n",
"\n",
"\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True, num_workers=12, drop_last=True)\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=16, shuffle=True, num_workers=12, drop_last=True)\n",
"print(\"Done\")"
]
},
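{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch (added, not part of the original run): if a genuinely\n",
"# separate test set is wanted, the single ImageFolder could be split with\n",
"# random_split. The 80/20 ratio and the seed below are assumptions, not\n",
"# values taken from the notebook.\n",
"from torch.utils.data import random_split\n",
"\n",
"full_set = torchvision.datasets.ImageFolder(root='../datasets/Kolory_mini/', transform=transform, is_valid_file=check_image)\n",
"n_train = int(0.8 * len(full_set))\n",
"train_subset, test_subset = random_split(\n",
"    full_set, [n_train, len(full_set) - n_train],\n",
"    generator=torch.Generator().manual_seed(0))\n",
"\n",
"# These subsets could then replace trainset/testset in the DataLoaders above.\n",
"print(len(train_subset), len(test_subset))"
]
},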
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor(-0.6549)\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAB4CAYAAADrPanmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAACqn0lEQVR4nOz9aZBkWXbfB/7uvW/13T32yIyM3Koqs/buru7qRgONHQ1BHJKSaKSGIqExcdgfZsZMtJFsRM2ngdlwTGM2MxrNl5GBIk2QDYeQaAQHJIwQCKDRQDd7q+qq6toys3KJjIzI2D18d3/bvXc+PPeIyKzcK6srQebfLDN8ef7efdt55/7P/5wjrLU8xVM8xVM8xZ8/yM96AE/xFE/xFE/xaHhqwJ/iKZ7iKf6c4qkBf4qneIqn+HOKpwb8KZ7iKZ7izymeGvCneIqneIo/p3hqwJ/iKZ7iKf6c4hMZcCHErwohLgkhrggh/u7jGtRTPMVTPMVT3B/iUXXgQggFfAT8MrAOvAH8L621Hz6+4T3FUzzFUzzF3fBJPPAvAVestdestQnw28BfejzDeoqneIqneIr7wfkEvz0GrB15vw68fq8fFAoFW6vVPsEmn+IpnuIp/u3D5ubmnrV25vbPP4kBfyAIIb4BfAOgWq3yjW9849Pe5FM8xVM8xb9R+I3f+I3VO33+SQz4TWDpyPvj489ugbX2N4HfBFhcXLQA58+fZ3l5+RNs+tPD6uoqFy5cQAjBL//yLyPlkynU+ZM/+RPiOGZpaYkXXnjhsx7OHdFsNnnjjTcA+OpXv0oQBAffCSFuWTbThu4oZX+QkGqL1gZjLRaw1iKxhL6Di6bTbvOjDy9Tqs8wNzNLMQxwXYWSAikE02WPRinAVfm5u1ucZ/L5m2++SafToV6v86XXX0fccenPFqPRiD/90z8F4PXXX6der3/GI7oz3n//fdbX1/F9n5/92Z/9rIdzRxhj+OM//mOstbzwwgssLS3d/0cPiZutJr1oRDUosFBvPNI6rl+/zsWLF++5zCcx4G8AzwghTpEb7v8Q+OsP8sNTp07xpS996RNs+tODlPLAgL/++usopT7rId0R3/nOd4jjmNnZWb785S9/1sO5I1ZWVg4M+CuvvEK5XD4w3Lcb8DjVbHdGXNsbMIwNaabRxmCtxViDAmaqAb5NWFu7wc6bHyFmargzJynXKhQCD1dJlIRn58ucmCoReM4txvtury9evEin06FcLvP6F7+EkHce42eJTqdzYMCff/75J9YB2t7eZn19Hc/zfqLXpbX2gc9XlmV885vfxFrLqVOneO211x77eN5ZW2G722apNsXzx0488no+NQNurc2EEP874A8ABfxDa+0Hj7q+p/g3G0KIW26w2284IcSR2Y4FAVKABYwFIQWlwEVpjZQSqRw8z8NxFEpJpMzXrySEnoM8YoQnxvpurw/GpC16kCJ9hXQlB9/eYhfGg+Iu6i0B4jP24Q/3605jFIf/P0EPqE+Kn+TD9vbr5vZtW2vJrMUKQej5n+pYPhEHbq39l8C/fExjeYp/i3CnGy6/McTH7I4ApBCEnovUBtd1cVwPqRRKKZSUKClxlMRzBKHrIB/hhk72h2z83oe48wWqz83hNkKUP1nX7TfpeLw2f9gIMdmnJ8MoWmvRSQdrUoSQSDm+1ZWPlB6IJ5MafFQ8jAf+OKCNAUDdgWKN05T9XpfeaEDpCG34aeBTD2J+UlgLJstIh0OG3S7luVmU4/zEuek0Ten3e2xtbDAzO0ulUsX1vCdqmv3nGTnXnetapRRYDTq3jfl35DeLKxS+5+F6bv6ZkiglD2gPKTh4/bAY7ne58M73CSsFFl4+TeXZOQon6gSzRVQ5QCoB1pINEuK9Ab2rTUbXm3gzFQqnGgSLZfxagHI+Y9rNGkzUpHP5fyLJYqRbwQmn8IMy1qYIPPzqMzilBaTyco8xS0iH+0jl4AZllPPpGp7HjZ/UfRhnKbudFt/95u9REJYz515h/thJKtU6Subn/b0b17i+sUbo+ZS8f8sNuElT4naHzvV1NlYuc+ylF6ktHccrFn9iRtwYTXNvl+9/7ztcfOsHLJ88yflXv8jSyTPUG1OfeaDzfslYT9pD5vbxTOgMbfL9sFiMMWhtwBoyo3GVh6sEjs0NtnKcWymY8d/cKTaPNK7BsM/bF1aYrtbRaKZbfUo3pgjnKwSzFfyZAlZAsjtgtN5luNtldHGfvXfXcRZCaq8eY+Gnn8Mregde+2dz7A06bpF1rtJp9clMiF+eY+7US0jHwZosP0bWonXGfnOL9z/4Hv2dq9Rrs5w4/QrHz7yKEPInMv7H4T3/JDzwzBg2RwO+vX6NN77zTQr9Djc/+oBT51/m9PlXWTp1jswY3l+/Tuj7PDt/HM91P9UxPdEG3BpD2usT7ewTb+/RvbmBkCCVQ/XYIl6x8BO5wIbDIWvra/zgB99n5d03aW6skiYxSRxx5rkXaExNA5+NoRyNRuzsbCOFxHVdpJQIKXOeWE74YXnAQd/OReewt7AWR1jSTw13uuEmD6Kjz6OcVBEoJXAdiWMVjuPiuh5KiFsWtliMhczYuzHUd+S+JximIzZ2LtMYVomF5lirxfT2PPX5WcqzdURJYKXEdDPSVsSw12e40WS/vU3QKyMLDqWFKfxaSDhTwgk8UJ8BJ24tVo8QZkR3b4dOJ8MvD5g68SphcRqLAOkSx306nTZvvvU93n3vT4nb6xyfmUeZiOlGFa84hfLKINSnem0/jnV/2veeBTpZwpV+h+/sbbLe2ad49QrDdpNue49hrw3WYhyPVq/D+WPLLE/P3ZFieZx4Qg14foPpJCXabRFv7SGiFFcq1t99D8fzkUpRX15CuvkufFo3ibWW3d1dPvroIz68cIlRZ4gdXkWnEf1elzQzfP6LXyYIw584Dwews7PND3/wPTwvoFQs4noeruvhui6u5+O5Lq7n4joOSjkoR+IoB6XUEWMvcm9LynHgMCcuRB6Re6y4mwpFCJDyCP0tBEIKJPlDyHcdXEfhCUEh9CkWQoRgrFQxB1y0sbkk8X7SwTuNIdIJV4cbrKU77I06nG3vcbrTxfZTsp0Bo9Y+VjhIxwUpiAZDsv6QbtwhGwKrWwz/OKMyXWPhp05ROT6FE3i3Phx/YpeHQHk+/d6Q3c0mXkcziix+vYzj+RiT0m5tc/Gj9/mDf/U7tDs3kSaBuMNMmNA7VqI89yJ+4xzSCXhSuP2fBKy1aK0xWiOkwHFyuu7maMA7nSZv9tp4UzPoDy9gNzaJ+j36zW3SaIhTnyP0QuYqNaYq1U99rE+kAZ/cY9F+i2i3Sdbp4yqXemOGtZXLrLz5BlmS4IYB5cX5T8XQTJBlKVeuXObdH/+Y4SjG+AV2WrukKzdoNjusXF/H8Qt84bXXUEr9xI34+++9w+/+zj9CCic32GMD7jg+vu/jez5BGFIIQ8JCgWKxSK1WY2Z2llqtRrlcplguUyiVKJVK+J6Hkgop5E/knp0cKykEvpdfjtaCMRZjLI7KA5Shq/CUpOi7T
NcqLM1Ns91PyTJNlmm0oxBINJAk5hYvfrKNybm5m3FPMeybATaG/XRE7AncYshUq0qy1kamlpGJaEdt+kmfICzw8k//NKJX5nuXf8Sl977JKB3x+YVn+Nn213nxP/hZSotTCGFzYY04eCx+uhAS6dfRhFRrZTwvQPkV9m/8gLh9kVK5wDCJubaxww8/uIrjpiwfm+XmZocoE0RRj6R5maHwcMvLSOU/0rVwL2rvfvfInX6b/+TTnxmmScLe3jbNvW0KpTLHl05ipOL9zj7f29tmPRoRLi8TvFdD7jeh3cPYm2Tx/0xx7jjnfubXaHjep+59wxNqwCGnT6KtPWx/hDOmBMqFCkq5dPZ2se/+GGvghV/9FfxK+VMbx/7+PlevXmXl+vXcS3Q9Isen2R4w7EdEw4jf/vv/L/Z2/hpf+vJXqNcbOO5P7rDqLGM0HCAQ2N74xhgH8uRYmue67liJkP9TShE4Lp7vEwQFgrBIWCxSqpapNWaYnp6h0ZiiXp+i0ahTLpcJwzBXf4wDyPKIuzx
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# functions to show an image\n",
"\n",
"\n",
"def imshow(img):\n",
" img = img / 2 + 0.5 # unnormalize\n",
" npimg = img.numpy()\n",
" plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
" plt.show()\n",
"\n",
"\n",
"# get some random training images\n",
"dataiter = iter(trainloader)\n",
"images, labels = dataiter.next()\n",
"\n",
"print(images[0].min())\n",
"\n",
"# show images\n",
"imshow(torchvision.utils.make_grid(images))\n",
"\n",
"# print labels\n",
"#print(' '.join('%5s' % classes[labels[j]] for j in range(4)))\n",
"\n"
]
},
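{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Small added sketch: inspect the classes ImageFolder discovered and the\n",
"# labels of the batch shown above (assumes the previous cell has been run).\n",
"print(trainset.classes)  # folder names become class names\n",
"print(' '.join(trainset.classes[labels[j]] for j in range(len(labels))))"
]
},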
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda:0\n"
]
},
{
"data": {
"text/plain": [
"Net(\n",
" (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))\n",
" (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
" (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
" (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
" (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
" (fc3): Linear(in_features=84, out_features=14, bias=True)\n",
")"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"\n",
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
" self.conv1 = nn.Conv2d(3, 6, 5)\n",
" self.pool = nn.MaxPool2d(2, 2)\n",
" self.conv2 = nn.Conv2d(6, 16, 5)\n",
" self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
" self.fc2 = nn.Linear(120, 84)\n",
" self.fc3 = nn.Linear(84, 14)\n",
"\n",
" def forward(self, x):\n",
" x = self.pool(F.relu(self.conv1(x)))\n",
" x = self.pool(F.relu(self.conv2(x)))\n",
" x = x.view(-1, 16 * 5 * 5)\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" return x\n",
"\n",
"\n",
"net = Net()\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"print(device)\n",
"net.to(device)"
]
},
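{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sanity-check sketch (not executed in the original run): pushing a\n",
"# dummy 32x32 batch through the network confirms the 16*5*5 flatten size and\n",
"# the 14-way output.\n",
"with torch.no_grad():\n",
"    dummy = torch.zeros(1, 3, 32, 32, device=device)\n",
"    print(net(dummy).shape)  # expected: torch.Size([1, 14])"
]
},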
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"cuda:0\n",
"wololo1\n",
"wololo3\n",
"wololo4\n",
"[1, 498] loss: 2.403\n",
"Accuracy of the network on the 10000 test images: 23.95 %\n",
"[2, 498] loss: 2.140\n",
"Accuracy of the network on the 10000 test images: 34.82 %\n",
"[3, 498] loss: 2.048\n",
"Accuracy of the network on the 10000 test images: 35.14 %\n",
"[4, 498] loss: 2.001\n",
"Accuracy of the network on the 10000 test images: 32.73 %\n",
"[5, 498] loss: 1.974\n",
"Accuracy of the network on the 10000 test images: 32.56 %\n",
"[6, 498] loss: 1.923\n",
"Accuracy of the network on the 10000 test images: 38.45 %\n",
"[7, 498] loss: 1.893\n",
"Accuracy of the network on the 10000 test images: 37.89 %\n",
"[8, 498] loss: 1.938\n",
"Accuracy of the network on the 10000 test images: 38.00 %\n",
"[9, 498] loss: 1.854\n",
"Accuracy of the network on the 10000 test images: 39.21 %\n",
"[10, 498] loss: 1.933\n",
"Accuracy of the network on the 10000 test images: 38.86 %\n",
"[11, 498] loss: 1.904\n",
"Accuracy of the network on the 10000 test images: 38.81 %\n",
"[12, 498] loss: 1.899\n",
"Accuracy of the network on the 10000 test images: 38.91 %\n",
"[13, 498] loss: 1.863\n",
"Accuracy of the network on the 10000 test images: 32.96 %\n",
"[14, 498] loss: 1.856\n",
"Accuracy of the network on the 10000 test images: 38.66 %\n",
"[15, 498] loss: 1.838\n",
"Accuracy of the network on the 10000 test images: 39.48 %\n",
"Finished Training\n",
"CPU times: user 1min 31s, sys: 25.7 s, total: 1min 57s\n",
"Wall time: 3min 52s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"import torch.optim as optim\n",
"\n",
"# Assuming that we are on a CUDA machine, this should print a CUDA device:\n",
"\n",
"print(device)\n",
"\n",
"BATCH_SIZE = 16\n",
"EPOCHS = 15\n",
"OUTPUTS= 1\n",
"LR = 0.025\n",
"MINI_BATCH_SIZE = 500\n",
"\n",
"print(\"wololo1\")\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"print(\"wololo3\")\n",
"\n",
"optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)\n",
"print(\"wololo4\")\n",
"\n",
"for epoch in range(EPOCHS): # loop over the dataset multiple times\n",
"\n",
" running_loss = 0.0\n",
" for i, data in enumerate(trainloader, 0):\n",
" # get the inputs; data is a list of [inputs, labels]\n",
" \n",
" inputs, labels = data[0].to(device), data[1].to(device)\n",
" #inputs, labels = data\n",
" \n",
" # zero the parameter gradients\n",
" optimizer.zero_grad()\n",
"\n",
" # forward + backward + optimize\n",
" outputs = net(inputs)\n",
" loss = criterion(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" # print statistics\n",
" running_loss += loss.item()\n",
" if i % MINI_BATCH_SIZE == MINI_BATCH_SIZE - 1: # print every 2000 mini-batches\n",
" print('[%d, %5d] loss: %.3f' %\n",
" (epoch + 1, i-1,running_loss / MINI_BATCH_SIZE))\n",
" running_loss = 0.0\n",
" \n",
" correct = 0\n",
" total = 0\n",
" with torch.no_grad():\n",
" for data in testloader:\n",
" images, labels = data[0].to(device), data[1].to(device)\n",
" outputs = net(images)\n",
" _, predicted = torch.max(outputs.data, 1)\n",
" total += labels.size(0)\n",
" correct += (predicted == labels).sum().item()\n",
" print('Accuracy of the network on the 10000 test images: %.2f %%' % (\n",
" 100.0 * correct / total))\n",
"\n",
"print('Finished Training')"
]
},
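{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch (not part of the original run): per-class accuracy often tells\n",
"# more than overall accuracy with 14 possibly imbalanced colour classes.\n",
"# Assumes net, testloader and device from the cells above.\n",
"class_correct = [0] * len(trainset.classes)\n",
"class_total = [0] * len(trainset.classes)\n",
"with torch.no_grad():\n",
"    for data in testloader:\n",
"        images, labels = data[0].to(device), data[1].to(device)\n",
"        _, predicted = torch.max(net(images), 1)\n",
"        for label, pred in zip(labels, predicted):\n",
"            class_total[label] += 1\n",
"            class_correct[label] += int(label == pred)\n",
"\n",
"for name, c, t in zip(trainset.classes, class_correct, class_total):\n",
"    print('%-28s %5.1f %% (%d/%d)' % (name, 100.0 * c / max(t, 1), c, t))"
]
},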
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAYaElEQVR4nO3dfWzd5XUH8O+5L35NiGM7JCEEHEKAZlACvYRMhbZr1YoxKqDaKPSFrKuaaipa0ei6iEkrU7upnQYVmqpu6YhKp65AW9qiiq2laUZGgRCHQkKSAoGGvJDYBuIkdmzft7M/7s3qsOec61zfF8Pz/UhR7Of459/jn318r3/nPucRVQURvf0lmj0BImoMJjtRJJjsRJFgshNFgslOFAkmO1EkUtM5WESuAnA3gCSAf1PVr3of39vbq319fdM5JRE59uzZg9dee01CsaqTXUSSAL4B4IMA9gPYIiIPqepO65i+vj5s2bKl2lNSnYgEfzboLSiTyZix6TyNXwlgt6q+rKpZAPcBuHYan4+I6mg6yb4IwL5J7+8vjxHRDFT3G3QiskZE+kWkf2hoqN6nIyLDdJL9AIDFk94/szx2ElVdp6oZVc3MmzdvGqcjoumYTrJvAbBMRJaISAuAGwE8VJtpEVGtVX03XlXzInILgJ+hVHpbr6o7ajYzahhv5SPv1L99TKvOrqoPA3i4RnMhojriK+iIIsFkJ4oEk50oEkx2okgw2YkiwWQnigSTnSgSTHaiSDDZiSLBZCeKBJOdKBLTem18bKwFI1wsQm8FfGQnigSTnSgSTHaiSDDZiSLBZCeKBJOdKBINLb1ls1ns3bs3GDvrrLMaNo+iFs3Yzh3mhjbY+MuNwfHLV60yj+nq6jJjTz1l745z0UUXmrHzzz/fjLW2tgTHWR4kPrITRYLJThQJJjtRJJjsRJFgshNFgslOFIlpld5EZA+AYwAKAPKqau8ED2Dvvn34/K23BWP/8PdfNo+74IJwqckrJxWLdnntiSeeNGN33nW3GXv99cPB8Y1bnzePWbqgy4w9vXWrGfO2ZFr+Drv0tnr1J4Pjl13mfmsoArWos/+Bqr5Wg89DRHXEp/FEkZhusiuAn4vIVhFZU4sJEVF9TPdp/BWqekBETgfwiIj8RlU3Tf6A8i+BNQDQ3t4+zdMRUbWm9ciuqgfK/w8C+BGAlYGPWaeqGVXNtLS2Tud0RDQNVSe7iHSKyOwTbwP4EIDnajUxIqqt6TyNnw/gR+XyVwrAf6jqf3kHjI6M4vHHw2Wvv7zti+Zxn/rTm4Pjy3/vHeYxAwMDZmz9Pd8xY6OjY2bMWlGWLEyYx/T09JqxJRf9vydC/2fX9l+bsQ0bHzVjBw8dCo5/9IY/MY9573vfY8bmzbPnn0wmzRjNPFUnu6q+DODiGs6FiOqIpTeiSDDZiSLBZCeKBJOdKBJMdqJIiLe6qtaSyZR2dM4Oxlpa7BfcdHZ2Bsd7e+2y0Jw5XfZExP4dJ2Jfj9GR0eB435Il5jHjRbvgsXvPfjOWO3LQjE2Mj5sxa7WfdQ0B4IJ3LDdjn/j4R83YR66/1owlEnwcaYZMJoP+/v7gclB+R4giwWQnigSTnSgSTHaiSDDZiSLR0O2fVIvIToTvJKvTM66QzxsRuwddLtFhxjra7Dv/bQl7HhMT2eD4s88+ax+TteYOZMftRTfJpP17OOEsQOmc3RWeB+zrseO3Q2bsn7/xr2asw7nDf/ll7wqO9/T0mMdQffGRnSgSTHaiSDDZiSLBZCeKBJOdKBJMdqJINLj0psibZTSbtc1Tou0085jWdNqMzW6xS3bHRnNmbGTkWHA81dJmHjNr8VIzlh8eNGO5Y2+YMafiiLHRkXAgbS/wybefbcZeevVVM/ZXX/hrM/bhD18THP/8X9xiHrNgwelmzNvqi6aGj+xEkWCyE0WCyU4UCSY7USSY7ESRYLITRaJi6U1E1gO4BsCgql5YHusGcD+APgB7ANygqoendspwCSiVDm+tBACzF4TLV7O67VJNWyK8Qg0Asjl7ZVtHh1NG6wyXqLxtkCaK9jyOOSvsssZ1AoBiwT4unwuXDltgH1Mc3G7Gkk6/vvG8XQ578MEfB8d3Pf+CecwnPnajGXvXpSvMWF+fXTpkL7zfmcqV+DaAq940thbABlVdBmBD+X0imsEqJnt5v/U3v8LjWgD3lt++F8B1tZ0WEdVatc9x5qvqiV7Hh1Da0ZWIZrBpv1xWVVWcZusisgbAmumeh4imp9pH9gERWQgA5f/NF3mr6jpVzahqpspzEVENVJvsDwFYXX57NYCf1GY6RFQvFbd/EpHvAXgfgF4AAwC+BODHAB4AcBaAV1AqvTnLtErSre3as/CccGy2XUZLd4ZXt3UUjprHJGCvriuKvSLOuxqtiXB0dqv9+bzGkQcOHrKPcxZ5tTkNM3NGWW5szG5uWSyc+kpEwF+JZm1D5Wlttcue5513rhlbu/aLZuxdl14SHG9psb9nb+UVdt72TxX/ZlfVm4zQB6Y1KyJqKL7igCgSTHaiSDDZiSLBZCeKBJOdKBIVS2+11NJxms5btioYSxXt0pAUwvvDJVL2Srm2ngX2PE7rNmOatxtOTrweLpVJ1mjyCKDN+XXqNcVsd0pDyZS9ym5wKLxv2/jEhHmMVyYrFgv2PJzVftX8XHnHeHPs6e01Y6tWXhYcv/nmj5vHXHaZ/fqvmV6W80pvfGQnigSTnSgSTHaiSDDZiSLBZCeKBJOdKBIN3esNhRxkJLx3WEHt0koiGS5DtXafYR6TnN1jT8OpChXGj5ux/EQ4lky1m8fk2jvtz5ezz1XM2SWv1qJTojLKV7M6O8xjsjl71VsyZf+IHB8dNWNqfD+9EppXyisU7OsxNGjvmffTh/8zOP7qwYPBcQD4s0/dbMauvPIKMzZnzhwzNhPwkZ0oEkx2okgw2YkiwWQnigSTnSgSjV0I09KqvfPPDMYSabuvmhhbELXOtdvVt8y1e9oVs/aikOzhATOWTIXn2HZ6+GsCgLRzN744ZvfQyznzSDmd8iQfXlCUHXcWwjiVkKSzfVKhYC8aymbDMe+uutevzzvO+xlOpsKVnJa0XWXo7LS/Z5dessKMfeQj15kxb3FNd/fc4LhXnbBwIQwRMdmJYsFkJ4oEk50oEkx2okgw2YkiUXEhjIisB3ANgEFVvbA8dgeAzwA40fDsdlV9uOLZEikkO8NlhpYeu3xVHHk9OD4xbC+AyI8OmzFJ2CUNrwTYMm9RcDzZ5pTXsuH+eQAwMRz+ugBAJ+yefJqyt0kq5MJlKHUWoCScUt74mL3YJeGU5drbw4uDvBLa+LizRZUz/1Ta7kWYMspX3jxGnQU+m5/aYsa2bX/OjPWdfZYZu/jidwbHr7zy3eYx5523LDiey9nl0Kk8sn8bwFWB8a+r6oryv8qJTkRNVTHZVXUTgIqbNhLRzDadv9lvEZFtIrJeRMLPzYloxqg22b8JYCmAFQAOArjT+kARWSMi/SLSX+3WwEQ0fVUlu6oOqGpBS+1IvgVgpfOx61Q1o6qZR
LKxjXGI6HeqSnYRWTjp3esB2LchiWhGqLjqTUS+B+B9AHoBDAD4Uvn9FQAUwB4An1VVu6lXWdusLl18YbiHl86xV6mJsboqN/CSeYxO2OUTMXraAUCyxS69tRnlwUSLXfrJHbbLgwVn1VsqbZfXvG2vkAuX+vJjR8xDxCm9eSvbxPnRkUR4m6Si0z/PK4eJU+bzWD/f3jZOaa+U5/Tkq3ZnKKuE2dZm/ywumB9e8fnoo7/E8PDh4EwqPq9W1ZsCw/dUOo6IZha+go4oEkx2okgw2YkiwWQnigSTnSgSDX2VixTzaM2GX2Y/MWqXO/LGqjevYWOi7TQzluywt+nRsWEzNjG0NzguXmPAfNaeh1dec16ApDl7dVhhYiQ4XnRKaCnnXAmj2SfgN6osGmU0r7zm1a7EOZdXPE4YKxzTabv86q3m80rV1TZvtVb0jYzYrzjdfSxcdp6YsBuL8pGdKBJMdqJIMNmJIsFkJ4oEk50oEkx2okg0tvQGRaIQLkW1jx4wjxsZORYcT7ba5bVEz9l
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[-1.0007, 0.3057, 1.1468, -2.7095, -2.2706, 2.1910, 2.8463, 2.8963,\n",
" 0.5560, -1.4890, -0.9385, -0.0219, -2.8346, 2.1300]],\n",
" device='cuda:0')\n",
"{0: 'biel', 1: 'czern', 2: 'inny-kolor', 3: 'odcienie-brazu-i-bezu', 4: 'odcienie-czerwieni', 5: 'odcienie-fioletu', 6: 'odcienie-granatowego', 7: 'odcienie-niebieskiego', 8: 'odcienie-pomaranczowego', 9: 'odcienie-rozu', 10: 'odcienie-szarosci-i-srebra', 11: 'odcienie-zieleni', 12: 'odcienie-zoltego-i-zlota', 13: 'wielokolorowy'}\n",
"odcienie-niebieskiego\n"
]
}
],
"source": [
"import requests\n",
"from torch.autograd import Variable\n",
"\n",
"\n",
"url1 = \"https://chillizet-static.hitraff.pl/uploads/productfeeds/images/99/dd/house-klapki-friends-czarny.jpg\"\n",
"url2 = \"https://e-obuwniczy.pl/pol_pl_POLBUTY-BUT-BAL-VENETTO-635-SKORA-LICOWA-CZARNY-2551_5.jpg\"\n",
"url3 = \"https://bhp-nord.pl/33827-thickbox_default/but-s1p-portwest-steelite-tove-ft15.jpg\"\n",
"url4 = \"https://www.sklepmartes.pl/174554-thickbox_default/dzieciece-kalosze-cosy-wellies-kids-2076-victoria-blue-bejo.jpg\"\n",
"\n",
"img = Image.open(requests.get(url3, stream=True).raw)\n",
"\n",
"image_tensor = transform(img).float()\n",
"imshow(image_tensor)\n",
"image_tensor = image_tensor.unsqueeze_(0)\n",
"inputi = Variable(image_tensor)\n",
"\n",
"\n",
"output = net(inputi.to(device))\n",
"_, predicted = torch.max(output.data, 1)\n",
"\n",
"print(output.data)\n",
"idx2class = {v: k for k, v in trainset.class_to_idx.items()}\n",
"print(idx2class)\n",
"\n",
"print(idx2class[int(predicted)])\n",
"\n",
"\n",
"#import pickle\n",
"#a_file = open(\"class-shoe.pkl\", \"wb\")\n",
"#pickle.dump(shoe_names, a_file)\n",
"#a_file.close()\n",
"\n",
"#a_file = open(\"class-shoe.pkl\", \"rb\")\n",
"#outpu = pickle.load(a_file)\n",
"#print(outpu[int(predicted)])\n",
"#a_file.close()"
]
},
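{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch: the raw outputs above are logits; softmax turns them into\n",
"# per-class probabilities, which are easier to read off. Assumes output and\n",
"# idx2class from the previous cell and F from the model cell.\n",
"probs = F.softmax(output, dim=1).squeeze(0)\n",
"top5 = sorted(zip(idx2class.values(), probs.tolist()), key=lambda x: -x[1])[:5]\n",
"for name, p in top5:\n",
"    print('%-28s %.3f' % (name, p))"
]
},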
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import os.path\n",
"import csv\n",
"\n",
"if not (os.path.isfile(\"nn-col-state-dict.pth\")):\n",
" torch.save(net.state_dict(), \"nn-col-state-dict.pth\")\n",
" "
]
},
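{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch: how the checkpoint saved above could be restored later for\n",
"# inference. The file name matches the save cell; the rest is standard\n",
"# state-dict loading.\n",
"restored = Net()\n",
"restored.load_state_dict(torch.load(\"nn-col-state-dict.pth\", map_location=device))\n",
"restored.to(device)\n",
"restored.eval()  # good practice before inference, even without dropout/batchnorm\n",
"print(\"weights restored\")"
]
},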
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}