diff --git a/Jenkinsfile-train b/Jenkinsfile-train index 8a09238..e75d67c 100644 --- a/Jenkinsfile-train +++ b/Jenkinsfile-train @@ -29,7 +29,7 @@ pipeline { stages { stage('Train model') { steps { - sh "python train_model.py -e ${params.EPOCHS} -b ${params.BATCHSIZE}" + sh "python train_model.py with 'epochs=${params.EPOCHS}' 'batch_size=${params.BATCHSIZE}'" } } stage('Archive model and evaluate it') { diff --git a/requirements.txt b/requirements.txt index a1e279b..5fef6d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ kaggle==1.5.12 pandas==1.4.1 torch==1.11.0 numpy~=1.22.3 -matplotlib==3.5.2 \ No newline at end of file +matplotlib==3.5.2 +sacred==0.8.2 diff --git a/train_model.py b/train_model.py index 5aa3932..8695f6d 100644 --- a/train_model.py +++ b/train_model.py @@ -5,6 +5,7 @@ import pandas as pd import torch from torch import nn from torch.utils.data import DataLoader, Dataset +from sacred import Experiment default_batch_size = 64 default_epochs = 5 @@ -90,21 +91,9 @@ def test(dataloader, model, loss_fn): return test_loss -def setup_args(): - args_parser = argparse.ArgumentParser(prefix_chars='-') - args_parser.add_argument('-b', '--batchSize', type=int, default=default_batch_size) - args_parser.add_argument('-e', '--epochs', type=int, default=default_epochs) - - return args_parser.parse_args() - - -def main(): +def main(batch_size, epochs): print(f"Using {device} device") - args = setup_args() - - batch_size = args.batchSize - plant_test = PlantsDataset('data/Plant_1_Generation_Data.csv.test') plant_train = PlantsDataset('data/Plant_1_Generation_Data.csv.train') @@ -121,7 +110,6 @@ def main(): loss_fn = nn.MSELoss() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - epochs = args.epochs for t in range(epochs): print(f"Epoch {t + 1}\n-------------------------------") train(train_dataloader, model, loss_fn, optimizer) @@ -132,5 +120,15 @@ def main(): print("Model saved in ./model_out file.") -if __name__ == "__main__": - main() +ex = Experiment('Predict power output for a given time') + + +@ex.config +def experiment_config(): + batch_size = 64 + epochs = 5 + + +@ex.automain +def run(batch_size, epochs): + main(batch_size, epochs)