ium_478841/scripts/grab_avocado.py

47 lines
1.8 KiB
Python
Raw Normal View History

2022-04-03 19:39:46 +02:00
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
2022-04-03 19:39:46 +02:00
2022-04-03 20:17:21 +02:00
# Header-only read (nrows=1) capturing the raw CSV column labels as a list.
cols = pd.read_csv("data/avocado.csv", nrows=1).columns.tolist()
2022-04-03 19:39:46 +02:00
# Load the full dataset and give the unnamed leading column a real name.
# NOTE(review): 'Week' presumably reflects what that counter column means in
# the raw export — confirm against the data source.
avocados = pd.read_csv("data/avocado.csv")
avocados = avocados.rename(columns={"Unnamed: 0": 'Week'})
2022-04-03 19:39:46 +02:00
# Standardize the numeric price/volume columns (zero mean, unit variance).
# NOTE(review): the scaler is fitted on the whole dataset *before* the
# train/valid/test split further down, so statistics from held-out rows leak
# into the training features — consider fitting on the train split only.
float_cols = ['AveragePrice', 'Total Volume', '4046', '4225',
              '4770', 'Total Bags', 'Small Bags', 'Large Bags', 'XLarge Bags']
avocados[float_cols] = StandardScaler().fit_transform(avocados[float_cols])


def _one_hot(column):
    """Return a dense one-hot DataFrame for a single categorical Series.

    A fresh encoder is fitted per call, so each column's categories are never
    mixed with another's. The encoder sees a bare (n, 1) array, so
    ``get_feature_names_out()`` yields generic ``x0_<category>`` names. The
    result reuses *column*'s index so a column-wise ``pd.concat`` lines up
    row-for-row even if the index is not a default RangeIndex.
    """
    encoder = OneHotEncoder(handle_unknown='ignore')
    dense = encoder.fit_transform(column.to_numpy().reshape(-1, 1)).toarray()
    return pd.DataFrame(
        dense, columns=encoder.get_feature_names_out(), index=column.index)


# Replace the categorical 'type' and 'region' columns with one-hot indicator
# columns (type indicators first, then region) and drop the 'Date' column.
encoded_types_frame = _one_hot(avocados['type'])
encoded_region_frame = _one_hot(avocados['region'])
avocados = pd.concat(
    [avocados, encoded_types_frame, encoded_region_frame], axis=1,
).drop(['type', 'region', 'Date'], axis=1)
2022-04-03 19:39:46 +02:00
print(avocados.head())

# Two-stage seeded split. Integer test_size values are absolute row counts:
# 2000 rows are held out for test first, then 2249 of the remaining rows
# become the validation set. The rest stays as training data.
avocado_train, avocado_test = train_test_split(
    avocados, random_state=3337, test_size=2000,
)
avocado_train, avocado_valid = train_test_split(
    avocado_train, random_state=3337, test_size=2249,
)
2022-04-03 19:39:46 +02:00
# Print a full describe() summary of each split. The train and valid
# summaries are followed by an extra "\n" argument; the test one is not.
for header, split_frame, trailer in (
    ("Train\n", avocado_train, ("\n",)),
    ("Valid\n", avocado_valid, ("\n",)),
    ("Test\n", avocado_test, ()),
):
    print(header, split_frame.describe(include="all"), *trailer)
2022-04-03 20:17:21 +02:00
# Persist each split as CSV without the pandas row index.
for out_path, split_frame in (
    ("data/avocado.data.train", avocado_train),
    ("data/avocado.data.valid", avocado_valid),
    ("data/avocado.data.test", avocado_test),
):
    split_frame.to_csv(out_path, index=False)