# Listing metadata: 59 lines, 1.5 KiB, Python
import numpy as np
|
|
import torch
|
|
import torchvision
|
|
import matplotlib.pyplot as plt
|
|
from time import time
|
|
from torchvision import datasets, transforms
|
|
from torch import nn, optim
|
|
import cv2
|
|
|
|
|
|
# Preprocessing pipeline applied to every image: convert it to a tensor, then
# normalize its single channel with mean 0.5 and standard deviation 0.5.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
transform = transforms.Compose(preprocessing_steps)
|
|
|
|
|
|
|
|
# load nn model
# Rebuild the classifier architecture, then restore its trained weights from
# the checkpoint file on disk.
input_size = 784  # = 28*28
hidden_sizes = [128, 128, 64]
output_size = 10

# Three Linear+ReLU hidden stages followed by a Linear output layer and a
# log-softmax over the last dimension.
model = nn.Sequential(
    nn.Linear(input_size, hidden_sizes[0]),
    nn.ReLU(),
    nn.Linear(hidden_sizes[0], hidden_sizes[1]),
    nn.ReLU(),
    nn.Linear(hidden_sizes[1], hidden_sizes[2]),
    nn.ReLU(),
    nn.Linear(hidden_sizes[2], output_size),
    nn.LogSoftmax(dim=-1),
)

try:
    # map_location='cpu' lets a checkpoint that was saved from a GPU session
    # be deserialized on a machine without CUDA.
    state_dict = torch.load('digit_reco_model2.pt', map_location='cpu')
    model.load_state_dict(state_dict)
except (OSError, RuntimeError):
    # nn.Sequential never evaluates to None, so a failed load is the only way
    # the model can be "not loaded"; report it and let the error propagate.
    print("Model is not loaded.")
    raise
else:
    # Switch to inference mode before the model is used for prediction.
    model.eval()
    print("Model is loaded.")
|
|
|
|
|
|
# img from dataset
# Build the MNIST dataset with train=False (downloading it if necessary),
# using the preprocessing pipeline, and wrap it in a shuffling loader.
val_set = datasets.MNIST(
    'PATH_TO_STORE_TESTSET',
    train=False,
    download=True,
    transform=transform,
)
val_loader = torch.utils.data.DataLoader(val_set, shuffle=True, batch_size=64)

# Pull the first shuffled batch and use its first image as the sample.
batch_iter = iter(val_loader)
images, labels = next(batch_iter)
print(type(images))

sample = images[0]
# Flatten the sample into a single row of 784 values, the shape fed to the model.
img = sample.view(1, 784)

# Display the sample in inverted grayscale.
plt.imshow(sample.numpy().squeeze(), cmap='gray_r')
plt.show()
|
|
|
|
|
|
# recognizing
# Run the forward pass without tracking gradients; only inference is needed.
with torch.no_grad():
    logps = model(img)

# Exponentiate the model output, then pick the index of the largest entry in
# the single row. torch.argmax returns the first maximal index on ties, which
# matches the previous list.index(max(...)) behaviour.
ps = torch.exp(logps)
predicted_digit = torch.argmax(ps, dim=1).item()
print("Predicted Digit =", predicted_digit)
|