MLflow
This commit is contained in:
parent
ba28ee1234
commit
fb163e5653
1
.gitignore
vendored
1
.gitignore
vendored
@ -65,3 +65,4 @@ dev.csv
|
||||
.venv/
|
||||
model.h5
|
||||
evaluation.png
|
||||
mlruns/*
|
11
MLProject
Normal file
11
MLProject
Normal file
@ -0,0 +1,11 @@
|
||||
name: Fifa Players
|
||||
docker_env:
|
||||
image: docker.io/adnovac/ium_s434760:2.0
|
||||
entry_points:
|
||||
train:
|
||||
parameters:
|
||||
batch_size: {type: int, default: 15}
|
||||
epochs: {type: int, default: 16}
|
||||
command: "python train.py {batch_size} {epochs}"
|
||||
evaluate:
|
||||
command: "python evaluate.py"
|
@ -2,8 +2,8 @@ import pandas as pd
|
||||
import numpy as np
|
||||
from os import path
|
||||
from tensorflow import keras
|
||||
import sys
|
||||
import matplotlib.pyplot as plt
|
||||
import mlflow
|
||||
|
||||
model_name = "model.h5"
|
||||
|
||||
@ -16,6 +16,8 @@ Y_test=test_data[["Overall"]].to_numpy()
|
||||
|
||||
#MeanSquaredError
|
||||
results_test = model.evaluate(X_test, Y_test, batch_size=128)
|
||||
mlflow.log_metric("rmse", results_test)
|
||||
|
||||
with open('results.txt', 'a+', encoding="UTF-8") as f:
|
||||
f.write(str(results_test) +"\n")
|
||||
|
||||
|
@ -5,3 +5,4 @@ sklearn
|
||||
tensorflow==2.4.1
|
||||
jinja2==2.11.3
|
||||
matplotlib
|
||||
mlflow==1.17.0
|
10
train.py
10
train.py
@ -2,6 +2,7 @@ import pandas as pd
|
||||
from os import path
|
||||
from tensorflow import keras
|
||||
from tensorflow.keras import layers
|
||||
import mlflow
|
||||
import sys
|
||||
|
||||
model_name = "model.h5"
|
||||
@ -25,11 +26,16 @@ model.compile(
|
||||
loss=keras.losses.MeanSquaredError(),
|
||||
)
|
||||
|
||||
batch_size = int(sys.argv[1])
|
||||
epochs = int(sys.argv[2])
|
||||
mlflow.log_param("batch_size", batch_size)
|
||||
mlflow.log_param("epochs", epochs)
|
||||
|
||||
history = model.fit(
|
||||
X,
|
||||
Y,
|
||||
batch_size=int(sys.argv[1]),
|
||||
epochs=int(sys.argv[2]),
|
||||
batch_size=batch_size,
|
||||
epochs=epochs,
|
||||
)
|
||||
|
||||
model.save(model_name)
|
Loading…
Reference in New Issue
Block a user