added functions for loading images, model and testing

This commit is contained in:
MarRac 2024-05-26 19:56:18 +02:00
parent fb0ec5057c
commit b45c2e0f1f

View File

@ -4,17 +4,16 @@ from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils from torchvision import datasets, transforms, utils
from torchvision.transforms import Compose, Lambda, ToTensor from torchvision.transforms import Compose, Lambda, ToTensor
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np from .model import *
from model import *
from PIL import Image from PIL import Image
device = torch.device('cuda') device = torch.device('cuda')
#data transform to tensors: #data transform to tensors:
data_transformer = transforms.Compose([ data_transformer = transforms.Compose([
transforms.Resize((150, 150)), transforms.Resize((100, 100)),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transforms.Normalize((0.5, 0.5, 0.5 ), (0.5, 0.5, 0.5))
]) ])
@ -24,13 +23,8 @@ test_set = datasets.ImageFolder(root='resources/test', transform=data_transforme
#to mozna nawet przerzucic do funkcji train: #to mozna nawet przerzucic do funkcji train:
#train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=2) # train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
#test_loader = DataLoader(test_set, batch_size=32, shuffle=True, num_workers=2) #test_loader = DataLoader(test_set, batch_size=32, shuffle=True)
#test if classes work properly:
#print(train_set.classes)
#print(train_set.class_to_idx)
#print(train_set.targets[3002])
#function for training model #function for training model
@ -62,12 +56,10 @@ def accuracy(model, dataset):
return correct.float() / len(dataset) return correct.float() / len(dataset)
model = Neural_Network_Model() model = Neural_Network_Model()
model.to(device) model.to(device)
#loading the already saved model:
model.load_state_dict(torch.load('model.pth')) model.load_state_dict(torch.load('model.pth'))
model.eval() model.eval()
@ -78,18 +70,27 @@ model.eval()
#TEST - loading the image and getting results: #TEST - loading the image and getting results:
testImage_path = 'resources/images/plant_photos/pexels-polina-tankilevitch-4110456.jpg' #testImage_path = 'resources/images/plant_photos/pexels-dxt-73640.jpg'
testImage = Image.open(testImage_path)
testImage = data_transformer(testImage) def load_model():
testImage = testImage.unsqueeze(0) model = Neural_Network_Model()
testImage = testImage.to(device) model.load_state_dict(torch.load('model.pth'))
model.eval()
return model
def load_image(image_path):
testImage = Image.open(image_path)
testImage = data_transformer(testImage)
testImage = testImage.unsqueeze(0)
return testImage
def guess_image(model, image_tensor):
with torch.no_grad():
testOutput = model(image_tensor)
_, predicted = torch.max(testOutput, 1)
predicted_class = train_set.classes[predicted.item()]
return predicted_class
model.load_state_dict(torch.load('model.pth'))
model.to(device)
model.eval()
testOutput = model(testImage)
_, predicted = torch.max(testOutput, 1)
predicted_class = train_set.classes[predicted.item()]
print(f'The predicted class is: {predicted_class}')