diff --git a/pytorch/pytorch.py b/pytorch/pytorch.py index b4a15ea..8d13800 100644 --- a/pytorch/pytorch.py +++ b/pytorch/pytorch.py @@ -87,7 +87,7 @@ train_ds, val_ds = random_split(dataset, [1300, 299]) batch_size=50 train_loader = DataLoader(train_ds, batch_size, shuffle=True) val_loader = DataLoader(val_ds, batch_size) - +lr = 1e-6 # In[8]: @@ -140,6 +140,7 @@ output_size = len(output_cols) def my_config(): global epochs epochs = numberOfEpochParam + lr=lr model=model train_loader=train_loader val_loader=val_loader @@ -203,9 +204,8 @@ with open("result.txt", "w+") as file: @ex.automain def main(): - lr = 1e-6 #my_config() #print("number of epochs is: ", epochs) - history5 = fit(lr) + history5 = fit() #ex.run()