45 lines
1.1 KiB
Python
45 lines
1.1 KiB
Python
|
import torch
|
||
|
import torchvision
|
||
|
import torchvision.transforms as transforms
|
||
|
import PIL.Image as Image
|
||
|
import os
|
||
|
|
||
|
|
||
|
# Per-channel normalization statistics applied to every input image.
# NOTE(review): presumably computed from the training set — confirm they match
# the values used when the model was trained.
_MEAN = [0.6908, 0.6612, 0.6218]
_STD = [0.1947, 0.1926, 0.2086]

# Class names indexed by the model's output position (argmax index -> name).
# NOTE(review): order must match the label order used at training time.
_CLASSES = [
    "glass",
    "mixed",
    "paper",
    "plastic",
]


def _load_model(model_path):
    """Load the pickled model at *model_path* onto the CPU in eval mode."""
    # NOTE(review): torch.load unpickles arbitrary objects; only point this at
    # trusted model files.
    try:
        # map_location="cpu": the input tensor built below lives on the CPU, so
        # a checkpoint saved from a GPU must be moved there too.
        # weights_only=False: PyTorch >= 2.6 defaults to True, which refuses
        # to unpickle a full nn.Module like the one stored here.
        model = torch.load(model_path, map_location="cpu", weights_only=False)
    except TypeError:
        # Older PyTorch releases do not accept the weights_only argument.
        model = torch.load(model_path, map_location="cpu")
    return model.eval()


def classify(image_path, model_path="./model_training/garbage_model.pth"):
    """Classify one image and report whether it matches its folder label.

    The expected label is the name of the directory that contains the image
    (e.g. ``data/paper/img.jpg`` -> ``"paper"``). The prediction is printed,
    followed by whether it matched that label.

    Args:
        image_path: Path to the image file to classify.
        model_path: Path to the saved model. Defaults to the location the
            script has always used, so existing callers are unaffected.

    Returns:
        The predicted class name: one of "glass", "mixed", "paper" or
        "plastic".
    """
    model = _load_model(model_path)

    image_transforms = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(torch.Tensor(_MEAN), torch.Tensor(_STD)),
    ])

    # The context manager closes the underlying file. convert("RGB") forces
    # 3 channels, so grayscale/palette/RGBA images do not break the 3-channel
    # Normalize above.
    with Image.open(image_path) as img:
        image = image_transforms(img.convert("RGB")).float()
    image = image.unsqueeze(0)  # add a leading batch dimension of size 1

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        output = model(image)
    predicted = output.argmax(dim=1)

    label = os.path.basename(os.path.dirname(image_path))
    prediction = _CLASSES[predicted.item()]

    print(f"predicted: {prediction}")
    if label == prediction:
        print("predicted correctly.")
    else:
        print("predicted incorrectly.")

    return prediction
|
||
|
|
||
|
|
||
|
# classify("./model_training/test.jpg")
|
||
|
|
||
|
|