From 61419f5f643089be315efe1b03c62578b2fd526b Mon Sep 17 00:00:00 2001 From: dar-gol Date: Sun, 30 May 2021 15:46:45 +0200 Subject: [PATCH] Merge branch 'neural_network' of D:\Documents\Studia\Semestr-4\sztuczna-inteligencja\super-saper with conflicts. --- .../neural_network/learning_neural_network.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/machine_learning/neural_network/learning_neural_network.py b/src/machine_learning/neural_network/learning_neural_network.py index 69ab85a..1f503ef 100644 --- a/src/machine_learning/neural_network/learning_neural_network.py +++ b/src/machine_learning/neural_network/learning_neural_network.py @@ -22,29 +22,28 @@ if __name__ == '__main__': nn.Linear(input_dim, output_dim), nn.LogSoftmax() ) - - + + def train(model, n_iter): criterion = nn.NLLLoss() optimizer = optim.SGD(model.parameters(), lr=0.001) - + for epoch in range(n_iter): for image, label in zip(train_images, train_labels): - print(image.shape) optimizer.zero_grad() - + output = model(image) loss = criterion(output.unsqueeze(0), label.unsqueeze(0)) loss.backward() optimizer.step() - + print(f'epoch: {epoch:03}') train(model, 100) # def accuracy(expected, predicted): - # return len([1 for e, p in zip(expected, predicted) if e == p]) / len(expected) + # return len([_ for e, p in zip(expected, predicted) if e == p]) / len(expected) # # # predicted = [model(image).argmax() for image in train_images]