Log model.py as artifact
All checks were successful
s444409-evaluation/pipeline/head This commit looks good
s444409-training/pipeline/head This commit looks good

This commit is contained in:
Marcin Kostrzewski 2022-05-11 19:16:57 +02:00
parent 8a967075c4
commit aa966f8960
2 changed files with 15 additions and 10 deletions

View File

@ -36,12 +36,13 @@ pipeline {
stage('Train model') { stage('Train model') {
steps { steps {
sh "python train_model.py -e ${params.EPOCHS} -b ${params.BATCHSIZE}" sh "python train_model.py -e ${params.EPOCHS} -b ${params.BATCHSIZE}"
archiveArtifacts artifacts: 'model_out', onlyIfSuccessful: true
archiveArtifacts artifacts: 'mlruns/**', onlyIfSuccessful: true
sh 'rm -r mlruns'
} }
} }
stage('Archive model and evaluate it') { stage('Evaluate model') {
steps { steps {
archiveArtifacts artifacts: 'model_out', onlyIfSuccessful: true
archiveArtifacts artifacts: 'mlruns/**', onlyIfSuccessful: true
build job: "s444409-evaluation/${params.BRANCH}/", parameters: [string(name: 'BRANCH', value: "${params.BRANCH}")] build job: "s444409-evaluation/${params.BRANCH}/", parameters: [string(name: 'BRANCH', value: "${params.BRANCH}")]
} }
} }

View File

@ -1,10 +1,11 @@
import torch
import mlflow
import argparse import argparse
from urllib.parse import urlparse
import mlflow
import numpy as np import numpy as np
import torch
from torch import nn from torch import nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from urllib.parse import urlparse
from model import PlantsDataset, MLP, train, test from model import PlantsDataset, MLP, train, test
@ -13,7 +14,7 @@ default_epochs = 5
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else "cpu"
mlflow.set_tracking_uri("http://172.17.0.1:5000") # mlflow.set_tracking_uri("http://172.17.0.1:5000")
mlflow.set_experiment("s444409") mlflow.set_experiment("s444409")
@ -50,7 +51,7 @@ if __name__ == "__main__":
loss_fn = nn.MSELoss() loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
print("Done!") print("Done!")
torch.save(model.state_dict(), './model_out') torch.save(model.state_dict(), './model_out')
@ -72,12 +73,15 @@ if __name__ == "__main__":
signature = mlflow.models.signature.infer_signature(plant_test.x_train.numpy(), preds.numpy()) signature = mlflow.models.signature.infer_signature(plant_test.x_train.numpy(), preds.numpy())
tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
if tracking_url_type_store != "file": if tracking_url_type_store != "file":
mlflow.log_artifact('model.py')
mlflow.pytorch.log_model( mlflow.pytorch.log_model(
model, model,
"s444409", "s444409",
registered_model_name="s444409", registered_model_name="s444409",
signature=signature, signature=signature,
input_example=input_example input_example=input_example,
extra_files=['model.py']
) )
else: else:
mlflow.pytorch.log_model(model, "s444409", signature=signature, input_example=input_example) mlflow.pytorch.log_model(model, "s444409", signature=signature, input_example=input_example,
extra_files=['model.py'])