From 66c7e1c583058093cf86f80eeb1bbab0489f8bd3 Mon Sep 17 00:00:00 2001 From: emkarcinos Date: Mon, 9 May 2022 11:20:41 +0200 Subject: [PATCH] Initial MLFlow setup --- Jenkinsfile-train | 2 +- requirements.txt | 3 +- train_model.py | 80 ++++++++++++++++++++++++++--------------------- 3 files changed, 47 insertions(+), 38 deletions(-) diff --git a/Jenkinsfile-train b/Jenkinsfile-train index df90581..be98319 100644 --- a/Jenkinsfile-train +++ b/Jenkinsfile-train @@ -34,7 +34,7 @@ pipeline { } stage('Train model') { steps { - sh "python train_model.py with 'epochs=${params.EPOCHS}' 'batch_size=${params.BATCHSIZE}'" + sh "python train_model.py -e ${params.EPOCHS} -b ${params.BATCHSIZE}" } } stage('Archive model and evaluate it') { diff --git a/requirements.txt b/requirements.txt index 13ff2e0..2aa39c9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ torch==1.11.0 numpy~=1.22.3 matplotlib==3.5.2 sacred==0.8.2 -pymongo==4.1.1 \ No newline at end of file +pymongo==4.1.1 +mlflow==1.25.1 \ No newline at end of file diff --git a/train_model.py b/train_model.py index 90ed416..87fc38f 100644 --- a/train_model.py +++ b/train_model.py @@ -1,12 +1,9 @@ -import argparse - -import numpy as np -import pandas as pd import torch -from sacred.observers import FileStorageObserver, MongoObserver +import mlflow +import argparse from torch import nn -from torch.utils.data import DataLoader, Dataset -from sacred import Experiment +from torch.utils.data import DataLoader +from urllib.parse import urlparse from model import PlantsDataset, MLP, train, test @@ -15,8 +12,23 @@ default_epochs = 5 device = "cuda" if torch.cuda.is_available() else "cpu" +mlflow.set_tracking_uri("http://172.17.0.1:5000") +mlflow.set_experiment("s444409") + + +def setup_args(): + args_parser = argparse.ArgumentParser(prefix_chars='-') + args_parser.add_argument('-b', '--batchSize', type=int, default=default_batch_size) + args_parser.add_argument('-e', '--epochs', type=int, default=default_epochs) + + return args_parser.parse_args() + + +if __name__ == "__main__": + args = setup_args() + batch_size = args.batchSize + epochs = args.epochs -def main(batch_size, epochs, _run): print(f"Using {device} device") plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test') @@ -35,37 +47,33 @@ def main(batch_size, epochs, _run): loss_fn = nn.MSELoss() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - for t in range(epochs): - print(f"Epoch {t + 1}\n-------------------------------") - train(train_dataloader, model, loss_fn, optimizer) - last_loss = test(test_dataloader, model, loss_fn) - _run.log_scalar('training.loss', last_loss, t) + print("Done!") torch.save(model.state_dict(), './model_out') print("Model saved in ./model_out file.") + with mlflow.start_run() as run: + print("MLflow run experiment_id: {0}".format(run.info.experiment_id)) + print("MLflow run artifact_uri: {0}".format(run.info.artifact_uri)) + mlflow.log_param("batch_size", batch_size) + mlflow.log_param("epochs", epochs) + for t in range(epochs): + print(f"Epoch {t + 1}\n-------------------------------") + train(train_dataloader, model, loss_fn, optimizer) + last_loss = test(test_dataloader, model, loss_fn) + mlflow.log_metric("rmse", last_loss) -def setup_experiment(): - ex = Experiment('Predict power output for a given time') - ex.observers.append(FileStorageObserver('sacred_runs')) - ex.observers.append(MongoObserver(url='mongodb://admin:IUM_2021@172.17.0.1:27017', - db_name='sacred')) - return ex - - -ex = setup_experiment() - - -@ex.config -def experiment_config(): - batch_size = 64 - epochs = 5 - - -@ex.automain -def run(batch_size, epochs, _run): - main(batch_size, epochs, _run) - - -ex.add_artifact('model_out') + with torch.no_grad(): + preds = model(plant_test.x_train) + signature = mlflow.models.signature.infer_signature(plant_test.x_train.numpy(), preds.numpy()) + tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme + if tracking_url_type_store != "file": + mlflow.pytorch.log_model( + model, + "s444409-power-plant-model", + registered_model_name="s444409PowerPlant", + signature=signature + ) + else: + mlflow.pytorch.log_model(model, "s444409-power-plant-model", signature=signature)