2022-05-23 20:20:15 +02:00
|
|
|
import torch
|
|
|
|
import torchvision.transforms as transforms
|
|
|
|
|
|
|
|
from PIL import Image
|
|
|
|
from NeuralNetwork import NeuralNetwork
|
|
|
|
|
2022-05-27 09:22:49 +02:00
|
|
|
def getPrediction(img_path, network_name):
    """Classify a single image with a saved network and return the class name.

    A fresh ``NeuralNetwork.NeuralNetwork`` is built on every call. Its weights
    are loaded (onto the CPU) from ``./NeuralNetwork/trained_networks/`` +
    ``network_name``.

    Args:
        img_path: Path to the image file to classify.
        network_name: File name of the saved state dict, relative to
            ``./NeuralNetwork/trained_networks/``.

    Returns:
        One of ``'glass'``, ``'metal'``, ``'paper'``, ``'plastic'``: the class
        whose network output is largest.
    """
    # Initialise the neural network.
    neural_net = NeuralNetwork.NeuralNetwork()

    PATH = './NeuralNetwork/trained_networks/'

    # Image.open keeps the underlying file open until closed. Use a context
    # manager so the handle is released once ToTensor has read the pixels.
    # NOTE(review): no colour-mode conversion is done. A grayscale or RGBA file
    # yields a 1- or 4-channel tensor; confirm the channels the network expects.
    with Image.open(img_path) as img:
        # unsqueeze_(0) adds a leading batch dimension of size 1.
        transform_tensor = transforms.ToTensor()(img).unsqueeze_(0)

    # Order must match the class indices used when the network was trained.
    classes = ['glass', 'metal', 'paper', 'plastic']

    # map_location='cpu' lets weights saved from a GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    neural_net.load_state_dict(torch.load(PATH + network_name, map_location='cpu'))

    neural_net.eval()

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        outputs = neural_net(transform_tensor)

    # Take the index of the largest output along dim 1. That index is the
    # recognised class, in this case the type of waste.
    predicted_index = torch.max(outputs, 1)[1].item()
    return classes[predicted_index]
|