Zaktualizuj 'train-sacred.py'

This commit is contained in:
Kornelia Girejko 2022-05-08 19:07:49 +02:00
parent 806dff6116
commit da3b0bbee8

View File

@ -75,7 +75,7 @@ criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
@ex.automain
def my_main(epochs, _run):
def my_main(epochs):
# Trening
#num_epochs = EPOCHS
for epochs in range(epochs):
@ -98,7 +98,5 @@ def my_main(epochs, _run):
torch.save(model, "modelP.pkl")
_run.info['accuracy'] = accuracy_score(y_testing, np.argmax(y_predicted, axis=1))
ex.run()