This commit is contained in:
Klaudia 2023-05-11 18:37:18 +02:00
parent a839ae2bb9
commit 2b95806379
2 changed files with 20 additions and 15 deletions

View File

@ -1,6 +1,7 @@
import pandas as pd import pandas as pd
import os import os
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
import train
CUTOFF = int(os.environ['CUTOFF']) CUTOFF = int(os.environ['CUTOFF'])
adults = pd.read_csv('adult.csv') adults = pd.read_csv('adult.csv')
@ -20,3 +21,5 @@ X_test.to_csv('X_test.csv', index=False)
Y_test.to_csv('Y_test.csv', index=False) Y_test.to_csv('Y_test.csv', index=False)
Y_train.to_csv('Y_train.csv', index=False) Y_train.to_csv('Y_train.csv', index=False)
Y_dev.to_csv('Y_dev.csv', index=False) Y_dev.to_csv('Y_dev.csv', index=False)
train.main()

View File

@ -2,23 +2,25 @@ import pandas as pd
import tensorflow import tensorflow
from keras.applications.densenet import layers from keras.applications.densenet import layers
train_data_x = pd.read_csv('./X_train.csv')
adults_train = train_data_x.copy() def main():
adults_predict = train_data_x.pop('age') train_data_x = pd.read_csv('./X_train.csv')
normalize = layers.Normalization()
normalize.adapt(adults_train)
adult_model = tensorflow.keras.Sequential([ adults_train = train_data_x.copy()
normalize, adults_predict = train_data_x.pop('age')
layers.Dense(64), normalize = layers.Normalization()
layers.Dense(1) normalize.adapt(adults_train)
])
adult_model.compile( adult_model = tensorflow.keras.Sequential([
loss=tensorflow.keras.losses.MeanSquaredError(), normalize,
optimizer=tensorflow.keras.optimizers.Adam()) layers.Dense(64),
layers.Dense(1)
])
adult_model.fit(adults_train, adults_predict, epochs=500) adult_model.compile(
loss=tensorflow.keras.losses.MeanSquaredError(),
optimizer=tensorflow.keras.optimizers.Adam())
adult_model.save('model') adult_model.fit(adults_train, adults_predict, epochs=500)
adult_model.save('model')