diff --git a/nn_train.py b/nn_train.py index 8f4331b..b9e8201 100644 --- a/nn_train.py +++ b/nn_train.py @@ -46,6 +46,11 @@ encoder.fit(y_test_set) encoded_Ytt = encoder.transform(y_test_set) dummy_ytt = np_utils.to_categorical(encoded_Ytt) +try: + no_epochs=int(sys.argv[1]) +except: + no_epochs = 200 + # model definition number_of_classes = 33 number_of_features = 5 @@ -53,7 +58,7 @@ model = Sequential() model.add(Dense(number_of_classes, activation='relu')) model.add(Dense(number_of_classes, activation='softmax',input_dim=number_of_features)) model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy', 'categorical_accuracy']) -model.fit(x_train_set, dummy_y, epochs=int(sys.argv[1]), validation_data=(x_validate_set, dummy_yv)) +model.fit(x_train_set, dummy_y, epochs=no_epochs, validation_data=(x_validate_set, dummy_yv)) model.save("my_model/") #model predictions