add coder.py

This commit is contained in:
shaaqu 2020-05-27 02:15:29 +02:00
parent 4720da3158
commit 38d24273c5
6 changed files with 191 additions and 49 deletions

View File

@ -19,12 +19,12 @@
<select />
</component>
<component name="ChangeListManager">
<list default="true" id="828778c9-9d97-422f-a727-18ddbd059b85" name="Default Changelist" comment="po">
<change afterPath="$PROJECT_DIR$/coder/digits_recognizer.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
<list default="true" id="828778c9-9d97-422f-a727-18ddbd059b85" name="Default Changelist" comment="going to pytorch on conda eve">
<change afterPath="$PROJECT_DIR$/coder/coder.py" afterDir="false" />
<change afterPath="$PROJECT_DIR$/coder/digit_reco_model.pt" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/wozek.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/wozek.iml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/coder/image.py" beforeDir="false" afterPath="$PROJECT_DIR$/coder/image.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/coder/digits_recognizer.py" beforeDir="false" afterPath="$PROJECT_DIR$/coder/digits_recognizer.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/coder/gr_test.png" beforeDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
@ -106,6 +106,25 @@
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="coder" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="wozek" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/coder" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/coder/coder.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="true" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="digits_recognizer" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="wozek" />
<option name="INTERPRETER_OPTIONS" value="" />
@ -147,35 +166,10 @@
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="main" type="PythonConfigurationType" factoryName="Python" temporary="true">
<module name="wozek" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
<option name="PARAMETERS" value="" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="rocognizer" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
<module name="wozek" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/coder" />
<option name="IS_MODULE_SDK" value="true" />
@ -215,18 +209,18 @@
</configuration>
<list>
<item itemvalue="Python.image" />
<item itemvalue="Python.main" />
<item itemvalue="Python.train_nn" />
<item itemvalue="Python.rocognizer" />
<item itemvalue="Python.digits_recognizer" />
<item itemvalue="Python.coder" />
</list>
<recent_temporary>
<list>
<item itemvalue="Python.digits_recognizer" />
<item itemvalue="Python.image" />
<item itemvalue="Python.coder" />
<item itemvalue="Python.rocognizer" />
<item itemvalue="Python.image" />
<item itemvalue="Python.train_nn" />
<item itemvalue="Python.main" />
</list>
</recent_temporary>
</component>
@ -262,7 +256,9 @@
<workItem from="1590409526059" duration="4922000" />
<workItem from="1590423569728" duration="2532000" />
<workItem from="1590436739719" duration="6325000" />
<workItem from="1590443664804" duration="2683000" />
<workItem from="1590443664804" duration="2943000" />
<workItem from="1590497613517" duration="6041000" />
<workItem from="1590518246722" duration="12460000" />
</task>
<task id="LOCAL-00001" summary="create Shelf">
<created>1589815443652</created>
@ -341,7 +337,14 @@
<option name="project" value="LOCAL" />
<updated>1590359074952</updated>
</task>
<option name="localTasksCounter" value="12" />
<task id="LOCAL-00012" summary="going to pytorch on conda eve">
<created>1590447313737</created>
<option name="number" value="00012" />
<option name="presentableId" value="LOCAL-00012" />
<option name="project" value="LOCAL" />
<updated>1590447313737</updated>
</task>
<option name="localTasksCounter" value="13" />
<servers />
</component>
<component name="TypeScriptGeneratedFilesManager">
@ -372,7 +375,8 @@
<MESSAGE value="finding barcode" />
<MESSAGE value="po" />
<MESSAGE value="new dataset" />
<option name="LAST_COMMIT_MESSAGE" value="new dataset" />
<MESSAGE value="going to pytorch on conda eve" />
<option name="LAST_COMMIT_MESSAGE" value="going to pytorch on conda eve" />
</component>
<component name="WindowStateProjectService">
<state x="115" y="162" key="#com.intellij.refactoring.safeDelete.UnsafeUsagesDialog" timestamp="1589923610328">
@ -399,10 +403,10 @@
<screen x="0" y="0" width="1536" height="824" />
</state>
<state x="277" y="57" key="SettingsEditor/0.0.1536.824@0.0.1536.824" timestamp="1590443566792" />
<state x="361" y="145" key="Vcs.Push.Dialog.v2" timestamp="1590359093497">
<state x="361" y="145" key="Vcs.Push.Dialog.v2" timestamp="1590447321698">
<screen x="0" y="0" width="1536" height="824" />
</state>
<state x="361" y="145" key="Vcs.Push.Dialog.v2/0.0.1536.824@0.0.1536.824" timestamp="1590359093496" />
<state x="361" y="145" key="Vcs.Push.Dialog.v2/0.0.1536.824@0.0.1536.824" timestamp="1590447321698" />
<state x="54" y="145" width="672" height="678" key="search.everywhere.popup" timestamp="1589918982407">
<screen x="0" y="0" width="1536" height="824" />
</state>

