ium_07 sacred

This commit is contained in:
Michal Gulczynski 2024-06-11 19:56:09 +02:00
parent 9e90c820b8
commit b88ddb3066
2 changed files with 10 additions and 0 deletions

View File

@ -63,6 +63,12 @@ def run_experiment(test_size, random_state, model_filename):
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=test_size, random_state=random_state)
Y_train = np.ravel(Y_train)
Y_test = np.ravel(Y_test)
ex.add_resource(X_train)
ex.add_resource(X_test)
ex.add_resource(Y_train)
ex.add_resource(Y_test)
scaler = StandardScaler()
numeric_columns = X_train.select_dtypes(include=['int', 'float']).columns
X_train_scaled = scaler.fit_transform(X_train[numeric_columns])

View File

@ -26,6 +26,10 @@ def run_evaluation(model_filename, test_dataset_filename):
X_test = test_df.drop(columns='playlist_genre')
Y_test = np.ravel(Y_test)
scaler = StandardScaler()
ex.add_resource(X_test)
ex.add_resource(Y_test)
numeric_columns = X_test.select_dtypes(include=['int', 'float']).columns
X_test_scaled = scaler.fit_transform(X_test[numeric_columns])
Y_pred = model.predict(X_test_scaled)