This commit is contained in:
parent
ed69bd98fa
commit
256b377625
@ -152,6 +152,7 @@ def evaluate(model, val_loader):
|
|||||||
|
|
||||||
@ex.capture
|
@ex.capture
|
||||||
def fit(epochs, lr, model, train_loader, val_loader, _run, opt_func=torch.optim.SGD):
|
def fit(epochs, lr, model, train_loader, val_loader, _run, opt_func=torch.optim.SGD):
|
||||||
|
epochs=epochs
|
||||||
history = []
|
history = []
|
||||||
optimizer = opt_func(model.parameters(), lr)
|
optimizer = opt_func(model.parameters(), lr)
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
@ -195,11 +196,11 @@ with open("result.txt", "w+") as file:
|
|||||||
file.write(str(predict_single(input_, target, model)))
|
file.write(str(predict_single(input_, target, model)))
|
||||||
|
|
||||||
|
|
||||||
@ex.main
|
@ex.automain
|
||||||
def main(epochs, _run):
|
def main(epochs, _run):
|
||||||
lr = 1e-6
|
lr = 1e-6
|
||||||
my_config()
|
my_config()
|
||||||
print("number of epochs is: ", epochs)
|
print("number of epochs is: ", epochs)
|
||||||
history5 = fit(epochs, lr, model, train_loader, val_loader)
|
history5 = fit(epochs, lr, model, train_loader, val_loader)
|
||||||
|
|
||||||
ex.run()
|
#ex.run()
|
||||||
|
Loading…
Reference in New Issue
Block a user