add parameter
This commit is contained in:
parent
643e52a372
commit
303bfb0193
@ -1,8 +1,9 @@
|
|||||||
import keras
|
import os
|
||||||
import numpy as np
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
|
build_no = int(os.environ['BUILD_NUMBER'])
|
||||||
data = tf.keras.models.load_model('model')
|
data = tf.keras.models.load_model('model')
|
||||||
|
|
||||||
x_to_test = pd.read_csv('./X_test.csv')
|
x_to_test = pd.read_csv('./X_test.csv')
|
||||||
@ -12,7 +13,7 @@ accu = data.evaluate(x_to_test, y_to_test)
|
|||||||
|
|
||||||
|
|
||||||
with open('metrics.csv', 'a') as file:
|
with open('metrics.csv', 'a') as file:
|
||||||
file.write(f'1,{accu}\n')
|
file.write(f'{build_no},{accu}\n')
|
||||||
|
|
||||||
pre = data.predict(x_to_test)
|
pre = data.predict(x_to_test)
|
||||||
|
|
||||||
|
5
train.py
5
train.py
@ -1,9 +1,12 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import tensorflow
|
import tensorflow
|
||||||
from keras.applications.densenet import layers
|
from keras.applications.densenet import layers
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
EPOCHS = int(os.environ['EPOCHS'])
|
||||||
train_data_x = pd.read_csv('./X_train.csv')
|
train_data_x = pd.read_csv('./X_train.csv')
|
||||||
|
|
||||||
adults_train = train_data_x.copy()
|
adults_train = train_data_x.copy()
|
||||||
@ -21,7 +24,7 @@ def main():
|
|||||||
loss=tensorflow.keras.losses.MeanSquaredError(),
|
loss=tensorflow.keras.losses.MeanSquaredError(),
|
||||||
optimizer=tensorflow.keras.optimizers.Adam())
|
optimizer=tensorflow.keras.optimizers.Adam())
|
||||||
|
|
||||||
adult_model.fit(adults_train, adults_predict, epochs=500)
|
adult_model.fit(adults_train, adults_predict, epochs=EPOCHS)
|
||||||
|
|
||||||
adult_model.save('model')
|
adult_model.save('model')
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user