projektAI/venv/Lib/site-packages/mlxtend/evaluate/holdout.py
2021-06-06 22:13:05 +02:00

191 lines
5.8 KiB
Python

# Sebastian Raschka 2014-2020
# mlxtend Machine Learning Library Extensions
#
# Author: Sebastian Raschka <sebastianraschka.com>
#
# License: BSD 3 clause
import numpy as np
from sklearn.model_selection import train_test_split
class RandomHoldoutSplit(object):
"""Train/Validation set splitter for sklearn's GridSearchCV etc.
Provides train/validation set indices to split a dataset
into train/validation sets using random indices.
Parameters
----------
valid_size : float (default: 0.5)
Proportion of examples that being assigned as
validation examples. 1-`valid_size` will then automatically
be assigned as training set examples.
random_seed : int (default: None)
The random seed for splitting the data
into training and validation set partitions.
stratify : bool (default: False)
True or False, whether to perform a stratified
split or not
Examples
-----------
For usage examples, please see
http://rasbt.github.io/mlxtend/user_guide/evaluate/RandomHoldoutSplit/
"""
def __init__(self, valid_size=0.5, random_seed=None, stratify=False):
self.valid_size = valid_size
self.random_seed = random_seed
self.stratify = stratify
def split(self, X, y, groups=None):
"""Generate indices to split data into training and test set.
Parameters
----------
X : array-like, shape (num_examples, num_features)
Training data, where num_examples is the number of
training examples and num_features is the number of features.
y : array-like, shape (num_examples,)
The target variable for supervised learning problems.
Stratification is done based on the y labels.
groups : object
Always ignored, exists for compatibility.
Yields
------
train_index : ndarray
The training set indices for that split.
valid_index : ndarray
The validation set indices for that split.
"""
ind = np.arange(X.shape[0])
if self.stratify:
train_index, valid_index, _, _ = \
train_test_split(ind, y,
test_size=self.valid_size,
shuffle=True,
stratify=y,
random_state=self.random_seed)
else:
train_index, valid_index, _, _ = \
train_test_split(ind, y,
test_size=self.valid_size,
shuffle=True,
stratify=y,
random_state=self.random_seed)
for i in range(1):
yield train_index, valid_index
def get_n_splits(self, X=None, y=None, groups=None):
"""Returns the number of splitting iterations in the cross-validator
Parameters
----------
X : object
Always ignored, exists for compatibility.
y : object
Always ignored, exists for compatibility.
groups : object
Always ignored, exists for compatibility.
Returns
-------
n_splits : 1
Returns the number of splitting iterations in the cross-validator.
Always returns 1.
"""
return 1
class PredefinedHoldoutSplit(object):
"""Train/Validation set splitter for sklearn's GridSearchCV etc.
Uses user-specified train/validation set indices to split a dataset
into train/validation sets using user-defined or random
indices.
Parameters
----------
valid_indices : array-like, shape (num_examples,)
Indices of the training examples in the training set
to be used for validation. All other indices in the
training set are used to for a training subset
for model fitting.
Examples
-----------
For usage examples, please see
http://rasbt.github.io/mlxtend/user_guide/evaluate/PredefinedHoldoutSplit/
"""
def __init__(self, valid_indices):
self.valid_indices = valid_indices
def split(self, X, y, groups=None):
"""Generate indices to split data into training and test set.
Parameters
----------
X : array-like, shape (num_examples, num_features)
Training data, where num_examples is the number of examples
and num_features is the number of features.
y : array-like, shape (num_examples,)
The target variable for supervised learning problems.
Stratification is done based on the y labels.
groups : object
Always ignored, exists for compatibility.
Yields
------
train_index : ndarray
The training set indices for that split.
valid_index : ndarray
The validation set indices for that split.
"""
ind = np.arange(X.shape[0])
train_mask = np.ones(X.shape[0], dtype=np.bool)
train_mask[self.valid_indices] = False
valid_mask = np.where(train_mask, False, True)
for i in range(1):
yield ind[train_mask], ind[valid_mask]
def get_n_splits(self, X=None, y=None, groups=None):
"""Returns the number of splitting iterations in the cross-validator
Parameters
----------
X : object
Always ignored, exists for compatibility.
y : object
Always ignored, exists for compatibility.
groups : object
Always ignored, exists for compatibility.
Returns
-------
n_splits : 1
Returns the number of splitting iterations in the cross-validator.
Always returns 1.
"""
return 1