AL-2020/coder/rocognizer.py

71 lines
1.9 KiB
Python
Raw Normal View History

import numpy as np
2020-05-20 11:45:55 +02:00
import argparse
import imutils
import cv2
import matplotlib.pyplot as plt
2020-05-30 15:52:48 +02:00
import torch
from matplotlib import cm
from torch import nn
2020-05-30 15:52:48 +02:00
from PIL import Image
from skimage.feature import hog
from torchvision.transforms import transforms
2020-05-20 11:45:55 +02:00
# Digits recognised in the image, in left-to-right order (filled by the
# classification loop below).
code = []

path = "test1.jpg"

# Preprocessing applied to each 28x28 ROI before it is fed to the network:
# ToTensor scales uint8 pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps
# them to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,)),
                                ])

img = cv2.imread(path)
if img is None:
    # cv2.imread does not raise for a missing/unreadable file; it returns None,
    # which would otherwise surface later as an obscure cvtColor error.
    raise FileNotFoundError("Could not read image: {}".format(path))

# Grayscale + light Gaussian blur to suppress noise before thresholding.
img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
img_gray = cv2.GaussianBlur(img_gray, (5, 5), 0)

# Inverse binary threshold: pixels with value <= 90 (presumably the ink) become
# 255 and everything brighter becomes 0, giving light digits on a dark
# background.
ret, im_th = cv2.threshold(img_gray, 90, 255, cv2.THRESH_BINARY_INV)

# Outer contours only. cv2.findContours returns (contours, hierarchy) in
# OpenCV 2.4/4.x but (image, contours, hierarchy) in 3.x; grab_contours picks
# the contour list out of either form so the script works on both.
ctrs = imutils.grab_contours(
    cv2.findContours(im_th.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE))

# One (x, y, w, h) bounding box per detected contour.
rects = [cv2.boundingRect(ctr) for ctr in ctrs]
# load nn model
# Fully connected classifier over a flattened 28x28 input. The layer sizes must
# match those used when 'digit_reco_model2.pt' was saved, otherwise
# load_state_dict (strict by default) raises on the size mismatch.
input_size = 784  # = 28*28
hidden_sizes = [128, 128, 64]
output_size = 10
model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[1], hidden_sizes[2]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[2], output_size),
                      nn.LogSoftmax(dim=-1))  # outputs log-probabilities
# Inference in this script runs on the CPU, so map the stored tensors there;
# without map_location a checkpoint saved from a GPU fails to load on a
# CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(torch.load('digit_reco_model2.pt', map_location="cpu"))
model.eval()
def _pad_to_square(roi, margin_ratio=0.2):
    """Return *roi* zero-padded into a centred square with a border.

    Resizing a tight, non-square crop straight to 28x28 stretches it (a thin
    "1" would fill the whole frame). Padding to a square first keeps the
    digit's aspect ratio; the border mimics the margin around MNIST-style
    digits (roughly a 20x20 digit inside a 28x28 frame, hence 0.2).

    Args:
        roi: 2-D uint8 array with a light digit on a dark (0) background.
        margin_ratio: border added on every side, as a fraction of the
            longer side of *roi*.

    Returns:
        A square 2-D array of the same dtype with *roi* centred in it.
    """
    h, w = roi.shape
    side = max(h, w)
    margin = int(side * margin_ratio)
    top = (side - h) // 2 + margin
    bottom = side - h - (side - h) // 2 + margin
    left = (side - w) // 2 + margin
    right = side - w - (side - w) // 2 + margin
    return np.pad(roi, ((top, bottom), (left, right)),
                  mode="constant", constant_values=0)


# Visit regions left-to-right (by x) so `code` follows reading order for a
# single row of digits; findContours does not guarantee any order.
for rect in sorted(rects, key=lambda r: r[0]):
    x, y, w, h = rect
    # Crop from the thresholded single-channel image, not the colour frame:
    # the network takes 784 = 28*28 inputs (one channel), and im_th already
    # has light-on-dark polarity (presumably matching the training data).
    crop_img = im_th[y:y + h, x:x + w]
    plt.imshow(crop_img, cmap="gray")
    plt.show()

    # Resize the image (padded to a square first to avoid distortion)
    roi = cv2.resize(_pad_to_square(crop_img), (28, 28),
                     interpolation=cv2.INTER_AREA)
    plt.imshow(roi, cmap="gray")
    plt.show()
    im = Image.fromarray(roi)
    im = transform(im)  # shape (1, 28, 28) for a 2-D uint8 input
    print(im)
    # matplotlib cannot display a (1, 28, 28) array; drop the channel axis.
    plt.imshow(im.squeeze(0).numpy(), cmap="gray")
    plt.show()
    with torch.no_grad():
        # The first layer is nn.Linear(784, ...): flatten to a (1, 784) batch.
        logps = model(im.view(1, -1))
    ps = torch.exp(logps)  # log-probabilities -> probabilities
    print(ps[0])
    digit = int(ps.argmax(dim=1))
    print("Predicted Digit =", digit)
    code.append(digit)

    # Mark the ROI and its prediction on the original image. Safe to draw on
    # img because crops are taken from im_th, not img.
    cv2.rectangle(img, (x, y), (x + w, y + h), (0, 255, 0), 2)
    cv2.putText(img, str(digit), (x, max(y - 5, 0)),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

print("Recognized code =", code)
cv2.imshow("Resulting Image with Rectangular ROIs", img)
cv2.waitKey()
cv2.destroyAllWindows()