Trashmaster/NeuralNetwork/prediction.py

20 lines
748 B
Python
Raw Normal View History

2022-05-23 20:20:15 +02:00
import torch
import torchvision.transforms as transforms
from PIL import Image
from NeuralNetwork import NeuralNetwork
def getPrediction(img_path, network_name):
    """Classify the type of waste shown in an image with a trained network.

    Args:
        img_path: Path to the image file to classify.
        network_name: File name of a saved ``state_dict`` located in
            ``./NeuralNetwork/trained_networks/``. That directory is resolved
            relative to the current working directory.

    Returns:
        One of ``'glass'``, ``'metal'``, ``'paper'`` or ``'plastic'``.
    """
    # Output index -> class label.
    # NOTE(review): presumably matches the label order used in training — confirm.
    classes = ['glass', 'metal', 'paper', 'plastic']
    PATH = './NeuralNetwork/trained_networks/'

    # Initialize the neural network.
    neural_net = NeuralNetwork.NeuralNetwork()

    # Open the image in a context manager so the file handle is released once
    # the pixels have been converted. ToTensor yields a (C, H, W) tensor, and
    # unsqueeze_ adds the leading batch dimension.
    # NOTE(review): assumes the network accepts the image's native mode/size
    # (e.g. 3-channel RGB) — confirm.
    with Image.open(img_path) as img:
        transform_tensor = transforms.ToTensor()(img).unsqueeze_(0)

    # Load the trained weights onto the CPU.
    # NOTE(review): torch.load unpickles the file; only load trusted weight files.
    neural_net.load_state_dict(torch.load(PATH + network_name, map_location='cpu'))
    neural_net.eval()

    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        outputs = neural_net(transform_tensor)

    # Take the highest-scoring output. Its index is the recognized class,
    # which in this case is the type of waste.
    return classes[torch.max(outputs, 1)[1].item()]