# ium_z444439/train.py
# 36 lines, 807 B, Python
# 2023-05-11 19:00:42 +02:00
import os
# 2023-05-11 18:27:25 +02:00
import pandas as pd
import tensorflow
from keras.applications.densenet import layers

# 2023-05-11 19:02:39 +02:00

def main(EPOCHS):
    """Train a small Keras regression model that predicts ``age`` and save it.

    Reads ``./X_train.csv`` (relative to the current working directory),
    uses its ``age`` column as the regression target and all remaining
    columns as input features, fits a normalization + dense network with
    mean-squared-error loss and the Adam optimizer, then saves the trained
    model under the path ``model``.

    Args:
        EPOCHS: Number of training epochs; ``0`` selects the default of 500.
    """
    if EPOCHS == 0:
        EPOCHS = 500

    train_data_x = pd.read_csv('./X_train.csv')

    # Separate the target from the features. The target must be popped from
    # the frame that is actually used as model input; popping it from a
    # different frame would leave 'age' among the features and leak the
    # label into training.
    adults_train = train_data_x.copy()
    adults_predict = adults_train.pop('age')

    # Use the public tf.keras layers namespace rather than relying on
    # keras.applications.densenet happening to re-export `layers`.
    keras_layers = tensorflow.keras.layers

    # Learn per-feature normalization statistics from the training inputs.
    # NOTE(review): assumes every remaining CSV column is numeric — confirm.
    normalize = keras_layers.Normalization()
    normalize.adapt(adults_train)

    adult_model = tensorflow.keras.Sequential([
        normalize,
        # NOTE(review): no activation is set, so this hidden layer is linear
        # and the stack is equivalent to a single linear layer; consider
        # activation='relu' if a non-linear model was intended.
        keras_layers.Dense(64),
        keras_layers.Dense(1),
    ])

    adult_model.compile(
        loss=tensorflow.keras.losses.MeanSquaredError(),
        optimizer=tensorflow.keras.optimizers.Adam(),
    )

    adult_model.fit(adults_train, adults_predict, epochs=EPOCHS)

    # NOTE(review): newer Keras versions require a '.keras'/'.h5' suffix for
    # Model.save; a bare 'model' path relies on TF2-era SavedModel saving.
    adult_model.save('model')

# 2023-05-11 18:55:10 +02:00

if __name__ == "__main__":
    # Epoch count comes from the EPOCHS environment variable. An unset or
    # blank value falls back to 0, which main() maps to its default epoch
    # count instead of crashing with KeyError/ValueError.
    EPOCHS = int(os.environ.get('EPOCHS', '').strip() or 0)
    main(EPOCHS)