58
coder/coder.py Normal file
View File

@ -0,0 +1,58 @@
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from time import time
from torchvision import datasets, transforms
from torch import nn, optim, nn, optim
import cv2
def view_classify(img, ps):
    """Plot an image next to a bar chart of its predicted class probabilities.

    Args:
        img: Tensor holding one 28x28 image (784 elements in total).
        ps: Tensor with 10 class probabilities; size-1 dimensions are
            squeezed away before plotting.

    The figure is only built here. Displaying it (e.g. ``plt.show()``) is
    left to the caller.
    """
    # detach() is the supported replacement for the legacy ``.data`` attribute.
    ps = ps.detach().numpy().squeeze()
    fig, (ax1, ax2) = plt.subplots(figsize=(6, 9), ncols=2)
    # reshape() leaves the caller's tensor untouched, whereas resize_()
    # resized it in place (and silently accepted a wrong element count).
    ax1.imshow(img.detach().reshape(1, 28, 28).numpy().squeeze())
    ax1.axis('off')
    # One horizontal bar per digit class 0-9.
    ax2.barh(np.arange(10), ps)
    ax2.set_aspect(0.1)
    ax2.set_yticks(np.arange(10))
    ax2.set_yticklabels(np.arange(10))
    ax2.set_title('Class Probability')
    # Probabilities top out at 1; leave a little headroom on the axis.
    ax2.set_xlim(0, 1.1)
    plt.tight_layout()
# load nn model
# NOTE: torch.load unpickles the file, so only load model files you trust.
# It raises on failure rather than returning None, so no None check is needed.
model = torch.load('digit_reco_model2.pt')
model.eval()
print("Model is loaded.")

# image
image_path = 'test3.png'
raw_image = cv2.imread(image_path)
if raw_image is None:
    # cv2.imread reports a missing or unreadable file by returning None.
    raise FileNotFoundError("Could not read image: {}".format(image_path))
img = cv2.cvtColor(raw_image, cv2.COLOR_BGR2GRAY)
img = cv2.blur(img, (9, 9))  # improves quality
img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
print(type(img))

# Reproduce the training-time preprocessing (ToTensor followed by
# Normalize((0.5,), (0.5,))): scale 0..255 to [0, 1], then shift to [-1, 1].
# NOTE(review): MNIST digits are light strokes on a dark background; confirm
# test3.png has the same polarity, otherwise it must be inverted first.
img = np.asarray(img, dtype=np.float32) / 255.0
img = (img - 0.5) / 0.5
img = torch.from_numpy(img).view(1, 784)

# recognizing
with torch.no_grad():
    logps = model(img)
# The model ends in LogSoftmax, so exp() turns its output into probabilities.
ps = torch.exp(logps)
print("Predicted Digit =", ps.argmax(dim=1).item())
view_classify(img.view(1, 28, 28), ps)
# When run as a script the figure is only rendered once show() is called.
plt.show()

