# Inteligentna_smieciarka/loadmodel.py
# (45 lines, 1.1 KiB, Python; last changed 2023-06-01 23:44:09 +02:00)

import torch
import torchvision
import torchvision.transforms as transforms
import PIL.Image as Image
import os
def classify(image_path, model_path='./model_training/garbage_model.pth'):
    """Classify a garbage image as glass, mixed, paper or plastic.

    Loads the pickled model from *model_path*, preprocesses the image at
    *image_path* (resize to 128x128, convert to a tensor, normalize with the
    fixed per-channel mean/std below) and returns the class with the highest
    score.

    The name of the directory that contains the image is treated as the
    expected label, and the function prints whether the prediction matches
    it. For an image that is not stored in a directory named after one of
    the classes, this check always prints "predicted incorrectly.".

    Args:
        image_path: Path of the image file to classify.
        model_path: Path of the serialized model. Defaults to the location
            that used to be hard-coded, so existing callers are unaffected.

    Returns:
        The predicted class name: "glass", "mixed", "paper" or "plastic".
    """
    # Index i of the model output is mapped to classes[i].
    # NOTE(review): presumably this order matches the label order used
    # during training -- confirm against the training script.
    classes = [
        "glass",
        "mixed",
        "paper",
        "plastic",
    ]
    # Per-channel normalization constants (three values: one per RGB channel).
    # NOTE(review): presumably computed on the training set -- confirm.
    mean = [0.6908, 0.6612, 0.6218]
    std = [0.1947, 0.1926, 0.2086]

    image_transforms = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(torch.Tensor(mean), torch.Tensor(std)),
    ])

    # NOTE: torch.load unpickles arbitrary Python objects; only point
    # model_path at model files you trust.
    # NOTE(review): recent PyTorch releases default to weights_only=True,
    # which rejects fully pickled models -- confirm the installed version.
    # map_location keeps the model on the CPU, where the input tensor lives,
    # even if the checkpoint was saved from a GPU.
    model = torch.load(model_path, map_location='cpu')
    model = model.eval()

    # Close the file handle deterministically, and force three channels so
    # grayscale / RGBA / palette images do not break the 3-channel Normalize.
    with Image.open(image_path) as raw_image:
        rgb_image = raw_image.convert('RGB')

    # Add a leading batch dimension of size 1 for the model.
    batch = image_transforms(rgb_image).float().unsqueeze(0)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = model(batch)
    _, predicted = torch.max(output, 1)

    # The parent directory name doubles as the ground-truth label.
    label = os.path.basename(os.path.dirname(image_path))
    prediction = classes[predicted.item()]
    print(f"predicted: {prediction}")
    if label == prediction:
        print("predicted correctly.")
    else:
        print("predicted incorrectly.")
    return prediction
# classify("./model_training/test.jpg")