forked from s444420/AL-2020
conda ready
This commit is contained in:
parent
9b083201e8
commit
d8b857bb0c
@ -3,5 +3,5 @@
|
||||
<component name="JavaScriptSettings">
|
||||
<option name="languageLevel" value="ES6" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.7 (AL-2020)" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.8 (AL-2020)" project-jdk-type="Python SDK" />
|
||||
</project>
|
File diff suppressed because one or more lines are too long
@ -4,10 +4,7 @@
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.7 (AL-2020)" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.8 (AL-2020)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="TestRunnerService">
|
||||
<option name="PROJECT_TEST_RUNNER" value="Unittests" />
|
||||
</component>
|
||||
</module>
|
@ -1,21 +1,9 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
import matplotlib.pyplot as plt
|
||||
from time import time
|
||||
from torchvision import datasets, transforms
|
||||
from torch import nn, optim
|
||||
import torch.nn.functional as F
|
||||
|
||||
import cv2
|
||||
|
||||
from nn_model import Net
|
||||
from torch import nn, optim
|
||||
from torchvision import datasets, transforms
|
||||
|
||||
'''
|
||||
Q:
|
||||
what is batch?
|
||||
|
||||
'''
|
||||
n_epochs = 3
|
||||
batch_size_train = 64
|
||||
batch_size_test = 1000
|
||||
|
@ -1,6 +1,5 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
|
@ -1,13 +1,7 @@
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from torch import nn
|
||||
from torchvision.transforms import transforms
|
||||
from torch.autograd import Variable
|
||||
import numpy as np
|
||||
|
||||
from nn_model import Net
|
||||
from torchvision.transforms import transforms
|
||||
|
||||
|
||||
def recognizer(a_path):
|
||||
@ -30,9 +24,6 @@ def recognizer(a_path):
|
||||
rects = [cv2.boundingRect(ctr) for ctr in ctrs]
|
||||
|
||||
# load nn model
|
||||
input_size = 784 # = 28*28
|
||||
hidden_sizes = [128, 128, 64]
|
||||
output_size = 10
|
||||
model = Net()
|
||||
model.load_state_dict(torch.load('model.pt'))
|
||||
model.eval()
|
||||
@ -60,4 +51,4 @@ def recognizer(a_path):
|
||||
|
||||
|
||||
recognizer("55555.jpg")
|
||||
# print(recognizer("55555.jpg"))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user