photos for predictions

This commit is contained in:
Zofia Lorenc 2024-05-26 17:44:01 +02:00
parent d2ad851cab
commit 4955e737c5
2 changed files with 36 additions and 0 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.2 KiB

View File

@ -0,0 +1,36 @@
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
classes = [
"bób", "brokuł", "brukselka", "burak", "cebula",
"cukinia", "dynia", "fasola", "groch", "jarmuż",
"kalafior", "kalarepa", "kapusta", "marchew",
"ogórek", "papryka", "pietruszka", "pomidor",
"por", "rzepa", "rzodkiewka", "sałata", "seler",
"szpinak", "ziemniak"]
model = torch.load("best_model.pth")
mean = [0.5322, 0.5120, 0.3696]
std = [0.2487, 0.2436, 0.2531]
image_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(torch.Tensor(mean),torch.Tensor(std))
])
def predict(model, image_transforms, image_path, classes):
model = model.eval()
image = Image.open(image_path)
image = image_transforms(image).float()
image = image.unsqueeze(0)
output = model(image)
_, predicted = torch.max(output.data, 1)
print(classes[predicted.item()])
predict(model, image_transforms, "marchew_118.jpg", classes)