Add parameters to training.py
This commit is contained in:
parent
7b1acab17d
commit
4208581057
6
ML/Jenkinsfile
vendored
6
ML/Jenkinsfile
vendored
@ -6,6 +6,12 @@ pipeline {
|
|||||||
defaultSelector: lastSuccessful(),
|
defaultSelector: lastSuccessful(),
|
||||||
description: 'A build to take the artifacts from'
|
description: 'A build to take the artifacts from'
|
||||||
)
|
)
|
||||||
|
string(
|
||||||
|
name: 'EPOCHS',
|
||||||
|
description: 'Number of epochs',
|
||||||
|
defaultValue: '10'
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
stages {
|
stages {
|
||||||
stage('Copy artifacts') {
|
stage('Copy artifacts') {
|
||||||
|
@ -8,7 +8,7 @@ from sklearn.preprocessing import LabelBinarizer
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn.preprocessing import MinMaxScaler
|
from sklearn.preprocessing import MinMaxScaler
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
|
||||||
class MyNeuralNetwork(nn.Module):
|
class MyNeuralNetwork(nn.Module):
|
||||||
@ -116,6 +116,11 @@ def create_dataset():
|
|||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
# create_dataset()
|
# create_dataset()
|
||||||
|
parser = argparse.ArgumentParser(description='A test program.')
|
||||||
|
parser.add_argument("epochs", help="Prints the supplied argument." ,default='10')
|
||||||
|
args = parser.parse_args()
|
||||||
|
epochs = int(args["epochs"])
|
||||||
|
|
||||||
train_dataset = load_data("home_loan_train.csv")
|
train_dataset = load_data("home_loan_train.csv")
|
||||||
val_dataset = load_data("home_loan_val.csv")
|
val_dataset = load_data("home_loan_val.csv")
|
||||||
|
|
||||||
@ -123,7 +128,7 @@ def main() -> None:
|
|||||||
dataloader_train = DataLoader(train_dataset, batch_size = batch_size, shuffle=True)
|
dataloader_train = DataLoader(train_dataset, batch_size = batch_size, shuffle=True)
|
||||||
dataloader_val = DataLoader(val_dataset, batch_size = batch_size)
|
dataloader_val = DataLoader(val_dataset, batch_size = batch_size)
|
||||||
|
|
||||||
model = train(20, dataloader_train, dataloader_val)
|
model = train(epochs, dataloader_train, dataloader_val)
|
||||||
torch.save(model.state_dict(), 'model.pt')
|
torch.save(model.state_dict(), 'model.pt')
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user