first script

This commit is contained in:
Alicja Szulecka 2024-04-13 19:07:21 +02:00
parent 00f443439e
commit 4709691adf

View File

@ -3,38 +3,36 @@ from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
def split(data): def split(data):
meteorite_train, meteorite_test = train_test_split(data, test_size=0.2, random_state=1) forest_train, forest_test = train_test_split(data, test_size=0.2, random_state=1)
meteorite_train, meteorite_val = train_test_split(meteorite_train, test_size=0.25, random_state=1) forest_train, forest_val = train_test_split(forest_train, test_size=0.25, random_state=1)
return meteorite_train, meteorite_test, meteorite_val return forest_train, forest_test, forest_val
def normalization(data): def normalization(data):
scaler = StandardScaler() scaler = StandardScaler()
data['mass'] = scaler.fit_transform(data[['mass']]) columns_to_normalize = data.columns[~data.columns.str.startswith('Soil_Type')]
columns_to_normalize = columns_to_normalize.to_list()
columns_to_normalize.remove('Cover_Type')
data[columns_to_normalize] = scaler.fit_transform(data[columns_to_normalize])
return data return data
def preprocessing(data): def preprocessing(data):
data = data.dropna(subset=['reclat']) #shuffle
data = data.sample(frac = 1)
incorrect_years_index = data.loc[(data['year'] > 2016) | (data['year'] < 860)].index
incorrect_location_index = data.loc[(data['reclat'] == 0) & (data['reclong'] == 0)].index
data.drop(incorrect_years_index.union(incorrect_location_index), inplace=True)
data.loc[(data['mass'].isnull()) & (data['name'].str.startswith('Österplana')), 'mass'] = 0
return data return data
data = pd.read_csv("meteorite-landings.csv") data = pd.read_csv("covtype.csv")
meteorite_train, meteorite_test, meteorite_val = split(data) forest_train, forest_test, forest_val = split(data)
meteorite_train = normalization(meteorite_train) forest_train = preprocessing(forest_train)
meteorite_test = normalization(meteorite_test) forest_test = preprocessing(forest_test)
meteorite_val = normalization(meteorite_val) forest_val = preprocessing(forest_val)
meteorite_train = normalization(meteorite_train) forest_train = normalization(forest_train)
meteorite_test = normalization(meteorite_test) forest_test = normalization(forest_test)
meteorite_val = normalization(meteorite_val) forest_val = normalization(forest_val)
meteorite_train.to_csv('meteorite_train.csv', encoding='utf-8') forest_train.to_csv('forest_train.csv', encoding='utf-8', index=False)
meteorite_test.to_csv('meteorite_test.csv', encoding='utf-8') forest_test.to_csv('forest_test.csv', encoding='utf-8', index=False)
meteorite_val.to_csv('meteorite_val.csv', encoding='utf-8') forest_val.to_csv('forest_val.csv', encoding='utf-8', index=False)