add parameter
This commit is contained in:
parent
643e52a372
commit
303bfb0193
@ -1,8 +1,9 @@
|
||||
import keras
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
import pandas as pd
|
||||
|
||||
build_no = int(os.environ['BUILD_NUMBER'])
|
||||
data = tf.keras.models.load_model('model')
|
||||
|
||||
x_to_test = pd.read_csv('./X_test.csv')
|
||||
@ -12,7 +13,7 @@ accu = data.evaluate(x_to_test, y_to_test)
|
||||
|
||||
|
||||
with open('metrics.csv', 'a') as file:
|
||||
file.write(f'1,{accu}\n')
|
||||
file.write(f'{build_no},{accu}\n')
|
||||
|
||||
pre = data.predict(x_to_test)
|
||||
|
||||
|
5
train.py
5
train.py
@ -1,9 +1,12 @@
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
import tensorflow
|
||||
from keras.applications.densenet import layers
|
||||
|
||||
|
||||
def main():
|
||||
EPOCHS = int(os.environ['EPOCHS'])
|
||||
train_data_x = pd.read_csv('./X_train.csv')
|
||||
|
||||
adults_train = train_data_x.copy()
|
||||
@ -21,7 +24,7 @@ def main():
|
||||
loss=tensorflow.keras.losses.MeanSquaredError(),
|
||||
optimizer=tensorflow.keras.optimizers.Adam())
|
||||
|
||||
adult_model.fit(adults_train, adults_predict, epochs=500)
|
||||
adult_model.fit(adults_train, adults_predict, epochs=EPOCHS)
|
||||
|
||||
adult_model.save('model')
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user