add mlflow
Some checks failed
s464980-training/pipeline/head There was a failure building this commit
Some checks failed
s464980-training/pipeline/head There was a failure building this commit
This commit is contained in:
parent
98d1c8d505
commit
8ba3d2b97c
6
train.py
6
train.py
@ -3,8 +3,8 @@ from tensorflow import keras
|
|||||||
from tensorflow.keras import layers
|
from tensorflow.keras import layers
|
||||||
import argparse
|
import argparse
|
||||||
import mlflow
|
import mlflow
|
||||||
import mlflow.sklearn
|
|
||||||
mlflow.set_experiment("s464980")
|
|
||||||
class RegressionModel:
|
class RegressionModel:
|
||||||
def __init__(self, optimizer="adam", loss="mean_squared_error"):
|
def __init__(self, optimizer="adam", loss="mean_squared_error"):
|
||||||
self.model = keras.Sequential([
|
self.model = keras.Sequential([
|
||||||
@ -48,8 +48,8 @@ parser = argparse.ArgumentParser()
|
|||||||
parser.add_argument('--epochs')
|
parser.add_argument('--epochs')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
mlflow.set_tracking_uri("http://localhost:5000")
|
||||||
model = RegressionModel()
|
model = RegressionModel()
|
||||||
model.load_data("df_train.csv", "df_test.csv")
|
|
||||||
with mlflow.start_run() as run:
|
with mlflow.start_run() as run:
|
||||||
model.train(epochs=int(args.epochs))
|
model.train(epochs=int(args.epochs))
|
||||||
rmse = model.evaluate()
|
rmse = model.evaluate()
|
||||||
|
Loading…
Reference in New Issue
Block a user