diff --git a/pytorch/pytorch.py b/pytorch/pytorch.py index 28b0975..96099e1 100644 --- a/pytorch/pytorch.py +++ b/pytorch/pytorch.py @@ -204,6 +204,6 @@ def main(): lr = 1e-6 #my_config() #print("number of epochs is: ", epochs) - history5 = fit(epochs, lr, model, train_loader, val_loader) + history5 = fit(lr, model, train_loader, val_loader) #ex.run()