From b45c2e0f1f0c2330123d095a2be21e0c14254da0 Mon Sep 17 00:00:00 2001 From: MarRac Date: Sun, 26 May 2024 19:56:18 +0200 Subject: [PATCH] added functions for loading images, model and testing --- source/NN/neural_network.py | 53 +++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/source/NN/neural_network.py b/source/NN/neural_network.py index fa61040..fcd2f95 100644 --- a/source/NN/neural_network.py +++ b/source/NN/neural_network.py @@ -4,17 +4,16 @@ from torch.utils.data import DataLoader from torchvision import datasets, transforms, utils from torchvision.transforms import Compose, Lambda, ToTensor import matplotlib.pyplot as plt -import numpy as np -from model import * +from .model import * from PIL import Image device = torch.device('cuda') #data transform to tensors: data_transformer = transforms.Compose([ - transforms.Resize((150, 150)), + transforms.Resize((100, 100)), transforms.ToTensor(), - transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + transforms.Normalize((0.5, 0.5, 0.5 ), (0.5, 0.5, 0.5)) ]) @@ -24,13 +23,8 @@ test_set = datasets.ImageFolder(root='resources/test', transform=data_transforme #to mozna nawet przerzucic do funkcji train: -#train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=2) -#test_loader = DataLoader(test_set, batch_size=32, shuffle=True, num_workers=2) - -#test if classes work properly: -#print(train_set.classes) -#print(train_set.class_to_idx) -#print(train_set.targets[3002]) +# train_loader = DataLoader(train_set, batch_size=64, shuffle=True) +#test_loader = DataLoader(test_set, batch_size=32, shuffle=True) #function for training model @@ -62,12 +56,10 @@ def accuracy(model, dataset): return correct.float() / len(dataset) - - - model = Neural_Network_Model() model.to(device) +#loading the already saved model: model.load_state_dict(torch.load('model.pth')) model.eval() @@ -78,18 +70,27 @@ model.eval() #TEST - loading the image and getting results: -testImage_path = 'resources/images/plant_photos/pexels-polina-tankilevitch-4110456.jpg' -testImage = Image.open(testImage_path) -testImage = data_transformer(testImage) -testImage = testImage.unsqueeze(0) -testImage = testImage.to(device) +#testImage_path = 'resources/images/plant_photos/pexels-dxt-73640.jpg' + +def load_model(): + model = Neural_Network_Model() + model.load_state_dict(torch.load('model.pth')) + model.eval() + return model + + +def load_image(image_path): + testImage = Image.open(image_path) + testImage = data_transformer(testImage) + testImage = testImage.unsqueeze(0) + return testImage + +def guess_image(model, image_tensor): + with torch.no_grad(): + testOutput = model(image_tensor) + _, predicted = torch.max(testOutput, 1) + predicted_class = train_set.classes[predicted.item()] + return predicted_class -model.load_state_dict(torch.load('model.pth')) -model.to(device) -model.eval() -testOutput = model(testImage) -_, predicted = torch.max(testOutput, 1) -predicted_class = train_set.classes[predicted.item()] -print(f'The predicted class is: {predicted_class}')