diff --git a/pytorch/pytorch.py b/pytorch/pytorch.py index bde8aab..4d89412 100644 --- a/pytorch/pytorch.py +++ b/pytorch/pytorch.py @@ -149,7 +149,7 @@ def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD): # In[12]: -epochs = sys.argv[1] +epochs = int(sys.argv[1]) lr = 1e-6 history5 = fit(epochs, lr, model, train_loader, val_loader)