Initial MLFlow setup
Some checks failed
s444409-training/pipeline/head There was a failure building this commit
Some checks failed
s444409-training/pipeline/head There was a failure building this commit
This commit is contained in:
parent
b6f47e9fef
commit
66c7e1c583
@ -34,7 +34,7 @@ pipeline {
|
||||
}
|
||||
stage('Train model') {
|
||||
steps {
|
||||
sh "python train_model.py with 'epochs=${params.EPOCHS}' 'batch_size=${params.BATCHSIZE}'"
|
||||
sh "python train_model.py -e ${params.EPOCHS} -b ${params.BATCHSIZE}"
|
||||
}
|
||||
}
|
||||
stage('Archive model and evaluate it') {
|
||||
|
@ -4,4 +4,5 @@ torch==1.11.0
|
||||
numpy~=1.22.3
|
||||
matplotlib==3.5.2
|
||||
sacred==0.8.2
|
||||
pymongo==4.1.1
|
||||
pymongo==4.1.1
|
||||
mlflow==1.25.1
|
@ -1,12 +1,9 @@
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from sacred.observers import FileStorageObserver, MongoObserver
|
||||
import mlflow
|
||||
import argparse
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
from sacred import Experiment
|
||||
from torch.utils.data import DataLoader
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from model import PlantsDataset, MLP, train, test
|
||||
|
||||
@ -15,8 +12,23 @@ default_epochs = 5
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
mlflow.set_tracking_uri("http://172.17.0.1:5000")
|
||||
mlflow.set_experiment("s444409")
|
||||
|
||||
|
||||
def setup_args():
|
||||
args_parser = argparse.ArgumentParser(prefix_chars='-')
|
||||
args_parser.add_argument('-b', '--batchSize', type=int, default=default_batch_size)
|
||||
args_parser.add_argument('-e', '--epochs', type=int, default=default_epochs)
|
||||
|
||||
return args_parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = setup_args()
|
||||
batch_size = args.batchSize
|
||||
epochs = args.epochs
|
||||
|
||||
def main(batch_size, epochs, _run):
|
||||
print(f"Using {device} device")
|
||||
|
||||
plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
|
||||
@ -35,37 +47,33 @@ def main(batch_size, epochs, _run):
|
||||
|
||||
loss_fn = nn.MSELoss()
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
||||
for t in range(epochs):
|
||||
print(f"Epoch {t + 1}\n-------------------------------")
|
||||
train(train_dataloader, model, loss_fn, optimizer)
|
||||
last_loss = test(test_dataloader, model, loss_fn)
|
||||
_run.log_scalar('training.loss', last_loss, t)
|
||||
|
||||
print("Done!")
|
||||
|
||||
torch.save(model.state_dict(), './model_out')
|
||||
print("Model saved in ./model_out file.")
|
||||
with mlflow.start_run() as run:
|
||||
print("MLflow run experiment_id: {0}".format(run.info.experiment_id))
|
||||
print("MLflow run artifact_uri: {0}".format(run.info.artifact_uri))
|
||||
|
||||
mlflow.log_param("batch_size", batch_size)
|
||||
mlflow.log_param("epochs", epochs)
|
||||
for t in range(epochs):
|
||||
print(f"Epoch {t + 1}\n-------------------------------")
|
||||
train(train_dataloader, model, loss_fn, optimizer)
|
||||
last_loss = test(test_dataloader, model, loss_fn)
|
||||
mlflow.log_metric("rmse", last_loss)
|
||||
|
||||
def setup_experiment():
|
||||
ex = Experiment('Predict power output for a given time')
|
||||
ex.observers.append(FileStorageObserver('sacred_runs'))
|
||||
ex.observers.append(MongoObserver(url='mongodb://admin:IUM_2021@172.17.0.1:27017',
|
||||
db_name='sacred'))
|
||||
return ex
|
||||
|
||||
|
||||
ex = setup_experiment()
|
||||
|
||||
|
||||
@ex.config
|
||||
def experiment_config():
|
||||
batch_size = 64
|
||||
epochs = 5
|
||||
|
||||
|
||||
@ex.automain
|
||||
def run(batch_size, epochs, _run):
|
||||
main(batch_size, epochs, _run)
|
||||
|
||||
|
||||
ex.add_artifact('model_out')
|
||||
with torch.no_grad():
|
||||
preds = model(plant_test.x_train)
|
||||
signature = mlflow.models.signature.infer_signature(plant_test.x_train.numpy(), preds.numpy())
|
||||
tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
|
||||
if tracking_url_type_store != "file":
|
||||
mlflow.pytorch.log_model(
|
||||
model,
|
||||
"s444409-power-plant-model",
|
||||
registered_model_name="s444409PowerPlant",
|
||||
signature=signature
|
||||
)
|
||||
else:
|
||||
mlflow.pytorch.log_model(model, "s444409-power-plant-model", signature=signature)
|
||||
|
Loading…
Reference in New Issue
Block a user