add getPrediction function
This commit is contained in:
parent
d65d632c62
commit
8029b42ed3
20
NeuralNetwork/prediction.py
Normal file
20
NeuralNetwork/prediction.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
import torch
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from NeuralNetwork import NeuralNetwork
|
||||||
|
|
||||||
|
def getPrediction(img_path):
|
||||||
|
|
||||||
|
# Inicjacja sieci neuronowej
|
||||||
|
neural_net = NeuralNetwork()
|
||||||
|
PATH = './trained_nn.pth'
|
||||||
|
img = Image.open(img_path)
|
||||||
|
transform_tensor = transforms.ToTensor()(img).unsqueeze_(0)
|
||||||
|
classes = ['glass', 'metal', 'paper', 'plastic']
|
||||||
|
neural_net.load_state_dict(torch.load(PATH))
|
||||||
|
neural_net.eval()
|
||||||
|
outputs = neural_net(transform_tensor)
|
||||||
|
|
||||||
|
# Wyciągnięcie największej wagi co przekłada się na rozpoznanie klasy, w tym przypadku rodzju odpadu
|
||||||
|
return classes[torch.max(outputs, 1)[1]]
|
Loading…
Reference in New Issue
Block a user