This commit is contained in:
parent
ed69bd98fa
commit
256b377625
@ -152,6 +152,7 @@ def evaluate(model, val_loader):
|
||||
|
||||
@ex.capture
|
||||
def fit(epochs, lr, model, train_loader, val_loader, _run, opt_func=torch.optim.SGD):
|
||||
epochs=epochs
|
||||
history = []
|
||||
optimizer = opt_func(model.parameters(), lr)
|
||||
for epoch in range(epochs):
|
||||
@ -195,11 +196,11 @@ with open("result.txt", "w+") as file:
|
||||
file.write(str(predict_single(input_, target, model)))
|
||||
|
||||
|
||||
@ex.main
|
||||
@ex.automain
|
||||
def main(epochs, _run):
|
||||
lr = 1e-6
|
||||
my_config()
|
||||
print("number of epochs is: ", epochs)
|
||||
history5 = fit(epochs, lr, model, train_loader, val_loader)
|
||||
|
||||
ex.run()
|
||||
#ex.run()
|
||||
|
Loading…
Reference in New Issue
Block a user