projekt_ml/predict.py
2025-01-19 16:27:39 +01:00

29 lines
637 B
Python

from sys import argv
import torch
from PIL import Image
from preprocess import TRANSFORM
from settings import MODEL_FILENAME, CLASSES, OurCNN
if __name__ == "__main__":
    # Command-line entry point: classify a single image file.
    #   usage: predict.py <path>
    # Prints the predicted class label (the matching entry of CLASSES) to stdout.
    if len(argv) < 2:
        # SystemExit with a string argument prints it to stderr and exits with
        # status 1 (same status as before), keeping stdout for results only.
        raise SystemExit("Usage: predict.py <path>")

    model = OurCNN()
    # map_location="cpu": allow loading a checkpoint that was saved from a GPU
    # session on a machine without CUDA.
    # weights_only=True: restrict unpickling to tensors and plain containers,
    # which is all a state_dict needs (avoids executing arbitrary pickled code).
    state_dict = torch.load(MODEL_FILENAME, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    # Put the network in inference mode so layers such as dropout/batch-norm
    # (if OurCNN has any) behave deterministically during prediction.
    model.eval()

    # Load and preprocess the image; the context manager closes the file
    # handle that Image.open keeps open.
    with Image.open(argv[1]) as image:
        tensor: torch.Tensor = TRANSFORM(image)  # type: ignore

    with torch.no_grad():
        # unsqueeze(0) adds a batch dimension of size 1 for the forward pass.
        output = model(tensor.unsqueeze(0))
    # Index of the highest-scoring class along dim 1 for the single batch
    # element, converted to a plain Python int for list indexing.
    predicted = output.argmax(dim=1).item()
    print(CLASSES[predicted])