diff --git a/training.py b/training.py index 0348e6b..38820bb 100644 --- a/training.py +++ b/training.py @@ -14,6 +14,8 @@ from tensorflow import keras from tensorflow.keras import layers from tensorflow.keras.layers.experimental import preprocessing +EPOCHS = int(sys.argv[1]) +BATCH_SIZE = int(sys.argv[2]) age = {"5-14 years": 0, "15-24 years": 1, "25-34 years": 2, "35-54 years": 3, "55-74 years": 4, "75+ years": 5} @@ -64,8 +66,8 @@ model.compile( # Train model history = model.fit( X_train, y_train, - batch_size=int(sys.argv[0]), - epochs=int(sys.argv[1]), + batch_size=BATCH_SIZE, + epochs=EPOCHS, validation_split=0.2) test_results = {}