add coder.py
This commit is contained in:
parent
4720da3158
commit
38d24273c5
@ -19,12 +19,12 @@
|
||||
<select />
|
||||
</component>
|
||||
<component name="ChangeListManager">
|
||||
<list default="true" id="828778c9-9d97-422f-a727-18ddbd059b85" name="Default Changelist" comment="po">
|
||||
<change afterPath="$PROJECT_DIR$/coder/digits_recognizer.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
|
||||
<list default="true" id="828778c9-9d97-422f-a727-18ddbd059b85" name="Default Changelist" comment="going to pytorch on conda eve">
|
||||
<change afterPath="$PROJECT_DIR$/coder/coder.py" afterDir="false" />
|
||||
<change afterPath="$PROJECT_DIR$/coder/digit_reco_model.pt" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/wozek.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/wozek.iml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/coder/image.py" beforeDir="false" afterPath="$PROJECT_DIR$/coder/image.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/coder/digits_recognizer.py" beforeDir="false" afterPath="$PROJECT_DIR$/coder/digits_recognizer.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/coder/gr_test.png" beforeDir="false" />
|
||||
</list>
|
||||
<option name="SHOW_DIALOG" value="false" />
|
||||
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
||||
@ -106,6 +106,25 @@
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="coder" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||
<module name="wozek" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/coder" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/coder/coder.py" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="SHOW_COMMAND_LINE" value="true" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="digits_recognizer" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||
<module name="wozek" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
@ -147,35 +166,10 @@
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="main" type="PythonConfigurationType" factoryName="Python" temporary="true">
|
||||
<module name="wozek" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
|
||||
<option name="PARAMETERS" value="" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="rocognizer" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||
<module name="wozek" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/coder" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
@ -215,18 +209,18 @@
|
||||
</configuration>
|
||||
<list>
|
||||
<item itemvalue="Python.image" />
|
||||
<item itemvalue="Python.main" />
|
||||
<item itemvalue="Python.train_nn" />
|
||||
<item itemvalue="Python.rocognizer" />
|
||||
<item itemvalue="Python.digits_recognizer" />
|
||||
<item itemvalue="Python.coder" />
|
||||
</list>
|
||||
<recent_temporary>
|
||||
<list>
|
||||
<item itemvalue="Python.digits_recognizer" />
|
||||
<item itemvalue="Python.image" />
|
||||
<item itemvalue="Python.coder" />
|
||||
<item itemvalue="Python.rocognizer" />
|
||||
<item itemvalue="Python.image" />
|
||||
<item itemvalue="Python.train_nn" />
|
||||
<item itemvalue="Python.main" />
|
||||
</list>
|
||||
</recent_temporary>
|
||||
</component>
|
||||
@ -262,7 +256,9 @@
|
||||
<workItem from="1590409526059" duration="4922000" />
|
||||
<workItem from="1590423569728" duration="2532000" />
|
||||
<workItem from="1590436739719" duration="6325000" />
|
||||
<workItem from="1590443664804" duration="2683000" />
|
||||
<workItem from="1590443664804" duration="2943000" />
|
||||
<workItem from="1590497613517" duration="6041000" />
|
||||
<workItem from="1590518246722" duration="12460000" />
|
||||
</task>
|
||||
<task id="LOCAL-00001" summary="create Shelf">
|
||||
<created>1589815443652</created>
|
||||
@ -341,7 +337,14 @@
|
||||
<option name="project" value="LOCAL" />
|
||||
<updated>1590359074952</updated>
|
||||
</task>
|
||||
<option name="localTasksCounter" value="12" />
|
||||
<task id="LOCAL-00012" summary="going to pytorch on conda eve">
|
||||
<created>1590447313737</created>
|
||||
<option name="number" value="00012" />
|
||||
<option name="presentableId" value="LOCAL-00012" />
|
||||
<option name="project" value="LOCAL" />
|
||||
<updated>1590447313737</updated>
|
||||
</task>
|
||||
<option name="localTasksCounter" value="13" />
|
||||
<servers />
|
||||
</component>
|
||||
<component name="TypeScriptGeneratedFilesManager">
|
||||
@ -372,7 +375,8 @@
|
||||
<MESSAGE value="finding barcode" />
|
||||
<MESSAGE value="po" />
|
||||
<MESSAGE value="new dataset" />
|
||||
<option name="LAST_COMMIT_MESSAGE" value="new dataset" />
|
||||
<MESSAGE value="going to pytorch on conda eve" />
|
||||
<option name="LAST_COMMIT_MESSAGE" value="going to pytorch on conda eve" />
|
||||
</component>
|
||||
<component name="WindowStateProjectService">
|
||||
<state x="115" y="162" key="#com.intellij.refactoring.safeDelete.UnsafeUsagesDialog" timestamp="1589923610328">
|
||||
@ -399,10 +403,10 @@
|
||||
<screen x="0" y="0" width="1536" height="824" />
|
||||
</state>
|
||||
<state x="277" y="57" key="SettingsEditor/0.0.1536.824@0.0.1536.824" timestamp="1590443566792" />
|
||||
<state x="361" y="145" key="Vcs.Push.Dialog.v2" timestamp="1590359093497">
|
||||
<state x="361" y="145" key="Vcs.Push.Dialog.v2" timestamp="1590447321698">
|
||||
<screen x="0" y="0" width="1536" height="824" />
|
||||
</state>
|
||||
<state x="361" y="145" key="Vcs.Push.Dialog.v2/0.0.1536.824@0.0.1536.824" timestamp="1590359093496" />
|
||||
<state x="361" y="145" key="Vcs.Push.Dialog.v2/0.0.1536.824@0.0.1536.824" timestamp="1590447321698" />
|
||||
<state x="54" y="145" width="672" height="678" key="search.everywhere.popup" timestamp="1589918982407">
|
||||
<screen x="0" y="0" width="1536" height="824" />
|
||||
</state>
|
||||
|
58
coder/coder.py
Normal file
58
coder/coder.py
Normal file
@ -0,0 +1,58 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
import matplotlib.pyplot as plt
|
||||
from time import time
|
||||
from torchvision import datasets, transforms
|
||||
from torch import nn, optim, nn, optim
|
||||
import cv2
|
||||
|
||||
|
||||
def view_classify(img, ps):
|
||||
''' Function for viewing an image and it's predicted classes.
|
||||
'''
|
||||
ps = ps.data.numpy().squeeze()
|
||||
|
||||
fig, (ax1, ax2) = plt.subplots(figsize=(6,9), ncols=2)
|
||||
ax1.imshow(img.resize_(1, 28, 28).numpy().squeeze())
|
||||
ax1.axis('off')
|
||||
ax2.barh(np.arange(10), ps)
|
||||
ax2.set_aspect(0.1)
|
||||
ax2.set_yticks(np.arange(10))
|
||||
ax2.set_yticklabels(np.arange(10))
|
||||
ax2.set_title('Class Probability')
|
||||
ax2.set_xlim(0, 1.1)
|
||||
plt.tight_layout()
|
||||
|
||||
# load nn model
|
||||
model = torch.load('digit_reco_model2.pt')
|
||||
|
||||
if model is None:
|
||||
print("Model is not loaded.")
|
||||
else:
|
||||
print("Model is loaded.")
|
||||
|
||||
# image
|
||||
img = cv2.cvtColor(cv2.imread('test3.png'), cv2.COLOR_BGR2GRAY)
|
||||
img = cv2.blur(img, (9, 9)) # poprawia jakosc
|
||||
img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
|
||||
img = img.reshape((len(img), -1))
|
||||
|
||||
print(type(img))
|
||||
# print(img.shape)
|
||||
# plt.imshow(img ,cmap='binary')
|
||||
# plt.show()
|
||||
img = np.array(img, dtype=np.float32)
|
||||
img = torch.from_numpy(img)
|
||||
img = img.view(1, 784)
|
||||
|
||||
# recognizing
|
||||
|
||||
with torch.no_grad():
|
||||
logps = model(img)
|
||||
|
||||
ps = torch.exp(logps)
|
||||
probab = list(ps.numpy()[0])
|
||||
print("Predicted Digit =", probab.index(max(probab)))
|
||||
|
||||
view_classify(img.view(1, 28, 28), ps)
|
BIN
coder/digit_reco_model.pt
Normal file
BIN
coder/digit_reco_model.pt
Normal file
Binary file not shown.
BIN
coder/digit_reco_model2.pt
Normal file
BIN
coder/digit_reco_model2.pt
Normal file
Binary file not shown.
@ -6,17 +6,19 @@ from time import time
|
||||
from torchvision import datasets, transforms
|
||||
from torch import nn, optim
|
||||
|
||||
# IMG transform
|
||||
transform = transforms.Compose([transforms.ToTensor(),
|
||||
transforms.Normalize((0.5,), (0.5,)),
|
||||
])
|
||||
|
||||
trainset = datasets.MNIST('PATH_TO_STORE_TRAINSET', download=True, train=True, transform=transform)
|
||||
valset = datasets.MNIST('PATH_TO_STORE_TESTSET', download=True, train=False, transform=transform)
|
||||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
|
||||
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True)
|
||||
# dataset download
|
||||
train_set = datasets.MNIST('PATH_TO_STORE_TRAINSET', download=True, train=True, transform=transform)
|
||||
val_set = datasets.MNIST('PATH_TO_STORE_TESTSET', download=True, train=False, transform=transform)
|
||||
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
|
||||
val_loader = torch.utils.data.DataLoader(val_set, batch_size=64, shuffle=True)
|
||||
|
||||
dataiter = iter(trainloader)
|
||||
images, labels = dataiter.next()
|
||||
data_iter = iter(train_loader)
|
||||
images, labels = data_iter.next()
|
||||
|
||||
print(images.shape)
|
||||
print(labels.shape)
|
||||
@ -25,15 +27,93 @@ plt.imshow(images[0].numpy().squeeze(), cmap='gray_r')
|
||||
plt.show()
|
||||
|
||||
# building nn model
|
||||
input_size = 784
|
||||
hidden_sizes = [128, 64]
|
||||
input_size = 784 # = 28*28
|
||||
hidden_sizes = [128, 128, 64]
|
||||
output_size = 10
|
||||
|
||||
model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_sizes[0], hidden_sizes[1]),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_sizes[1], output_size),
|
||||
nn.LogSoftmax(dim=1))
|
||||
print(model)
|
||||
nn.Linear(hidden_sizes[1], hidden_sizes[2]),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_sizes[2], output_size),
|
||||
nn.LogSoftmax(dim=-1))
|
||||
# print(model)
|
||||
|
||||
criterion = nn.NLLLoss()
|
||||
images, labels = next(iter(train_loader))
|
||||
images = images.view(images.shape[0], -1)
|
||||
|
||||
logps = model(images) # log probabilities
|
||||
loss = criterion(logps, labels) # calculate the NLL loss
|
||||
|
||||
# print('Before backward pass: \n', model[0].weight.grad)
|
||||
loss.backward()
|
||||
# print('After backward pass: \n', model[0].weight.grad)
|
||||
|
||||
# training
|
||||
|
||||
optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
|
||||
time0 = time()
|
||||
epochs = 100
|
||||
for e in range(epochs):
|
||||
running_loss = 0
|
||||
for images, labels in train_loader:
|
||||
# Flatten MNIST images into a 784 long vector
|
||||
images = images.view(images.shape[0], -1)
|
||||
|
||||
# Training pass
|
||||
optimizer.zero_grad()
|
||||
|
||||
output = model(images)
|
||||
loss = criterion(output, labels)
|
||||
|
||||
# This is where the model learns by backpropagating
|
||||
loss.backward()
|
||||
|
||||
# And optimizes its weights here
|
||||
optimizer.step()
|
||||
|
||||
running_loss += loss.item()
|
||||
else:
|
||||
print("Epoch {} - Training loss: {}".format(e + 1, running_loss / len(train_loader)))
|
||||
|
||||
print("\nTraining Time (in minutes) =", (time() - time0) / 60)
|
||||
|
||||
# testing
|
||||
|
||||
images, labels = next(iter(val_loader))
|
||||
print(type(images))
|
||||
img = images[0].view(1, 784)
|
||||
with torch.no_grad():
|
||||
logps = model(img)
|
||||
ps = torch.exp(logps)
|
||||
probab = list(ps.numpy()[0])
|
||||
print("Predicted Digit =", probab.index(max(probab)))
|
||||
# view_classify(img.view(1, 28, 28), ps)
|
||||
|
||||
# accuracy
|
||||
correct_count, all_count = 0, 0
|
||||
for images, labels in val_loader:
|
||||
for i in range(len(labels)):
|
||||
img = images[i].view(1, 784)
|
||||
with torch.no_grad():
|
||||
logps = model(img)
|
||||
|
||||
ps = torch.exp(logps)
|
||||
probab = list(ps.numpy()[0])
|
||||
pred_label = probab.index(max(probab))
|
||||
true_label = labels.numpy()[i]
|
||||
if true_label == pred_label:
|
||||
correct_count += 1
|
||||
all_count += 1
|
||||
|
||||
print("Number Of Images Tested =", all_count)
|
||||
print("\nModel Accuracy =", (correct_count / all_count))
|
||||
|
||||
|
||||
# saving model
|
||||
|
||||
# torch.save(model, './digit_reco_model.pt')
|
||||
torch.save(model, './digit_reco_model2.pt')
|
Binary file not shown.
Before Width: | Height: | Size: 9.1 KiB |
Loading…
Reference in New Issue
Block a user