From aa966f8960c12f65de89970be7f8a4a343cfa7f8 Mon Sep 17 00:00:00 2001 From: Marcin Kostrzewski Date: Wed, 11 May 2022 19:16:57 +0200 Subject: [PATCH] Log model.py as artifact --- Jenkinsfile-train | 7 ++++--- train_model.py | 18 +++++++++++------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/Jenkinsfile-train b/Jenkinsfile-train index f317d74..eb05d97 100644 --- a/Jenkinsfile-train +++ b/Jenkinsfile-train @@ -36,12 +36,13 @@ pipeline { stage('Train model') { steps { sh "python train_model.py -e ${params.EPOCHS} -b ${params.BATCHSIZE}" + archiveArtifacts artifacts: 'model_out', onlyIfSuccessful: true + archiveArtifacts artifacts: 'mlruns/**', onlyIfSuccessful: true + sh 'rm -r mlruns' } } - stage('Archive model and evaluate it') { + stage('Evaluate model') { steps { - archiveArtifacts artifacts: 'model_out', onlyIfSuccessful: true - archiveArtifacts artifacts: 'mlruns/**', onlyIfSuccessful: true build job: "s444409-evaluation/${params.BRANCH}/", parameters: [string(name: 'BRANCH', value: "${params.BRANCH}")] } } diff --git a/train_model.py b/train_model.py index d648d55..2983507 100644 --- a/train_model.py +++ b/train_model.py @@ -1,10 +1,11 @@ -import torch -import mlflow import argparse +from urllib.parse import urlparse + +import mlflow import numpy as np +import torch from torch import nn from torch.utils.data import DataLoader -from urllib.parse import urlparse from model import PlantsDataset, MLP, train, test @@ -13,7 +14,7 @@ default_epochs = 5 device = "cuda" if torch.cuda.is_available() else "cpu" -mlflow.set_tracking_uri("http://172.17.0.1:5000") +# mlflow.set_tracking_uri("http://172.17.0.1:5000") mlflow.set_experiment("s444409") @@ -50,7 +51,7 @@ if __name__ == "__main__": loss_fn = nn.MSELoss() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - + print("Done!") torch.save(model.state_dict(), './model_out') @@ -72,12 +73,15 @@ if __name__ == "__main__": signature = mlflow.models.signature.infer_signature(plant_test.x_train.numpy(), preds.numpy()) tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme if tracking_url_type_store != "file": + mlflow.log_artifact('model.py') mlflow.pytorch.log_model( model, "s444409", registered_model_name="s444409", signature=signature, - input_example=input_example + input_example=input_example, + extra_files=['model.py'] ) else: - mlflow.pytorch.log_model(model, "s444409", signature=signature, input_example=input_example) + mlflow.pytorch.log_model(model, "s444409", signature=signature, input_example=input_example, + extra_files=['model.py'])