forked from s444420/AL-2020
add coder.py
This commit is contained in:
parent
4720da3158
commit
38d24273c5
@ -19,12 +19,12 @@
|
|||||||
<select />
|
<select />
|
||||||
</component>
|
</component>
|
||||||
<component name="ChangeListManager">
|
<component name="ChangeListManager">
|
||||||
<list default="true" id="828778c9-9d97-422f-a727-18ddbd059b85" name="Default Changelist" comment="po">
|
<list default="true" id="828778c9-9d97-422f-a727-18ddbd059b85" name="Default Changelist" comment="going to pytorch on conda eve">
|
||||||
<change afterPath="$PROJECT_DIR$/coder/digits_recognizer.py" afterDir="false" />
|
<change afterPath="$PROJECT_DIR$/coder/coder.py" afterDir="false" />
|
||||||
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
|
<change afterPath="$PROJECT_DIR$/coder/digit_reco_model.pt" afterDir="false" />
|
||||||
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||||
<change beforePath="$PROJECT_DIR$/.idea/wozek.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/wozek.iml" afterDir="false" />
|
<change beforePath="$PROJECT_DIR$/coder/digits_recognizer.py" beforeDir="false" afterPath="$PROJECT_DIR$/coder/digits_recognizer.py" afterDir="false" />
|
||||||
<change beforePath="$PROJECT_DIR$/coder/image.py" beforeDir="false" afterPath="$PROJECT_DIR$/coder/image.py" afterDir="false" />
|
<change beforePath="$PROJECT_DIR$/coder/gr_test.png" beforeDir="false" />
|
||||||
</list>
|
</list>
|
||||||
<option name="SHOW_DIALOG" value="false" />
|
<option name="SHOW_DIALOG" value="false" />
|
||||||
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
||||||
@ -106,6 +106,25 @@
|
|||||||
<option name="INPUT_FILE" value="" />
|
<option name="INPUT_FILE" value="" />
|
||||||
<method v="2" />
|
<method v="2" />
|
||||||
</configuration>
|
</configuration>
|
||||||
|
<configuration name="coder" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||||
|
<module name="wozek" />
|
||||||
|
<option name="INTERPRETER_OPTIONS" value="" />
|
||||||
|
<option name="PARENT_ENVS" value="true" />
|
||||||
|
<option name="SDK_HOME" value="" />
|
||||||
|
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/coder" />
|
||||||
|
<option name="IS_MODULE_SDK" value="true" />
|
||||||
|
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||||
|
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||||
|
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||||
|
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/coder/coder.py" />
|
||||||
|
<option name="PARAMETERS" value="" />
|
||||||
|
<option name="SHOW_COMMAND_LINE" value="true" />
|
||||||
|
<option name="EMULATE_TERMINAL" value="false" />
|
||||||
|
<option name="MODULE_MODE" value="false" />
|
||||||
|
<option name="REDIRECT_INPUT" value="false" />
|
||||||
|
<option name="INPUT_FILE" value="" />
|
||||||
|
<method v="2" />
|
||||||
|
</configuration>
|
||||||
<configuration name="digits_recognizer" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
<configuration name="digits_recognizer" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||||
<module name="wozek" />
|
<module name="wozek" />
|
||||||
<option name="INTERPRETER_OPTIONS" value="" />
|
<option name="INTERPRETER_OPTIONS" value="" />
|
||||||
@ -147,35 +166,10 @@
|
|||||||
<option name="INPUT_FILE" value="" />
|
<option name="INPUT_FILE" value="" />
|
||||||
<method v="2" />
|
<method v="2" />
|
||||||
</configuration>
|
</configuration>
|
||||||
<configuration name="main" type="PythonConfigurationType" factoryName="Python" temporary="true">
|
|
||||||
<module name="wozek" />
|
|
||||||
<option name="INTERPRETER_OPTIONS" value="" />
|
|
||||||
<option name="PARENT_ENVS" value="true" />
|
|
||||||
<envs>
|
|
||||||
<env name="PYTHONUNBUFFERED" value="1" />
|
|
||||||
</envs>
|
|
||||||
<option name="SDK_HOME" value="" />
|
|
||||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$" />
|
|
||||||
<option name="IS_MODULE_SDK" value="true" />
|
|
||||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
|
||||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
|
||||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
|
||||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
|
|
||||||
<option name="PARAMETERS" value="" />
|
|
||||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
|
||||||
<option name="EMULATE_TERMINAL" value="false" />
|
|
||||||
<option name="MODULE_MODE" value="false" />
|
|
||||||
<option name="REDIRECT_INPUT" value="false" />
|
|
||||||
<option name="INPUT_FILE" value="" />
|
|
||||||
<method v="2" />
|
|
||||||
</configuration>
|
|
||||||
<configuration name="rocognizer" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
<configuration name="rocognizer" type="PythonConfigurationType" factoryName="Python" temporary="true" nameIsGenerated="true">
|
||||||
<module name="wozek" />
|
<module name="wozek" />
|
||||||
<option name="INTERPRETER_OPTIONS" value="" />
|
<option name="INTERPRETER_OPTIONS" value="" />
|
||||||
<option name="PARENT_ENVS" value="true" />
|
<option name="PARENT_ENVS" value="true" />
|
||||||
<envs>
|
|
||||||
<env name="PYTHONUNBUFFERED" value="1" />
|
|
||||||
</envs>
|
|
||||||
<option name="SDK_HOME" value="" />
|
<option name="SDK_HOME" value="" />
|
||||||
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/coder" />
|
<option name="WORKING_DIRECTORY" value="$PROJECT_DIR$/coder" />
|
||||||
<option name="IS_MODULE_SDK" value="true" />
|
<option name="IS_MODULE_SDK" value="true" />
|
||||||
@ -215,18 +209,18 @@
|
|||||||
</configuration>
|
</configuration>
|
||||||
<list>
|
<list>
|
||||||
<item itemvalue="Python.image" />
|
<item itemvalue="Python.image" />
|
||||||
<item itemvalue="Python.main" />
|
|
||||||
<item itemvalue="Python.train_nn" />
|
<item itemvalue="Python.train_nn" />
|
||||||
<item itemvalue="Python.rocognizer" />
|
<item itemvalue="Python.rocognizer" />
|
||||||
<item itemvalue="Python.digits_recognizer" />
|
<item itemvalue="Python.digits_recognizer" />
|
||||||
|
<item itemvalue="Python.coder" />
|
||||||
</list>
|
</list>
|
||||||
<recent_temporary>
|
<recent_temporary>
|
||||||
<list>
|
<list>
|
||||||
<item itemvalue="Python.digits_recognizer" />
|
<item itemvalue="Python.digits_recognizer" />
|
||||||
<item itemvalue="Python.image" />
|
<item itemvalue="Python.coder" />
|
||||||
<item itemvalue="Python.rocognizer" />
|
<item itemvalue="Python.rocognizer" />
|
||||||
|
<item itemvalue="Python.image" />
|
||||||
<item itemvalue="Python.train_nn" />
|
<item itemvalue="Python.train_nn" />
|
||||||
<item itemvalue="Python.main" />
|
|
||||||
</list>
|
</list>
|
||||||
</recent_temporary>
|
</recent_temporary>
|
||||||
</component>
|
</component>
|
||||||
@ -262,7 +256,9 @@
|
|||||||
<workItem from="1590409526059" duration="4922000" />
|
<workItem from="1590409526059" duration="4922000" />
|
||||||
<workItem from="1590423569728" duration="2532000" />
|
<workItem from="1590423569728" duration="2532000" />
|
||||||
<workItem from="1590436739719" duration="6325000" />
|
<workItem from="1590436739719" duration="6325000" />
|
||||||
<workItem from="1590443664804" duration="2683000" />
|
<workItem from="1590443664804" duration="2943000" />
|
||||||
|
<workItem from="1590497613517" duration="6041000" />
|
||||||
|
<workItem from="1590518246722" duration="12460000" />
|
||||||
</task>
|
</task>
|
||||||
<task id="LOCAL-00001" summary="create Shelf">
|
<task id="LOCAL-00001" summary="create Shelf">
|
||||||
<created>1589815443652</created>
|
<created>1589815443652</created>
|
||||||
@ -341,7 +337,14 @@
|
|||||||
<option name="project" value="LOCAL" />
|
<option name="project" value="LOCAL" />
|
||||||
<updated>1590359074952</updated>
|
<updated>1590359074952</updated>
|
||||||
</task>
|
</task>
|
||||||
<option name="localTasksCounter" value="12" />
|
<task id="LOCAL-00012" summary="going to pytorch on conda eve">
|
||||||
|
<created>1590447313737</created>
|
||||||
|
<option name="number" value="00012" />
|
||||||
|
<option name="presentableId" value="LOCAL-00012" />
|
||||||
|
<option name="project" value="LOCAL" />
|
||||||
|
<updated>1590447313737</updated>
|
||||||
|
</task>
|
||||||
|
<option name="localTasksCounter" value="13" />
|
||||||
<servers />
|
<servers />
|
||||||
</component>
|
</component>
|
||||||
<component name="TypeScriptGeneratedFilesManager">
|
<component name="TypeScriptGeneratedFilesManager">
|
||||||
@ -372,7 +375,8 @@
|
|||||||
<MESSAGE value="finding barcode" />
|
<MESSAGE value="finding barcode" />
|
||||||
<MESSAGE value="po" />
|
<MESSAGE value="po" />
|
||||||
<MESSAGE value="new dataset" />
|
<MESSAGE value="new dataset" />
|
||||||
<option name="LAST_COMMIT_MESSAGE" value="new dataset" />
|
<MESSAGE value="going to pytorch on conda eve" />
|
||||||
|
<option name="LAST_COMMIT_MESSAGE" value="going to pytorch on conda eve" />
|
||||||
</component>
|
</component>
|
||||||
<component name="WindowStateProjectService">
|
<component name="WindowStateProjectService">
|
||||||
<state x="115" y="162" key="#com.intellij.refactoring.safeDelete.UnsafeUsagesDialog" timestamp="1589923610328">
|
<state x="115" y="162" key="#com.intellij.refactoring.safeDelete.UnsafeUsagesDialog" timestamp="1589923610328">
|
||||||
@ -399,10 +403,10 @@
|
|||||||
<screen x="0" y="0" width="1536" height="824" />
|
<screen x="0" y="0" width="1536" height="824" />
|
||||||
</state>
|
</state>
|
||||||
<state x="277" y="57" key="SettingsEditor/0.0.1536.824@0.0.1536.824" timestamp="1590443566792" />
|
<state x="277" y="57" key="SettingsEditor/0.0.1536.824@0.0.1536.824" timestamp="1590443566792" />
|
||||||
<state x="361" y="145" key="Vcs.Push.Dialog.v2" timestamp="1590359093497">
|
<state x="361" y="145" key="Vcs.Push.Dialog.v2" timestamp="1590447321698">
|
||||||
<screen x="0" y="0" width="1536" height="824" />
|
<screen x="0" y="0" width="1536" height="824" />
|
||||||
</state>
|
</state>
|
||||||
<state x="361" y="145" key="Vcs.Push.Dialog.v2/0.0.1536.824@0.0.1536.824" timestamp="1590359093496" />
|
<state x="361" y="145" key="Vcs.Push.Dialog.v2/0.0.1536.824@0.0.1536.824" timestamp="1590447321698" />
|
||||||
<state x="54" y="145" width="672" height="678" key="search.everywhere.popup" timestamp="1589918982407">
|
<state x="54" y="145" width="672" height="678" key="search.everywhere.popup" timestamp="1589918982407">
|
||||||
<screen x="0" y="0" width="1536" height="824" />
|
<screen x="0" y="0" width="1536" height="824" />
|
||||||
</state>
|
</state>
|
||||||
|
58
coder/coder.py
Normal file
58
coder/coder.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torchvision
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from time import time
|
||||||
|
from torchvision import datasets, transforms
|
||||||
|
from torch import nn, optim, nn, optim
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
def view_classify(img, ps):
|
||||||
|
''' Function for viewing an image and it's predicted classes.
|
||||||
|
'''
|
||||||
|
ps = ps.data.numpy().squeeze()
|
||||||
|
|
||||||
|
fig, (ax1, ax2) = plt.subplots(figsize=(6,9), ncols=2)
|
||||||
|
ax1.imshow(img.resize_(1, 28, 28).numpy().squeeze())
|
||||||
|
ax1.axis('off')
|
||||||
|
ax2.barh(np.arange(10), ps)
|
||||||
|
ax2.set_aspect(0.1)
|
||||||
|
ax2.set_yticks(np.arange(10))
|
||||||
|
ax2.set_yticklabels(np.arange(10))
|
||||||
|
ax2.set_title('Class Probability')
|
||||||
|
ax2.set_xlim(0, 1.1)
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# load nn model
|
||||||
|
model = torch.load('digit_reco_model2.pt')
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
print("Model is not loaded.")
|
||||||
|
else:
|
||||||
|
print("Model is loaded.")
|
||||||
|
|
||||||
|
# image
|
||||||
|
img = cv2.cvtColor(cv2.imread('test3.png'), cv2.COLOR_BGR2GRAY)
|
||||||
|
img = cv2.blur(img, (9, 9)) # poprawia jakosc
|
||||||
|
img = cv2.resize(img, (28, 28), interpolation=cv2.INTER_AREA)
|
||||||
|
img = img.reshape((len(img), -1))
|
||||||
|
|
||||||
|
print(type(img))
|
||||||
|
# print(img.shape)
|
||||||
|
# plt.imshow(img ,cmap='binary')
|
||||||
|
# plt.show()
|
||||||
|
img = np.array(img, dtype=np.float32)
|
||||||
|
img = torch.from_numpy(img)
|
||||||
|
img = img.view(1, 784)
|
||||||
|
|
||||||
|
# recognizing
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logps = model(img)
|
||||||
|
|
||||||
|
ps = torch.exp(logps)
|
||||||
|
probab = list(ps.numpy()[0])
|
||||||
|
print("Predicted Digit =", probab.index(max(probab)))
|
||||||
|
|
||||||
|
view_classify(img.view(1, 28, 28), ps)
|
BIN
coder/digit_reco_model.pt
Normal file
BIN
coder/digit_reco_model.pt
Normal file
Binary file not shown.
BIN
coder/digit_reco_model2.pt
Normal file
BIN
coder/digit_reco_model2.pt
Normal file
Binary file not shown.
@ -6,17 +6,19 @@ from time import time
|
|||||||
from torchvision import datasets, transforms
|
from torchvision import datasets, transforms
|
||||||
from torch import nn, optim
|
from torch import nn, optim
|
||||||
|
|
||||||
|
# IMG transform
|
||||||
transform = transforms.Compose([transforms.ToTensor(),
|
transform = transforms.Compose([transforms.ToTensor(),
|
||||||
transforms.Normalize((0.5,), (0.5,)),
|
transforms.Normalize((0.5,), (0.5,)),
|
||||||
])
|
])
|
||||||
|
|
||||||
trainset = datasets.MNIST('PATH_TO_STORE_TRAINSET', download=True, train=True, transform=transform)
|
# dataset download
|
||||||
valset = datasets.MNIST('PATH_TO_STORE_TESTSET', download=True, train=False, transform=transform)
|
train_set = datasets.MNIST('PATH_TO_STORE_TRAINSET', download=True, train=True, transform=transform)
|
||||||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
|
val_set = datasets.MNIST('PATH_TO_STORE_TESTSET', download=True, train=False, transform=transform)
|
||||||
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True)
|
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
|
||||||
|
val_loader = torch.utils.data.DataLoader(val_set, batch_size=64, shuffle=True)
|
||||||
|
|
||||||
dataiter = iter(trainloader)
|
data_iter = iter(train_loader)
|
||||||
images, labels = dataiter.next()
|
images, labels = data_iter.next()
|
||||||
|
|
||||||
print(images.shape)
|
print(images.shape)
|
||||||
print(labels.shape)
|
print(labels.shape)
|
||||||
@ -25,15 +27,93 @@ plt.imshow(images[0].numpy().squeeze(), cmap='gray_r')
|
|||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
# building nn model
|
# building nn model
|
||||||
input_size = 784
|
input_size = 784 # = 28*28
|
||||||
hidden_sizes = [128, 64]
|
hidden_sizes = [128, 128, 64]
|
||||||
output_size = 10
|
output_size = 10
|
||||||
|
|
||||||
model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
|
model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(hidden_sizes[0], hidden_sizes[1]),
|
nn.Linear(hidden_sizes[0], hidden_sizes[1]),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Linear(hidden_sizes[1], output_size),
|
nn.Linear(hidden_sizes[1], hidden_sizes[2]),
|
||||||
nn.LogSoftmax(dim=1))
|
nn.ReLU(),
|
||||||
print(model)
|
nn.Linear(hidden_sizes[2], output_size),
|
||||||
|
nn.LogSoftmax(dim=-1))
|
||||||
|
# print(model)
|
||||||
|
|
||||||
|
criterion = nn.NLLLoss()
|
||||||
|
images, labels = next(iter(train_loader))
|
||||||
|
images = images.view(images.shape[0], -1)
|
||||||
|
|
||||||
|
logps = model(images) # log probabilities
|
||||||
|
loss = criterion(logps, labels) # calculate the NLL loss
|
||||||
|
|
||||||
|
# print('Before backward pass: \n', model[0].weight.grad)
|
||||||
|
loss.backward()
|
||||||
|
# print('After backward pass: \n', model[0].weight.grad)
|
||||||
|
|
||||||
|
# training
|
||||||
|
|
||||||
|
optimizer = optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
|
||||||
|
time0 = time()
|
||||||
|
epochs = 100
|
||||||
|
for e in range(epochs):
|
||||||
|
running_loss = 0
|
||||||
|
for images, labels in train_loader:
|
||||||
|
# Flatten MNIST images into a 784 long vector
|
||||||
|
images = images.view(images.shape[0], -1)
|
||||||
|
|
||||||
|
# Training pass
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
output = model(images)
|
||||||
|
loss = criterion(output, labels)
|
||||||
|
|
||||||
|
# This is where the model learns by backpropagating
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
# And optimizes its weights here
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
running_loss += loss.item()
|
||||||
|
else:
|
||||||
|
print("Epoch {} - Training loss: {}".format(e + 1, running_loss / len(train_loader)))
|
||||||
|
|
||||||
|
print("\nTraining Time (in minutes) =", (time() - time0) / 60)
|
||||||
|
|
||||||
|
# testing
|
||||||
|
|
||||||
|
images, labels = next(iter(val_loader))
|
||||||
|
print(type(images))
|
||||||
|
img = images[0].view(1, 784)
|
||||||
|
with torch.no_grad():
|
||||||
|
logps = model(img)
|
||||||
|
ps = torch.exp(logps)
|
||||||
|
probab = list(ps.numpy()[0])
|
||||||
|
print("Predicted Digit =", probab.index(max(probab)))
|
||||||
|
# view_classify(img.view(1, 28, 28), ps)
|
||||||
|
|
||||||
|
# accuracy
|
||||||
|
correct_count, all_count = 0, 0
|
||||||
|
for images, labels in val_loader:
|
||||||
|
for i in range(len(labels)):
|
||||||
|
img = images[i].view(1, 784)
|
||||||
|
with torch.no_grad():
|
||||||
|
logps = model(img)
|
||||||
|
|
||||||
|
ps = torch.exp(logps)
|
||||||
|
probab = list(ps.numpy()[0])
|
||||||
|
pred_label = probab.index(max(probab))
|
||||||
|
true_label = labels.numpy()[i]
|
||||||
|
if true_label == pred_label:
|
||||||
|
correct_count += 1
|
||||||
|
all_count += 1
|
||||||
|
|
||||||
|
print("Number Of Images Tested =", all_count)
|
||||||
|
print("\nModel Accuracy =", (correct_count / all_count))
|
||||||
|
|
||||||
|
|
||||||
|
# saving model
|
||||||
|
|
||||||
|
# torch.save(model, './digit_reco_model.pt')
|
||||||
|
torch.save(model, './digit_reco_model2.pt')
|
Binary file not shown.
Before Width: | Height: | Size: 9.1 KiB |
Loading…
Reference in New Issue
Block a user