diff --git a/src/veggies_recognition/marchew_118.jpg b/src/veggies_recognition/marchew_118.jpg
new file mode 100644
index 00000000..62ba9ca7
Binary files /dev/null and b/src/veggies_recognition/marchew_118.jpg differ
diff --git a/src/veggies_recognition/predict.py b/src/veggies_recognition/predict.py
new file mode 100644
index 00000000..12d8aa76
--- /dev/null
+++ b/src/veggies_recognition/predict.py
@@ -0,0 +1,41 @@
+import torch
+import torchvision.transforms as transforms
+from PIL import Image
+
+# Class labels (Polish vegetable names); the order must match the order used during training.
+classes = [
+    "bób", "brokuł", "brukselka", "burak", "cebula",
+    "cukinia", "dynia", "fasola", "groch", "jarmuż",
+    "kalafior", "kalarepa", "kapusta", "marchew",
+    "ogórek", "papryka", "pietruszka", "pomidor",
+    "por", "rzepa", "rzodkiewka", "sałata", "seler",
+    "szpinak", "ziemniak"]
+
+# Load the fully serialized model (saved with torch.save(model, ...)).
+model = torch.load("best_model.pth", map_location="cpu")
+
+# Per-channel statistics of the training set, used for input normalization.
+mean = [0.5322, 0.5120, 0.3696]
+std = [0.2487, 0.2436, 0.2531]
+
+image_transforms = transforms.Compose([
+    transforms.Resize((224, 224)),
+    transforms.ToTensor(),
+    transforms.Normalize(torch.Tensor(mean), torch.Tensor(std))
+])
+
+def predict(model, image_transforms, image_path, classes):
+    """Print the predicted class name for a single image."""
+    model.eval()
+    # Convert to RGB so grayscale or RGBA inputs also yield three channels.
+    image = Image.open(image_path).convert("RGB")
+    image = image_transforms(image).float()
+    image = image.unsqueeze(0)  # add batch dimension: [1, 3, 224, 224]
+
+    with torch.no_grad():
+        output = model(image)
+    _, predicted = torch.max(output, 1)
+
+    print(classes[predicted.item()])
+
+predict(model, image_transforms, "marchew_118.jpg", classes)
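
Note: torch.load("best_model.pth") only works as written if the training script serialized the whole model object with torch.save(model, ...). If the checkpoint instead holds only a state_dict, the network has to be rebuilt before the weights can be loaded. A minimal sketch of that variant, assuming a state_dict checkpoint and a torchvision resnet18 backbone with the classifier head resized to the 25 classes above (both assumptions, not confirmed by this diff):

    import torch
    import torchvision.models as models

    # Assumption: the backbone is a torchvision resnet18 (torchvision >= 0.13 API);
    # swap in whatever architecture was actually trained.
    model = models.resnet18(weights=None)
    # Assumption: the final layer was replaced to output the 25 vegetable classes.
    model.fc = torch.nn.Linear(model.fc.in_features, 25)
    # Load only the weights; map_location keeps this working on CPU-only machines.
    model.load_state_dict(torch.load("best_model.pth", map_location="cpu"))
    model.eval()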