assure that epochs is saved
s444354-training/pipeline/head There was a failure building this commit Details

This commit is contained in:
Adrian Charkiewicz 2022-05-07 21:45:27 +02:00
parent 718378d69a
commit 2a5e48f63c
1 changed files with 3 additions and 2 deletions

View File

@ -38,7 +38,7 @@ except:
@ex.config @ex.config
def my_config(): def my_config():
epochs = numberOfEpochParam epochs = numberOfEpochParam
my_config()
dataframe_raw = pd.read_csv("winequality-red.csv") dataframe_raw = pd.read_csv("winequality-red.csv")
@ -150,7 +150,7 @@ def evaluate(model, val_loader):
return model.validation_epoch_end(outputs) return model.validation_epoch_end(outputs)
@ex.capture @ex.capture
def fit(epochs, lr, model, train_loader, val_loader,_run, opt_func=torch.optim.SGD): def fit(epochs, lr, model, train_loader, val_loader, _run, opt_func=torch.optim.SGD):
history = [] history = []
optimizer = opt_func(model.parameters(), lr) optimizer = opt_func(model.parameters(), lr)
for epoch in range(epochs): for epoch in range(epochs):
@ -216,6 +216,7 @@ with open("result.txt", "w+") as file:
@ex.main @ex.main
def main(): def main():
lr = 1e-6 lr = 1e-6
my_config()
history5 = fit(epochs, lr, model, train_loader, val_loader) history5 = fit(epochs, lr, model, train_loader, val_loader)
ex.run() ex.run()