Traktor/source/NN/neural_network.py
2024-06-09 16:37:41 +02:00

127 lines
4.0 KiB
Python

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
from torchvision.transforms import Compose, Lambda, ToTensor
import matplotlib.pyplot as plt
from NN.model import *
from PIL import Image
import pygame
from area.constants import GREY
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessing applied to every image: resize to 100x100, convert to a
# tensor, then normalize each of the three channels as (x - 0.5) / 0.5.
data_transformer = Compose([
    transforms.Resize((100, 100)),
    ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Datasets are built at import time; ImageFolder derives the class names from
# the sub-directory names found under each root.
train_set = datasets.ImageFolder('resources/train', transform=data_transformer)
test_set = datasets.ImageFolder('resources/test', transform=data_transformer)

# The loaders below could even be moved into the train function:
# train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
# test_loader = DataLoader(test_set, batch_size=32, shuffle=True)
#function for training model
def train(model, dataset, iter=100, batch_size=64):
    """Train ``model`` in place with SGD (lr=0.01) and negative log-likelihood loss.

    Args:
        model: network to optimise. Because the criterion is ``nn.NLLLoss``
            the model output is expected to be log-probabilities
            (NOTE(review): presumably the model ends in log-softmax - confirm
            against NN.model).
        dataset: dataset yielding ``(input, label)`` pairs; it is wrapped in a
            shuffling ``DataLoader``.
        iter: number of epochs. The name shadows the builtin but is kept so
            existing keyword callers keep working.
        batch_size: number of samples per optimisation step.

    The loss of the last batch is printed every 10th epoch (0, 10, 20, ...).
    The model is left in training mode.
    """
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = nn.NLLLoss()
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Send every batch to wherever the model's weights actually live instead
    # of the module-level default, so a model that was never moved (e.g. one
    # loaded onto the CPU) still trains on a CUDA machine. The optimizer above
    # already rejected a parameterless model, so next() cannot be exhausted.
    target = next(model.parameters()).device

    model.train()
    for epoch in range(iter):
        loss = None
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            output = model(inputs.to(target))
            loss = criterion(output, labels.to(target))
            loss.backward()
            optimizer.step()
        # Guard against a loader that produced no batch, where there is no
        # loss value to report.
        if epoch % 10 == 0 and loss is not None:
            print('epoch: %3d loss: %.4f' % (epoch, loss.item()))
#function for getting accuracy
def accuracy(model, dataset):
    """Return the fraction of samples in ``dataset`` the model classifies correctly.

    The predicted class of a sample is the index of the largest model output
    along dimension 1.

    Args:
        model: classifier to evaluate; it is switched to eval mode and left
            that way.
        dataset: dataset yielding ``(input, label)`` pairs.

    Returns:
        A 0-dim float tensor holding ``correct / len(dataset)``.

    Raises:
        ValueError: if ``dataset`` is empty, since the ratio is undefined.
    """
    total = len(dataset)
    if total == 0:
        raise ValueError('cannot compute accuracy on an empty dataset')

    # Evaluate on the device that holds the model's weights; fall back to the
    # module-level default only for a model without parameters.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else device

    # Sample order does not affect the count, so no shuffling is needed.
    loader = DataLoader(dataset, batch_size=64, shuffle=False)
    correct = torch.zeros((), dtype=torch.long, device=target)

    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            predictions = model(inputs.to(target)).argmax(dim=1)
            correct += (predictions == labels.to(target)).sum()
    return correct.float() / total
# model = Conv_Neural_Network_Model()
# model.to(device)
#loading the already saved model:
# model.load_state_dict(torch.load('CNN_model.pth'))
# model.eval()
# #training the model:
# train(model, train_set)
# print(f"Accuracy of the network is: {100*accuracy(model, test_set)}%")
# torch.save(model.state_dict(), 'CNN_model.pth')
def load_model():
    """Create the CNN and restore its weights from ``CNN_model.pth``.

    The checkpoint is mapped onto the CPU regardless of where it was saved,
    and the returned network is already in evaluation mode.

    NOTE(review): ``Conv_Neural_Network_Model`` is presumably provided by the
    ``from NN.model import *`` import - confirm.
    """
    network = Conv_Neural_Network_Model()
    cpu = torch.device('cpu')
    weights = torch.load('CNN_model.pth', map_location=cpu)
    network.load_state_dict(weights)
    network.eval()
    return network
def load_image(image_path):
    """Read an image file and return it as a one-element batch tensor.

    The image is converted to RGB, passed through the module-level
    ``data_transformer`` and given a leading batch dimension.
    """
    rgb_image = Image.open(image_path).convert('RGB')
    image_tensor = data_transformer(rgb_image)
    # Insert a new axis at position 0 so the result forms a batch of one.
    return image_tensor.unsqueeze(0)
#display the image for prediction next to the field
def display_image(screen, image_path, position):
    """Draw the image at ``image_path`` onto ``screen``, scaled to 250x250.

    Args:
        screen: pygame surface to draw on.
        image_path: path of the image file to show.
        position: destination passed straight to ``screen.blit``.
    """
    target_size = (250, 250)
    picture = pygame.image.load(image_path)
    scaled_picture = pygame.transform.scale(picture, target_size)
    screen.blit(scaled_picture, position)
#display result of the guessed image (text under the image)
def display_result(screen, position, predicted_class):
    """Render the prediction as white text and blit it onto ``screen``.

    Args:
        screen: pygame surface to draw on.
        position: destination passed straight to ``screen.blit``.
        predicted_class: value shown after the fixed label; converted with
            ``str``.
    """
    white = (255, 255, 255)
    # Default pygame font (None) at size 30.
    font = pygame.font.Font(None, 30)
    message = "The predicted image is: " + str(predicted_class)
    text_surface = font.render(message, 1, white)
    screen.blit(text_surface, position)
def clear_text_area(win, x, y, width, height, color=GREY):
    """Paint a rectangle over an area of ``win`` and refresh the display.

    Args:
        win: pygame surface to draw on.
        x, y: top-left corner of the rectangle.
        width, height: size of the rectangle.
        color: fill colour; defaults to ``GREY``.
    """
    area = (x, y, width, height)
    pygame.draw.rect(win, color, area)
    pygame.display.update()
def guess_image(model, image_tensor):
    """Classify one image tensor and return the matching class name.

    The model is run without gradient tracking; the index of the largest
    output along dimension 1 is looked up in ``train_set.classes``.

    Args:
        model: callable network producing one score per class.
        image_tensor: batch holding a single image (``.item()`` on the
            predicted index requires exactly one element).

    Returns:
        The class name taken from ``train_set.classes``.
    """
    with torch.no_grad():
        scores = model(image_tensor)
        best_index = torch.max(scores, 1).indices
    class_names = train_set.classes
    return class_names[best_index.item()]
#TEST - loading the image and getting results:
# testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
# testImage = Image.open(testImage_path)
# testImage = data_transformer(testImage)
# testImage = testImage.unsqueeze(0)
# testImage = testImage.to(device)
# model.load_state_dict(torch.load('CNN_model.pth'))
# model.to(device)
# model.eval()
# testOutput = model(testImage)
# _, predicted = torch.max(testOutput, 1)
# predicted_class = train_set.classes[predicted.item()]
# print(f'The predicted class is: {predicted_class}')