Merge branch 'neural_network' of D:\Documents\Studia\Semestr-4\sztuczna-inteligencja\super-saper with conflicts.

This commit is contained in:
dar-gol 2021-05-30 15:46:45 +02:00
parent 72c6c5e2d8
commit 61419f5f64

View File

@ -22,29 +22,28 @@ if __name__ == '__main__':
nn.Linear(input_dim, output_dim), nn.Linear(input_dim, output_dim),
nn.LogSoftmax() nn.LogSoftmax()
) )
def train(model, n_iter): def train(model, n_iter):
criterion = nn.NLLLoss() criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001) optimizer = optim.SGD(model.parameters(), lr=0.001)
for epoch in range(n_iter): for epoch in range(n_iter):
for image, label in zip(train_images, train_labels): for image, label in zip(train_images, train_labels):
print(image.shape)
optimizer.zero_grad() optimizer.zero_grad()
output = model(image) output = model(image)
loss = criterion(output.unsqueeze(0), label.unsqueeze(0)) loss = criterion(output.unsqueeze(0), label.unsqueeze(0))
loss.backward() loss.backward()
optimizer.step() optimizer.step()
print(f'epoch: {epoch:03}') print(f'epoch: {epoch:03}')
train(model, 100) train(model, 100)
# def accuracy(expected, predicted): # def accuracy(expected, predicted):
# return len([1 for e, p in zip(expected, predicted) if e == p]) / len(expected) # return len([_ for e, p in zip(expected, predicted) if e == p]) / len(expected)
# #
# #
# predicted = [model(image).argmax() for image in train_images] # predicted = [model(image).argmax() for image in train_images]