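# AL-2020/coder/coder.py
#
# Loads saved weights into a fully connected digit classifier and prints the
# predicted digit for one image drawn from the MNIST test set.
# (Captured from a repository file view: 59 lines, 1.5 KiB, Python;
# earliest revision 2020-05-27 02:15:29 +02:00.)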
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from time import time
from torchvision import datasets, transforms
2020-05-30 15:52:48 +02:00
from torch import nn, optim
2020-05-27 02:15:29 +02:00
import cv2
2020-05-30 15:52:48 +02:00
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
])
2020-05-27 02:15:29 +02:00
# load nn model
# Fully connected classifier for flattened 28x28 images: three ReLU hidden
# layers followed by LogSoftmax, so a forward pass returns log-probabilities
# over the 10 output classes.
MODEL_PATH = 'digit_reco_model2.pt'
input_size = 784  # = 28*28
hidden_sizes = [128, 128, 64]
output_size = 10
# load_state_dict (strict by default) expects parameter keys named after each
# layer's position in the nn.Sequential ("0.weight", "2.weight", ...), so this
# layer order must match the architecture the checkpoint was saved from.
model = nn.Sequential(
    nn.Linear(input_size, hidden_sizes[0]),
    nn.ReLU(),
    nn.Linear(hidden_sizes[0], hidden_sizes[1]),
    nn.ReLU(),
    nn.Linear(hidden_sizes[1], hidden_sizes[2]),
    nn.ReLU(),
    nn.Linear(hidden_sizes[2], output_size),
    nn.LogSoftmax(dim=-1),
)
# map_location='cpu' remaps any CUDA tensors in the checkpoint onto the CPU,
# so the file also loads on machines without a GPU; the model runs on CPU
# here anyway, since the MNIST sample fed to it comes from a CPU DataLoader.
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
# Put the modules in inference mode. These particular layers behave the same
# in train and eval mode, but this keeps the script correct if dropout or
# batch-norm layers are ever added.
model.eval()
# Report that the weights are ready. `model` is always bound to the
# nn.Sequential built earlier, and both torch.load and load_state_dict raise
# on failure instead of returning None, so the former `if model is None`
# branch could never run: reaching this line means loading succeeded.
print("Model is loaded.")
# img from dataset
# Draw one shuffled batch of 64 from the MNIST test split (downloaded on first
# use) and keep its first image, flattened to the (1, 784) shape the model's
# first Linear layer expects.
# NOTE(review): the dataset root is a directory literally named
# 'PATH_TO_STORE_TESTSET', which looks like an unreplaced placeholder - confirm.
val_set = datasets.MNIST(
    'PATH_TO_STORE_TESTSET',
    train=False,
    download=True,
    transform=transform,
)
# shuffle=True means a different sample is shown on each run.
val_loader = torch.utils.data.DataLoader(val_set, shuffle=True, batch_size=64)
batch_iter = iter(val_loader)
images, labels = next(batch_iter)
print(type(images))
first_image = images[0]
img = first_image.view(1, 784)
# Preview the chosen digit; squeeze() drops the singleton channel dimension.
plt.imshow(first_image.numpy().squeeze(), cmap='gray_r')
plt.show()
# recognizing
# Run the flattened sample through the network with autograd disabled
# (inference only, no gradients needed).
with torch.no_grad():
    logps = model(img)
# The network ends in LogSoftmax, so exp() turns its output back into class
# probabilities.
ps = torch.exp(logps)
# Index of the most probable class for the single sample in row 0, the same
# choice as the former list(...).index(max(...)) lookup; .item() returns a
# plain Python int for printing.
print("Predicted Digit =", ps[0].argmax().item())