""" The :mod:`sklearn.model_selection._validation` module includes classes and functions to validate the model. """ # Author: Alexandre Gramfort # Gael Varoquaux # Olivier Grisel # Raghav RV # Michal Karbownik # License: BSD 3 clause import warnings import numbers import time from functools import partial from traceback import format_exc from contextlib import suppress from collections import Counter import numpy as np import scipy.sparse as sp from joblib import logger from ..base import is_classifier, clone from ..utils import indexable, check_random_state, _safe_indexing from ..utils.validation import _check_fit_params from ..utils.validation import _num_samples from ..utils.parallel import delayed, Parallel from ..utils.metaestimators import _safe_split from ..metrics import check_scoring from ..metrics._scorer import _check_multimetric_scoring, _MultimetricScorer from ..exceptions import FitFailedWarning from ._split import check_cv from ..preprocessing import LabelEncoder __all__ = [ "cross_validate", "cross_val_score", "cross_val_predict", "permutation_test_score", "learning_curve", "validation_curve", ] def cross_validate( estimator, X, y=None, *, groups=None, scoring=None, cv=None, n_jobs=None, verbose=0, fit_params=None, pre_dispatch="2*n_jobs", return_train_score=False, return_estimator=False, error_score=np.nan, ): """Evaluate metric(s) by cross-validation and also record fit/score times. Read more in the :ref:`User Guide `. Parameters ---------- estimator : estimator object implementing 'fit' The object to use to fit the data. X : array-like of shape (n_samples, n_features) The data to fit. Can be for example a list, or an array. y : array-like of shape (n_samples,) or (n_samples, n_outputs), default=None The target variable to try to predict in the case of supervised learning. groups : array-like of shape (n_samples,), default=None Group labels for the samples used while splitting the dataset into train/test set. Only used in conjunction with a "Group" :term:`cv` instance (e.g., :class:`GroupKFold`). scoring : str, callable, list, tuple, or dict, default=None Strategy to evaluate the performance of the cross-validated model on the test set. If `scoring` represents a single score, one can use: - a single string (see :ref:`scoring_parameter`); - a callable (see :ref:`scoring`) that returns a single value. If `scoring` represents multiple scores, one can use: - a list or tuple of unique strings; - a callable returning a dictionary where the keys are the metric names and the values are the metric scores; - a dictionary with metric names as keys and callables a values. See :ref:`multimetric_grid_search` for an example. cv : int, cross-validation generator or an iterable, default=None Determines the cross-validation splitting strategy. Possible inputs for cv are: - None, to use the default 5-fold cross validation, - int, to specify the number of folds in a `(Stratified)KFold`, - :term:`CV splitter`, - An iterable yielding (train, test) splits as arrays of indices. For int/None inputs, if the estimator is a classifier and ``y`` is either binary or multiclass, :class:`StratifiedKFold` is used. In all other cases, :class:`KFold` is used. These splitters are instantiated with `shuffle=False` so the splits will be the same across calls. Refer :ref:`User Guide ` for the various cross-validation strategies that can be used here. .. versionchanged:: 0.22 ``cv`` default value if None changed from 3-fold to 5-fold. n_jobs : int, default=None Number of jobs to run in parallel. Training the estimator and computing the score are parallelized over the cross-validation splits. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all processors. See :term:`Glossary ` for more details. verbose : int, default=0 The verbosity level. fit_params : dict, default=None Parameters to pass to the fit method of the estimator. pre_dispatch : int or str, default='2*n_jobs' Controls the number of jobs that get dispatched during parallel execution. Reducing this number can be useful to avoid an explosion of memory consumption when more jobs get dispatched than CPUs can process. This parameter can be: - None, in which case all the jobs are immediately created and spawned. Use this for lightweight and fast-running jobs, to avoid delays due to on-demand spawning of the jobs - An int, giving the exact number of total jobs that are spawned - A str, giving an expression as a function of n_jobs, as in '2*n_jobs' return_train_score : bool, default=False Whether to include train scores. Computing training scores is used to get insights on how different parameter settings impact the overfitting/underfitting trade-off. However computing the scores on the training set can be computationally expensive and is not strictly required to select the parameters that yield the best generalization performance. .. versionadded:: 0.19 .. versionchanged:: 0.21 Default value was changed from ``True`` to ``False`` return_estimator : bool, default=False Whether to return the estimators fitted on each split. .. versionadded:: 0.20 error_score : 'raise' or numeric, default=np.nan Value to assign to the score if an error occurs in estimator fitting. If set to 'raise', the error is raised. If a numeric value is given, FitFailedWarning is raised. .. versionadded:: 0.20 Returns ------- scores : dict of float arrays of shape (n_splits,) Array of scores of the estimator for each run of the cross validation. A dict of arrays containing the score/time arrays for each scorer is returned. The possible keys for this ``dict`` are: ``test_score`` The score array for test scores on each cv split. Suffix ``_score`` in ``test_score`` changes to a specific metric like ``test_r2`` or ``test_auc`` if there are multiple scoring metrics in the scoring parameter. ``train_score`` The score array for train scores on each cv split. Suffix ``_score`` in ``train_score`` changes to a specific metric like ``train_r2`` or ``train_auc`` if there are multiple scoring metrics in the scoring parameter. This is available only if ``return_train_score`` parameter is ``True``. ``fit_time`` The time for fitting the estimator on the train set for each cv split. ``score_time`` The time for scoring the estimator on the test set for each cv split. (Note time for scoring on the train set is not included even if ``return_train_score`` is set to ``True`` ``estimator`` The estimator objects for each cv split. This is available only if ``return_estimator`` parameter is set to ``True``. See Also -------- cross_val_score : Run cross-validation for single metric evaluation. cross_val_predict : Get predictions from each split of cross-validation for diagnostic purposes. sklearn.metrics.make_scorer : Make a scorer from a performance metric or loss function. Examples -------- >>> from sklearn import datasets, linear_model >>> from sklearn.model_selection import cross_validate >>> from sklearn.metrics import make_scorer >>> from sklearn.metrics import confusion_matrix >>> from sklearn.svm import LinearSVC >>> diabetes = datasets.load_diabetes() >>> X = diabetes.data[:150] >>> y = diabetes.target[:150] >>> lasso = linear_model.Lasso() Single metric evaluation using ``cross_validate`` >>> cv_results = cross_validate(lasso, X, y, cv=3) >>> sorted(cv_results.keys()) ['fit_time', 'score_time', 'test_score'] >>> cv_results['test_score'] array([0.3315057 , 0.08022103, 0.03531816]) Multiple metric evaluation using ``cross_validate`` (please refer the ``scoring`` parameter doc for more information) >>> scores = cross_validate(lasso, X, y, cv=3, ... scoring=('r2', 'neg_mean_squared_error'), ... return_train_score=True) >>> print(scores['test_neg_mean_squared_error']) [-3635.5... -3573.3... -6114.7...] >>> print(scores['train_r2']) [0.28009951 0.3908844 0.22784907] """ X, y, groups = indexable(X, y, groups) cv = check_cv(cv, y, classifier=is_classifier(estimator)) if callable(scoring): scorers = scoring elif scoring is None or isinstance(scoring, str): scorers = check_scoring(estimator, scoring) else: scorers = _check_multimetric_scoring(estimator, scoring) # We clone the estimator to make sure that all the folds are # independent, and that it is pickle-able. parallel = Parallel(n_jobs=n_jobs, verbose=verbose, pre_dispatch=pre_dispatch) results = parallel( delayed(_fit_and_score)( clone(estimator), X, y, scorers, train, test, verbose, None, fit_params, return_train_score=return_train_score, return_times=True, return_estimator=return_estimator, error_score=error_score, ) for train, test in cv.split(X, y, groups) ) _warn_or_raise_about_fit_failures(results, error_score) # For callabe scoring, the return type is only know after calling. If the # return type is a dictionary, the error scores can now be inserted with # the correct key. if callable(scoring): _insert_error_scores(results, error_score) results = _aggregate_score_dicts(results) ret = {} ret["fit_time"] = results["fit_time"] ret["score_time"] = results["score_time"] if return_estimator: ret["estimator"] = results["estimator"] test_scores_dict = _normalize_score_results(results["test_scores"]) if return_train_score: train_scores_dict = _normalize_score_results(results["train_scores"]) for name in test_scores_dict: ret["test_%s" % name] = test_scores_dict[name] if return_train_score: key = "train_%s" % name ret[key] = train_scores_dict[name] return ret def _insert_error_scores(results, error_score): """Insert error in `results` by replacing them inplace with `error_score`. This only applies to multimetric scores because `_fit_and_score` will handle the single metric case. """ successful_score = None failed_indices = [] for i, result in enumerate(results): if result["fit_error"] is not None: failed_indices.append(i) elif successful_score is None: successful_score = result["test_scores"] if isinstance(successful_score, dict): formatted_error = {name: error_score for name in successful_score} for i in failed_indices: results[i]["test_scores"] = formatted_error.copy() if "train_scores" in results[i]: results[i]["train_scores"] = formatted_error.copy() def _normalize_score_results(scores, scaler_score_key="score"): """Creates a scoring dictionary based on the type of `scores`""" if isinstance(scores[0], dict): # multimetric scoring return _aggregate_score_dicts(scores) # scaler return {scaler_score_key: scores} def _warn_or_raise_about_fit_failures(results, error_score): fit_errors = [ result["fit_error"] for result in results if result["fit_error"] is not None ] if fit_errors: num_failed_fits = len(fit_errors) num_fits = len(results) fit_errors_counter = Counter(fit_errors) delimiter = "-" * 80 + "\n" fit_errors_summary = "\n".join( f"{delimiter}{n} fits failed with the following error:\n{error}" for error, n in fit_errors_counter.items() ) if num_failed_fits == num_fits: all_fits_failed_message = ( f"\nAll the {num_fits} fits failed.\n" "It is very likely that your model is misconfigured.\n" "You can try to debug the error by setting error_score='raise'.\n\n" f"Below are more details about the failures:\n{fit_errors_summary}" ) raise ValueError(all_fits_failed_message) else: some_fits_failed_message = ( f"\n{num_failed_fits} fits failed out of a total of {num_fits}.\n" "The score on these train-test partitions for these parameters" f" will be set to {error_score}.\n" "If these failures are not expected, you can try to debug them " "by setting error_score='raise'.\n\n" f"Below are more details about the failures:\n{fit_errors_summary}" ) warnings.warn(some_fits_failed_message, FitFailedWarning) def cross_val_score( estimator, X, y=None, *, groups=None, scoring=None, cv=None, n_jobs=None, verbose=0, fit_params=None, pre_dispatch="2*n_jobs", error_score=np.nan, ): """Evaluate a score by cross-validation. Read more in the :ref:`User Guide `. Parameters ---------- estimator : estimator object implementing 'fit' The object to use to fit the data. X : array-like of shape (n_samples, n_features) The data to fit. Can be for example a list, or an array. y : array-like of shape (n_samples,) or (n_samples, n_outputs), \ default=None The target variable to try to predict in the case of supervised learning. groups : array-like of shape (n_samples,), default=None Group labels for the samples used while splitting the dataset into train/test set. Only used in conjunction with a "Group" :term:`cv` instance (e.g., :class:`GroupKFold`). scoring : str or callable, default=None A str (see model evaluation documentation) or a scorer callable object / function with signature ``scorer(estimator, X, y)`` which should return only a single value. Similar to :func:`cross_validate` but only a single metric is permitted. If `None`, the estimator's default scorer (if available) is used. cv : int, cross-validation generator or an iterable, default=None Determines the cross-validation splitting strategy. Possible inputs for cv are: - `None`, to use the default 5-fold cross validation, - int, to specify the number of folds in a `(Stratified)KFold`, - :term:`CV splitter`, - An iterable that generates (train, test) splits as arrays of indices. For `int`/`None` inputs, if the estimator is a classifier and `y` is either binary or multiclass, :class:`StratifiedKFold` is used. In all other cases, :class:`KFold` is used. These splitters are instantiated with `shuffle=False` so the splits will be the same across calls. Refer :ref:`User Guide ` for the various cross-validation strategies that can be used here. .. versionchanged:: 0.22 `cv` default value if `None` changed from 3-fold to 5-fold. n_jobs : int, default=None Number of jobs to run in parallel. Training the estimator and computing the score are parallelized over the cross-validation splits. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all processors. See :term:`Glossary ` for more details. verbose : int, default=0 The verbosity level. fit_params : dict, default=None Parameters to pass to the fit method of the estimator. pre_dispatch : int or str, default='2*n_jobs' Controls the number of jobs that get dispatched during parallel execution. Reducing this number can be useful to avoid an explosion of memory consumption when more jobs get dispatched than CPUs can process. This parameter can be: - ``None``, in which case all the jobs are immediately created and spawned. Use this for lightweight and fast-running jobs, to avoid delays due to on-demand spawning of the jobs - An int, giving the exact number of total jobs that are spawned - A str, giving an expression as a function of n_jobs, as in '2*n_jobs' error_score : 'raise' or numeric, default=np.nan Value to assign to the score if an error occurs in estimator fitting. If set to 'raise', the error is raised. If a numeric value is given, FitFailedWarning is raised. .. versionadded:: 0.20 Returns ------- scores : ndarray of float of shape=(len(list(cv)),) Array of scores of the estimator for each run of the cross validation. See Also -------- cross_validate : To run cross-validation on multiple metrics and also to return train scores, fit times and score times. cross_val_predict : Get predictions from each split of cross-validation for diagnostic purposes. sklearn.metrics.make_scorer : Make a scorer from a performance metric or loss function. Examples -------- >>> from sklearn import datasets, linear_model >>> from sklearn.model_selection import cross_val_score >>> diabetes = datasets.load_diabetes() >>> X = diabetes.data[:150] >>> y = diabetes.target[:150] >>> lasso = linear_model.Lasso() >>> print(cross_val_score(lasso, X, y, cv=3)) [0.3315057 0.08022103 0.03531816] """ # To ensure multimetric format is not supported scorer = check_scoring(estimator, scoring=scoring) cv_results = cross_validate( estimator=estimator, X=X, y=y, groups=groups, scoring={"score": scorer}, cv=cv, n_jobs=n_jobs, verbose=verbose, fit_params=fit_params, pre_dispatch=pre_dispatch, error_score=error_score, ) return cv_results["test_score"] def _fit_and_score( estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score=False, return_parameters=False, return_n_test_samples=False, return_times=False, return_estimator=False, split_progress=None, candidate_progress=None, error_score=np.nan, ): """Fit estimator and compute scores for a given dataset split. Parameters ---------- estimator : estimator object implementing 'fit' The object to use to fit the data. X : array-like of shape (n_samples, n_features) The data to fit. y : array-like of shape (n_samples,) or (n_samples, n_outputs) or None The target variable to try to predict in the case of supervised learning. scorer : A single callable or dict mapping scorer name to the callable If it is a single callable, the return value for ``train_scores`` and ``test_scores`` is a single float. For a dict, it should be one mapping the scorer name to the scorer callable object / function. The callable object / fn should have signature ``scorer(estimator, X, y)``. train : array-like of shape (n_train_samples,) Indices of training samples. test : array-like of shape (n_test_samples,) Indices of test samples. verbose : int The verbosity level. error_score : 'raise' or numeric, default=np.nan Value to assign to the score if an error occurs in estimator fitting. If set to 'raise', the error is raised. If a numeric value is given, FitFailedWarning is raised. parameters : dict or None Parameters to be set on the estimator. fit_params : dict or None Parameters that will be passed to ``estimator.fit``. return_train_score : bool, default=False Compute and return score on training set. return_parameters : bool, default=False Return parameters that has been used for the estimator. split_progress : {list, tuple} of int, default=None A list or tuple of format (, ). candidate_progress : {list, tuple} of int, default=None A list or tuple of format (, ). return_n_test_samples : bool, default=False Whether to return the ``n_test_samples``. return_times : bool, default=False Whether to return the fit/score times. return_estimator : bool, default=False Whether to return the fitted estimator. Returns ------- result : dict with the following attributes train_scores : dict of scorer name -> float Score on training set (for all the scorers), returned only if `return_train_score` is `True`. test_scores : dict of scorer name -> float Score on testing set (for all the scorers). n_test_samples : int Number of test samples. fit_time : float Time spent for fitting in seconds. score_time : float Time spent for scoring in seconds. parameters : dict or None The parameters that have been evaluated. estimator : estimator object The fitted estimator. fit_error : str or None Traceback str if the fit failed, None if the fit succeeded. """ if not isinstance(error_score, numbers.Number) and error_score != "raise": raise ValueError( "error_score must be the string 'raise' or a numeric value. " "(Hint: if using 'raise', please make sure that it has been " "spelled correctly.)" ) progress_msg = "" if verbose > 2: if split_progress is not None: progress_msg = f" {split_progress[0]+1}/{split_progress[1]}" if candidate_progress and verbose > 9: progress_msg += f"; {candidate_progress[0]+1}/{candidate_progress[1]}" if verbose > 1: if parameters is None: params_msg = "" else: sorted_keys = sorted(parameters) # Ensure deterministic o/p params_msg = ", ".join(f"{k}={parameters[k]}" for k in sorted_keys) if verbose > 9: start_msg = f"[CV{progress_msg}] START {params_msg}" print(f"{start_msg}{(80 - len(start_msg)) * '.'}") # Adjust length of sample weights fit_params = fit_params if fit_params is not None else {} fit_params = _check_fit_params(X, fit_params, train) if parameters is not None: # clone after setting parameters in case any parameters # are estimators (like pipeline steps) # because pipeline doesn't clone steps in fit cloned_parameters = {} for k, v in parameters.items(): cloned_parameters[k] = clone(v, safe=False) estimator = estimator.set_params(**cloned_parameters) start_time = time.time() X_train, y_train = _safe_split(estimator, X, y, train) X_test, y_test = _safe_split(estimator, X, y, test, train) result = {} try: if y_train is None: estimator.fit(X_train, **fit_params) else: estimator.fit(X_train, y_train, **fit_params) except Exception: # Note fit time as time until error fit_time = time.time() - start_time score_time = 0.0 if error_score == "raise": raise elif isinstance(error_score, numbers.Number): if isinstance(scorer, dict): test_scores = {name: error_score for name in scorer} if return_train_score: train_scores = test_scores.copy() else: test_scores = error_score if return_train_score: train_scores = error_score result["fit_error"] = format_exc() else: result["fit_error"] = None fit_time = time.time() - start_time test_scores = _score(estimator, X_test, y_test, scorer, error_score) score_time = time.time() - start_time - fit_time if return_train_score: train_scores = _score(estimator, X_train, y_train, scorer, error_score) if verbose > 1: total_time = score_time + fit_time end_msg = f"[CV{progress_msg}] END " result_msg = params_msg + (";" if params_msg else "") if verbose > 2: if isinstance(test_scores, dict): for scorer_name in sorted(test_scores): result_msg += f" {scorer_name}: (" if return_train_score: scorer_scores = train_scores[scorer_name] result_msg += f"train={scorer_scores:.3f}, " result_msg += f"test={test_scores[scorer_name]:.3f})" else: result_msg += ", score=" if return_train_score: result_msg += f"(train={train_scores:.3f}, test={test_scores:.3f})" else: result_msg += f"{test_scores:.3f}" result_msg += f" total time={logger.short_format_time(total_time)}" # Right align the result_msg end_msg += "." * (80 - len(end_msg) - len(result_msg)) end_msg += result_msg print(end_msg) result["test_scores"] = test_scores if return_train_score: result["train_scores"] = train_scores if return_n_test_samples: result["n_test_samples"] = _num_samples(X_test) if return_times: result["fit_time"] = fit_time result["score_time"] = score_time if return_parameters: result["parameters"] = parameters if return_estimator: result["estimator"] = estimator return result def _score(estimator, X_test, y_test, scorer, error_score="raise"): """Compute the score(s) of an estimator on a given test set. Will return a dict of floats if `scorer` is a dict, otherwise a single float is returned. """ if isinstance(scorer, dict): # will cache method calls if needed. scorer() returns a dict scorer = _MultimetricScorer(scorers=scorer, raise_exc=(error_score == "raise")) try: if y_test is None: scores = scorer(estimator, X_test) else: scores = scorer(estimator, X_test, y_test) except Exception: if isinstance(scorer, _MultimetricScorer): # If `_MultimetricScorer` raises exception, the `error_score` # parameter is equal to "raise". raise else: if error_score == "raise": raise else: scores = error_score warnings.warn( "Scoring failed. The score on this train-test partition for " f"these parameters will be set to {error_score}. Details: \n" f"{format_exc()}", UserWarning, ) # Check non-raised error messages in `_MultimetricScorer` if isinstance(scorer, _MultimetricScorer): exception_messages = [ (name, str_e) for name, str_e in scores.items() if isinstance(str_e, str) ] if exception_messages: # error_score != "raise" for name, str_e in exception_messages: scores[name] = error_score warnings.warn( "Scoring failed. The score on this train-test partition for " f"these parameters will be set to {error_score}. Details: \n" f"{str_e}", UserWarning, ) error_msg = "scoring must return a number, got %s (%s) instead. (scorer=%s)" if isinstance(scores, dict): for name, score in scores.items(): if hasattr(score, "item"): with suppress(ValueError): # e.g. unwrap memmapped scalars score = score.item() if not isinstance(score, numbers.Number): raise ValueError(error_msg % (score, type(score), name)) scores[name] = score else: # scalar if hasattr(scores, "item"): with suppress(ValueError): # e.g. unwrap memmapped scalars scores = scores.item() if not isinstance(scores, numbers.Number): raise ValueError(error_msg % (scores, type(scores), scorer)) return scores def cross_val_predict( estimator, X, y=None, *, groups=None, cv=None, n_jobs=None, verbose=0, fit_params=None, pre_dispatch="2*n_jobs", method="predict", ): """Generate cross-validated estimates for each input data point. The data is split according to the cv parameter. Each sample belongs to exactly one test set, and its prediction is computed with an estimator fitted on the corresponding training set. Passing these predictions into an evaluation metric may not be a valid way to measure generalization performance. Results can differ from :func:`cross_validate` and :func:`cross_val_score` unless all tests sets have equal size and the metric decomposes over samples. Read more in the :ref:`User Guide `. Parameters ---------- estimator : estimator object implementing 'fit' and 'predict' The object to use to fit the data. X : array-like of shape (n_samples, n_features) The data to fit. Can be, for example a list, or an array at least 2d. y : array-like of shape (n_samples,) or (n_samples, n_outputs), \ default=None The target variable to try to predict in the case of supervised learning. groups : array-like of shape (n_samples,), default=None Group labels for the samples used while splitting the dataset into train/test set. Only used in conjunction with a "Group" :term:`cv` instance (e.g., :class:`GroupKFold`). cv : int, cross-validation generator or an iterable, default=None Determines the cross-validation splitting strategy. Possible inputs for cv are: - None, to use the default 5-fold cross validation, - int, to specify the number of folds in a `(Stratified)KFold`, - :term:`CV splitter`, - An iterable that generates (train, test) splits as arrays of indices. For int/None inputs, if the estimator is a classifier and ``y`` is either binary or multiclass, :class:`StratifiedKFold` is used. In all other cases, :class:`KFold` is used. These splitters are instantiated with `shuffle=False` so the splits will be the same across calls. Refer :ref:`User Guide ` for the various cross-validation strategies that can be used here. .. versionchanged:: 0.22 ``cv`` default value if None changed from 3-fold to 5-fold. n_jobs : int, default=None Number of jobs to run in parallel. Training the estimator and predicting are parallelized over the cross-validation splits. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all processors. See :term:`Glossary ` for more details. verbose : int, default=0 The verbosity level. fit_params : dict, default=None Parameters to pass to the fit method of the estimator. pre_dispatch : int or str, default='2*n_jobs' Controls the number of jobs that get dispatched during parallel execution. Reducing this number can be useful to avoid an explosion of memory consumption when more jobs get dispatched than CPUs can process. This parameter can be: - None, in which case all the jobs are immediately created and spawned. Use this for lightweight and fast-running jobs, to avoid delays due to on-demand spawning of the jobs - An int, giving the exact number of total jobs that are spawned - A str, giving an expression as a function of n_jobs, as in '2*n_jobs' method : {'predict', 'predict_proba', 'predict_log_proba', \ 'decision_function'}, default='predict' The method to be invoked by `estimator`. Returns ------- predictions : ndarray This is the result of calling `method`. Shape: - When `method` is 'predict' and in special case where `method` is 'decision_function' and the target is binary: (n_samples,) - When `method` is one of {'predict_proba', 'predict_log_proba', 'decision_function'} (unless special case above): (n_samples, n_classes) - If `estimator` is :term:`multioutput`, an extra dimension 'n_outputs' is added to the end of each shape above. See Also -------- cross_val_score : Calculate score for each CV split. cross_validate : Calculate one or more scores and timings for each CV split. Notes ----- In the case that one or more classes are absent in a training portion, a default score needs to be assigned to all instances for that class if ``method`` produces columns per class, as in {'decision_function', 'predict_proba', 'predict_log_proba'}. For ``predict_proba`` this value is 0. In order to ensure finite output, we approximate negative infinity by the minimum finite float value for the dtype in other cases. Examples -------- >>> from sklearn import datasets, linear_model >>> from sklearn.model_selection import cross_val_predict >>> diabetes = datasets.load_diabetes() >>> X = diabetes.data[:150] >>> y = diabetes.target[:150] >>> lasso = linear_model.Lasso() >>> y_pred = cross_val_predict(lasso, X, y, cv=3) """ X, y, groups = indexable(X, y, groups) cv = check_cv(cv, y, classifier=is_classifier(estimator)) splits = list(cv.split(X, y, groups)) test_indices = np.concatenate([test for _, test in splits]) if not _check_is_permutation(test_indices, _num_samples(X)): raise ValueError("cross_val_predict only works for partitions") # If classification methods produce multiple columns of output, # we need to manually encode classes to ensure consistent column ordering. encode = ( method in ["decision_function", "predict_proba", "predict_log_proba"] and y is not None ) if encode: y = np.asarray(y) if y.ndim == 1: le = LabelEncoder() y = le.fit_transform(y) elif y.ndim == 2: y_enc = np.zeros_like(y, dtype=int) for i_label in range(y.shape[1]): y_enc[:, i_label] = LabelEncoder().fit_transform(y[:, i_label]) y = y_enc # We clone the estimator to make sure that all the folds are # independent, and that it is pickle-able. parallel = Parallel(n_jobs=n_jobs, verbose=verbose, pre_dispatch=pre_dispatch) predictions = parallel( delayed(_fit_and_predict)( clone(estimator), X, y, train, test, verbose, fit_params, method ) for train, test in splits ) inv_test_indices = np.empty(len(test_indices), dtype=int) inv_test_indices[test_indices] = np.arange(len(test_indices)) if sp.issparse(predictions[0]): predictions = sp.vstack(predictions, format=predictions[0].format) elif encode and isinstance(predictions[0], list): # `predictions` is a list of method outputs from each fold. # If each of those is also a list, then treat this as a # multioutput-multiclass task. We need to separately concatenate # the method outputs for each label into an `n_labels` long list. n_labels = y.shape[1] concat_pred = [] for i_label in range(n_labels): label_preds = np.concatenate([p[i_label] for p in predictions]) concat_pred.append(label_preds) predictions = concat_pred else: predictions = np.concatenate(predictions) if isinstance(predictions, list): return [p[inv_test_indices] for p in predictions] else: return predictions[inv_test_indices] def _fit_and_predict(estimator, X, y, train, test, verbose, fit_params, method): """Fit estimator and predict values for a given dataset split. Read more in the :ref:`User Guide `. Parameters ---------- estimator : estimator object implementing 'fit' and 'predict' The object to use to fit the data. X : array-like of shape (n_samples, n_features) The data to fit. .. versionchanged:: 0.20 X is only required to be an object with finite length or shape now y : array-like of shape (n_samples,) or (n_samples, n_outputs) or None The target variable to try to predict in the case of supervised learning. train : array-like of shape (n_train_samples,) Indices of training samples. test : array-like of shape (n_test_samples,) Indices of test samples. verbose : int The verbosity level. fit_params : dict or None Parameters that will be passed to ``estimator.fit``. method : str Invokes the passed method name of the passed estimator. Returns ------- predictions : sequence Result of calling 'estimator.method' """ # Adjust length of sample weights fit_params = fit_params if fit_params is not None else {} fit_params = _check_fit_params(X, fit_params, train) X_train, y_train = _safe_split(estimator, X, y, train) X_test, _ = _safe_split(estimator, X, y, test, train) if y_train is None: estimator.fit(X_train, **fit_params) else: estimator.fit(X_train, y_train, **fit_params) func = getattr(estimator, method) predictions = func(X_test) encode = ( method in ["decision_function", "predict_proba", "predict_log_proba"] and y is not None ) if encode: if isinstance(predictions, list): predictions = [ _enforce_prediction_order( estimator.classes_[i_label], predictions[i_label], n_classes=len(set(y[:, i_label])), method=method, ) for i_label in range(len(predictions)) ] else: # A 2D y array should be a binary label indicator matrix n_classes = len(set(y)) if y.ndim == 1 else y.shape[1] predictions = _enforce_prediction_order( estimator.classes_, predictions, n_classes, method ) return predictions def _enforce_prediction_order(classes, predictions, n_classes, method): """Ensure that prediction arrays have correct column order When doing cross-validation, if one or more classes are not present in the subset of data used for training, then the output prediction array might not have the same columns as other folds. Use the list of class names (assumed to be ints) to enforce the correct column order. Note that `classes` is the list of classes in this fold (a subset of the classes in the full training set) and `n_classes` is the number of classes in the full training set. """ if n_classes != len(classes): recommendation = ( "To fix this, use a cross-validation " "technique resulting in properly " "stratified folds" ) warnings.warn( "Number of classes in training fold ({}) does " "not match total number of classes ({}). " "Results may not be appropriate for your use case. " "{}".format(len(classes), n_classes, recommendation), RuntimeWarning, ) if method == "decision_function": if predictions.ndim == 2 and predictions.shape[1] != len(classes): # This handles the case when the shape of predictions # does not match the number of classes used to train # it with. This case is found when sklearn.svm.SVC is # set to `decision_function_shape='ovo'`. raise ValueError( "Output shape {} of {} does not match " "number of classes ({}) in fold. " "Irregular decision_function outputs " "are not currently supported by " "cross_val_predict".format(predictions.shape, method, len(classes)) ) if len(classes) <= 2: # In this special case, `predictions` contains a 1D array. raise ValueError( "Only {} class/es in training fold, but {} " "in overall dataset. This " "is not supported for decision_function " "with imbalanced folds. {}".format( len(classes), n_classes, recommendation ) ) float_min = np.finfo(predictions.dtype).min default_values = { "decision_function": float_min, "predict_log_proba": float_min, "predict_proba": 0, } predictions_for_all_classes = np.full( (_num_samples(predictions), n_classes), default_values[method], dtype=predictions.dtype, ) predictions_for_all_classes[:, classes] = predictions predictions = predictions_for_all_classes return predictions def _check_is_permutation(indices, n_samples): """Check whether indices is a reordering of the array np.arange(n_samples) Parameters ---------- indices : ndarray int array to test n_samples : int number of expected elements Returns ------- is_partition : bool True iff sorted(indices) is np.arange(n) """ if len(indices) != n_samples: return False hit = np.zeros(n_samples, dtype=bool) hit[indices] = True if not np.all(hit): return False return True def permutation_test_score( estimator, X, y, *, groups=None, cv=None, n_permutations=100, n_jobs=None, random_state=0, verbose=0, scoring=None, fit_params=None, ): """Evaluate the significance of a cross-validated score with permutations. Permutes targets to generate 'randomized data' and compute the empirical p-value against the null hypothesis that features and targets are independent. The p-value represents the fraction of randomized data sets where the estimator performed as well or better than in the original data. A small p-value suggests that there is a real dependency between features and targets which has been used by the estimator to give good predictions. A large p-value may be due to lack of real dependency between features and targets or the estimator was not able to use the dependency to give good predictions. Read more in the :ref:`User Guide `. Parameters ---------- estimator : estimator object implementing 'fit' The object to use to fit the data. X : array-like of shape at least 2D The data to fit. y : array-like of shape (n_samples,) or (n_samples, n_outputs) or None The target variable to try to predict in the case of supervised learning. groups : array-like of shape (n_samples,), default=None Labels to constrain permutation within groups, i.e. ``y`` values are permuted among samples with the same group identifier. When not specified, ``y`` values are permuted among all samples. When a grouped cross-validator is used, the group labels are also passed on to the ``split`` method of the cross-validator. The cross-validator uses them for grouping the samples while splitting the dataset into train/test set. cv : int, cross-validation generator or an iterable, default=None Determines the cross-validation splitting strategy. Possible inputs for cv are: - `None`, to use the default 5-fold cross validation, - int, to specify the number of folds in a `(Stratified)KFold`, - :term:`CV splitter`, - An iterable yielding (train, test) splits as arrays of indices. For `int`/`None` inputs, if the estimator is a classifier and `y` is either binary or multiclass, :class:`StratifiedKFold` is used. In all other cases, :class:`KFold` is used. These splitters are instantiated with `shuffle=False` so the splits will be the same across calls. Refer :ref:`User Guide ` for the various cross-validation strategies that can be used here. .. versionchanged:: 0.22 `cv` default value if `None` changed from 3-fold to 5-fold. n_permutations : int, default=100 Number of times to permute ``y``. n_jobs : int, default=None Number of jobs to run in parallel. Training the estimator and computing the cross-validated score are parallelized over the permutations. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all processors. See :term:`Glossary ` for more details. random_state : int, RandomState instance or None, default=0 Pass an int for reproducible output for permutation of ``y`` values among samples. See :term:`Glossary `. verbose : int, default=0 The verbosity level. scoring : str or callable, default=None A single str (see :ref:`scoring_parameter`) or a callable (see :ref:`scoring`) to evaluate the predictions on the test set. If `None` the estimator's score method is used. fit_params : dict, default=None Parameters to pass to the fit method of the estimator. .. versionadded:: 0.24 Returns ------- score : float The true score without permuting targets. permutation_scores : array of shape (n_permutations,) The scores obtained for each permutations. pvalue : float The p-value, which approximates the probability that the score would be obtained by chance. This is calculated as: `(C + 1) / (n_permutations + 1)` Where C is the number of permutations whose score >= the true score. The best possible p-value is 1/(n_permutations + 1), the worst is 1.0. Notes ----- This function implements Test 1 in: Ojala and Garriga. `Permutation Tests for Studying Classifier Performance `_. The Journal of Machine Learning Research (2010) vol. 11 """ X, y, groups = indexable(X, y, groups) cv = check_cv(cv, y, classifier=is_classifier(estimator)) scorer = check_scoring(estimator, scoring=scoring) random_state = check_random_state(random_state) # We clone the estimator to make sure that all the folds are # independent, and that it is pickle-able. score = _permutation_test_score( clone(estimator), X, y, groups, cv, scorer, fit_params=fit_params ) permutation_scores = Parallel(n_jobs=n_jobs, verbose=verbose)( delayed(_permutation_test_score)( clone(estimator), X, _shuffle(y, groups, random_state), groups, cv, scorer, fit_params=fit_params, ) for _ in range(n_permutations) ) permutation_scores = np.array(permutation_scores) pvalue = (np.sum(permutation_scores >= score) + 1.0) / (n_permutations + 1) return score, permutation_scores, pvalue def _permutation_test_score(estimator, X, y, groups, cv, scorer, fit_params): """Auxiliary function for permutation_test_score""" # Adjust length of sample weights fit_params = fit_params if fit_params is not None else {} avg_score = [] for train, test in cv.split(X, y, groups): X_train, y_train = _safe_split(estimator, X, y, train) X_test, y_test = _safe_split(estimator, X, y, test, train) fit_params = _check_fit_params(X, fit_params, train) estimator.fit(X_train, y_train, **fit_params) avg_score.append(scorer(estimator, X_test, y_test)) return np.mean(avg_score) def _shuffle(y, groups, random_state): """Return a shuffled copy of y eventually shuffle among same groups.""" if groups is None: indices = random_state.permutation(len(y)) else: indices = np.arange(len(groups)) for group in np.unique(groups): this_mask = groups == group indices[this_mask] = random_state.permutation(indices[this_mask]) return _safe_indexing(y, indices) def learning_curve( estimator, X, y, *, groups=None, train_sizes=np.linspace(0.1, 1.0, 5), cv=None, scoring=None, exploit_incremental_learning=False, n_jobs=None, pre_dispatch="all", verbose=0, shuffle=False, random_state=None, error_score=np.nan, return_times=False, fit_params=None, ): """Learning curve. Determines cross-validated training and test scores for different training set sizes. A cross-validation generator splits the whole dataset k times in training and test data. Subsets of the training set with varying sizes will be used to train the estimator and a score for each training subset size and the test set will be computed. Afterwards, the scores will be averaged over all k runs for each training subset size. Read more in the :ref:`User Guide `. Parameters ---------- estimator : object type that implements the "fit" and "predict" methods An object of that type which is cloned for each validation. X : array-like of shape (n_samples, n_features) Training vector, where `n_samples` is the number of samples and `n_features` is the number of features. y : array-like of shape (n_samples,) or (n_samples, n_outputs) Target relative to X for classification or regression; None for unsupervised learning. groups : array-like of shape (n_samples,), default=None Group labels for the samples used while splitting the dataset into train/test set. Only used in conjunction with a "Group" :term:`cv` instance (e.g., :class:`GroupKFold`). train_sizes : array-like of shape (n_ticks,), \ default=np.linspace(0.1, 1.0, 5) Relative or absolute numbers of training examples that will be used to generate the learning curve. If the dtype is float, it is regarded as a fraction of the maximum size of the training set (that is determined by the selected validation method), i.e. it has to be within (0, 1]. Otherwise it is interpreted as absolute sizes of the training sets. Note that for classification the number of samples usually have to be big enough to contain at least one sample from each class. cv : int, cross-validation generator or an iterable, default=None Determines the cross-validation splitting strategy. Possible inputs for cv are: - None, to use the default 5-fold cross validation, - int, to specify the number of folds in a `(Stratified)KFold`, - :term:`CV splitter`, - An iterable yielding (train, test) splits as arrays of indices. For int/None inputs, if the estimator is a classifier and ``y`` is either binary or multiclass, :class:`StratifiedKFold` is used. In all other cases, :class:`KFold` is used. These splitters are instantiated with `shuffle=False` so the splits will be the same across calls. Refer :ref:`User Guide ` for the various cross-validation strategies that can be used here. .. versionchanged:: 0.22 ``cv`` default value if None changed from 3-fold to 5-fold. scoring : str or callable, default=None A str (see model evaluation documentation) or a scorer callable object / function with signature ``scorer(estimator, X, y)``. exploit_incremental_learning : bool, default=False If the estimator supports incremental learning, this will be used to speed up fitting for different training set sizes. n_jobs : int, default=None Number of jobs to run in parallel. Training the estimator and computing the score are parallelized over the different training and test sets. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all processors. See :term:`Glossary ` for more details. pre_dispatch : int or str, default='all' Number of predispatched jobs for parallel execution (default is all). The option can reduce the allocated memory. The str can be an expression like '2*n_jobs'. verbose : int, default=0 Controls the verbosity: the higher, the more messages. shuffle : bool, default=False Whether to shuffle training data before taking prefixes of it based on``train_sizes``. random_state : int, RandomState instance or None, default=None Used when ``shuffle`` is True. Pass an int for reproducible output across multiple function calls. See :term:`Glossary `. error_score : 'raise' or numeric, default=np.nan Value to assign to the score if an error occurs in estimator fitting. If set to 'raise', the error is raised. If a numeric value is given, FitFailedWarning is raised. .. versionadded:: 0.20 return_times : bool, default=False Whether to return the fit and score times. fit_params : dict, default=None Parameters to pass to the fit method of the estimator. .. versionadded:: 0.24 Returns ------- train_sizes_abs : array of shape (n_unique_ticks,) Numbers of training examples that has been used to generate the learning curve. Note that the number of ticks might be less than n_ticks because duplicate entries will be removed. train_scores : array of shape (n_ticks, n_cv_folds) Scores on training sets. test_scores : array of shape (n_ticks, n_cv_folds) Scores on test set. fit_times : array of shape (n_ticks, n_cv_folds) Times spent for fitting in seconds. Only present if ``return_times`` is True. score_times : array of shape (n_ticks, n_cv_folds) Times spent for scoring in seconds. Only present if ``return_times`` is True. Examples -------- >>> from sklearn.datasets import make_classification >>> from sklearn.tree import DecisionTreeClassifier >>> from sklearn.model_selection import learning_curve >>> X, y = make_classification(n_samples=100, n_features=10, random_state=42) >>> tree = DecisionTreeClassifier(max_depth=4, random_state=42) >>> train_size_abs, train_scores, test_scores = learning_curve( ... tree, X, y, train_sizes=[0.3, 0.6, 0.9] ... ) >>> for train_size, cv_train_scores, cv_test_scores in zip( ... train_size_abs, train_scores, test_scores ... ): ... print(f"{train_size} samples were used to train the model") ... print(f"The average train accuracy is {cv_train_scores.mean():.2f}") ... print(f"The average test accuracy is {cv_test_scores.mean():.2f}") 24 samples were used to train the model The average train accuracy is 1.00 The average test accuracy is 0.85 48 samples were used to train the model The average train accuracy is 1.00 The average test accuracy is 0.90 72 samples were used to train the model The average train accuracy is 1.00 The average test accuracy is 0.93 """ if exploit_incremental_learning and not hasattr(estimator, "partial_fit"): raise ValueError( "An estimator must support the partial_fit interface " "to exploit incremental learning" ) X, y, groups = indexable(X, y, groups) cv = check_cv(cv, y, classifier=is_classifier(estimator)) # Store it as list as we will be iterating over the list multiple times cv_iter = list(cv.split(X, y, groups)) scorer = check_scoring(estimator, scoring=scoring) n_max_training_samples = len(cv_iter[0][0]) # Because the lengths of folds can be significantly different, it is # not guaranteed that we use all of the available training data when we # use the first 'n_max_training_samples' samples. train_sizes_abs = _translate_train_sizes(train_sizes, n_max_training_samples) n_unique_ticks = train_sizes_abs.shape[0] if verbose > 0: print("[learning_curve] Training set sizes: " + str(train_sizes_abs)) parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch, verbose=verbose) if shuffle: rng = check_random_state(random_state) cv_iter = ((rng.permutation(train), test) for train, test in cv_iter) if exploit_incremental_learning: classes = np.unique(y) if is_classifier(estimator) else None out = parallel( delayed(_incremental_fit_estimator)( clone(estimator), X, y, classes, train, test, train_sizes_abs, scorer, verbose, return_times, error_score=error_score, fit_params=fit_params, ) for train, test in cv_iter ) out = np.asarray(out).transpose((2, 1, 0)) else: train_test_proportions = [] for train, test in cv_iter: for n_train_samples in train_sizes_abs: train_test_proportions.append((train[:n_train_samples], test)) results = parallel( delayed(_fit_and_score)( clone(estimator), X, y, scorer, train, test, verbose, parameters=None, fit_params=fit_params, return_train_score=True, error_score=error_score, return_times=return_times, ) for train, test in train_test_proportions ) results = _aggregate_score_dicts(results) train_scores = results["train_scores"].reshape(-1, n_unique_ticks).T test_scores = results["test_scores"].reshape(-1, n_unique_ticks).T out = [train_scores, test_scores] if return_times: fit_times = results["fit_time"].reshape(-1, n_unique_ticks).T score_times = results["score_time"].reshape(-1, n_unique_ticks).T out.extend([fit_times, score_times]) ret = train_sizes_abs, out[0], out[1] if return_times: ret = ret + (out[2], out[3]) return ret def _translate_train_sizes(train_sizes, n_max_training_samples): """Determine absolute sizes of training subsets and validate 'train_sizes'. Examples: _translate_train_sizes([0.5, 1.0], 10) -> [5, 10] _translate_train_sizes([5, 10], 10) -> [5, 10] Parameters ---------- train_sizes : array-like of shape (n_ticks,) Numbers of training examples that will be used to generate the learning curve. If the dtype is float, it is regarded as a fraction of 'n_max_training_samples', i.e. it has to be within (0, 1]. n_max_training_samples : int Maximum number of training samples (upper bound of 'train_sizes'). Returns ------- train_sizes_abs : array of shape (n_unique_ticks,) Numbers of training examples that will be used to generate the learning curve. Note that the number of ticks might be less than n_ticks because duplicate entries will be removed. """ train_sizes_abs = np.asarray(train_sizes) n_ticks = train_sizes_abs.shape[0] n_min_required_samples = np.min(train_sizes_abs) n_max_required_samples = np.max(train_sizes_abs) if np.issubdtype(train_sizes_abs.dtype, np.floating): if n_min_required_samples <= 0.0 or n_max_required_samples > 1.0: raise ValueError( "train_sizes has been interpreted as fractions " "of the maximum number of training samples and " "must be within (0, 1], but is within [%f, %f]." % (n_min_required_samples, n_max_required_samples) ) train_sizes_abs = (train_sizes_abs * n_max_training_samples).astype( dtype=int, copy=False ) train_sizes_abs = np.clip(train_sizes_abs, 1, n_max_training_samples) else: if ( n_min_required_samples <= 0 or n_max_required_samples > n_max_training_samples ): raise ValueError( "train_sizes has been interpreted as absolute " "numbers of training samples and must be within " "(0, %d], but is within [%d, %d]." % ( n_max_training_samples, n_min_required_samples, n_max_required_samples, ) ) train_sizes_abs = np.unique(train_sizes_abs) if n_ticks > train_sizes_abs.shape[0]: warnings.warn( "Removed duplicate entries from 'train_sizes'. Number " "of ticks will be less than the size of " "'train_sizes': %d instead of %d." % (train_sizes_abs.shape[0], n_ticks), RuntimeWarning, ) return train_sizes_abs def _incremental_fit_estimator( estimator, X, y, classes, train, test, train_sizes, scorer, verbose, return_times, error_score, fit_params, ): """Train estimator on training subsets incrementally and compute scores.""" train_scores, test_scores, fit_times, score_times = [], [], [], [] partitions = zip(train_sizes, np.split(train, train_sizes)[:-1]) if fit_params is None: fit_params = {} if classes is None: partial_fit_func = partial(estimator.partial_fit, **fit_params) else: partial_fit_func = partial(estimator.partial_fit, classes=classes, **fit_params) for n_train_samples, partial_train in partitions: train_subset = train[:n_train_samples] X_train, y_train = _safe_split(estimator, X, y, train_subset) X_partial_train, y_partial_train = _safe_split(estimator, X, y, partial_train) X_test, y_test = _safe_split(estimator, X, y, test, train_subset) start_fit = time.time() if y_partial_train is None: partial_fit_func(X_partial_train) else: partial_fit_func(X_partial_train, y_partial_train) fit_time = time.time() - start_fit fit_times.append(fit_time) start_score = time.time() test_scores.append(_score(estimator, X_test, y_test, scorer, error_score)) train_scores.append(_score(estimator, X_train, y_train, scorer, error_score)) score_time = time.time() - start_score score_times.append(score_time) ret = ( (train_scores, test_scores, fit_times, score_times) if return_times else (train_scores, test_scores) ) return np.array(ret).T def validation_curve( estimator, X, y, *, param_name, param_range, groups=None, cv=None, scoring=None, n_jobs=None, pre_dispatch="all", verbose=0, error_score=np.nan, fit_params=None, ): """Validation curve. Determine training and test scores for varying parameter values. Compute scores for an estimator with different values of a specified parameter. This is similar to grid search with one parameter. However, this will also compute training scores and is merely a utility for plotting the results. Read more in the :ref:`User Guide `. Parameters ---------- estimator : object type that implements the "fit" and "predict" methods An object of that type which is cloned for each validation. X : array-like of shape (n_samples, n_features) Training vector, where `n_samples` is the number of samples and `n_features` is the number of features. y : array-like of shape (n_samples,) or (n_samples, n_outputs) or None Target relative to X for classification or regression; None for unsupervised learning. param_name : str Name of the parameter that will be varied. param_range : array-like of shape (n_values,) The values of the parameter that will be evaluated. groups : array-like of shape (n_samples,), default=None Group labels for the samples used while splitting the dataset into train/test set. Only used in conjunction with a "Group" :term:`cv` instance (e.g., :class:`GroupKFold`). cv : int, cross-validation generator or an iterable, default=None Determines the cross-validation splitting strategy. Possible inputs for cv are: - None, to use the default 5-fold cross validation, - int, to specify the number of folds in a `(Stratified)KFold`, - :term:`CV splitter`, - An iterable yielding (train, test) splits as arrays of indices. For int/None inputs, if the estimator is a classifier and ``y`` is either binary or multiclass, :class:`StratifiedKFold` is used. In all other cases, :class:`KFold` is used. These splitters are instantiated with `shuffle=False` so the splits will be the same across calls. Refer :ref:`User Guide ` for the various cross-validation strategies that can be used here. .. versionchanged:: 0.22 ``cv`` default value if None changed from 3-fold to 5-fold. scoring : str or callable, default=None A str (see model evaluation documentation) or a scorer callable object / function with signature ``scorer(estimator, X, y)``. n_jobs : int, default=None Number of jobs to run in parallel. Training the estimator and computing the score are parallelized over the combinations of each parameter value and each cross-validation split. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. ``-1`` means using all processors. See :term:`Glossary ` for more details. pre_dispatch : int or str, default='all' Number of predispatched jobs for parallel execution (default is all). The option can reduce the allocated memory. The str can be an expression like '2*n_jobs'. verbose : int, default=0 Controls the verbosity: the higher, the more messages. error_score : 'raise' or numeric, default=np.nan Value to assign to the score if an error occurs in estimator fitting. If set to 'raise', the error is raised. If a numeric value is given, FitFailedWarning is raised. .. versionadded:: 0.20 fit_params : dict, default=None Parameters to pass to the fit method of the estimator. .. versionadded:: 0.24 Returns ------- train_scores : array of shape (n_ticks, n_cv_folds) Scores on training sets. test_scores : array of shape (n_ticks, n_cv_folds) Scores on test set. Notes ----- See :ref:`sphx_glr_auto_examples_model_selection_plot_validation_curve.py` """ X, y, groups = indexable(X, y, groups) cv = check_cv(cv, y, classifier=is_classifier(estimator)) scorer = check_scoring(estimator, scoring=scoring) parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch, verbose=verbose) results = parallel( delayed(_fit_and_score)( clone(estimator), X, y, scorer, train, test, verbose, parameters={param_name: v}, fit_params=fit_params, return_train_score=True, error_score=error_score, ) # NOTE do not change order of iteration to allow one time cv splitters for train, test in cv.split(X, y, groups) for v in param_range ) n_params = len(param_range) results = _aggregate_score_dicts(results) train_scores = results["train_scores"].reshape(-1, n_params).T test_scores = results["test_scores"].reshape(-1, n_params).T return train_scores, test_scores def _aggregate_score_dicts(scores): """Aggregate the list of dict to dict of np ndarray The aggregated output of _aggregate_score_dicts will be a list of dict of form [{'prec': 0.1, 'acc':1.0}, {'prec': 0.1, 'acc':1.0}, ...] Convert it to a dict of array {'prec': np.array([0.1 ...]), ...} Parameters ---------- scores : list of dict List of dicts of the scores for all scorers. This is a flat list, assumed originally to be of row major order. Example ------- >>> scores = [{'a': 1, 'b':10}, {'a': 2, 'b':2}, {'a': 3, 'b':3}, ... {'a': 10, 'b': 10}] # doctest: +SKIP >>> _aggregate_score_dicts(scores) # doctest: +SKIP {'a': array([1, 2, 3, 10]), 'b': array([10, 2, 3, 10])} """ return { key: np.asarray([score[key] for score in scores]) if isinstance(scores[0][key], numbers.Number) else [score[key] for score in scores] for key in scores[0] }