Zaktualizuj 'sacred.py'

This commit is contained in:
Michał Dudziak 2023-05-10 13:17:38 +02:00
parent 9de361e834
commit 7d39517665

View File

@ -29,7 +29,8 @@ def run_experiment():
cars_train, cars_test = sklearn.model_selection.train_test_split(cars_normalized, test_size=23586, random_state=1) cars_train, cars_test = sklearn.model_selection.train_test_split(cars_normalized, test_size=23586, random_state=1)
cars_dev, cars_test = sklearn.model_selection.train_test_split(cars_test, test_size=11793, random_state=1) cars_dev, cars_test = sklearn.model_selection.train_test_split(cars_test, test_size=11793, random_state=1)
cars_train.rename(columns = {list(cars_train)[0]: 'id'}, inplace = True) cars_train.rename(columns = {list(cars_train)[0]: 'id'}, inplace = True)
cars_train.to_csv('test.csv') cars_train.to_csv('train.csv')
cars_test.to_csv('test,csv')
feature_cols = ['year', 'mileage', 'vol_engine'] feature_cols = ['year', 'mileage', 'vol_engine']
inputs = tf.keras.Input(shape=(len(feature_cols),)) inputs = tf.keras.Input(shape=(len(feature_cols),))