Projekt_Si/classes/Jimmy_Neuron/predict_image.py
2024-06-12 00:24:57 +02:00

37 lines
1.0 KiB
Python

import torch
from torchvision.transforms import Compose, Lambda
import torchvision.io as io
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Width of the first fully-connected hidden layer (135 * 64 = 8640).
# NOTE(review): the reason for this particular factorisation is not recorded here.
hidden_size = 135 * 64
# Classifier architecture. It must match the network whose weights are stored in
# model.pt layer-for-layer and in this exact order. nn.Sequential names its
# parameters by position ("0.weight", "2.weight", "4.weight", ...), so
# reordering, inserting or removing a layer breaks load_state_dict below.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 6, 5),  # 3 input channels -> 6 feature maps, 5x5 kernel
    torch.nn.MaxPool2d(2, 2),  # 2x2 max-pool, halves each spatial dimension
    torch.nn.Conv2d(6, 16, 5),  # 6 -> 16 feature maps, 5x5 kernel
    # NOTE: there is no activation between the conv/pool layers. It is kept
    # as-is because this definition has to reproduce the trained network.
    torch.nn.Flatten(),
    # 53824 = 16 * 58 * 58, which is what a 3x128x128 input produces
    # (128 -conv-> 124 -pool-> 62 -conv-> 58). Presumably inputs are 128x128
    # (TODO confirm). Any input whose post-conv H*W != 3364 fails at this layer.
    torch.nn.Linear(53824, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, 32 * 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32 * 32, 10),  # one output per class: 10 classes
    torch.nn.LogSoftmax(dim=-1)  # log-probabilities; argmax picks the class
).to(device)
# Load the trained weights. 'model.pt' is a relative path, so it is resolved
# against the current working directory, not this file's directory.
# map_location places the stored tensors on `device` (e.g. a CUDA-saved
# checkpoint can still be loaded on a CPU-only machine).
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
# Consider passing weights_only=True where the installed PyTorch supports it.
model.load_state_dict(torch.load('model.pt', map_location=device))
# Put the module in inference mode. None of the layers above behave differently
# between train and eval mode, but this is the standard setting for prediction.
model.eval()
def predict_image(image_path):
    """Classify a single image file with the module-level ``model``.

    Processing steps:

    - Decode the file with torchvision and force it to 3-channel RGB.
    - Convert it to float, keeping raw 0-255 pixel values with no
      normalisation or resizing.
    - Add a batch dimension and move it to ``device``.
    - Run it through the network without gradient tracking.

    The predicted class index is printed and returned.

    Args:
        image_path: Path to an image file readable by
            ``torchvision.io.read_image`` (e.g. JPEG or PNG).

    Returns:
        int: Index of the class with the highest log-probability (0-9 for the
        10-output network defined in this module).

    Raises:
        RuntimeError: If the image's spatial size does not fit the network's
            first linear layer. Presumably the network expects 128x128 input;
            TODO confirm.
    """
    # Use RGB mode so the tensor always has 3 channels, matching the first
    # Conv2d(3, ...). UNCHANGED mode would pass grayscale (1-channel) or RGBA
    # (4-channel) files straight through and crash that convolution. Ordinary
    # RGB files decode identically in either mode.
    image = io.read_image(image_path, mode=io.ImageReadMode.RGB)
    # read_image returns an integer tensor (uint8 for typical 8-bit images),
    # but the network expects float input. Values stay in the 0-255 range,
    # presumably matching the training preprocessing (confirm).
    batch = image.float().unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(batch)
    predicted_class = output.argmax(dim=1).item()
    print(predicted_class)
    return predicted_class