2024-05-25 02:07:27 +02:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from torch.utils.data import DataLoader
|
2024-05-25 18:41:25 +02:00
|
|
|
from torchvision import datasets, transforms, utils
|
2024-05-25 02:07:27 +02:00
|
|
|
from torchvision.transforms import Compose, Lambda, ToTensor
|
|
|
|
import matplotlib.pyplot as plt
|
2024-05-27 05:28:48 +02:00
|
|
|
from NN.model import *
|
2024-05-25 22:30:04 +02:00
|
|
|
from PIL import Image
|
2024-06-04 16:55:27 +02:00
|
|
|
import pygame
|
2024-05-25 02:07:27 +02:00
|
|
|
|
2024-05-27 05:28:48 +02:00
|
|
|
# Run on the current CUDA device when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessing shared by the ImageFolder datasets below and by single-image
# inference: resize to 100x100, convert to a tensor, then normalize each of
# the three channels with mean 0.5 and std 0.5.
data_transformer = transforms.Compose(
    [
        transforms.Resize((100, 100)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

# Datasets read from the on-disk folders. train_set.classes is also used at
# inference time to map a predicted index back to a class name.
train_set = datasets.ImageFolder(root='resources/train', transform=data_transformer)
test_set = datasets.ImageFolder(root='resources/test', transform=data_transformer)
|
2024-05-25 02:07:27 +02:00
|
|
|
|
|
|
|
|
|
|
|
# TODO: these loaders could be moved into the train function:
|
2024-05-26 19:56:18 +02:00
|
|
|
# train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
|
|
|
|
#test_loader = DataLoader(test_set, batch_size=32, shuffle=True)
|
2024-05-25 16:33:34 +02:00
|
|
|
|
2024-05-25 02:07:27 +02:00
|
|
|
|
2024-05-25 16:33:34 +02:00
|
|
|
def train(model, dataset, iter=100, batch_size=64, lr=0.01):
    """Train ``model`` on ``dataset`` in place with plain SGD.

    Args:
        model: network to optimise. It is moved to the module-level ``device``
            first, so its parameters live where the input batches are sent.
        dataset: map-style dataset yielding ``(input, label)`` pairs.
        iter: number of epochs (full passes over ``dataset``). The name shadows
            the builtin but is kept for backward compatibility with callers.
        batch_size: mini-batch size used by the shuffling DataLoader.
        lr: SGD learning rate (defaults to the previously hard-coded 0.01).

    The criterion is ``nn.NLLLoss``, so the model presumably ends in a
    log-softmax and emits log-probabilities -- TODO confirm against the model
    definition.

    On every 10th epoch (starting with epoch 0) the mean batch loss of that
    epoch is printed.
    """
    # Without this, a CPU-resident model would receive CUDA inputs and fail.
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    criterion = nn.NLLLoss()
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    model.train()
    for epoch in range(iter):
        epoch_loss = 0.0
        n_batches = 0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            output = model(inputs.to(device))
            loss = criterion(output, labels.to(device))
            loss.backward()
            optimizer.step()
            # .item() extracts a plain float so no autograd graph is kept alive.
            epoch_loss += loss.item()
            n_batches += 1

        # Report the epoch mean rather than the (noisy) last batch only.
        if epoch % 10 == 0 and n_batches:
            print('epoch: %3d loss: %.4f' % (epoch, epoch_loss / n_batches))
|
|
|
|
|
2024-05-25 16:33:34 +02:00
|
|
|
def accuracy(model, dataset, batch_size=64):
    """Return the fraction of ``dataset`` samples that ``model`` classifies correctly.

    The predicted class of each sample is the argmax over dimension 1 of the
    model output.

    Args:
        model: classifier to evaluate. It is moved to the module-level
            ``device`` and switched to eval mode (and left in eval mode).
        dataset: map-style dataset yielding ``(input, label)`` pairs.
        batch_size: evaluation batch size (defaults to the previously
            hard-coded 64).

    Returns:
        A 0-dim float tensor (on ``device``) holding correct / len(dataset).

    Raises:
        ValueError: if ``dataset`` is empty, since the ratio is undefined.
    """
    if len(dataset) == 0:
        raise ValueError('cannot compute accuracy on an empty dataset')

    model.to(device)
    model.eval()
    # Counting correct predictions does not depend on order, so no shuffling.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    correct = torch.zeros((), dtype=torch.long, device=device)
    with torch.no_grad():
        for inputs, labels in loader:
            predictions = model(inputs.to(device)).argmax(dim=1)
            correct += (predictions == labels.to(device)).sum()

    return correct.float() / len(dataset)
|
|
|
|
|
2024-05-25 02:07:27 +02:00
|
|
|
|
2024-05-27 05:28:48 +02:00
|
|
|
# model = Conv_Neural_Network_Model()
|
|
|
|
# model.to(device)
|
2024-05-25 22:30:04 +02:00
|
|
|
|
2024-05-26 19:56:18 +02:00
|
|
|
#loading the already saved model:
|
2024-05-27 05:28:48 +02:00
|
|
|
# model.load_state_dict(torch.load('CNN_model.pth'))
|
|
|
|
# model.eval()
|
2024-05-25 22:30:04 +02:00
|
|
|
|
2024-05-27 04:27:46 +02:00
|
|
|
# #training the model:
|
2024-05-25 22:30:04 +02:00
|
|
|
# train(model, train_set)
|
|
|
|
# print(f"Accuracy of the network is: {100*accuracy(model, test_set)}%")
|
2024-05-27 04:27:46 +02:00
|
|
|
# torch.save(model.state_dict(), 'CNN_model.pth')
|
2024-05-25 22:30:04 +02:00
|
|
|
|
|
|
|
|
2024-05-26 19:56:18 +02:00
|
|
|
|
|
|
|
def load_model(path='CNN_model.pth'):
    """Build a ``Conv_Neural_Network_Model`` and load trained weights into it.

    Args:
        path: checkpoint file containing the model's ``state_dict``
            (defaults to the previously hard-coded ``'CNN_model.pth'``).

    Returns:
        The model with weights mapped onto the CPU, in eval mode.
    """
    model = Conv_Neural_Network_Model()
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a state_dict needs. It also avoids running arbitrary
    # pickled code that could be stored in the checkpoint file.
    state_dict = torch.load(path, map_location=torch.device('cpu'), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model
|
|
|
|
|
|
|
|
|
|
|
|
def load_image(image_path):
    """Open ``image_path`` and return it as a preprocessed batch of one image.

    The image is converted to RGB, passed through ``data_transformer`` and
    given a leading batch dimension of size 1 via ``unsqueeze(0)``.
    """
    # A context manager guarantees the file handle is released. Pillow keeps
    # it open after loading for multi-frame formats. convert() returns an
    # independent in-memory image, so it stays valid after the file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert('RGB')
    image_tensor = data_transformer(rgb_image)
    return image_tensor.unsqueeze(0)
|
|
|
|
|
2024-06-04 16:55:27 +02:00
|
|
|
def display_image(screen, image_path, position):
    """Load the image file at ``image_path``, scale it to 250x250 pixels and
    blit it onto ``screen`` at ``position``."""
    scaled_surface = pygame.transform.scale(pygame.image.load(image_path), (250, 250))
    screen.blit(scaled_surface, position)
|
|
|
|
|
|
|
|
def display_result(screen, position, predicted_class):
    """Render the predicted class as white, size-30 text in pygame's default
    font and blit it onto ``screen`` at ``position``."""
    message = "The predicted image is: " + str(predicted_class)
    text_surface = pygame.font.Font(None, 30).render(message, 1, (255, 255, 255))
    screen.blit(text_surface, position)
|
|
|
|
|
2024-05-26 19:56:18 +02:00
|
|
|
def guess_image(model, image_tensor):
    """Return the class name that ``model`` scores highest for ``image_tensor``.

    The highest-scoring index along dimension 1 is looked up in
    ``train_set.classes``. Because ``.item()`` needs a single element,
    ``image_tensor`` is expected to be a batch containing exactly one image.
    """
    with torch.no_grad():
        scores = model(image_tensor)
        best_index = torch.max(scores, 1).indices
    return train_set.classes[best_index.item()]
|
2024-05-25 22:30:04 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
2024-05-26 23:28:22 +02:00
|
|
|
#TEST - loading the image and getting results:
|
2024-05-27 05:28:48 +02:00
|
|
|
# testImage_path = 'resources/images/plant_photos/1c76aa4d-11f4-47d1-8bdd-2cb78deeeccf.jpg'
|
|
|
|
# testImage = Image.open(testImage_path)
|
|
|
|
# testImage = data_transformer(testImage)
|
|
|
|
# testImage = testImage.unsqueeze(0)
|
|
|
|
# testImage = testImage.to(device)
|
|
|
|
|
|
|
|
# model.load_state_dict(torch.load('CNN_model.pth'))
|
|
|
|
# model.to(device)
|
|
|
|
# model.eval()
|
|
|
|
|
|
|
|
# testOutput = model(testImage)
|
|
|
|
# _, predicted = torch.max(testOutput, 1)
|
|
|
|
# predicted_class = train_set.classes[predicted.item()]
|
|
|
|
# print(f'The predicted class is: {predicted_class}')
|