import os

import PIL.Image as Image
import torch
import torchvision
import torchvision.transforms as transforms

# Default location of the trained model. The path is relative to the current
# working directory, not to this file.
DEFAULT_MODEL_PATH = "./model_training/garbage_model.pth"

# Per-channel (R, G, B) normalization statistics.
# NOTE(review): presumably computed over the training set; they must match the
# values used during training — confirm against the training script.
MEAN = [0.6908, 0.6612, 0.6218]
STD = [0.1947, 0.1926, 0.2086]

# Model output index -> class name. The list is alphabetical, which matches the
# order torchvision's ImageFolder would assign; presumably the model was trained
# that way — confirm before reordering or adding classes.
CLASSES = [
    "glass",
    "mixed",
    "paper",
    "plastic",
]

# Preprocessing pipeline, built once instead of on every call.
# NOTE(review): the 128x128 input size is assumed to match training — confirm.
_IMAGE_TRANSFORMS = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(torch.Tensor(MEAN), torch.Tensor(STD)),
])


def classify(image_path, model_path=DEFAULT_MODEL_PATH):
    """Classify a single image and report whether the prediction is correct.

    The image's parent directory name is treated as its ground-truth label.
    For example, ``data/paper/x.jpg`` is labelled ``"paper"``. The function
    prints the prediction and whether it matches that label.

    Args:
        image_path: Path to the image file to classify.
        model_path: Path to a pickled full model saved with ``torch.save``.
            Defaults to ``DEFAULT_MODEL_PATH``.

    Returns:
        The predicted class name, one of ``CLASSES``.

    Errors from loading the model or opening the image (e.g. a missing file)
    propagate to the caller.
    """
    # SECURITY: torch.load unpickles arbitrary Python objects. Only load
    # model files from a trusted source.
    # map_location="cpu" keeps the model on the same device as the CPU input
    # tensor, and allows loading a CUDA-saved model on a CPU-only machine.
    # NOTE(review): on torch >= 2.6, torch.load defaults to weights_only=True,
    # which rejects full-model pickles. Confirm the installed torch version or
    # pass weights_only=False for trusted files.
    model = torch.load(model_path, map_location="cpu")
    model.eval()

    # Convert to RGB so RGBA/grayscale/palette images yield the 3 channels
    # that Normalize expects. The `with` block closes the underlying file.
    with Image.open(image_path) as img:
        image = _IMAGE_TRANSFORMS(img.convert("RGB"))
    batch = image.unsqueeze(0)  # add a batch dimension of size 1

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        output = model(batch)
    predicted = output.argmax(dim=1)

    prediction = CLASSES[predicted.item()]
    label = os.path.basename(os.path.dirname(image_path))

    print(f"predicted: {prediction}")
    if label == prediction:
        print("predicted correctly.")
    else:
        print("predicted incorrectly.")
    return prediction


# Example usage:
# classify("./model_training/test.jpg")