garbage_recognition #27

Merged
s464843 merged 6 commits from garbage_recognition into master 2022-05-26 10:26:49 +02:00
Showing only changes of commit 8029b42ed3 - Show all commits

View File

@ -0,0 +1,20 @@
import torch
import torchvision.transforms as transforms
from PIL import Image
from NeuralNetwork import NeuralNetwork
def getPrediction(img_path):
# Inicjacja sieci neuronowej
neural_net = NeuralNetwork()
PATH = './trained_nn.pth'
img = Image.open(img_path)
transform_tensor = transforms.ToTensor()(img).unsqueeze_(0)
classes = ['glass', 'metal', 'paper', 'plastic']
neural_net.load_state_dict(torch.load(PATH))
neural_net.eval()
outputs = neural_net(transform_tensor)
# Wyciągnięcie największej wagi co przekłada się na rozpoznanie klasy, w tym przypadku rodzju odpadu
return classes[torch.max(outputs, 1)[1]]