photos for predictions
This commit is contained in:
parent
d2ad851cab
commit
4955e737c5
BIN
src/veggies_recognition/marchew_118.jpg
Normal file
BIN
src/veggies_recognition/marchew_118.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 9.2 KiB |
36
src/veggies_recognition/predict.py
Normal file
36
src/veggies_recognition/predict.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
classes = [
|
||||||
|
"bób", "brokuł", "brukselka", "burak", "cebula",
|
||||||
|
"cukinia", "dynia", "fasola", "groch", "jarmuż",
|
||||||
|
"kalafior", "kalarepa", "kapusta", "marchew",
|
||||||
|
"ogórek", "papryka", "pietruszka", "pomidor",
|
||||||
|
"por", "rzepa", "rzodkiewka", "sałata", "seler",
|
||||||
|
"szpinak", "ziemniak"]
|
||||||
|
|
||||||
|
model = torch.load("best_model.pth")
|
||||||
|
|
||||||
|
mean = [0.5322, 0.5120, 0.3696]
|
||||||
|
std = [0.2487, 0.2436, 0.2531]
|
||||||
|
|
||||||
|
image_transforms = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(torch.Tensor(mean),torch.Tensor(std))
|
||||||
|
])
|
||||||
|
|
||||||
|
def predict(model, image_transforms, image_path, classes):
|
||||||
|
model = model.eval()
|
||||||
|
image = Image.open(image_path)
|
||||||
|
image = image_transforms(image).float()
|
||||||
|
image = image.unsqueeze(0)
|
||||||
|
|
||||||
|
output = model(image)
|
||||||
|
_, predicted = torch.max(output.data, 1)
|
||||||
|
|
||||||
|
print(classes[predicted.item()])
|
||||||
|
|
||||||
|
predict(model, image_transforms, "marchew_118.jpg", classes)
|
Loading…
Reference in New Issue
Block a user