diff --git a/model_creator.py b/model_creator.py index 9d5521a..f35b62f 100644 --- a/model_creator.py +++ b/model_creator.py @@ -80,7 +80,7 @@ check_datasets_presence() result_df = datasets_preparation() Y = result_df[['playlist_genre']] X = result_df.drop(columns='playlist_genre') -X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=sys.argv[1], random_state=42) +X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=float(sys.argv[1]), random_state=42) Y_train = np.ravel(Y_train) @@ -91,7 +91,7 @@ numeric_columns = X_train.select_dtypes(include=['int', 'float']).columns X_train_scaled = scaler.fit_transform(X_train[numeric_columns]) X_test_scaled = scaler.transform(X_test[numeric_columns]) -model = LogisticRegression(max_iter=sys.argv[2]) +model = LogisticRegression(max_iter=int(sys.argv[2])) model.fit(X_train_scaled, Y_train)