add mlflow
This commit is contained in:
parent
793211467e
commit
98d1c8d505
@ -17,9 +17,9 @@ predictions = model.predict(X_test)
|
||||
with open("predictions.txt", "w") as f:
|
||||
f.write(str(predictions))
|
||||
|
||||
accuracy = root_mean_squared_error(y_test, predictions)
|
||||
rmse = root_mean_squared_error(y_test, predictions)
|
||||
with open("rmse.txt", 'a') as file:
|
||||
file.write(str(accuracy)+"\n")
|
||||
file.write(str(rmse)+"\n")
|
||||
|
||||
with open("rmse.txt", 'r') as file:
|
||||
lines = file.readlines()
|
||||
|
@ -3,3 +3,4 @@ scikit-learn
|
||||
tensorflow
|
||||
numpy
|
||||
matplotlib
|
||||
mlflow
|
11
train.py
11
train.py
@ -2,7 +2,9 @@ import pandas as pd
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras import layers
|
||||
import argparse
|
||||
|
||||
import mlflow
|
||||
import mlflow.sklearn
|
||||
mlflow.set_experiment("s464980")
|
||||
class RegressionModel:
|
||||
def __init__(self, optimizer="adam", loss="mean_squared_error"):
|
||||
self.model = keras.Sequential([
|
||||
@ -26,7 +28,6 @@ class RegressionModel:
|
||||
self.y_test = data_test["Performance Index"]
|
||||
|
||||
def train(self, epochs=30):
|
||||
|
||||
self.model.compile(optimizer=self.optimizer, loss=self.loss)
|
||||
self.model.fit(self.X_train, self.y_train, epochs=epochs, batch_size=32, validation_data=(self.X_test, self.y_test))
|
||||
|
||||
@ -37,6 +38,7 @@ class RegressionModel:
|
||||
def evaluate(self):
|
||||
test_loss = self.model.evaluate(self.X_test, self.y_test)
|
||||
print(f"Test Loss: {test_loss:.4f}")
|
||||
return test_loss
|
||||
|
||||
def save_model(self):
|
||||
self.model.save("model.keras")
|
||||
@ -48,6 +50,9 @@ parser.add_argument('--epochs')
|
||||
args = parser.parse_args()
|
||||
model = RegressionModel()
|
||||
model.load_data("df_train.csv", "df_test.csv")
|
||||
with mlflow.start_run() as run:
|
||||
model.train(epochs=int(args.epochs))
|
||||
model.evaluate()
|
||||
rmse = model.evaluate()
|
||||
mlflow.log_param("epoch", int(args.epochs))
|
||||
mlflow.log_metric("rmse", rmse)
|
||||
model.save_model()
|
||||
|
Loading…
Reference in New Issue
Block a user