fix github workflow
This commit is contained in:
parent
578976bbe4
commit
81aa2cbbec
@ -12,8 +12,8 @@ epochs = int(sys.argv[1])
|
|||||||
learning_rate = float(sys.argv[2])
|
learning_rate = float(sys.argv[2])
|
||||||
batch_size = int(sys.argv[3])
|
batch_size = int(sys.argv[3])
|
||||||
|
|
||||||
hp_train = pd.read_csv('hp_train.csv')
|
hp_train = pd.read_csv('./github_project/hp_train.csv')
|
||||||
hp_dev = pd.read_csv('hp_dev.csv')
|
hp_dev = pd.read_csv('./github_project/hp_dev.csv')
|
||||||
|
|
||||||
X_train, Y_train = prepare_tensors(hp_train)
|
X_train, Y_train = prepare_tensors(hp_train)
|
||||||
X_dev, Y_dev = prepare_tensors(hp_dev)
|
X_dev, Y_dev = prepare_tensors(hp_dev)
|
||||||
@ -30,7 +30,7 @@ model.compile(optimizer=adam, loss='mean_squared_error')
|
|||||||
|
|
||||||
model.fit(X_train, Y_train, epochs=epochs, batch_size=batch_size, validation_data=(X_dev, Y_dev))
|
model.fit(X_train, Y_train, epochs=epochs, batch_size=batch_size, validation_data=(X_dev, Y_dev))
|
||||||
|
|
||||||
model.save('hp_model.h5')
|
model.save('./github_project/hp_model.h5')
|
||||||
|
|
||||||
with mlflow.start_run() as run:
|
with mlflow.start_run() as run:
|
||||||
mlflow.log_param("epochs", epochs)
|
mlflow.log_param("epochs", epochs)
|
||||||
|
@ -14,15 +14,15 @@ if len(sys.argv) > 1:
|
|||||||
else:
|
else:
|
||||||
build_number = 0
|
build_number = 0
|
||||||
|
|
||||||
hp_test = pd.read_csv('hp_test.csv')
|
hp_test = pd.read_csv('./github_project/hp_test.csv')
|
||||||
X_test, Y_test = prepare_tensors(hp_test)
|
X_test, Y_test = prepare_tensors(hp_test)
|
||||||
|
|
||||||
model = load_model('hp_model.h5')
|
model = load_model('./github_project/hp_model.h5')
|
||||||
|
|
||||||
test_predictions = model.predict(X_test)
|
test_predictions = model.predict(X_test)
|
||||||
|
|
||||||
predictions_df = pd.DataFrame(test_predictions, columns=["Predicted_Price"])
|
predictions_df = pd.DataFrame(test_predictions, columns=["Predicted_Price"])
|
||||||
predictions_df.to_csv('hp_test_predictions.csv', index=False)
|
predictions_df.to_csv('./github_project/hp_test_predictions.csv', index=False)
|
||||||
|
|
||||||
rmse = np.sqrt(mean_squared_error(Y_test, test_predictions))
|
rmse = np.sqrt(mean_squared_error(Y_test, test_predictions))
|
||||||
mae = mean_absolute_error(Y_test, test_predictions)
|
mae = mean_absolute_error(Y_test, test_predictions)
|
||||||
@ -35,7 +35,7 @@ metrics_df = pd.DataFrame({
|
|||||||
'R2': [r2]
|
'R2': [r2]
|
||||||
})
|
})
|
||||||
|
|
||||||
metrics_file = 'hp_test_metrics.csv'
|
metrics_file = './github_project/hp_test_metrics.csv'
|
||||||
if os.path.isfile(metrics_file):
|
if os.path.isfile(metrics_file):
|
||||||
existing_metrics_df = pd.read_csv(metrics_file)
|
existing_metrics_df = pd.read_csv(metrics_file)
|
||||||
updated_metrics_df = pd.concat([existing_metrics_df, metrics_df], ignore_index=True)
|
updated_metrics_df = pd.concat([existing_metrics_df, metrics_df], ignore_index=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user