59 lines
1.4 KiB
Python
59 lines
1.4 KiB
Python
|
import numpy as np
|
||
|
import torch
|
||
|
import torchvision
|
||
|
import matplotlib.pyplot as plt
|
||
|
from time import time
|
||
|
from torchvision import datasets, transforms
|
||
|
from torch import nn, optim, nn, optim
|
||
|
import cv2
|
||
|
|
||
|
|
||
|
def view_classify(img, ps):
    """Show an image next to a bar chart of its predicted class probabilities.

    Args:
        img: Tensor holding a single image with 28 * 28 elements (e.g. shape
            (1, 28, 28) or (1, 784)). It is reshaped for display only; the
            caller's tensor is left unmodified.
        ps: Tensor of class probabilities for that image, e.g. shape (1, 10).
            It is flattened, and one bar is drawn per element.

    The figure is drawn with matplotlib but not shown; call ``plt.show()``
    afterwards when running non-interactively.
    """
    # detach()/cpu() make the conversion work for tensors that require grad
    # or live on a GPU; a bare .numpy() raises for either.
    probs = ps.detach().cpu().numpy().reshape(-1)
    classes = np.arange(probs.size)

    _, (ax1, ax2) = plt.subplots(figsize=(6, 9), ncols=2)
    # reshape() does not mutate its input, unlike resize_(), which would
    # have resized the caller's tensor in place as a side effect.
    ax1.imshow(img.detach().cpu().reshape(28, 28).numpy())
    ax1.axis('off')
    ax2.barh(classes, probs)
    ax2.set_aspect(0.1)
    ax2.set_yticks(classes)
    ax2.set_yticklabels(classes)
    ax2.set_title('Class Probability')
    ax2.set_xlim(0, 1.1)
    plt.tight_layout()
|
||
|
|
||
|
# Load the trained digit-recognition network (presumably a whole pickled
# nn.Module, since it is called directly below — confirm against the saver).
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# checkpoint files. On PyTorch >= 2.6 (weights_only defaults to True) loading
# a whole pickled module may additionally require weights_only=False.
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
# the inference code converts outputs with .numpy(), which needs CPU tensors.
model = torch.load('digit_reco_model2.pt', map_location='cpu')

# torch.load raises on failure instead of returning None, so the first branch
# only fires if the checkpoint itself contained None.
if model is None:
    print("Model is not loaded.")
else:
    print("Model is loaded.")
    # Put the network in inference mode so layers such as dropout or
    # batch-norm (if present) behave deterministically during prediction.
    model.eval()
|
||
|
|
||
|
# Load the input image and turn it into the (1, 784) float32 tensor that is
# fed to the model below.
raw = cv2.imread('test3.png')
if raw is None:
    # cv2.imread reports a missing/unreadable file by returning None instead
    # of raising; fail here with a clear message rather than in cvtColor.
    raise FileNotFoundError("Could not read image file 'test3.png'")
img = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)
img = cv2.blur(img, (9, 9))  # smoothing improves recognition quality
img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)

print(type(img))

# NOTE(review): pixel values stay in the raw 0-255 range here. If the model
# was trained on torchvision ToTensor()/Normalize()-transformed data (and
# MNIST-style white-on-black digits), the input likely needs the same
# scaling/inversion — confirm against the training script.
img = torch.from_numpy(img.astype(np.float32))
img = img.view(1, 784)
|
||
|
|
||
|
# Run the model on the preprocessed image and report the most likely digit.
with torch.no_grad():
    logps = model(img)

# exp() of the output is treated as class probabilities, which presumes the
# network ends in LogSoftmax — confirm against the model definition.
ps = torch.exp(logps)
# argmax returns the first index of the maximum, same as list.index(max(...)).
predicted = ps.argmax(dim=1).item()
print("Predicted Digit =", predicted)

view_classify(img.view(1, 28, 28), ps)
# view_classify only draws the figure; without show() nothing appears when
# the file is run as a script.
plt.show()
|