"""Load CNN weights from ``model.pt`` and classify single images into 10 classes."""

import torch
import torchvision.io as io
from torchvision.transforms import Compose, Lambda

# Prefer the GPU when one is available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

hidden_size = 135 * 64

# NOTE: the layer order fixes the state-dict keys ("0.weight", "2.weight", ...),
# so it must stay in sync with the checkpoint stored in model.pt.
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 6, 5),
    torch.nn.MaxPool2d(2, 2),
    torch.nn.Conv2d(6, 16, 5),
    torch.nn.Flatten(),
    # 53824 = 16 * 58 * 58, which is what a 3x128x128 input produces after the
    # two unpadded 5x5 convolutions and the 2x2 max-pool above.
    torch.nn.Linear(53824, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Linear(hidden_size, 32 * 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32 * 32, 10),
    torch.nn.LogSoftmax(dim=-1),
).to(device)

# NOTE(security): torch.load unpickles the file, so only load checkpoints from
# a trusted source.
model.load_state_dict(torch.load('model.pt', map_location=device))
model.eval()

# Built once at import time instead of on every call: casts the decoded image
# tensor to float32 without rescaling the pixel values.
_TO_FLOAT = Compose([Lambda(lambda x: x.float())])


def predict_image(image_path: str) -> int:
    """Classify the image stored at *image_path* and return the class index.

    The image is decoded as 3-channel RGB so that grayscale or RGBA files are
    converted to the channel count the first convolution requires, rather
    than failing with a channel mismatch.  No resizing is performed: the
    decoded image must have a spatial size for which the convolutional
    layers yield 53824 features (e.g. 128x128), otherwise the first linear
    layer raises a shape error.

    Args:
        image_path: Path of the image file to classify.

    Returns:
        The index (0-9) of the output with the highest log-probability.
        The index is also printed to stdout.
    """
    image = io.read_image(image_path, mode=io.ImageReadMode.RGB)
    # Add the leading batch dimension the model expects, then move the batch
    # to the same device as the model.
    batch = _TO_FLOAT(image).unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(batch)

    predicted_class = output.argmax(dim=1).item()
    print(predicted_class)
    return predicted_class