This commit is contained in:
Jakub Zaręba 2023-05-11 01:06:21 +02:00
parent 1c1a1d6658
commit e1ca9046a6

View File

@ -8,6 +8,7 @@ from sacred.observers import MongoObserver, FileStorageObserver
import os
import tensorflow as tf
from tensorflow.python.framework import tensor_spec
import numpy as np
os.environ["SACRED_NO_GIT"] = "1"
@ -36,7 +37,7 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
from imblearn.over_sampling import SMOTE
smote = SMOTE(random_state=random_state)
data = pd.read_csv(data_file, sep=';')
data = pd.read_csv(data_file, sep=';', header=0)
print('Total rows:', len(data))
print('Rows with medal:', len(data.dropna(subset=['Medal'])))
@ -75,7 +76,9 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
input_signature = {
'input': tensor_spec.TensorSpec(shape=X_train[0].shape, dtype=X_train.dtype)
}
signature = infer_signature(input_signature, model.output)
X_train_numpy = X_train.to_numpy()
signature = infer_signature(X_train_numpy, model.predict(X_train_numpy))
input_example = X_train.head(1).to_numpy()
mlflow.keras.log_model(model, "model")
mlflow.log_artifact("model.h5")