hotfix/netural_network #29
@ -7,7 +7,7 @@ from NeuralNetwork import NeuralNetwork
|
|||||||
def getPrediction(img_path):
|
def getPrediction(img_path):
|
||||||
|
|
||||||
# Inicjacja sieci neuronowej
|
# Inicjacja sieci neuronowej
|
||||||
neural_net = NeuralNetwork()
|
neural_net = NeuralNetwork.NeuralNetwork()
|
||||||
PATH = './trained_nn.pth'
|
PATH = './trained_nn.pth'
|
||||||
img = Image.open(img_path)
|
img = Image.open(img_path)
|
||||||
transform_tensor = transforms.ToTensor()(img).unsqueeze_(0)
|
transform_tensor = transforms.ToTensor()(img).unsqueeze_(0)
|
||||||
|
@ -20,7 +20,7 @@ def trainNeuralNetwork():
|
|||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
optimizer = optim.SGD(neural_net.parameters(), lr=0.001, momentum=0.9)
|
optimizer = optim.SGD(neural_net.parameters(), lr=0.001, momentum=0.9)
|
||||||
|
|
||||||
epoch_num = 4 # najlepiej 10, dla lepszej wiarygodności
|
epoch_num = 10 # najlepiej 10, dla lepszej wiarygodności
|
||||||
for epoch in range(epoch_num):
|
for epoch in range(epoch_num):
|
||||||
measure_loss = 0.0
|
measure_loss = 0.0
|
||||||
for i, data in enumerate(trainloader, 0):
|
for i, data in enumerate(trainloader, 0):
|
||||||
|
7
test_nn.py
Normal file
7
test_nn.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from NeuralNetwork import prediction
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print(prediction.getPrediction("./resources/trash_dataset/test/paper/google-image(0639).jpeg"))
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Binary file not shown.
Loading…
Reference in New Issue
Block a user