BIN
coder/digit_reco_model.pt Normal file

Binary file not shown.

BIN
coder/digit_reco_model2.pt Normal file

Binary file not shown.

View File

@ -6,17 +6,19 @@ from time import time
from torchvision import datasets, transforms
from torch import nn, optim
# IMG transform
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)),
])
trainset = datasets.MNIST('PATH_TO_STORE_TRAINSET', download=True, train=True, transform=transform)
valset = datasets.MNIST('PATH_TO_STORE_TESTSET', download=True, train=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True)
# dataset download
train_set = datasets.MNIST('PATH_TO_STORE_TRAINSET', download=True, train=True, transform=transform)
val_set = datasets.MNIST('PATH_TO_STORE_TESTSET', download=True, train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=64, shuffle=True)
dataiter = iter(trainloader)
images, labels = dataiter.next()
# Pull a single batch to inspect the tensor shapes.
data_iter = iter(train_loader)
# Use the built-in next(): DataLoader iterators only guarantee the iterator
# protocol (__next__); the .next() alias is missing on newer PyTorch releases.
images, labels = next(data_iter)
print(images.shape)
print(labels.shape)
@ -25,15 +27,93 @@ plt.imshow(images[0].numpy().squeeze(), cmap='gray_r')
plt.show()
# building nn model
input_size = 784
hidden_sizes = [128, 64]
input_size = 784 # = 28*28
hidden_sizes = [128, 128, 64]
output_size = 10
model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
nn.ReLU(),
nn.Linear(hidden_sizes[0], hidden_sizes[1]),
nn.ReLU(),
nn.Linear(hidden_sizes[1], output_size),
nn.LogSoftmax(dim=1))
print(model)
nn.Linear(hidden_sizes[1], hidden_sizes[2]),
nn.ReLU(),
nn.Linear(hidden_sizes[2], output_size),
nn.LogSoftmax(dim=-1))
criterion = nn.NLLLoss()

# Sanity check: one forward/backward pass on a single batch. The gradients
# computed here are discarded by optimizer.zero_grad() at the start of training.
images, labels = next(iter(train_loader))
images = images.view(images.shape[0], -1)
logps = model(images)  # log probabilities
loss = criterion(logps, labels)  # calculate the NLL loss
loss.backward()

# training
optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
time0 = time()
epochs = 100
for e in range(epochs):
    running_loss = 0
    for images, labels in train_loader:
        # Flatten MNIST images into a 784 long vector
        images = images.view(images.shape[0], -1)

        # Training pass
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)

        # This is where the model learns by backpropagating
        loss.backward()

        # And optimizes its weights here
        optimizer.step()

        running_loss += loss.item()
    # Report the mean batch loss once the epoch has finished.
    print("Epoch {} - Training loss: {}".format(e + 1, running_loss / len(train_loader)))
print("\nTraining Time (in minutes) =", (time() - time0) / 60)

# testing
images, labels = next(iter(val_loader))
print(type(images))
img = images[0].view(1, 784)
with torch.no_grad():
    logps = model(img)
# The model outputs log-probabilities; exp() converts them to probabilities.
ps = torch.exp(logps)
print("Predicted Digit =", ps.argmax(dim=1).item())

# accuracy
correct_count, all_count = 0, 0
with torch.no_grad():
    for images, labels in val_loader:
        # Score the whole batch in one forward pass. The arg-max of the
        # log-probabilities is the same class as the arg-max of the
        # probabilities, so exp() is not needed here.
        logps = model(images.view(images.shape[0], -1))
        pred_labels = logps.argmax(dim=1)
        correct_count += (pred_labels == labels).sum().item()
        all_count += labels.shape[0]

print("Number Of Images Tested =", all_count)
print("\nModel Accuracy =", (correct_count / all_count))

# saving model
# The whole module (not just its state_dict) is saved because the inference
# script restores it with a plain torch.load().
torch.save(model, './digit_reco_model2.pt')

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.1 KiB