140 lines
4.4 KiB
Python
140 lines
4.4 KiB
Python
|
import numpy as np
|
||
|
import matplotlib.pyplot as plt
|
||
|
import torch
|
||
|
import torch.nn.functional as F
|
||
|
from torch import nn
|
||
|
from torch import optim
|
||
|
from torchvision import datasets, transforms, models
|
||
|
import torchvision
|
||
|
|
||
|
# File where NetRunner.train() saves the model state_dict and NetRunner.test() loads it back.
PATH = './cifar_net.pth'
|
||
|
|
||
|
class Net(nn.Module):
    """Small LeNet-style CNN producing class logits.

    Layout: two (5x5 conv -> ReLU -> 2x2 max-pool) stages, then a hidden
    fully connected layer of 120 units and a final linear output layer.

    Args:
        input_size: Side length, in pixels, of the square input images. The
            default of 224 matches the ``CenterCrop(224)`` preprocessing used
            in ``NetRunner.prepare_data``; ``fc1`` is sized from it.
        num_classes: Number of output logits (2 by default).

    Raises:
        ValueError: If ``input_size`` is too small to survive both
            conv/pool stages.
    """

    def __init__(self, input_size=224, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Each unpadded 5x5 conv shrinks the side by 4 and each pool halves it
        # (rounding down). For 224 px: 220 -> 110 -> 106 -> 53, i.e.
        # 16 * 53 * 53 = 44944 flattened features, as originally hard-coded.
        side = ((input_size - 4) // 2 - 4) // 2
        if side < 1:
            raise ValueError(
                f"input_size={input_size} is too small for this network")
        self.fc1 = nn.Linear(16 * side * side, 120)
        self.fc2 = nn.Linear(120, num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        Assumes ``x`` is (batch, 3, input_size, input_size) -- TODO confirm
        for callers other than the loaders built in ``prepare_data``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        return self.fc2(x)
|
||
|
|
||
|
|
||
|
class NetRunner:
    """Prepare the image data, train ``Net`` and spot-check its predictions."""

    # Per-channel statistics passed to ``transforms.Normalize`` in
    # ``prepare_data``; also used to undo that normalization for display.
    MEAN = (0.485, 0.456, 0.406)
    STD = (0.229, 0.224, 0.225)

    @staticmethod
    def _to_display_array(img, mean=None, std=None):
        """Convert a (C, H, W) image tensor into an (H, W, C) numpy array.

        When both ``mean`` and ``std`` are given, the per-channel
        normalization ``(x - mean) / std`` is undone and the result clipped
        to [0, 1] so matplotlib shows true colours. Otherwise the values are
        passed through unchanged.
        """
        npimg = img.detach().cpu().numpy()
        if mean is not None and std is not None:
            mean_arr = np.asarray(mean, dtype=npimg.dtype).reshape(-1, 1, 1)
            std_arr = np.asarray(std, dtype=npimg.dtype).reshape(-1, 1, 1)
            npimg = np.clip(npimg * std_arr + mean_arr, 0.0, 1.0)
        return np.transpose(npimg, (1, 2, 0))

    @staticmethod
    def imshow(img, mean=None, std=None):
        """Display a (C, H, W) image tensor, optionally un-normalizing it first."""
        plt.imshow(NetRunner._to_display_array(img, mean, std))
        plt.show()
        plt.close()

    @staticmethod
    def train(trainloader, epochs=2, lr=0.001, momentum=0.9, log_every=2000):
        """Train a fresh ``Net`` with SGD and save its weights to ``PATH``.

        Args:
            trainloader: Iterable of ``(inputs, labels)`` mini-batches; labels
                are integer class indices, as ``nn.CrossEntropyLoss`` expects.
            epochs: Number of passes over ``trainloader``.
            lr: SGD learning rate.
            momentum: SGD momentum.
            log_every: Print the mean loss every this many mini-batches.

        Returns:
            The trained network (its ``state_dict`` is also saved to ``PATH``).
        """
        net = Net()
        net.train()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)

        for epoch in range(epochs):
            running_loss = 0.0
            for i, (inputs, labels) in enumerate(trainloader):
                optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                if i % log_every == log_every - 1:
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i + 1, running_loss / log_every))
                    running_loss = 0.0

        print('Finished Training')
        torch.save(net.state_dict(), PATH)
        return net

    @staticmethod
    def evaluate(net, testloader):
        """Print and return the accuracy of ``net`` on ``testloader``, in percent.

        Args:
            net: Module mapping an image batch to per-class scores.
            testloader: Iterable of ``(images, labels)`` batches.

        Raises:
            ValueError: If ``testloader`` yields no samples.
        """
        correct = 0
        total = 0
        net.eval()
        with torch.no_grad():
            for images, labels in testloader:
                outputs = net(images)
                # The class with the highest score is taken as the prediction.
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        if total == 0:
            raise ValueError("testloader yielded no samples")
        accuracy = 100 * correct / total
        print('Accuracy of the network: %d %%' % accuracy)
        return accuracy

    def test(self, testloader, classes):
        """Show one test batch and print the saved model's predictions for it.

        Loads the weights from ``PATH`` into a new ``Net``.

        Args:
            testloader: Iterable of ``(images, labels)`` batches.
            classes: Sequence mapping class index -> display name, in the
                same index order as the dataset labels.
        """
        images, _labels = next(iter(testloader))

        self.imshow(torchvision.utils.make_grid(images), self.MEAN, self.STD)

        net = Net()
        net.load_state_dict(torch.load(PATH))
        net.eval()
        with torch.no_grad():  # inference only; no autograd graph needed
            outputs = net(images)

        _, predicted = torch.max(outputs, 1)

        print(predicted)
        # One name per image in the batch (batch_size may be > 1).
        print('Predicted: ', ' '.join('%5s' % classes[idx]
                                      for idx in predicted.tolist()))

    def prepare_data(self, retrain=False):
        """Build loaders from ``dataset/train`` and ``dataset/test``, then run ``test``.

        Both splits use the same preprocessing: resize to 256, center-crop to
        224, convert to a tensor and normalize with ``MEAN`` / ``STD``.

        Args:
            retrain: When True, train and save a new model first; otherwise
                ``test`` uses the weights already saved at ``PATH``.
        """
        data_dir = 'dataset'
        train_dir = data_dir + '/train'
        test_dir = data_dir + '/test'

        # The same deterministic pipeline serves both splits.
        preprocessing = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(list(self.MEAN), list(self.STD)),
        ])

        training_dataset = datasets.ImageFolder(train_dir, transform=preprocessing)
        testing_dataset = datasets.ImageFolder(test_dir, transform=preprocessing)

        trainloader = torch.utils.data.DataLoader(
            training_dataset, batch_size=4, shuffle=True)
        testloader = torch.utils.data.DataLoader(
            testing_dataset, batch_size=1, shuffle=True)

        # ImageFolder derives label indices from the sorted sub-folder names,
        # so take the names from the dataset instead of hard-coding an order
        # that may not match the labels.
        classes = tuple(training_dataset.classes)

        if retrain:
            self.train(trainloader)
        self.test(testloader, classes)
|
||
|
|
||
|
|