# Repository page metadata captured with this file (not part of the program):
# 2
# 0
# forked from s444420/AL-2020
# AL-2020/coder/coder.py
# 2020-05-31 17:21:05 +02:00
#
# 60 lines
# 1.6 KiB
# Python

import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from time import time
from torchvision import datasets, transforms
from torch import nn, optim
import cv2
# Per-sample preprocessing for MNIST images: ToTensor converts the image to a
# float tensor in [0, 1], then Normalize maps it to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
# Build the fully-connected digit classifier and load its trained weights.
# The layer sizes must match the architecture the checkpoint
# 'digit_reco_model2.pt' was saved from, otherwise load_state_dict raises.
input_size = 784  # = 28*28, one flattened image
hidden_sizes = [128, 128, 64]
output_size = 10  # one output per digit class
model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[1], hidden_sizes[2]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[2], output_size),
                      nn.LogSoftmax(dim=-1))
# map_location='cpu' lets a checkpoint that was saved from GPU tensors be
# loaded on a machine without CUDA; the model itself lives on the CPU here.
model.load_state_dict(torch.load('digit_reco_model2.pt', map_location='cpu'))
model.eval()  # put the network in evaluation (inference) mode
# A missing file or mismatched weights raise above, and `model` is always an
# nn.Sequential, so reaching this point means the weights were loaded.
print("Model is loaded.")
# Draw one random batch from the MNIST test split and take its first image
# as the sample to classify; also display it for visual comparison.
val_set = datasets.MNIST('PATH_TO_STORE_TESTSET', download=True, train=False, transform=transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=64, shuffle=True)
batch_iter = iter(val_loader)
images, labels = next(batch_iter)
print(type(images))
sample = images[0]
# Flatten the single image into a (1, 784) row for the Linear input layer.
img = sample.view(1, 784)
plt.imshow(sample.numpy().squeeze(), cmap='gray_r')
plt.show()
# Classify the sample image with the loaded network.
with torch.no_grad():  # inference only: skip autograd bookkeeping
    logps = model(img)
print(logps)
# The network ends in LogSoftmax, so exp() recovers the class probabilities.
ps = torch.exp(logps)
probab = list(ps.numpy()[0])
# argmax picks the most probable class directly on the tensor; on ties it
# returns the first maximal index, same as list.index(max(...)).
print("Predicted Digit =", ps.argmax(dim=1).item())