37 lines
1.0 KiB
Python
37 lines
1.0 KiB
Python
import torch
|
|
from torchvision.transforms import Compose, Lambda
|
|
import torchvision.io as io
|
|
|
|
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Width of the first fully connected hidden layer (135 * 64 = 8640 units).
hidden_size = 135 * 64

# CNN classifier ending in LogSoftmax over 10 outputs.
# NOTE(review): do not reorder, insert or remove layers -- the state_dict
# loaded below keys parameters by their position in this Sequential
# ('0.weight', '2.weight', '4.weight', '6.weight', '8.weight').
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 6, 5),
    torch.nn.MaxPool2d(2, 2),
    torch.nn.Conv2d(6, 16, 5),
    torch.nn.Flatten(),
    # 16 channels * 58 * 58 spatial positions = 53824 flattened features.
    # Working back through conv(5) -> pool(2) -> conv(5), this implies
    # 128x128 inputs (129x129 also fits) -- TODO confirm training size.
    torch.nn.Linear(16 * 58 * 58, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, 32 * 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32 * 32, 10),
    torch.nn.LogSoftmax(dim=-1),
).to(device)

# weights_only=True restricts unpickling to tensors and plain containers,
# so a tampered model.pt cannot run arbitrary code while loading. The file
# holds a state_dict (it is passed to load_state_dict), which this mode
# supports.
state_dict = torch.load('model.pt', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
# Switch to evaluation mode before inference.
model.eval()
|
|
|
|
def predict_image(image_path):
    """Classify one image file with the module-level ``model``.

    The file is decoded with ``torchvision.io.read_image`` and forced to
    three channels, so it matches the model's first ``Conv2d(3, ...)``
    layer. With the previous ``UNCHANGED`` mode, a grayscale (1-channel) or
    RGBA (4-channel) file failed with a channel-mismatch error.

    Pixel values are cast to float but not rescaled, so they stay in
    0-255. This is the same preprocessing as before.

    NOTE(review): the first Linear layer expects 53824 = 16 * 58 * 58
    features, which implies 128x128 (or 129x129) inputs. Other sizes will
    raise a shape error -- TODO confirm the training resolution.

    Args:
        image_path: Path of the image file to read.

    Returns:
        Index of the highest-scoring output class as a Python ``int``.
        The index is also printed to stdout.
    """
    # Force RGB: UNCHANGED keeps whatever channel count the file has.
    image = io.read_image(image_path, mode=io.ImageReadMode.RGB)
    # Cast uint8 -> float32 without rescaling, exactly as before.
    image = image.float()
    # Add a batch dimension and move to the model's device.
    batch = image.unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(batch)

    predicted_class = output.argmax(dim=1).item()
    print(predicted_class)

    return predicted_class
|
|
|