s
This commit is contained in:
parent
1c1a1d6658
commit
e1ca9046a6
7
train.py
7
train.py
@ -8,6 +8,7 @@ from sacred.observers import MongoObserver, FileStorageObserver
|
|||||||
import os
|
import os
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
os.environ["SACRED_NO_GIT"] = "1"
|
os.environ["SACRED_NO_GIT"] = "1"
|
||||||
|
|
||||||
@ -36,7 +37,7 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
|
|||||||
from imblearn.over_sampling import SMOTE
|
from imblearn.over_sampling import SMOTE
|
||||||
|
|
||||||
smote = SMOTE(random_state=random_state)
|
smote = SMOTE(random_state=random_state)
|
||||||
data = pd.read_csv(data_file, sep=';')
|
data = pd.read_csv(data_file, sep=';', header=0)
|
||||||
|
|
||||||
print('Total rows:', len(data))
|
print('Total rows:', len(data))
|
||||||
print('Rows with medal:', len(data.dropna(subset=['Medal'])))
|
print('Rows with medal:', len(data.dropna(subset=['Medal'])))
|
||||||
@ -75,7 +76,9 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
|
|||||||
input_signature = {
|
input_signature = {
|
||||||
'input': tensor_spec.TensorSpec(shape=X_train[0].shape, dtype=X_train.dtype)
|
'input': tensor_spec.TensorSpec(shape=X_train[0].shape, dtype=X_train.dtype)
|
||||||
}
|
}
|
||||||
signature = infer_signature(input_signature, model.output)
|
X_train_numpy = X_train.to_numpy()
|
||||||
|
signature = infer_signature(X_train_numpy, model.predict(X_train_numpy))
|
||||||
|
input_example = X_train.head(1).to_numpy()
|
||||||
|
|
||||||
mlflow.keras.log_model(model, "model")
|
mlflow.keras.log_model(model, "model")
|
||||||
mlflow.log_artifact("model.h5")
|
mlflow.log_artifact("model.h5")
|
||||||
|
Loading…
Reference in New Issue
Block a user