add getPrediction function
This commit is contained in:
parent
d65d632c62
commit
8029b42ed3
20
NeuralNetwork/prediction.py
Normal file
20
NeuralNetwork/prediction.py
Normal file
@ -0,0 +1,20 @@
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
from PIL import Image
|
||||
from NeuralNetwork import NeuralNetwork
|
||||
|
||||
def getPrediction(img_path):
|
||||
|
||||
# Inicjacja sieci neuronowej
|
||||
neural_net = NeuralNetwork()
|
||||
PATH = './trained_nn.pth'
|
||||
img = Image.open(img_path)
|
||||
transform_tensor = transforms.ToTensor()(img).unsqueeze_(0)
|
||||
classes = ['glass', 'metal', 'paper', 'plastic']
|
||||
neural_net.load_state_dict(torch.load(PATH))
|
||||
neural_net.eval()
|
||||
outputs = neural_net(transform_tensor)
|
||||
|
||||
# Wyciągnięcie największej wagi co przekłada się na rozpoznanie klasy, w tym przypadku rodzju odpadu
|
||||
return classes[torch.max(outputs, 1)[1]]
|
Loading…
Reference in New Issue
Block a user