36 lines
807 B
Python
36 lines
807 B
Python
import os
|
|
|
|
import pandas as pd
|
|
import tensorflow
|
|
from keras.applications.densenet import layers
|
|
|
|
|
|
def main(EPOCHS=0):
    """Train a small regression model that predicts the ``age`` column and save it.

    Reads ``./X_train.csv`` (relative to the current working directory). Every
    column except ``age`` is used as an input feature, and ``age`` is the
    regression target. The function fits a normalization layer followed by two
    Dense layers with mean-squared-error loss and Adam, then saves the model to
    ``model``.

    Args:
        EPOCHS: Number of training epochs. ``0`` (the default) selects 500.
    """
    if EPOCHS == 0:
        EPOCHS = 500

    train_data_x = pd.read_csv('./X_train.csv')

    # Pop the label from the *copy* so the feature frame no longer contains it;
    # otherwise the model would receive its own target as an input feature.
    adults_train = train_data_x.copy()
    adults_predict = adults_train.pop('age')

    # Explicit float32 arrays, as in the TensorFlow CSV tutorial; this fails
    # loudly if any remaining column is non-numeric.
    features = adults_train.to_numpy(dtype="float32")
    labels = adults_predict.to_numpy(dtype="float32")

    # Take every Keras component from tensorflow.keras so the layers, model,
    # loss and optimizer all come from the same Keras package.
    keras = tensorflow.keras

    normalize = keras.layers.Normalization()
    # adapt() learns per-feature mean/variance from the training features.
    normalize.adapt(features)

    # NOTE(review): Dense(64) has no activation, so the two Dense layers compose
    # into a single linear map. Add e.g. activation='relu' if a non-linear
    # model is intended.
    adult_model = keras.Sequential([
        normalize,
        keras.layers.Dense(64),
        keras.layers.Dense(1),
    ])

    adult_model.compile(
        loss=keras.losses.MeanSquaredError(),
        optimizer=keras.optimizers.Adam(),
    )

    adult_model.fit(features, labels, epochs=EPOCHS)

    # NOTE(review): an extension-less path saves a SavedModel directory under
    # Keras 2; Keras 3 rejects it and requires a '.keras' or '.h5' suffix.
    adult_model.save('model')
|
|
|
|
|
|
if __name__ == "__main__":
    # An unset or empty EPOCHS variable falls back to 0, which main() maps to
    # its default epoch count instead of crashing with KeyError/ValueError.
    EPOCHS = int(os.environ.get("EPOCHS") or 0)
    main(EPOCHS)
|