diff --git a/pytorch/pytorch.py b/pytorch/pytorch.py index 91ac241..efe1254 100644 --- a/pytorch/pytorch.py +++ b/pytorch/pytorch.py @@ -37,11 +37,6 @@ try: except: numberOfEpochParam = 1500 -@ex.config -def my_config(): - global epochs - epochs = numberOfEpochParam - @@ -141,7 +136,14 @@ output_size = len(output_cols) # In[10]: - +@ex.config +def my_config(): + global epochs + epochs = numberOfEpochParam + lr=lr + model=model + train_loader=train_loader + val_loader=val_loader model=WineQuality() @@ -205,6 +207,6 @@ def main(): lr = 1e-6 #my_config() #print("number of epochs is: ", epochs) - history5 = fit(lr, model, train_loader, val_loader) + history5 = fit() #ex.run()