This commit is contained in:
Jakub Zaręba 2023-05-10 23:54:20 +02:00
parent 2dffac65ab
commit 04fe4bb719
2 changed files with 16 additions and 5 deletions

View File

@ -1,7 +1,12 @@
name: s487187-training name: s487187-training
docker_env:
image: python:3.11
volumes: [ "/root/.cache:/root/.cache" ]
user_env_vars: [ "SACRED_IGNORE_GIT" ]
entry_points: entry_points:
main: main:
parameters: parameters:
epochs: int epochs: int
command: "python train.py --epochs {epochs}" command: "python train.py --epochs {epochs}"

View File

@ -1,8 +1,10 @@
from sacred import Experiment from sacred import Experiment
from sacred.observers import MongoObserver, FileStorageObserver from sacred.observers import MongoObserver
import os import os
import mlflow import mlflow
import mlflow.keras import mlflow.keras
from mlflow.models.signature import infer_signature
from mlflow.models import Model
os.environ["SACRED_NO_GIT"] = "1" os.environ["SACRED_NO_GIT"] = "1"
@ -12,12 +14,13 @@ ex.observers.append(MongoObserver(url='mongodb://admin:IUM_2021@172.17.0.1:27017
@ex.config @ex.config
def my_config(): def my_config():
data_file = 'data.csv' data_file = 'data.csv'
model_file = 'model.h5' model_file = 'model'
epochs = 10 epochs = 10
batch_size = 32 batch_size = 32
test_size = 0.2 test_size = 0.2
random_state = 42 random_state = 42
@ex.capture @ex.capture
def train_model(data_file, model_file, epochs, batch_size, test_size, random_state): def train_model(data_file, model_file, epochs, batch_size, test_size, random_state):
import pandas as pd import pandas as pd
@ -62,7 +65,10 @@ def train_model(data_file, model_file, epochs, batch_size, test_size, random_sta
mlflow.log_metric("loss", loss) mlflow.log_metric("loss", loss)
mlflow.log_metric("accuracy", accuracy) mlflow.log_metric("accuracy", accuracy)
mlflow.keras.save_model(model, model_file) signature = infer_signature(X_train, model.predict(X_train))
input_example = Model.log_input_example(X_train.iloc[0])
mlflow.keras.log_model(model, model_file, signature=signature, input_example=input_example)
return accuracy return accuracy
@ -72,4 +78,4 @@ def run_experiment():
ex.log_scalar('accuracy', accuracy) ex.log_scalar('accuracy', accuracy)
ex.add_artifact('model.h5') ex.add_artifact('model.h5')
ex.run() ex.run()