ium_434704/ml_model.py
Wojciech Jarmosz 8c94666b25
All checks were successful
s434704-training/pipeline/head This commit looks good
s434704-evaluation/pipeline/head This commit looks good
Add registry saving
2021-05-23 20:12:07 +02:00

62 lines
1.9 KiB
Python

import pandas as pd
import numpy as np
import tensorflow as tf
import os.path
import mlflow
import sys
from mlflow.tracking import MlflowClient
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers.experimental import preprocessing
# Command-line interface: ml_model.py <verbose> <epochs>
arguments = sys.argv[1:]
if len(arguments) < 2:
    # Fail with a readable message instead of a bare IndexError.
    sys.exit("Usage: ml_model.py <verbose> <epochs>")
verbose = int(arguments[0])  # forwarded to model.fit(verbose=...)
epochs = int(arguments[1])  # forwarded to model.fit(epochs=...)

# Load the training and test splits.
train_data = pd.read_csv("./MoviesOnStreamingPlatforms_updated.train")
test_data = pd.read_csv("./MoviesOnStreamingPlatforms_updated.test")

# Build feature/target tensors: the "IMDb" column is predicted from
# the columns listed in columns_to_use.
columns_to_use = ['Year', 'Runtime', 'Netflix']
train_X = tf.convert_to_tensor(train_data[columns_to_use])
train_Y = tf.convert_to_tensor(train_data[["IMDb"]])
test_X = tf.convert_to_tensor(test_data[columns_to_use])
test_Y = tf.convert_to_tensor(test_data[["IMDb"]])

# Normalization layer adapted on the training features only. Its input
# width is derived from the feature list so it cannot drift away from the
# keras.Input shape below when columns are added or removed.
normalizer = preprocessing.Normalization(input_shape=[len(columns_to_use)])
normalizer.adapt(train_X)

# Create the model: normalized inputs -> three ReLU layers -> one output.
model = keras.Sequential([
    keras.Input(shape=(len(columns_to_use),)),
    normalizer,
    layers.Dense(30, activation='relu'),
    layers.Dense(10, activation='relu'),
    layers.Dense(25, activation='relu'),
    layers.Dense(1),
])
model.compile(
    loss='mean_absolute_error',
    optimizer=tf.keras.optimizers.Adam(0.001),
    metrics=[tf.keras.metrics.RootMeanSquaredError()],
)
model.fit(train_X, train_Y, verbose=verbose, epochs=epochs)

# Infer the MLflow model signature from the training inputs and the
# predictions the trained model makes for them.
train_features = train_X.numpy()
signature = mlflow.models.signature.infer_signature(
    train_features, model.predict(train_features)
)

# The test features are stored with the model as its input example.
input_data = test_X
input_example = input_data.numpy()

# Settings for registering the model in MLflow.
mlflow.set_tracking_uri("http://172.17.0.1:5000")
client = MlflowClient()  # not referenced again in the visible part of the script
model_name = "s434704"
model_path = "movies_on_streaming_platforms_model"

# Log the model to the tracking server and register it under model_name.
with mlflow.start_run():
    mlflow.keras.log_model(
        model,
        model_path,
        registered_model_name=model_name,
        input_example=input_example,
        signature=signature,
    )

# Also save a copy of the model under the local path model_path.
mlflow.keras.save_model(
    model,
    model_path,
    input_example=input_example,
    signature=signature,
)