402 lines
14 KiB
Python
402 lines
14 KiB
Python
|
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Wrapper for using the Scikit-Learn API with Keras models."""
|
||
|
|
||
|
|
||
|
import copy
|
||
|
import types
|
||
|
import warnings
|
||
|
|
||
|
import numpy as np
|
||
|
|
||
|
from keras import losses
|
||
|
from keras.models import Sequential
|
||
|
from keras.utils.generic_utils import has_arg
|
||
|
from keras.utils.np_utils import to_categorical
|
||
|
|
||
|
# isort: off
|
||
|
from tensorflow.python.util.tf_export import keras_export
|
||
|
from tensorflow.tools.docs import doc_controls
|
||
|
|
||
|
|
||
|
class BaseWrapper:
    """Shared machinery behind the Keras scikit-learn wrappers.

    Warning: Do not instantiate this class directly; use one of its
    descendant classes instead.

    Args:
        build_fn: callable function or class instance
        **sk_params: model parameters & fitting parameters

    `build_fn` is expected to construct, compile and return a Keras model,
    which is then used to fit/predict. It may be one of:

    1. A function
    2. An instance of a class that implements the `__call__` method
    3. None. In that case you implement a class inheriting from either
       `KerasClassifier` or `KerasRegressor`, and the `__call__` method of
       that class is treated as the default `build_fn`.

    `sk_params` carries both model parameters and fitting parameters. Legal
    model parameters are the arguments of `build_fn`. As with every other
    scikit-learn estimator, `build_fn` should give default values to its
    arguments so the estimator can be created without passing anything in
    `sk_params`.

    `sk_params` may also hold parameters for the `fit`, `predict`,
    `predict_proba`, and `score` methods (e.g., `epochs`, `batch_size`).
    Fitting (predicting) parameters are resolved in this order:

    1. Values passed to the dictionary arguments of
       `fit`, `predict`, `predict_proba`, and `score` methods
    2. Values passed to `sk_params`
    3. The default values of the `keras.models.Sequential`
       `fit`, `predict` methods.

    With scikit-learn's `grid_search` API, the tunable parameters are
    exactly those accepted by `sk_params`, fitting parameters included. In
    other words `grid_search` can look for the best `batch_size` or
    `epochs` as well as the model parameters.
    """

    def __init__(self, build_fn=None, **sk_params):
        # Both attributes must exist before validation: `check_params`
        # reads `self.build_fn` to decide which signature to inspect.
        self.build_fn = build_fn
        self.sk_params = sk_params
        self.check_params(sk_params)

    def check_params(self, params):
        """Validates that every name in `params` is accepted somewhere.

        A name is legal when it is an argument of `Sequential.fit`,
        `Sequential.predict`, `Sequential.evaluate`, or of the model
        builder. The name `nb_epoch` is always tolerated.

        Args:
            params: dictionary; the parameters to be checked

        Raises:
            ValueError: if any member of `params` is not a valid argument.
        """
        signature_sources = [
            Sequential.fit,
            Sequential.predict,
            Sequential.evaluate,
        ]
        builder = self.build_fn
        if builder is None:
            # No explicit builder: the subclass' own `__call__` builds.
            signature_sources.append(self.__call__)
        elif isinstance(builder, (types.FunctionType, types.MethodType)):
            signature_sources.append(builder)
        else:
            # Callable instance: its `__call__` carries the arguments.
            signature_sources.append(builder.__call__)

        for key in params:
            accepted = any(has_arg(fn, key) for fn in signature_sources)
            if not accepted and key != "nb_epoch":
                raise ValueError(f"{key} is not a legal parameter")

    def get_params(self, **params):
        """Gets parameters for this estimator.

        Args:
            **params: ignored (exists for API compatibility).

        Returns:
            A new dictionary holding every entry of `sk_params` plus the
            `build_fn` entry.
        """
        return {**self.sk_params, "build_fn": self.build_fn}

    def set_params(self, **params):
        """Sets the parameters of this estimator.

        The new values are validated with `check_params` and then merged
        into `sk_params`.

        Args:
            **params: Dictionary of parameter names mapped to their values.

        Returns:
            self
        """
        self.check_params(params)
        self.sk_params.update(params)
        return self

    def fit(self, x, y, **kwargs):
        """Builds a fresh model through `build_fn` and fits it to `(x, y)`.

        Args:
            x : array-like, shape `(n_samples, n_features)`
                Training samples where `n_samples` is the number of samples
                and `n_features` is the number of features.
            y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
                True labels for `x`.
            **kwargs: dictionary arguments
                Legal arguments are the arguments of `Sequential.fit`

        Returns:
            history : object
                details about the training history at each epoch.
        """
        builder = self.build_fn
        if builder is None:
            factory = self.__call__
            signature_source = self.__call__
        elif isinstance(builder, (types.FunctionType, types.MethodType)):
            factory = builder
            signature_source = builder
        else:
            # Callable instance: call the object, inspect its `__call__`.
            factory = builder
            signature_source = builder.__call__
        self.model = factory(**self.filter_sk_params(signature_source))

        # Categorical cross-entropy gets `to_categorical` targets unless
        # `y` already has two axes.
        if losses.is_categorical_crossentropy(self.model.loss):
            if len(y.shape) != 2:
                y = to_categorical(y)

        # Deep copy so the stored `sk_params` values are not shared with
        # the training call; explicit `kwargs` win.
        fit_args = copy.deepcopy(self.filter_sk_params(Sequential.fit))
        fit_args.update(kwargs)

        return self.model.fit(x, y, **fit_args)

    def filter_sk_params(self, fn, override=None):
        """Selects the entries of `sk_params` that `fn` accepts.

        Args:
            fn : arbitrary function
            override: dictionary, values to override `sk_params`

        Returns:
            res : dictionary containing variables
                in both `sk_params` and `fn`'s arguments, updated with
                every entry of `override`.
        """
        selected = {
            name: value
            for name, value in self.sk_params.items()
            if has_arg(fn, name)
        }
        if override:
            selected.update(override)
        return selected
|
||
|
|
||
|
|
||
|
@keras_export("keras.wrappers.scikit_learn.KerasClassifier")
@doc_controls.do_not_generate_docs
class KerasClassifier(BaseWrapper):
    """Implementation of the scikit-learn classifier API for Keras.

    DEPRECATED. Use [Sci-Keras](https://github.com/adriangb/scikeras) instead.
    See https://www.adriangb.com/scikeras/stable/migration.html
    for help migrating.
    """

    def __init__(self, build_fn=None, **sk_params):
        """Emits a `DeprecationWarning` and initializes the base wrapper.

        Args:
            build_fn: callable function or class instance
            **sk_params: model parameters & fitting parameters
        """
        warnings.warn(
            "KerasClassifier is deprecated, "
            "use Sci-Keras (https://github.com/adriangb/scikeras) instead. "
            "See https://www.adriangb.com/scikeras/stable/migration.html "
            "for help migrating.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(build_fn, **sk_params)

    def fit(self, x, y, **kwargs):
        """Constructs a new model with `build_fn` & fit the model to `(x, y)`.

        Before delegating to the base class, the labels are recorded in
        `classes_` / `n_classes_`. One-dimensional or single-column labels
        are replaced by their index in the sorted unique labels.

        Args:
            x : array-like, shape `(n_samples, n_features)`
                Training samples where `n_samples` is the number of samples
                and `n_features` is the number of features.
            y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
                True labels for `x`.
            **kwargs: dictionary arguments
                Legal arguments are the arguments of `Sequential.fit`

        Returns:
            history : object
                details about the training history at each epoch.

        Raises:
            ValueError: In case of invalid shape for `y` argument.
        """
        y = np.array(y)
        if len(y.shape) == 2 and y.shape[1] > 1:
            # Several columns: one class per column.
            self.classes_ = np.arange(y.shape[1])
        elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1:
            # Raw labels: map each one to its position among the sorted
            # unique labels so `predict` can map indices back.
            self.classes_ = np.unique(y)
            y = np.searchsorted(self.classes_, y)
        else:
            raise ValueError("Invalid shape for y: " + str(y.shape))
        self.n_classes_ = len(self.classes_)
        return super().fit(x, y, **kwargs)

    def predict(self, x, **kwargs):
        """Returns the class predictions for the given test data.

        Args:
            x: array-like, shape `(n_samples, n_features)`
                Test samples where `n_samples` is the number of samples
                and `n_features` is the number of features.
            **kwargs: dictionary arguments
                Legal arguments are the arguments
                of `Sequential.predict`. They take precedence over
                matching entries of `sk_params`.

        Returns:
            preds: array-like
                Class predictions taken from `classes_`. When the model
                output has more than one unit on its last axis, that axis
                is reduced with `argmax`. Otherwise the output is
                thresholded at 0.5 and the result keeps the shape of the
                model output.
        """
        # Honor predict-time settings stored in `sk_params` (for example
        # `batch_size`), as `KerasRegressor.predict` and `score` do.
        kwargs = self.filter_sk_params(Sequential.predict, kwargs)
        proba = self.model.predict(x, **kwargs)
        if proba.shape[-1] > 1:
            classes = proba.argmax(axis=-1)
        else:
            classes = (proba > 0.5).astype("int32")
        return self.classes_[classes]

    def predict_proba(self, x, **kwargs):
        """Returns class probability estimates for the given test data.

        Args:
            x: array-like, shape `(n_samples, n_features)`
                Test samples where `n_samples` is the number of samples
                and `n_features` is the number of features.
            **kwargs: dictionary arguments
                Legal arguments are the arguments
                of `Sequential.predict`. They take precedence over
                matching entries of `sk_params`.

        Returns:
            proba: array-like, shape `(n_samples, n_outputs)`
                Class probability estimates.
                In the case of binary classification,
                to match the scikit-learn API,
                will return an array of shape `(n_samples, 2)`
                (instead of `(n_sample, 1)` as in Keras).
        """
        # Same parameter resolution as `predict`: `sk_params` first, then
        # explicit keyword arguments on top.
        kwargs = self.filter_sk_params(Sequential.predict, kwargs)
        probs = self.model.predict(x, **kwargs)

        # check if binary classification
        if probs.shape[1] == 1:
            # first column is probability of class 0 and second is of class 1
            probs = np.hstack([1 - probs, probs])
        return probs

    def score(self, x, y, **kwargs):
        """Returns the mean accuracy on the given test data and labels.

        Args:
            x: array-like, shape `(n_samples, n_features)`
                Test samples where `n_samples` is the number of samples
                and `n_features` is the number of features.
            y: array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
                True labels for `x`.
            **kwargs: dictionary arguments
                Legal arguments are the arguments of `Sequential.evaluate`.

        Returns:
            score: float
                Mean accuracy of predictions on `x` wrt. `y`.

        Raises:
            ValueError: If the underlying model isn't configured to
                compute accuracy. You should pass `metrics=["accuracy"]` to
                the `.compile()` method of the model.
        """
        # Translate labels to the indices used during `fit`.
        y = np.searchsorted(self.classes_, y)
        kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)

        # Use the same loss detection as `BaseWrapper.fit`, so targets are
        # encoded the same way at train and evaluation time.
        if (
            losses.is_categorical_crossentropy(self.model.loss)
            and len(y.shape) != 2
        ):
            y = to_categorical(y)

        outputs = self.model.evaluate(x, y, **kwargs)
        if not isinstance(outputs, list):
            outputs = [outputs]
        for name, output in zip(self.model.metrics_names, outputs):
            if name in ("accuracy", "acc"):
                return output
        raise ValueError(
            "The model is not configured to compute accuracy. "
            'You should pass `metrics=["accuracy"]` to '
            "the `model.compile()` method."
        )
|
||
|
|
||
|
|
||
|
@keras_export("keras.wrappers.scikit_learn.KerasRegressor")
@doc_controls.do_not_generate_docs
class KerasRegressor(BaseWrapper):
    """Implementation of the scikit-learn regressor API for Keras.

    DEPRECATED. Use [Sci-Keras](https://github.com/adriangb/scikeras) instead.
    See https://www.adriangb.com/scikeras/stable/migration.html
    for help migrating.
    """

    @doc_controls.do_not_doc_inheritable
    def __init__(self, build_fn=None, **sk_params):
        """Emits a `DeprecationWarning` and initializes the base wrapper.

        Args:
            build_fn: callable function or class instance
            **sk_params: model parameters & fitting parameters
        """
        warnings.warn(
            "KerasRegressor is deprecated, "
            "use Sci-Keras (https://github.com/adriangb/scikeras) instead. "
            "See https://www.adriangb.com/scikeras/stable/migration.html "
            "for help migrating.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(build_fn, **sk_params)

    def predict(self, x, **kwargs):
        """Returns predictions for the given test data.

        Args:
            x: array-like, shape `(n_samples, n_features)`
                Test samples where `n_samples` is the number of samples
                and `n_features` is the number of features.
            **kwargs: dictionary arguments
                Legal arguments are the arguments of `Sequential.predict`.
                They take precedence over matching entries of `sk_params`.

        Returns:
            preds: array-like, shape `(n_samples,)`
                Predictions with all length-one axes removed; the result
                always has at least one dimension.
        """
        kwargs = self.filter_sk_params(Sequential.predict, kwargs)
        # A bare `squeeze` turns a single-sample, single-output prediction
        # into a 0-d array; `atleast_1d` restores the documented
        # `(n_samples,)` shape and leaves every other case untouched.
        return np.atleast_1d(np.squeeze(self.model.predict(x, **kwargs)))

    def score(self, x, y, **kwargs):
        """Returns the negated loss on the given test data and labels.

        The sign is flipped so that, as with other scikit-learn scores,
        a larger value means a better fit.

        Args:
            x: array-like, shape `(n_samples, n_features)`
                Test samples where `n_samples` is the number of samples
                and `n_features` is the number of features.
            y: array-like, shape `(n_samples,)`
                True labels for `x`.
            **kwargs: dictionary arguments
                Legal arguments are the arguments of `Sequential.evaluate`.

        Returns:
            score: float
                Negative of the value returned by the model's `evaluate`
                call on `(x, y)`. When `evaluate` returns a list, its
                first entry is used.
        """
        kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)
        loss = self.model.evaluate(x, y, **kwargs)
        if isinstance(loss, list):
            return -loss[0]
        return -loss
|