diff --git a/sacred_train.py b/sacred_train.py index 0623a99..0488419 100644 --- a/sacred_train.py +++ b/sacred_train.py @@ -63,7 +63,7 @@ def my_config(): epochs = 10 learning_rate = 0.001 -@ex.main +@ex.automain def my_main(epochs, learning_rate, _run): ex.open_resource("Customers.csv", "r") @@ -97,7 +97,7 @@ def my_main(epochs, learning_rate, _run): ex.add_artifact("cifar_net.pth") -if __name__ == '__main__': +#if __name__ == '__main__': - ex.run() + #ex.run()