automain
s444354-training/pipeline/head There was a failure building this commit Details

This commit is contained in:
Adrian Charkiewicz 2022-05-07 23:03:06 +02:00
parent ed69bd98fa
commit 256b377625
1 changed files with 3 additions and 2 deletions

View File

@ -152,6 +152,7 @@ def evaluate(model, val_loader):
@ex.capture @ex.capture
def fit(epochs, lr, model, train_loader, val_loader, _run, opt_func=torch.optim.SGD): def fit(epochs, lr, model, train_loader, val_loader, _run, opt_func=torch.optim.SGD):
epochs=epochs
history = [] history = []
optimizer = opt_func(model.parameters(), lr) optimizer = opt_func(model.parameters(), lr)
for epoch in range(epochs): for epoch in range(epochs):
@ -195,11 +196,11 @@ with open("result.txt", "w+") as file:
file.write(str(predict_single(input_, target, model))) file.write(str(predict_single(input_, target, model)))
@ex.main @ex.automain
def main(epochs, _run): def main(epochs, _run):
lr = 1e-6 lr = 1e-6
my_config() my_config()
print("number of epochs is: ", epochs) print("number of epochs is: ", epochs)
history5 = fit(epochs, lr, model, train_loader, val_loader) history5 = fit(epochs, lr, model, train_loader, val_loader)
ex.run() #ex.run()