diff --git a/create-dataset.py b/create-dataset.py index 132e88f..81def55 100644 --- a/create-dataset.py +++ b/create-dataset.py @@ -23,4 +23,4 @@ Y_test.to_csv('Y_test.csv', index=False) Y_train.to_csv('Y_train.csv', index=False) Y_dev.to_csv('Y_dev.csv', index=False) -train.main() \ No newline at end of file +train.main(0) diff --git a/train.py b/train.py index 70bed0e..4eabf73 100644 --- a/train.py +++ b/train.py @@ -5,8 +5,9 @@ import tensorflow from keras.applications.densenet import layers -def main(): - EPOCHS = int(os.environ['EPOCHS']) +def main(EPOCHS): + if EPOCHS == 0: + EPOCHS = 500 train_data_x = pd.read_csv('./X_train.csv') adults_train = train_data_x.copy() @@ -30,4 +31,5 @@ def main(): if __name__ == "__main__": - main() + EPOCHS = int(os.environ['EPOCHS']) + main(EPOCHS)