Traktor/source/NN/neural_network.py
2024-05-25 22:30:04 +02:00

96 lines
2.8 KiB
Python

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from torchvision.transforms import Compose, Lambda, ToTensor
import matplotlib.pyplot as plt
import numpy as np
from model import *
from PIL import Image
# Run on the GPU when one is available and fall back to the CPU otherwise;
# a hard-coded 'cuda' device makes the whole script fail on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Preprocessing applied to every image: resize to 150x150 pixels, convert to a
# tensor, then normalise each of the three channels with mean 0.5 and std 0.5
# (assumes 3-channel RGB input).
data_transformer = Compose(
    [
        transforms.Resize((150, 150)),
        ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
# Load the training and test image folders; every image is passed through
# data_transformer when it is read.
train_set, test_set = (
    datasets.ImageFolder(root='resources/train', transform=data_transformer),
    datasets.ImageFolder(root='resources/test', transform=data_transformer),
)
# train() and accuracy() build their own DataLoaders from these datasets.
# Sanity checks for the class mapping (uncomment to inspect):
# print(train_set.classes)
# print(train_set.class_to_idx)
# print(train_set.targets[3002])
def train(model, dataset, iter=100, batch_size=64, lr=0.01, log_every=10):
    """Train ``model`` on ``dataset`` with plain SGD and an NLL loss.

    Args:
        model: network to optimise in place. ``nn.NLLLoss`` expects
            log-probabilities, so the model presumably ends in a
            ``LogSoftmax`` layer -- TODO confirm for ``Neural_Network_Model``.
        dataset: map-style dataset yielding ``(input, label)`` pairs.
        iter: number of epochs (name kept for backward compatibility even
            though it shadows the ``iter`` builtin).
        batch_size: mini-batch size.
        lr: SGD learning rate.
        log_every: print the epoch loss every ``log_every`` epochs;
            0 or a negative value disables printing.

    Returns:
        list[float]: mean training loss of each epoch (weighted by batch size).
    """
    params = next(model.parameters(), None)
    # Send batches to wherever the model lives instead of a module-level
    # device, so this works on CPU-only machines as well.
    target_device = params.device if params is not None else torch.device('cpu')
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.NLLLoss()
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    model.train()
    history = []
    for epoch in range(iter):
        running_loss = 0.0
        seen = 0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(target_device), labels.to(target_device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            # Accumulate a detached tensor (no per-batch host sync) weighted by
            # batch size, so a short final batch does not skew the epoch mean.
            running_loss = running_loss + loss.detach() * labels.size(0)
            seen += labels.size(0)
        epoch_loss = float(running_loss) / seen
        history.append(epoch_loss)
        if log_every > 0 and epoch % log_every == 0:
            print('epoch: %3d loss: %.4f' % (epoch, epoch_loss))
    return history
def accuracy(model, dataset, batch_size=64):
    """Return the fraction of ``dataset`` samples that ``model`` classifies correctly.

    The predicted class of each sample is the arg-max over dimension 1 of the
    model output. The model is switched to (and left in) ``eval()`` mode.

    Args:
        model: classifier to evaluate.
        dataset: map-style dataset yielding ``(input, label)`` pairs.
        batch_size: evaluation mini-batch size.

    Returns:
        torch.Tensor: 0-dim float tensor in ``[0, 1]`` on the model's device;
        call ``.item()`` to get a plain Python float (e.g. before formatting).

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError('accuracy() needs a non-empty dataset')
    params = next(model.parameters(), None)
    # Evaluate on the model's own device; parameter-free models run on CPU.
    target_device = params.device if params is not None else torch.device('cpu')
    model.eval()
    # Sample order does not affect the count, so no shuffling is needed.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    correct = torch.zeros((), dtype=torch.long, device=target_device)
    with torch.no_grad():
        for inputs, labels in loader:
            predictions = model(inputs.to(target_device)).argmax(dim=1)
            correct += (predictions == labels.to(target_device)).sum()
    return correct.float() / len(dataset)
# Build the network, move it to the selected device and load the saved weights.
model = Neural_Network_Model()
model.to(device)
# map_location lets a checkpoint that was saved from a GPU be loaded onto
# whatever `device` is (e.g. a CPU-only machine).
# NOTE(review): torch.load unpickles the file -- only load checkpoints you
# trust (consider weights_only=True if the installed torch supports it).
model.load_state_dict(torch.load('model.pth', map_location=device))
model.eval()
# To retrain instead of using the saved weights, uncomment:
# train(model, train_set)
# print(f"Accuracy of the network is: {100*accuracy(model, test_set).item()}%")
# torch.save(model.state_dict(), 'model.pth')
# Smoke test: classify a single photo with the already-loaded model.
testImage_path = 'resources/images/plant_photos/pexels-polina-tankilevitch-4110456.jpg'
# Convert to RGB (as ImageFolder's default loader does for the training data)
# so grayscale/RGBA/CMYK files still give the 3 channels Normalize expects;
# the context manager closes the underlying file handle.
with Image.open(testImage_path) as raw_image:
    testImage = data_transformer(raw_image.convert('RGB'))
# Add a batch dimension of size 1 and move the tensor to the model's device.
testImage = testImage.unsqueeze(0).to(device)
# The weights were loaded into `model` when it was created, so there is no
# need to reload them; just ensure inference mode and skip autograd tracking.
model.eval()
with torch.no_grad():
    testOutput = model(testImage)
_, predicted = torch.max(testOutput, 1)
# train_set.classes maps the predicted label index back to a class name.
predicted_class = train_set.classes[predicted.item()]
print(f'The predicted class is: {predicted_class}')