diff --git a/pytorch/pytorch.py b/pytorch/pytorch.py index 2367513..8e44a1f 100644 --- a/pytorch/pytorch.py +++ b/pytorch/pytorch.py @@ -29,16 +29,18 @@ ex.observers.append(FileStorageObserver('my_runs')) # ex.observers.append(MongoObserver(url='mongodb://mongo_user:mongo_password_IUM_2021@localhost:27017', db_name='sacred')) -#try: -# numberOfEpochParam = int(sys.argv[1]) -#except: -# dafault val -#numberOfEpochParam = 1500 + + + +try: + numberOfEpochParam = int(sys.argv[1]) +except: + numberOfEpochParam = 1500 @ex.config def my_config(): global epochs - epochs = 1500 + epochs = numberOfEpochParam @@ -200,8 +202,8 @@ with open("result.txt", "w+") as file: @ex.automain def main(epochs, _run): lr = 1e-6 - my_config() - print("number of epochs is: ", epochs) + #my_config() + #print("number of epochs is: ", epochs) history5 = fit(epochs, lr, model, train_loader, val_loader) #ex.run()