s
This commit is contained in:
parent
1c1a1d6658
commit
e1ca9046a6
7
train.py
7
train.py
@ -8,6 +8,7 @@ from sacred.observers import MongoObserver, FileStorageObserver
|
||||
import os
|
||||
import tensorflow as tf
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
import numpy as np
|
||||
|
||||
os.environ["SACRED_NO_GIT"] = "1"
|
||||
|
||||
@ -36,7 +37,7 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
|
||||
from imblearn.over_sampling import SMOTE
|
||||
|
||||
smote = SMOTE(random_state=random_state)
|
||||
data = pd.read_csv(data_file, sep=';')
|
||||
data = pd.read_csv(data_file, sep=';', header=0)
|
||||
|
||||
print('Total rows:', len(data))
|
||||
print('Rows with medal:', len(data.dropna(subset=['Medal'])))
|
||||
@ -75,7 +76,9 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
|
||||
input_signature = {
|
||||
'input': tensor_spec.TensorSpec(shape=X_train[0].shape, dtype=X_train.dtype)
|
||||
}
|
||||
signature = infer_signature(input_signature, model.output)
|
||||
X_train_numpy = X_train.to_numpy()
|
||||
signature = infer_signature(X_train_numpy, model.predict(X_train_numpy))
|
||||
input_example = X_train.head(1).to_numpy()
|
||||
|
||||
mlflow.keras.log_model(model, "model")
|
||||
mlflow.log_artifact("model.h5")
|
||||
|
Loading…
Reference in New Issue
Block a user