# AL-2020/coder/coder.py
# 2020-05-27 02:15:29 +02:00
#
# 59 lines
# 1.4 KiB
# Python

import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from time import time
from torchvision import datasets, transforms
from torch import nn, optim, nn, optim
import cv2
def view_classify(img, ps):
    """Show an image next to a horizontal bar chart of its predicted class probabilities.

    Args:
        img: tensor holding one 28x28 image (exactly 784 elements, any shape
            such as (1, 28, 28) or (1, 784)). It is reshaped for display only;
            the caller's tensor is not modified.
        ps: tensor of per-class probabilities, e.g. shape (1, n_classes). It is
            flattened and one bar is drawn per class, labelled 0..n_classes-1.

    The figure is laid out but not shown; call ``plt.show()`` (or use an
    interactive backend) to display it.
    """
    # detach()/cpu() make .numpy() safe for tensors that require grad or live on a GPU.
    probs = ps.detach().cpu().numpy().reshape(-1)
    n_classes = len(probs)
    # reshape (not the in-place resize_) so the caller's tensor is left untouched
    # and a wrongly-sized input raises instead of being silently truncated.
    image = img.detach().cpu().reshape(28, 28).numpy()

    _, (ax1, ax2) = plt.subplots(figsize=(6, 9), ncols=2)
    ax1.imshow(image)
    ax1.axis('off')

    class_ids = np.arange(n_classes)
    ax2.barh(class_ids, probs)
    ax2.set_aspect(0.1)
    ax2.set_yticks(class_ids)
    ax2.set_yticklabels(class_ids)
    ax2.set_title('Class Probability')
    ax2.set_xlim(0, 1.1)  # probabilities lie in [0, 1]; small margin on the right
    plt.tight_layout()
# --- Load the trained network -------------------------------------------------
MODEL_PATH = 'digit_reco_model2.pt'
IMAGE_PATH = 'test3.png'

# NOTE(review): torch.load unpickles arbitrary Python objects; only load model
# files from a trusted source. PyTorch >= 2.6 defaults to weights_only=True,
# which refuses a fully pickled nn.Module; if loading fails there, pass
# weights_only=False (trusted files only).
try:
    # map_location='cpu' lets a model saved on a GPU load on a CPU-only machine;
    # the input tensor built below lives on the CPU as well.
    model = torch.load(MODEL_PATH, map_location='cpu')
except FileNotFoundError:
    # torch.load never returns None, so a failed load must be caught here.
    raise SystemExit("Model is not loaded.") from None
print("Model is loaded.")
# Inference mode: disables dropout / uses running batch-norm stats, if present.
model.eval()

# --- Preprocess the input image -----------------------------------------------
bgr = cv2.imread(IMAGE_PATH)
if bgr is None:  # cv2.imread returns None (it does not raise) for missing/unreadable files
    raise SystemExit(f"Could not read image {IMAGE_PATH!r}.")
img = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
img = cv2.blur(img, (9, 9))  # 9x9 box blur; the original author noted it improves quality
img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
# NOTE(review): raw 0-255 pixel values are fed to the model unchanged. If the
# network was trained on normalized inputs (e.g. ToTensor() + Normalize), the
# same scaling should be applied here; confirm against the training code.
img = torch.from_numpy(img.astype(np.float32)).view(1, 28 * 28)

# --- Recognize ------------------------------------------------------------------
with torch.no_grad():
    logps = model(img)
# The output is exponentiated, i.e. treated as log-probabilities.
ps = torch.exp(logps)
predicted_digit = int(torch.argmax(ps[0]))
print("Predicted Digit =", predicted_digit)
view_classify(img.view(1, 28, 28), ps)
plt.show()  # without this the figure is never displayed when run as a script