29 lines
637 B
Python
29 lines
637 B
Python
from sys import argv
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from preprocess import TRANSFORM
|
|
from settings import MODEL_FILENAME, CLASSES, OurCNN
|
|
|
|
if __name__ == "__main__":
    # Script entry point: classify the single image file given on the command
    # line with the OurCNN weights stored at MODEL_FILENAME, and print the
    # name of the predicted class.
    #
    # Usage: predict.py <path>
    import sys  # sys.exit/sys.stderr; the builtin exit() is injected by `site` and may be missing

    if len(argv) < 2:
        # Usage errors go to stderr so stdout only ever carries the prediction.
        print("Usage: predict.py <path>", file=sys.stderr)
        sys.exit(1)

    model = OurCNN()
    # map_location="cpu": the input tensor below is built on the CPU, and this
    # lets a checkpoint saved from a GPU run load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints; consider weights_only=True on torch versions that support it.
    model.load_state_dict(torch.load(MODEL_FILENAME, map_location="cpu"))
    # Put layers such as dropout / batch-norm (if OurCNN has any) into
    # inference mode; without this, predictions can be wrong or nondeterministic.
    model.eval()

    # Open the image as a context manager so the underlying file is closed
    # once preprocessing is done.
    # NOTE(review): the image is passed in its native mode (RGB, RGBA, L, ...);
    # confirm TRANSFORM yields the channel count the model expects.
    with Image.open(argv[1]) as image:
        tensor: torch.Tensor = TRANSFORM(image)  # type: ignore

    # Add a batch dimension of size 1, run the model without autograd
    # tracking, and take the index of the highest score along dim 1.
    with torch.no_grad():
        output = model(tensor.unsqueeze(0))
    _, predicted = torch.max(output, 1)

    # predicted holds a single index for the one-image batch; .item() turns it
    # into a plain Python int before looking up the class name.
    print(CLASSES[predicted.item()])