fix mlflow

This commit is contained in:
Sheaza 2024-05-15 00:30:21 +02:00
parent 32448bea84
commit 2b853809be
5 changed files with 23 additions and 9 deletions

13
MLFLOW/Dockerfile Normal file
View File

@ -0,0 +1,13 @@
FROM python:3.11
RUN apt-get update && apt-get -y upgrade
RUN apt-get install -y build-essential
RUN python -m pip install --upgrade pip
COPY ../requirements.txt /tmp
RUN python -m pip install -r /tmp/requirements.txt
WORKDIR ./app
COPY train.py ./
CMD bash

3
MLFLOW/requirements.txt Normal file
View File

@ -0,0 +1,3 @@
pandas
tensorflow
mlflow

View File

@ -3,8 +3,6 @@ from tensorflow import keras
from tensorflow.keras import layers from tensorflow.keras import layers
import argparse import argparse
import mlflow import mlflow
class RegressionModel: class RegressionModel:
def __init__(self, optimizer="adam", loss="mean_squared_error"): def __init__(self, optimizer="adam", loss="mean_squared_error"):
self.model = keras.Sequential([ self.model = keras.Sequential([
@ -48,11 +46,11 @@ parser = argparse.ArgumentParser()
parser.add_argument('--epochs') parser.add_argument('--epochs')
args = parser.parse_args() args = parser.parse_args()
mlflow.set_tracking_uri("http://localhost:5000")
model = RegressionModel()
with mlflow.start_run() as run: with mlflow.start_run() as run:
model = RegressionModel()
model.load_data("df_train.csv", "df_test.csv")
model.train(epochs=int(args.epochs)) model.train(epochs=int(args.epochs))
mlflow.log_param("epochs", int(args.epochs))
rmse = model.evaluate() rmse = model.evaluate()
mlflow.log_param("epoch", int(args.epochs))
mlflow.log_metric("rmse", rmse) mlflow.log_metric("rmse", rmse)
model.save_model() model.save_model()

View File

@ -5,9 +5,9 @@ pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None) pd.set_option('display.max_rows', None)
df = pd.read_csv('./dataset.csv') df = pd.read_csv('MLFLOW/dataset.csv')
df_train = pd.read_csv('./df_train.csv') df_train = pd.read_csv('MLFLOW/df_train.csv')
df_test = pd.read_csv('./df_test.csv') df_test = pd.read_csv('MLFLOW/df_test.csv')
with open('stats.txt', 'w') as f: with open('stats.txt', 'w') as f:

View File

@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
np.set_printoptions(threshold=np.inf) np.set_printoptions(threshold=np.inf)
data = pd.read_csv("df_test.csv") data = pd.read_csv("MLFLOW/df_test.csv")
X_test = data.drop("Performance Index", axis=1) X_test = data.drop("Performance Index", axis=1)
y_test = data["Performance Index"] y_test = data["Performance Index"]