added functions for loading images, model and testing
This commit is contained in:
parent
fb0ec5057c
commit
b45c2e0f1f
@ -4,17 +4,16 @@ from torch.utils.data import DataLoader
|
|||||||
from torchvision import datasets, transforms, utils
|
from torchvision import datasets, transforms, utils
|
||||||
from torchvision.transforms import Compose, Lambda, ToTensor
|
from torchvision.transforms import Compose, Lambda, ToTensor
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
from .model import *
|
||||||
from model import *
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
device = torch.device('cuda')
|
device = torch.device('cuda')
|
||||||
|
|
||||||
#data transform to tensors:
|
#data transform to tensors:
|
||||||
data_transformer = transforms.Compose([
|
data_transformer = transforms.Compose([
|
||||||
transforms.Resize((150, 150)),
|
transforms.Resize((100, 100)),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
transforms.Normalize((0.5, 0.5, 0.5 ), (0.5, 0.5, 0.5))
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
@ -24,13 +23,8 @@ test_set = datasets.ImageFolder(root='resources/test', transform=data_transforme
|
|||||||
|
|
||||||
|
|
||||||
#to mozna nawet przerzucic do funkcji train:
|
#to mozna nawet przerzucic do funkcji train:
|
||||||
#train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=2)
|
# train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
|
||||||
#test_loader = DataLoader(test_set, batch_size=32, shuffle=True, num_workers=2)
|
#test_loader = DataLoader(test_set, batch_size=32, shuffle=True)
|
||||||
|
|
||||||
#test if classes work properly:
|
|
||||||
#print(train_set.classes)
|
|
||||||
#print(train_set.class_to_idx)
|
|
||||||
#print(train_set.targets[3002])
|
|
||||||
|
|
||||||
|
|
||||||
#function for training model
|
#function for training model
|
||||||
@ -62,12 +56,10 @@ def accuracy(model, dataset):
|
|||||||
return correct.float() / len(dataset)
|
return correct.float() / len(dataset)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
model = Neural_Network_Model()
|
model = Neural_Network_Model()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
|
#loading the already saved model:
|
||||||
model.load_state_dict(torch.load('model.pth'))
|
model.load_state_dict(torch.load('model.pth'))
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@ -78,18 +70,27 @@ model.eval()
|
|||||||
|
|
||||||
|
|
||||||
#TEST - loading the image and getting results:
|
#TEST - loading the image and getting results:
|
||||||
testImage_path = 'resources/images/plant_photos/pexels-polina-tankilevitch-4110456.jpg'
|
#testImage_path = 'resources/images/plant_photos/pexels-dxt-73640.jpg'
|
||||||
testImage = Image.open(testImage_path)
|
|
||||||
testImage = data_transformer(testImage)
|
def load_model():
|
||||||
testImage = testImage.unsqueeze(0)
|
model = Neural_Network_Model()
|
||||||
testImage = testImage.to(device)
|
model.load_state_dict(torch.load('model.pth'))
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(image_path):
|
||||||
|
testImage = Image.open(image_path)
|
||||||
|
testImage = data_transformer(testImage)
|
||||||
|
testImage = testImage.unsqueeze(0)
|
||||||
|
return testImage
|
||||||
|
|
||||||
|
def guess_image(model, image_tensor):
|
||||||
|
with torch.no_grad():
|
||||||
|
testOutput = model(image_tensor)
|
||||||
|
_, predicted = torch.max(testOutput, 1)
|
||||||
|
predicted_class = train_set.classes[predicted.item()]
|
||||||
|
return predicted_class
|
||||||
|
|
||||||
model.load_state_dict(torch.load('model.pth'))
|
|
||||||
model.to(device)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
testOutput = model(testImage)
|
|
||||||
_, predicted = torch.max(testOutput, 1)
|
|
||||||
predicted_class = train_set.classes[predicted.item()]
|
|
||||||
print(f'The predicted class is: {predicted_class}')
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user