2020-05-20 11:45:55 +02:00
|
|
|
import cv2
|
|
|
|
import matplotlib.pyplot as plt
|
2020-05-30 15:52:48 +02:00
|
|
|
import torch
|
2020-06-01 00:21:32 +02:00
|
|
|
from PIL.Image import Image
|
2020-05-31 17:21:05 +02:00
|
|
|
from torch import nn
|
|
|
|
from torchvision.transforms import transforms
|
2020-05-20 11:45:55 +02:00
|
|
|
|
2020-06-01 00:21:32 +02:00
|
|
|
def white_bg_square(img):
    """Return a white square RGB image with ``img`` pasted at its center.

    The side of the square equals the longer side of ``img``.

    Args:
        img: A PIL image (anything exposing ``.size`` as ``(width, height)``
            that ``PIL.Image.Image.paste`` accepts).

    Returns:
        A new ``'RGB'`` PIL image of size ``(side, side)`` filled with white
        and containing ``img`` centered on it.
    """
    # ``new`` is a module-level factory of ``PIL.Image``; the file-level
    # ``from PIL.Image import Image`` binds the Image *class*, which does not
    # provide it, so import the module locally under a distinct name.
    from PIL import Image as pil_image

    side = max(img.size)
    size = (side, side)
    layer = pil_image.new('RGB', size, (255, 255, 255))

    # ``paste`` needs integer pixel coordinates, hence floor division.
    offset = tuple((outer - inner) // 2 for outer, inner in zip(size, img.size))
    layer.paste(img, offset)
    return layer
|
|
|
|
|
2020-05-31 17:21:05 +02:00
|
|
|
code = []
path = "test5.jpg"

# Convert an image to a tensor, then normalize with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# cv2.imread signals failure by returning None instead of raising, so fail
# early with a clear message rather than with an obscure cvtColor error.
img = cv2.imread(path)
if img is None:
    raise FileNotFoundError("Could not read image: {!r}".format(path))

# Grayscale + blur to suppress noise before thresholding.
img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img_gray = cv2.GaussianBlur(img_gray, (5, 5), 0)

# Inverted binary threshold: pixels darker than 90 become 255, the rest 0.
ret, im_th = cv2.threshold(img_gray, 90, 255, cv2.THRESH_BINARY_INV)

# findContours returns (image, contours, hierarchy) on OpenCV 3.x but
# (contours, hierarchy) on 2.x/4.x; the last two items are always the
# contours and the hierarchy.
ctrs, hier = cv2.findContours(
    im_th.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)[-2:]

# One (x, y, width, height) bounding box per external contour.
rects = [cv2.boundingRect(ctr) for ctr in ctrs]
|
2020-05-20 11:45:55 +02:00
|
|
|
|
2020-05-31 17:21:05 +02:00
|
|
|
# Build the classifier and load its trained weights.
input_size = 784  # = 28*28
hidden_sizes = [128, 128, 64]
output_size = 10

# Fully connected network: 784 -> 128 -> 128 -> 64 -> 10, emitting
# log-probabilities over the output classes.
model = nn.Sequential(
    nn.Linear(input_size, hidden_sizes[0]),
    nn.ReLU(),
    nn.Linear(hidden_sizes[0], hidden_sizes[1]),
    nn.ReLU(),
    nn.Linear(hidden_sizes[1], hidden_sizes[2]),
    nn.ReLU(),
    nn.Linear(hidden_sizes[2], output_size),
    nn.LogSoftmax(dim=-1),
)

# map_location='cpu' lets a checkpoint saved from a CUDA device load on a
# CPU-only machine; the model itself is never moved off the CPU.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(torch.load('digit_reco_model2.pt', map_location='cpu'))

# Inference mode for layers that behave differently during training.
model.eval()
|
|
|
|
|
2020-05-30 15:52:48 +02:00
|
|
|
# cv2.dilate expects a structuring-element array. A bare (3, 3) tuple is
# converted to a 2x1 matrix rather than a 3x3 kernel, so build the kernel
# explicitly (once, outside the loop).
dilate_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))

for rect in rects:
    x, y, w, h = rect

    # Crop the first channel of the original image around the bounding box,
    # extended by 10 px to the right and bottom (slicing clamps at the border).
    # NOTE(review): the crop comes from `img`, not the thresholded `im_th`,
    # and the margin is one-sided - confirm this matches the training data.
    crop_img = img[y:y + h + 10, x:x + w + 10, 0]

    # Resize to the 28x28 input expected by the network, then dilate.
    roi = cv2.resize(crop_img, (28, 28), interpolation=cv2.INTER_LINEAR)
    roi = cv2.dilate(roi, dilate_kernel)

    # Show the patch that is fed to the model (blocks until closed).
    plt.imshow(roi)
    plt.show()

    # Tensor conversion + normalization, flattened to one row of 784 values.
    im = transform(roi)
    im = im.view(1, 784)

    with torch.no_grad():
        logps = model(im)

    # The model outputs log-probabilities; exponentiate to probabilities
    # and report the most likely class.
    ps = torch.exp(logps)
    print("Predicted Digit =", int(torch.argmax(ps, dim=1)))
|
2020-05-20 08:24:33 +02:00
|
|
|
|
2020-06-01 00:21:32 +02:00
|
|
|
# Display the source image and wait for a key press before exiting.
cv2.imshow("Code", img)
cv2.waitKey()
# Release the HighGUI window once the user has dismissed it.
cv2.destroyAllWindows()
|