diff --git a/sacred.py b/sacred.py index d9c2c48..42342e9 100644 --- a/sacred.py +++ b/sacred.py @@ -29,7 +29,8 @@ def run_experiment(): cars_train, cars_test = sklearn.model_selection.train_test_split(cars_normalized, test_size=23586, random_state=1) cars_dev, cars_test = sklearn.model_selection.train_test_split(cars_test, test_size=11793, random_state=1) cars_train.rename(columns = {list(cars_train)[0]: 'id'}, inplace = True) - cars_train.to_csv('test.csv') + cars_train.to_csv('train.csv') + cars_test.to_csv('test,csv') feature_cols = ['year', 'mileage', 'vol_engine'] inputs = tf.keras.Input(shape=(len(feature_cols),))