Add parameters to training.py

This commit is contained in:
s487179 2023-06-11 12:40:44 +02:00
parent 7b1acab17d
commit 4208581057
2 changed files with 13 additions and 2 deletions

6
ML/Jenkinsfile vendored
View File

@ -6,6 +6,12 @@ pipeline {
defaultSelector: lastSuccessful(), defaultSelector: lastSuccessful(),
description: 'A build to take the artifacts from' description: 'A build to take the artifacts from'
) )
string(
name: 'EPOCHS',
description: 'Number of epochs',
defaultValue: '10'
)
}
} }
stages { stages {
stage('Copy artifacts') { stage('Copy artifacts') {

View File

@ -8,7 +8,7 @@ from sklearn.preprocessing import LabelBinarizer
import numpy as np import numpy as np
from sklearn.preprocessing import MinMaxScaler from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
import argparse
class MyNeuralNetwork(nn.Module): class MyNeuralNetwork(nn.Module):
@ -116,6 +116,11 @@ def create_dataset():
def main() -> None: def main() -> None:
# create_dataset() # create_dataset()
parser = argparse.ArgumentParser(description='A test program.')
parser.add_argument("epochs", help="Prints the supplied argument." ,default='10')
args = parser.parse_args()
epochs = int(args["epochs"])
train_dataset = load_data("home_loan_train.csv") train_dataset = load_data("home_loan_train.csv")
val_dataset = load_data("home_loan_val.csv") val_dataset = load_data("home_loan_val.csv")
@ -123,7 +128,7 @@ def main() -> None:
dataloader_train = DataLoader(train_dataset, batch_size = batch_size, shuffle=True) dataloader_train = DataLoader(train_dataset, batch_size = batch_size, shuffle=True)
dataloader_val = DataLoader(val_dataset, batch_size = batch_size) dataloader_val = DataLoader(val_dataset, batch_size = batch_size)
model = train(20, dataloader_train, dataloader_val) model = train(epochs, dataloader_train, dataloader_val)
torch.save(model.state_dict(), 'model.pt') torch.save(model.state_dict(), 'model.pt')
if __name__ == "__main__": if __name__ == "__main__":