hotfix/netural_network #29

Merged
s464843 merged 3 commits from hotfix/netural_network into master 2022-05-26 21:48:29 +02:00
4 changed files with 10 additions and 3 deletions

View File

@ -7,7 +7,7 @@ from NeuralNetwork import NeuralNetwork
def getPrediction(img_path): def getPrediction(img_path):
# Inicjacja sieci neuronowej # Inicjacja sieci neuronowej
neural_net = NeuralNetwork() neural_net = NeuralNetwork.NeuralNetwork()
PATH = './trained_nn.pth' PATH = './trained_nn.pth'
img = Image.open(img_path) img = Image.open(img_path)
transform_tensor = transforms.ToTensor()(img).unsqueeze_(0) transform_tensor = transforms.ToTensor()(img).unsqueeze_(0)

View File

@ -20,7 +20,7 @@ def trainNeuralNetwork():
criterion = nn.CrossEntropyLoss() criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(neural_net.parameters(), lr=0.001, momentum=0.9) optimizer = optim.SGD(neural_net.parameters(), lr=0.001, momentum=0.9)
epoch_num = 4 # najlepiej 10, dla lepszej wiarygodności epoch_num = 10 # najlepiej 10, dla lepszej wiarygodności
for epoch in range(epoch_num): for epoch in range(epoch_num):
measure_loss = 0.0 measure_loss = 0.0
for i, data in enumerate(trainloader, 0): for i, data in enumerate(trainloader, 0):

7
test_nn.py Normal file
View File

@ -0,0 +1,7 @@
from NeuralNetwork import prediction
def main():
print(prediction.getPrediction("./resources/trash_dataset/test/paper/google-image(0639).jpeg"))
if __name__ == '__main__':
main()