495 lines
76 KiB
Plaintext
495 lines
76 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Done\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import torch\n",
|
||
|
"import torchvision\n",
|
||
|
"import torchvision.transforms as transforms\n",
|
||
|
"from PIL import Image\n",
|
||
|
"import os\n",
|
||
|
"\n",
|
||
"# Preprocessing shared by the train and test sets (and the single-image demo below).\n",
"# NOTE(review): Resize(32) scales the *shorter* side to 32 px, so both sides are >= 32\n",
"# and the 32x32 center crop never reaches the 10 px white border added by Pad --\n",
"# Pad is effectively a no-op here. To letterbox whole shoes, resize the longer side instead.\n",
"transform = transforms.Compose([\n",
"    transforms.Resize(size=32),\n",
"    transforms.Pad(padding=10, fill=255),\n",
"    transforms.CenterCrop(size=(32, 32)),\n",
"    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]\n",
"    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]\n",
"])\n",
|
||
|
"\n",
|
||
"def check_image(path):\n",
"    \"\"\"Return True if PIL can open and verify the image at `path`, else False.\n",
"\n",
"    Used as ImageFolder's `is_valid_file` hook so unreadable or corrupt files are\n",
"    skipped (their path is printed) instead of failing later inside the DataLoader.\n",
"    \"\"\"\n",
"    try:\n",
"        # The context manager closes the file even if verify() raises. The previous\n",
"        # `finally: im.close()` referenced an unbound `im` whenever Image.open() itself\n",
"        # failed, raising UnboundLocalError and masking the `return False`.\n",
"        with Image.open(path) as im:\n",
"            im.verify()\n",
"        return True\n",
"    except Exception:  # PIL reports bad files with several exception types\n",
"        print(path)\n",
"        return False\n",
|
||
|
"\n",
|
||
"# Alternative: ImageNet channel statistics (what pretrained torchvision models expect):\n",
"#transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
"# (0.5, 0.5, 0.5) is not grayscale-specific: it maps each RGB channel from [0, 1] to [-1, 1],\n",
"# and imshow() relies on exactly that mapping to undo it.\n",
"trainset = torchvision.datasets.ImageFolder(root='../datasets/Damskie_mini/', transform=transform, is_valid_file=check_image)\n",
"testset = torchvision.datasets.ImageFolder(root='../datasets/Damskie_mini_test/', transform=transform, is_valid_file=check_image)\n",
"# ImageFolder derives label indices from the sorted folder names of each root; if the\n",
"# two roots differ, test labels would not line up with what the network was trained on.\n",
"assert trainset.class_to_idx == testset.class_to_idx, 'train/test class folders differ'\n",
"\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True, num_workers=12, drop_last=True)\n",
"# Evaluation must see every test image exactly once: no shuffling and, above all,\n",
"# no dropping of the final partial batch (drop_last=True silently skipped up to 15 images).\n",
"testloader = torch.utils.data.DataLoader(testset, batch_size=16, shuffle=False, num_workers=12, drop_last=False)\n",
"print(\"Done\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"tensor(-0.7882)\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXAAAAB4CAYAAADrPanmAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAACeAUlEQVR4nOz9aZBt13XfCf72PuMdc858+eYB72EGSBAgQNI0KdmUJdluyZJDkl1hqSKqg/WlIrqiq6PL1V+6FNEdXf5S1RX9oarlctsuy21ZstSaKIkzSIokCGLGm/Hm93Ke7nymPfSHfe7NfAOANwKk6y1EIvNl3nvOPufus/ba//Vf/yWstTy0h/bQHtpD++kz+XEP4KE9tIf20B7a3dlDB/7QHtpDe2g/pfbQgT+0h/bQHtpPqT104A/toT20h/ZTag8d+EN7aA/tof2U2kMH/tAe2kN7aD+ldk8OXAjx80KIM0KIc0KIf3q/BvXQHtpDe2gP7cNN3C0PXAjhAWeBLwHXgB8D/8hae/L+De+hPbSH9tAe2vvZvUTgnwbOWWsvWGtz4PeAX7o/w3poD+2hPbSH9mHm38N79wBXd/z7GvDiB72hWq3a8fHxezjlQ3toD+2h/W/PlpaW1q21Mzf+/l4c+G2ZEOLLwJcBxsbG+PKXv/ygT/nQHtpDe2j/Udlv//ZvX77V7+/FgS8A+3b8e2/5u+vMWvs7wO8A7N692wI8/vjjHDhw4B5O/eDs8uXLnDp1CiEEX/rSl5DyJ5Oo8+1vf5ssy9i3bx9PPvnkXR/HWksyGLBw9SpLC9fQWmOtBWsxxpJmKcYYDh8+wu59exmfmER63m0de2Njgx//+McAfPGLXySO47sepzGGoiiwxiCEIAjD+/bZvPLKK7RaLSYmJnjxxVtvIq2FjZ7i6npKd5CRFwWqyLHaUK3EKKXRSmOtwRgzep/nSTwpMQiMFUjPo1ENmZuosG+mgu+J2xpjkiR85zvfAeDTn/40ExMT937hD8COHz/OwsICcRzzxS9+8eMezi1Na803vvENrLU8+eST7Nu378PfdMe2M7foPuPtfKMtfyuue6UQ268FuHTpEqdPn/7As9yLA/8xcFQIcQjnuH8D+Me388ZDhw7x6U9/+h5O/eBMSjly4C+++CLebTqrj9r++q//mizLmJ2d5aWXXrrr41hraW1uIrRia3WZwlqM1mij0YVCpQnWwvzsDE898SR7Dhy47Xty8eLFkQN/7rnnaDabdz3GXq/L5UuXSJKEarXKocNHqFQqCHF7DvCD7Pjx47RaLRqNxi3vpbUWZeCtywnLso3a6pH0+6S2jzEZUW2KPCvIsxyjFYXVw6eR0JeEgY/BQ1mJHwRMzzTZd2icTx8bJwxubxFqt9sjB/7EE0/8xAZAq6urLCwsEATBPc3LB2lKKb75zW9ireXQoUM8//zz9/X41hqsVlijsFhcqtFijVvgnZ8WSCHL15f/9gKkH153rAfmwK21SgjxXwBfBTzg/2OtPXG3x3toH48No20pJJ6QGCnRSpFlOYPBAIulUqlQrdUIo+hj25EsLCzw737v91hcWmL/vv381m/9FocOHXrg57XWYiy0Bpo3LnS5sNSj3enS73XptdvopIcyEVmuSJIMpQqKonTgAgJfEoY+nhei8YhiS6OuyZW9L4vPgzRrf/LH+JNoRhUUSQud9bBWAT5YjcpTrFGAxfMCPN9DCInRGryIoDZNVA8/7PDX2T1h4NbavwD+4l6O8dA+XnMOykUHAEkyoNfrk6QpqijwfJ+pqSpjExPE1crHOE64dPkaJ06eZHVtg8997nMfkQOHQWZ49VyPswsdut0OaTIgSwdkSRfTXsOoeQQSIQTaON+9vYEWWCMw0v3ek4JK5FEJPbyfTHRuZA+d991Z0l0n2byKyTr4YYjVGq1ytHbQmhAgpQAhEEJgtEVGTeqVsTs+10/4FPrfnlljMFpjSiz6o9Br11rT6/
Vot9ukaUaSJBR5jsUhcmEYEVcq+H7wkT7U1lqyLOXUqZP8q3/9r3jv3Dm63S6bm5u8/c7b5Hn+QO+PtZZOojmzlPDu5S6DQZ8kGZCnCTpLsEmHdj8hKxRmJ75ZRt9KGQqlMSVuLz2JkJIokISh5KF//I/TPD9ESoG1GqsLjC4QUiKkAAxGF+gyh6Jyt2uz1iLlncfTD5yF8tA+2CyAdUhZ2m6zfvoU6fIyvhTMfPpFarNzwIOLhoRwCcJ+krC2tk6tWt3+AyCkII5CwjD6SPMB1lra7Tbnzp/j9//97/Otl7/D4uISeZ4Dlh//+HUuXbrI4cNHCPxgZ+7n/o0hz8jaHVqLLVptRTIYUGQZpkihSCiylKSwFErh++7eSCGw5ViELwl8D9+X5fAEUgqiwCPyJffqwR9CHD+ZJqTnvgCGn1GJeRthEcJCiY67IK18n7jzePqhAy9NK0Wv16Pf6320J3YQNNmgz+VXX+Xq979HevUqtTjCBAH7v/BFgrhyzw/7Bw/BYowhzTKiKHJ4uJQuAheCuFIhCIOPDP+21tLtdnnvvff49ssv882XX+bixcvkWeYehG6PS5cu89rrrzE7M0tzbOz+LS7l02SLDNvbwmuvUx9sMaElyypH5ynFoE8+6JIWGoNw8JNwi50QYrSWCCnwve17Ce5+hr4kHDn1u7cPc94PHfzHY0JKEBIQZWRdPktSOshDuESnLPMrQgjnvH+aHPhw66u1Qgq3tfw4JttwHO3WFuffO8u1K7ekWz6wc1tr0XnO5vkLvPkHv8/Sj38EW5tMj41hvYDJY4/R3LsXLwxvohndnzG4757nEwQ+WiuEAN/zXP5cSKq1GkFw/2h7HzAaAIqi4OLFi3zve9/jT/70zzh/4RJ5mrpJLyVKaVrtDt/45jd59plnieKISqV6z/PHWkuWK3yrob+J3Vii2l3ngO3ydATLIiJXA3q9Ft1Ol0Q7hokU4HsSWz6MWHZAKtebQBD4ksC/97HeaiqIHT998P24meZ26989tDs1gUQgRndTeh7WGPAEVgisAWGdc7doLBKk84F3ah+rA1dFwfrKIvXmGNVaA8+/fjgP2qEPnXeR53z9a3/Fd77+NQpj2XPokQd63uG5LZYiTWhdusz3/8f/Fyde/gZBv0cdS2oU177+V8SHD/PUr/8Gjbk5h6He53EYY9BKY4xGSOnGVWK2vu/j+x61Ws1F5g8QQtnJkV24dpU//uM/5mvf+Canz75HlqZgDBaLLbejShtef/1tvvvd71Cr1Th48CD36nT6g4yTJy+xPxzQKFrYpIdNBkRFwlNRDhMeb6YD3pUZW8bNz+nxOnHo4XkCY0paGBZtSgKZsXiA53lY4XjhYSAJbpP/fct7VZ5j6CHEDX9zz82HO2N3z2987e1F7Q+j+5ttNIdFmQaxFmtuWMjLv1k7TGZKrJUIIUe0wjuxj8WBa61ZunaVV777LX708tf5ws9+iU9+7ouMT89graVSrX1kk6MoCl7/0Q/58z/8D2yurTA9v+cjOS9A2u6wcvw4b/6b3+WNl79Bpd+nbhQVwCvA9vu8/e/+HfX9+zn42c8yNj9//wdhLcZodJ7jyaEDciaEIAxDZnbtIqpUHmgEboxhMOjz7jtv82d/9hV+9OprLC4ulclcg7GO3uFLDy/wEZ6k1x/wjW9+m8OHDjM1OUlzbPyexnD16hr/j3/2H/gvfmEfz8yHeDZHFwqrFaE0HI4TGnOaMQEV6XGp7zM7M8nnP3mAq+s5V1d6ZL5HXpQLYIl3B75HYQzC8/ClJPQk3j1QUC5vKJJY4UtB6EMzloQSAg98Ke4IbdPakBc5QkDgBW5ct3j/jQ77ofO+tRmdo7I+RZ6gVI6UEt8EO6hJbtEUwkXgWmmMAIv46YnAF65eYenyBb76lT+nvbbM6698n6XVNawXIH2fX/ilf8D8bTrS4aqnigJweJP03GV90CQzxpAM+ixcucJf/tmf0NpYR+X5NqbwgGw43j
xJWHz9Dc587eucevVVksGAXQKqnkckLL4QJNpQbW9y7bvfoTY5SXV8nLBSvb8DcoQJd9/ENoY7jMSjKGJicgo/eHB
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"import numpy as np\n",
|
||
|
"\n",
|
||
"def imshow(img):\n",
"    \"\"\"Show a normalized, channel-first image tensor with matplotlib.\n",
"\n",
"    Undoes Normalize(mean=0.5, std=0.5), moves channels last for matplotlib\n",
"    and displays the figure immediately.\n",
"    \"\"\"\n",
"    pixels = (img / 2 + 0.5).numpy()       # [-1, 1] -> [0, 1]\n",
"    plt.imshow(pixels.transpose(1, 2, 0))  # (C, H, W) -> (H, W, C)\n",
"    plt.show()\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
"# Grab one random training batch and display it as a grid.\n",
"dataiter = iter(trainloader)\n",
"# next(...) rather than dataiter.next(): DataLoader iterators in current PyTorch\n",
"# no longer provide a .next() method.\n",
"images, labels = next(dataiter)\n",
"\n",
"print(images[0].min())\n",
"\n",
"imshow(torchvision.utils.make_grid(images))\n",
"\n",
"# Class names come from ImageFolder; the previous code referenced an undefined `classes`.\n",
"print(' '.join(trainset.classes[j] for j in labels.tolist()))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"cuda:0\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"Net(\n",
|
||
|
" (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))\n",
|
||
|
" (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n",
|
||
|
" (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n",
|
||
|
" (fc1): Linear(in_features=400, out_features=120, bias=True)\n",
|
||
|
" (fc2): Linear(in_features=120, out_features=84, bias=True)\n",
|
||
|
" (fc3): Linear(in_features=84, out_features=18, bias=True)\n",
|
||
|
")"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"import torch\n",
|
||
|
"import torch.nn as nn\n",
|
||
|
"import torch.nn.functional as F\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
"class Net(nn.Module):\n",
"    \"\"\"LeNet-style CNN for 3x32x32 inputs (the size produced by `transform`).\n",
"\n",
"    Args:\n",
"        num_classes: number of output logits; defaults to the 18 shoe categories.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, num_classes=18):\n",
"        super(Net, self).__init__()\n",
"        self.conv1 = nn.Conv2d(3, 6, 5)\n",
"        self.pool = nn.MaxPool2d(2, 2)\n",
"        self.conv2 = nn.Conv2d(6, 16, 5)\n",
"        # spatial size: 32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5, with 16 channels\n",
"        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
"        self.fc2 = nn.Linear(120, 84)\n",
"        self.fc3 = nn.Linear(84, num_classes)\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Map an (N, 3, 32, 32) batch to (N, num_classes) unnormalized logits.\"\"\"\n",
"        x = self.pool(F.relu(self.conv1(x)))\n",
"        x = self.pool(F.relu(self.conv2(x)))\n",
"        # Flatten per sample. Unlike view(-1, 400), a wrongly sized input now fails in\n",
"        # fc1 instead of possibly being reshaped into a different number of samples.\n",
"        x = torch.flatten(x, 1)\n",
"        x = F.relu(self.fc1(x))\n",
"        x = F.relu(self.fc2(x))\n",
"        return self.fc3(x)\n",
|
||
|
"\n",
|
||
|
"\n",
|
||
"# Build the model and move its parameters to the first GPU when CUDA is available.\n",
"net = Net()\n",
"device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
"print(device)\n",
"net.to(device)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 17,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"cuda:0\n",
|
||
|
"wololo1\n",
|
||
|
"wololo3\n",
|
||
|
"wololo4\n",
|
||
|
"[1, 498] loss: 2.228\n",
|
||
|
"[1, 998] loss: 2.298\n",
|
||
|
"Accuracy of the network on the 10000 test images: 29.24 %\n",
|
||
|
"[2, 498] loss: 2.171\n",
|
||
|
"[2, 998] loss: 2.312\n",
|
||
|
"Accuracy of the network on the 10000 test images: 30.47 %\n",
|
||
|
"[3, 498] loss: 2.146\n",
|
||
|
"[3, 998] loss: 2.262\n",
|
||
|
"Accuracy of the network on the 10000 test images: 28.68 %\n",
|
||
|
"[4, 498] loss: 2.201\n",
|
||
|
"[4, 998] loss: 2.335\n",
|
||
|
"Accuracy of the network on the 10000 test images: 31.08 %\n",
|
||
|
"[5, 498] loss: 2.197\n",
|
||
|
"[5, 998] loss: 2.223\n",
|
||
|
"Accuracy of the network on the 10000 test images: 32.20 %\n",
|
||
|
"[6, 498] loss: 2.279\n",
|
||
|
"[6, 998] loss: 2.253\n",
|
||
|
"Accuracy of the network on the 10000 test images: 24.39 %\n",
|
||
|
"[7, 498] loss: 2.237\n",
|
||
|
"[7, 998] loss: 2.246\n",
|
||
|
"Accuracy of the network on the 10000 test images: 32.92 %\n",
|
||
|
"[8, 498] loss: 2.223\n",
|
||
|
"[8, 998] loss: 2.195\n",
|
||
|
"Accuracy of the network on the 10000 test images: 23.49 %\n",
|
||
|
"[9, 498] loss: 2.249\n",
|
||
|
"[9, 998] loss: 2.239\n",
|
||
|
"Accuracy of the network on the 10000 test images: 27.79 %\n",
|
||
|
"[10, 498] loss: 2.261\n",
|
||
|
"[10, 998] loss: 2.262\n",
|
||
|
"Accuracy of the network on the 10000 test images: 28.63 %\n",
|
||
|
"[11, 498] loss: 2.238\n",
|
||
|
"[11, 998] loss: 2.299\n",
|
||
|
"Accuracy of the network on the 10000 test images: 26.84 %\n",
|
||
|
"[12, 498] loss: 2.198\n",
|
||
|
"[12, 998] loss: 2.344\n",
|
||
|
"Accuracy of the network on the 10000 test images: 30.41 %\n",
|
||
|
"[13, 498] loss: 2.240\n",
|
||
|
"[13, 998] loss: 2.282\n",
|
||
|
"Accuracy of the network on the 10000 test images: 28.68 %\n",
|
||
|
"[14, 498] loss: 2.263\n",
|
||
|
"[14, 998] loss: 2.230\n",
|
||
|
"Accuracy of the network on the 10000 test images: 31.75 %\n",
|
||
|
"[15, 498] loss: 2.298\n",
|
||
|
"[15, 998] loss: 2.278\n",
|
||
|
"Accuracy of the network on the 10000 test images: 30.97 %\n",
|
||
|
"Finished Training\n",
|
||
|
"CPU times: user 1min 23s, sys: 20.6 s, total: 1min 44s\n",
|
||
|
"Wall time: 2min 59s\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
"%%time\n",
"\n",
"import torch.optim as optim\n",
"\n",
"print(device)  # expected to be a CUDA device when a GPU is available\n",
"\n",
"EPOCHS = 15\n",
"LR = 0.025\n",
"LOG_EVERY = 500  # print the average training loss every LOG_EVERY batches\n",
"\n",
"criterion = nn.CrossEntropyLoss()\n",
"optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)\n",
"\n",
"\n",
"def evaluate(model, loader):\n",
"    \"\"\"Return (correct, total) top-1 prediction counts of `model` over all of `loader`.\"\"\"\n",
"    model.eval()\n",
"    correct = total = 0\n",
"    with torch.no_grad():\n",
"        for images, labels in loader:\n",
"            images, labels = images.to(device), labels.to(device)\n",
"            predicted = model(images).argmax(dim=1)\n",
"            total += labels.size(0)\n",
"            correct += (predicted == labels).sum().item()\n",
"    return correct, total\n",
"\n",
"\n",
"for epoch in range(EPOCHS):  # loop over the dataset multiple times\n",
"    net.train()  # evaluate() leaves the model in eval mode; switch back each epoch\n",
"    running_loss = 0.0\n",
"    for i, (inputs, labels) in enumerate(trainloader):\n",
"        inputs, labels = inputs.to(device), labels.to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"        loss = criterion(net(inputs), labels)\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        running_loss += loss.item()\n",
"        if (i + 1) % LOG_EVERY == 0:\n",
"            # report the 1-based number of batches seen (the old code printed i - 1)\n",
"            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / LOG_EVERY))\n",
"            running_loss = 0.0\n",
"\n",
"    correct, total = evaluate(net, testloader)\n",
"    accuracy = 100.0 * correct / total if total else float('nan')\n",
"    print('Accuracy of the network on the %d test images: %.2f %%' % (total, accuracy))\n",
"\n",
"print('Finished Training')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"ename": "NameError",
|
||
|
"evalue": "name 'Image' is not defined",
|
||
|
"output_type": "error",
|
||
|
"traceback": [
|
||
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||
|
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||
|
"\u001b[0;32m<ipython-input-1-9770c3effe28>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0murl4\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"https://www.sklepmartes.pl/174554-thickbox_default/dzieciece-kalosze-cosy-wellies-kids-2076-victoria-blue-bejo.jpg\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrequests\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0murl4\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstream\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mraw\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0mimage_tensor\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
|
||
|
"\u001b[0;31mNameError\u001b[0m: name 'Image' is not defined"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
"import io\n",
"\n",
"import requests\n",
"from PIL import Image\n",
"\n",
"\n",
"url1 = \"https://chillizet-static.hitraff.pl/uploads/productfeeds/images/99/dd/house-klapki-friends-czarny.jpg\"\n",
"url2 = \"https://e-obuwniczy.pl/pol_pl_POLBUTY-BUT-BAL-VENETTO-635-SKORA-LICOWA-CZARNY-2551_5.jpg\"\n",
"url3 = \"https://bhp-nord.pl/33827-thickbox_default/but-s1p-portwest-steelite-tove-ft15.jpg\"\n",
"url4 = \"https://www.sklepmartes.pl/174554-thickbox_default/dzieciece-kalosze-cosy-wellies-kids-2076-victoria-blue-bejo.jpg\"\n",
"\n",
"# Fail loudly on HTTP errors instead of handing PIL an error page.\n",
"response = requests.get(url4, timeout=30)\n",
"response.raise_for_status()\n",
"# convert('RGB'): downloads may be RGBA or palette images, but conv1 expects 3 channels.\n",
"img = Image.open(io.BytesIO(response.content)).convert('RGB')\n",
"\n",
"image_tensor = transform(img).float()\n",
"imshow(image_tensor)\n",
"batch = image_tensor.unsqueeze(0)  # add the batch dimension expected by the network\n",
"\n",
"# Index -> class name, derived from ImageFolder so it cannot drift from the training labels.\n",
"shoe_names = {idx: name for name, idx in trainset.class_to_idx.items()}\n",
"\n",
"net.eval()\n",
"with torch.no_grad():\n",
"    output = net(batch.to(device))\n",
"predicted = output.argmax(dim=1)\n",
"\n",
"print(shoe_names[int(predicted)])"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 14,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Trekkingowe\n",
|
||
|
"Balerinki\n",
|
||
|
"Sniegowce\n",
|
||
|
"Mokasyny\n",
|
||
|
"Czolenka\n",
|
||
|
"Domowe\n",
|
||
|
"Glany\n",
|
||
|
"Creepersy\n",
|
||
|
"Polbuty\n",
|
||
|
"Sandaly\n",
|
||
|
"Pozostale\n",
|
||
|
"Botki\n",
|
||
|
"Kalosze\n",
|
||
|
"Kozaki\n",
|
||
|
"Espadryle\n",
|
||
|
"Tenisowki\n",
|
||
|
"Sportowe\n",
|
||
|
"Klapki\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"{'Balerinki': 0,\n",
|
||
|
" 'Botki': 1,\n",
|
||
|
" 'Creepersy': 2,\n",
|
||
|
" 'Czolenka': 3,\n",
|
||
|
" 'Domowe': 4,\n",
|
||
|
" 'Espadryle': 5,\n",
|
||
|
" 'Glany': 6,\n",
|
||
|
" 'Kalosze': 7,\n",
|
||
|
" 'Klapki': 8,\n",
|
||
|
" 'Kozaki': 9,\n",
|
||
|
" 'Mokasyny': 10,\n",
|
||
|
" 'Polbuty': 11,\n",
|
||
|
" 'Pozostale': 12,\n",
|
||
|
" 'Sandaly': 13,\n",
|
||
|
" 'Sniegowce': 14,\n",
|
||
|
" 'Sportowe': 15,\n",
|
||
|
" 'Tenisowki': 16,\n",
|
||
|
" 'Trekkingowe': 17}"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 14,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
"import os\n",
"import sys\n",
"\n",
"# Print every entry of the training-set root (in os.listdir order), then show\n",
"# ImageFolder's class-name -> label-index mapping for comparison.\n",
"path = \"../datasets/Damskie_mini/\"\n",
"dirs = os.listdir(path)\n",
"for entry in dirs:\n",
"    print(entry)\n",
"\n",
"trainset.class_to_idx"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAVLElEQVR4nO3da2ycVXoH8P8z4/HYHt/jiXGTgBMSLlkKAUwKgq5Y2N0CXQRUFYKqiA9os1otUqm2HxCVCiv1A1sVEB8qqrBEsCvKZRcQaQu7hCyFZVsFDIQkJMstG9gEx5ckTny3Z+bph3mjOu55jp25Ojn/nxRlfB6feY9f+5nXfp8554iqgohOf7FqD4CIKoPJThQIJjtRIJjsRIFgshMFgslOFIiaYjqLyHUAHgUQB/ATVX3Q9/kdHR3a3d1dzCGrqpAypYiUYSREbvv27cPQ0JDzh67gZBeROIB/AfAtAPsBvCsim1V1t9Wnu7sbvb29hR6yInwJPTU14w548rkuWVvkiIgWrqenx4wV82v8egCfqepeVZ0G8CyAm4p4PiIqo2KSfRmAP8z6eH/URkSLUNlv0InIBhHpFZHewcHBch+OiAzFJPsBACtmfbw8ajuBqm5U1R5V7Umn00UcjoiKUUyyvwtgjYisFJFaALcB2FyaYRFRqRV8N15VMyJyN4BfIV9626SqH5VsZGXku+P+q9feNGNvvr3T2X7euavMPn99+/VmLB7n2xyocoqqs6vqKwBeKdFYiKiMeGkhCgSTnSgQTHaiQDDZiQLBZCcKRFF3409V09PGhBYATz/3shn7zW/fc7Zf1rPO7LNkaZsZy2SyZqw2Yb8OtzY32sdrb3G2t7W1mn2am1JmLFmbMGOc0Hdq4ZWdKBBMdqJAMNmJAsFkJwoEk50oEEHejT9ybNKM9R2ZMmOjYyPO9te3bjH7bP2v182Yb0m7eDxuxuqSSTPW2tLqbO/s7DT7LF9xlhn74wvWmrH1PReasbXnuScHLWm17/xzvb7y4pWdKBBMdqJAMNmJAsFkJwoEk50oEEx2okAEWXobOGpPhDkS/yMzlmrrcLYPHfjE7JPN2ZNdROzXWl9scmLMjI2MuMuDBwf6zT47d+0yY6/+0l51rCHVbMbWXbre2f6j+/7W7HPR+dx2oJx4ZScKBJOdKBBMdqJAMNmJAsFkJwoEk50oEEWV3kRkH4ARAFkAGVW1d4JfRNoa7Vljqy/8phnLtrhngMX7dpt9pg/tNWMz48N2bNqemQfY0+XiMffMsaYm99p0ANDY4i4pAsDItP0jUt9s95tuWuluz9lr2lF5laLO/g1VHSrB8xBRGfHXeKJAFJvsCuA1EXlPRDaUYkBEVB7F/hp/laoeEJGlALaIyO9U9a3ZnxC9CGwAgDPPPLPIwxFRoYq6sqvqgej/AQAvAfh/b4hW1Y2q2qOqPel0upjDEVERCk52EUmJSNPxxwC+DcCeUUFEVVXMr/GdAF6KFgmsAfBvqvrLkoyqzBrr7fLPyvYGMzY0fIazvSG11OxTs/pqM5bLTJuxQ8bsNQAYm7TLcprLONuzYpfrjnoWepzJ2bFEfZ0Zq21pd7a3eMqeVF4FJ7uq7gVwUQnHQkRlxNIbUSCY7ESBYLITBYLJThQIJjtRIIJccLKh3v6yV55hl5O+HHLvAzc6mTP7TM/YsQxqzVgqaZcA+yfc5TUAGJ0yYll7kc047DF2puzz0Za096Nbla5392myv2YqL17ZiQLBZCcKBJOdKBBMdqJAMNmJAhHk3fh43J7ckaqzX//qjLOV9SyrZh8J8MxNQaunYrAia08m6YP7rrtk7K+rrdY+VkfK/uI6mu0762uWpZztLbwbXzW8shMFgslOFAgmO1EgmOxEgWCyEwWCyU4UiDBLb8YWSQDQ0miXhmLGS2M2Z9fQsll7kol/Gyc7tqTOnoDSFDPKcll7FPGcPUkmVeOe/JMfh12WW9roPmBNnNeXauGZJw
oEk50oEEx2okAw2YkCwWQnCgSTnSgQ85beRGQTgO8AGFDVC6K2dgDPAegGsA/Arap6pHzDLC3xbHfU5pnJlaxx97NKcvljLXhYc/rZTxqL2aW3hBETtZ+vMTZmxnpW2mvhrep2z2wDgBVL3ecx5il7Unkt5Mr+JIDr5rTdC2Crqq4BsDX6mIgWsXmTPdpv/fCc5psAPBU9fgrAzaUdFhGVWqF/s3eqal/0+CDyO7oS0SJW9A06VVV43vcpIhtEpFdEegcHB4s9HBEVqNBk7xeRLgCI/h+wPlFVN6pqj6r2pNPpAg9HRMUqNNk3A7gzenwngJdLMxwiKpeFlN6eAXA1gA4R2Q/gfgAPAnheRO4C8AWAW8s5yEryld4SRukt7qmveRec9Mx686xFiZjveNZimmr3ias9M69v8Kg9jroWM3bZ+jYzRtUxb7Kr6u1G6NoSj4WIyojvoCMKBJOdKBBMdqJAMNmJAsFkJwpEkAtO+jSn7NJbXcL92hjzbNrmm+TlL8t5YoVMHPN0Gs3Ys9c+P2b/iCQzS8xYe7tdlqPq4JWdKBBMdqJAMNmJAsFkJwoEk50oEEx2okCw9DZHqt4+JQ1J92ujeueo+fj6+WJ2GU2NEpt4+uTELjf6fkTamuvMWEMdf7QWG17ZiQLBZCcKBJOdKBBMdqJAMNmJAsFbpnMkk/bWSs2phLM9Ya37BiDueTnNeKe7FMr9nDnf3Xj13Pn3xOqT9o9PTQ2vI4sNvyNEgWCyEwWCyU4UCCY7USCY7ESBYLITBWIh2z9tAvAdAAOqekHU9gCA7wI4vi3rfar6SrkGWUk1nlqZtf1T0lNmyiTs0lU2Y8dyOXtLJu9EGOP1W9UeY86zNVQu51lfz7PAnm+LKqqOhVzZnwRwnaP9EVVdF/07LRKd6HQ2b7Kr6lsADldgLERURsX8zX63iOwQkU0iwi07iRa5QpP9MQBnA1gHoA/AQ9YnisgGEekVkd7BwUHr04iozApKdlXtV9WsquYAPA5gvedzN6pqj6r2pNPpQsdJREUqKNlFpGvWh7cA2FWa4RBRuSyk9PYMgKsBdIjIfgD3A7haRNYhv1DaPgDfK98QK8s3A+zwyIyzfWTc3Q74t2qqsSfYIZv1zVKzYyLuJ815XtdHpjJm7Kgn1nd00ox5Z9JRVcyb7Kp6u6P5iTKMhYjKiO+gIwoEk50oEEx2okAw2YkCwWQnCgQXnJxjJmPPNvvqyJSzfeDotNknacyUAwDPpDEv31ZOEPfr94zndf33wxNmbMxzPj78csSMDY+6z0lne73Zh8qLV3aiQDDZiQLBZCcKBJOdKBBMdqJAMNmJAsHS2xzjE/YsryMj7nLS2IxdnjoykTVjtZ494uo9U+Linql01p5uY/aXhSnvj4FdVvz0q2Ez9vlXo852lt6qh1d2okAw2YkCwWQnCgSTnSgQTHaiQPBu/By+u/FDxt34Q5416HxbQ43M2Hfqp7P2c9Z4ZtAkjdio5+uazvm2f7Jj49N2FeI3O4ac7T3ntpt9ahOeRfmoaLyyEwWCyU4UCCY7USCY7ESBYLITBYLJThSIhWz/tALATwF0Ir/d00ZVfVRE2gE8B6Ab+S2gblXVI+UbamVYa6cBwJdH3Gu1HfZMdlnWZk/8EM8OSZ2NSTN2dMIuy01m3CW2xjr7+VL1KTM2MWmXwzKestz2z4452z/42P4RuexrS8xYzLePFi3IQq7sGQA/VNW1AC4H8AMRWQvgXgBbVXUNgK3Rx0S0SM2b7Krap6rvR49HAOwBsAzATQCeij7tKQA3l2mMRFQCJ/U3u4h0A7gYwDYAnaraF4UOIv9rPhEtUgtOdhFpBPACgHtU9YQ/yFRVkf973tVvg4j0ikjv4OBgUYMlosItKNlFJIF8oj+tqi9Gzf0i0hXFuwAMuPqq6kZV7VHVnnQ6XYoxE1EB5k12ERHk92Pfo6oPzwptBnBn9PhOAC+XfnhEVCoLmfV2JYA7AOwUke1R230AHg
TwvIjcBeALALeWZYQVls3a9bBk3P3aODw1afZJjNmlq/PTjWasodb+1iypS5ixmYx7/LUx+3W9xhObzjSYsfaGOjN
|
||
|
"text/plain": [
|
||
|
"<Figure size 432x288 with 1 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {
|
||
|
"needs_background": "light"
|
||
|
},
|
||
|
"output_type": "display_data"
|
||
|
},
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"Sniegowce\n"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"'\\n??????????????\\nimport os, sys\\n\\n# Open a file\\npath = \"./datasets/Damskie_mini/\"\\ndirs = os.listdir( path )\\n\\n# This would print all the files and directories\\nfor file in dirs:\\n print(file)\\n'"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": []
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 19,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
"import numpy as np\n",
"import os.path\n",
"import csv\n",
"\n",
"# Whole-module pickle: reloading it requires the Net class to be importable.\n",
"torch.save(net, 'nn.pth')\n",
"# Parameters only: load with net.load_state_dict(torch.load(...)).\n",
"torch.save(net.state_dict(), \"nn-state-dict.pth\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.6.9"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 4
|
||
|
}
|