This commit is contained in:
Mateusz 2024-05-04 11:41:45 +02:00
parent c0b07aaac4
commit f6c7f5981e
2 changed files with 10 additions and 6 deletions

6
Jenkinsfile vendored
View File

@ -13,8 +13,8 @@ pipeline {
description: 'Which build to use for copying artifacts', description: 'Which build to use for copying artifacts',
name: 'BUILD_SELECTOR' name: 'BUILD_SELECTOR'
) )
string(name: 'learning_rate', defaultValue: '0.001', description: 'Learning rate') string(name: 'LEARNING_RATE', defaultValue: '0.001', description: 'Learning rate')
string(name: 'epochs', defaultValue: '5', description: 'Number of epochs') string(name: 'EPOCHS', defaultValue: '5', description: 'Number of epochs')
} }
stages { stages {
@ -33,7 +33,7 @@ pipeline {
stage('Run train_model script') { stage('Run train_model script') {
steps { steps {
sh 'chmod +x train_model.py' sh 'chmod +x train_model.py'
sh 'python3 ./train_model.py' sh 'python3 ./train_model.py ${params.LEARNING_RATE} ${params.EPOCHS}'
} }
} }

View File

@ -6,6 +6,7 @@ from keras.models import Sequential
from keras.layers import BatchNormalization, Dropout, Dense, Flatten, Conv1D from keras.layers import BatchNormalization, Dropout, Dense, Flatten, Conv1D
from keras.optimizers import Adam from keras.optimizers import Adam
import pandas as pd import pandas as pd
import sys
def main(): def main():
@ -22,6 +23,9 @@ def main():
X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1) X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_val = X_val.reshape(X_val.shape[0], X_val.shape[1], 1) X_val = X_val.reshape(X_val.shape[0], X_val.shape[1], 1)
learning_rate = float(sys.argv[1])
epochs = int(sys.argv[2])
model = Sequential( model = Sequential(
[ [
Conv1D(32, 2, activation="relu", input_shape=X_train[0].shape), Conv1D(32, 2, activation="relu", input_shape=X_train[0].shape),
@ -38,7 +42,7 @@ def main():
) )
model.compile( model.compile(
optimizer=Adam(learning_rate=1e-3), optimizer=Adam(learning_rate=learning_rate),
loss="binary_crossentropy", loss="binary_crossentropy",
metrics=["accuracy"], metrics=["accuracy"],
) )
@ -47,7 +51,7 @@ def main():
X_train, X_train,
y_train, y_train,
validation_data=(X_val, y_val), validation_data=(X_val, y_val),
epochs=5, epochs=epochs,
verbose=1, verbose=1,
) )