Add basic Sacred support
All checks were successful
s444409-evaluation/pipeline/head This commit looks good
s444409-training/pipeline/head This commit looks good

This commit is contained in:
Marcin Kostrzewski 2022-05-06 20:43:53 +02:00
parent 593b17cd2b
commit f716e52cbf
3 changed files with 17 additions and 18 deletions

View File

@ -29,7 +29,7 @@ pipeline {
stages {
stage('Train model') {
steps {
sh "python train_model.py -e ${params.EPOCHS} -b ${params.BATCHSIZE}"
sh "python train_model.py with 'epochs=${params.EPOCHS}' 'batch_size=${params.BATCHSIZE}'"
}
}
stage('Archive model and evaluate it') {

View File

@ -3,3 +3,4 @@ pandas==1.4.1
torch==1.11.0
numpy~=1.22.3
matplotlib==3.5.2
sacred==0.8.2

View File

@ -5,6 +5,7 @@ import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from sacred import Experiment
default_batch_size = 64
default_epochs = 5
@ -90,21 +91,9 @@ def test(dataloader, model, loss_fn):
return test_loss
def setup_args():
args_parser = argparse.ArgumentParser(prefix_chars='-')
args_parser.add_argument('-b', '--batchSize', type=int, default=default_batch_size)
args_parser.add_argument('-e', '--epochs', type=int, default=default_epochs)
return args_parser.parse_args()
def main():
def main(batch_size, epochs):
print(f"Using {device} device")
args = setup_args()
batch_size = args.batchSize
plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test')
plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train')
@ -121,7 +110,6 @@ def main():
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs = args.epochs
for t in range(epochs):
print(f"Epoch {t + 1}\n-------------------------------")
train(train_dataloader, model, loss_fn, optimizer)
@ -132,5 +120,15 @@ def main():
print("Model saved in ./model_out file.")
if __name__ == "__main__":
main()
ex = Experiment('Predict power output for a given time')
@ex.config
def experiment_config():
batch_size = 64
epochs = 5
@ex.automain
def run(batch_size, epochs):
main(batch_size, epochs)