From 8ba3d2b97c186cc497a647783e57f83535151ff4 Mon Sep 17 00:00:00 2001 From: Sheaza Date: Tue, 14 May 2024 23:43:06 +0200 Subject: [PATCH] add mlflow --- train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index babcd5c..9a2a74f 100644 --- a/train.py +++ b/train.py @@ -3,8 +3,8 @@ from tensorflow import keras from tensorflow.keras import layers import argparse import mlflow -import mlflow.sklearn -mlflow.set_experiment("s464980") + + class RegressionModel: def __init__(self, optimizer="adam", loss="mean_squared_error"): self.model = keras.Sequential([ @@ -48,8 +48,8 @@ parser = argparse.ArgumentParser() parser.add_argument('--epochs') args = parser.parse_args() +mlflow.set_tracking_uri("http://localhost:5000") model = RegressionModel() -model.load_data("df_train.csv", "df_test.csv") with mlflow.start_run() as run: model.train(epochs=int(args.epochs)) rmse = model.evaluate()