Intelegentny_Pszczelarz/.venv/Lib/site-packages/keras/layers/preprocessing/hashed_crossing.py
2023-06-19 00:49:18 +02:00

228 lines
8.5 KiB
Python

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras hashed crossing preprocessing layer."""
import tensorflow.compat.v2 as tf
from keras import backend
from keras.engine import base_layer
from keras.engine import base_preprocessing_layer
from keras.layers.preprocessing import preprocessing_utils as utils
from keras.utils import layer_utils
# isort: off
from tensorflow.python.util.tf_export import keras_export
# Module-level aliases for the output-mode identifiers defined in
# `preprocessing_utils`. `HashedCrossing` accepts exactly these two values
# for its `output_mode` argument.
INT = utils.INT
ONE_HOT = utils.ONE_HOT
@keras_export(
    "keras.layers.HashedCrossing",
    "keras.layers.experimental.preprocessing.HashedCrossing",
    v1=[],
)
class HashedCrossing(base_layer.Layer):
    """A preprocessing layer which crosses features using the "hashing trick".

    The layer combines several categorical features into a single crossed
    feature by means of the "hashing trick". Conceptually the transformation
    is: hash(concatenation of features) % `num_bins`.

    Only scalar inputs and batches of scalar inputs can currently be crossed.
    Valid input shapes are `(batch_size, 1)`, `(batch_size,)` and `()`.

    For an overview and full list of preprocessing layers, see the
    preprocessing
    [guide](https://www.tensorflow.org/guide/keras/preprocessing_layers).

    Args:
        num_bins: Number of hash bins.
        output_mode: Specification for the output of the layer. Defaults to
            `"int"`. Values can be `"int"`, or `"one_hot"` configuring the
            layer as follows:
            - `"int"`: Return the integer bin indices directly.
            - `"one_hot"`: Encodes each individual element in the input into
              an array the same size as `num_bins`, containing a 1 at the
              input's bin index.
        sparse: Boolean. Only applicable to `"one_hot"` mode. If True,
            returns a `SparseTensor` instead of a dense `Tensor`. Defaults
            to False.
        **kwargs: Keyword arguments to construct a layer.

    Examples:

    **Crossing two scalar features.**

    >>> layer = tf.keras.layers.HashedCrossing(
    ...     num_bins=5)
    >>> feat1 = tf.constant(['A', 'B', 'A', 'B', 'A'])
    >>> feat2 = tf.constant([101, 101, 101, 102, 102])
    >>> layer((feat1, feat2))
    <tf.Tensor: shape=(5,), dtype=int64, numpy=array([1, 4, 1, 1, 3])>

    **Crossing and one-hotting two scalar features.**

    >>> layer = tf.keras.layers.HashedCrossing(
    ...     num_bins=5, output_mode='one_hot')
    >>> feat1 = tf.constant(['A', 'B', 'A', 'B', 'A'])
    >>> feat2 = tf.constant([101, 101, 101, 102, 102])
    >>> layer((feat1, feat2))
    <tf.Tensor: shape=(5, 5), dtype=float32, numpy=
      array([[0., 1., 0., 0., 0.],
             [0., 0., 0., 0., 1.],
             [0., 1., 0., 0., 0.],
             [0., 1., 0., 0., 0.],
             [0., 0., 0., 1., 0.]], dtype=float32)>
    """

    def __init__(self, num_bins, output_mode="int", sparse=False, **kwargs):
        """Configure the layer and validate `output_mode` / `dtype`.

        Raises:
            ValueError: If `output_mode` is `"int"` but the resolved compute
                dtype is not an integer type. `output_mode` itself is
                checked by `layer_utils.validate_string_arg`.
        """
        # When the caller gives no dtype, integer mode yields int64 and any
        # other mode yields the backend's float type.
        if kwargs.get("dtype") is None:
            if output_mode == INT:
                kwargs["dtype"] = tf.int64
            else:
                kwargs["dtype"] = backend.floatx()

        super().__init__(**kwargs)

        # Record that this preprocessing layer has been instantiated.
        gauge = base_preprocessing_layer.keras_kpl_gauge
        gauge.get_cell("HashedCrossing").set(True)

        # The dtype is validated only now, against the compute dtype the
        # base layer resolved from `kwargs`.
        resolved_dtype = tf.as_dtype(self.compute_dtype)
        if output_mode == INT and not resolved_dtype.is_integer:
            input_dtype = kwargs["dtype"]
            raise ValueError(
                "When `output_mode='int'`, `dtype` should be an integer "
                f"type. Received: dtype={input_dtype}"
            )

        # `output_mode` has to be one of (INT, ONE_HOT).
        layer_utils.validate_string_arg(
            output_mode,
            allowable_strings=(INT, ONE_HOT),
            layer_name=self.__class__.__name__,
            arg_name="output_mode",
        )

        self.num_bins = num_bins
        self.output_mode = output_mode
        self.sparse = sparse

    def call(self, inputs):
        """Cross the given features and encode the resulting bin indices.

        Args:
            inputs: List or tuple of at least two equally shaped dense
                inputs with integer or string dtype.

        Returns:
            The result of `utils.encode_categorical_inputs` applied to the
            hashed cross, reshaped to the rank of the first input.
        """
        # Validate the container, convert every entry to a tensor, then
        # validate shapes and dtypes. Only scalars and batches of scalars
        # are supported.
        self._check_at_least_two_inputs(inputs)
        tensors = [utils.ensure_tensor(item) for item in inputs]
        self._check_input_shape_and_type(tensors)

        # The cross op is fed rank-2 tensors: add one trailing axis per
        # missing rank (two for scalars, one for vectors, none for rank 2).
        rank = tensors[0].shape.rank
        for _ in range(2 - rank):
            tensors = [utils.expand_dims(t, -1) for t in tensors]

        # Hash the cross into `num_bins` buckets and densify the result.
        crossed = tf.sparse.cross_hashed(tensors, self.num_bins)
        crossed = tf.sparse.to_dense(crossed)

        # Restore the caller's rank. For rank 2 the reshape also pins the
        # last dimension to 1, which the cross op leaves undefined.
        target_shapes = {2: [-1, 1], 1: [-1], 0: []}
        if rank in target_shapes:
            crossed = tf.reshape(crossed, target_shapes[rank])

        # Apply the configured output encoding.
        return utils.encode_categorical_inputs(
            crossed,
            output_mode=self.output_mode,
            depth=self.num_bins,
            sparse=self.sparse,
            dtype=self.compute_dtype,
        )

    def compute_output_shape(self, input_shapes):
        """Derive the output shape from the shape of the first input."""
        self._check_at_least_two_inputs(input_shapes)
        # NOTE(review): only the first input shape is forwarded; neither
        # `output_mode` nor `num_bins` is passed along. Confirm that the
        # helper's signature and result are right for "one_hot" outputs.
        leading_shape = input_shapes[0]
        return utils.compute_shape_for_encode_categorical(leading_shape)

    def compute_output_signature(self, input_specs):
        """Build the output spec: sparse if requested or fed sparse specs."""
        shapes = [spec.shape.as_list() for spec in input_specs]
        result_shape = self.compute_output_shape(shapes)
        # NOTE(review): `call` rejects sparse inputs, so the
        # `SparseTensorSpec` branch of this test may never fire in practice;
        # confirm against callers.
        spec_type = (
            tf.SparseTensorSpec
            if self.sparse
            or any(
                isinstance(spec, tf.SparseTensorSpec) for spec in input_specs
            )
            else tf.TensorSpec
        )
        return spec_type(shape=result_shape, dtype=self.compute_dtype)

    def get_config(self):
        """Return the base layer config extended with this layer's args."""
        config = super().get_config()
        config["num_bins"] = self.num_bins
        config["output_mode"] = self.output_mode
        config["sparse"] = self.sparse
        return config

    def _check_at_least_two_inputs(self, inputs):
        """Raise `ValueError` unless `inputs` is a list/tuple of length >= 2."""
        is_sequence = isinstance(inputs, (list, tuple))
        if not is_sequence:
            raise ValueError(
                "`HashedCrossing` should be called on a list or tuple of "
                f"inputs. Received: inputs={inputs}"
            )
        if len(inputs) <= 1:
            raise ValueError(
                "`HashedCrossing` should be called on at least two inputs. "
                f"Received: inputs={inputs}"
            )

    def _check_input_shape_and_type(self, inputs):
        """Validate shape, density and dtype of the converted inputs.

        Checks run in this order, each raising `ValueError`: the first
        input's shape is `[]`, `[batch_size]` or `[batch_size, 1]`; all
        inputs share that shape; no input is ragged or sparse; every input
        has an integer or string dtype.
        """
        first_shape = inputs[0].shape.as_list()
        rank = len(first_shape)

        # Ranks 0 and 1 are always fine; rank 2 needs a trailing dim of 1.
        shape_supported = rank < 2 or (rank == 2 and first_shape[-1] == 1)
        if not shape_supported:
            raise ValueError(
                "All `HashedCrossing` inputs should have shape `[]`, "
                "`[batch_size]` or `[batch_size, 1]`. "
                f"Received: inputs={inputs}"
            )

        for other in inputs[1:]:
            if other.shape.as_list() != first_shape:
                raise ValueError(
                    "All `HashedCrossing` inputs should have equal shape. "
                    f"Received: inputs={inputs}"
                )

        for tensor in inputs:
            if isinstance(tensor, (tf.RaggedTensor, tf.SparseTensor)):
                raise ValueError(
                    "All `HashedCrossing` inputs should be dense tensors. "
                    f"Received: inputs={inputs}"
                )

        for tensor in inputs:
            dtype = tensor.dtype
            if not (dtype.is_integer or dtype == tf.string):
                raise ValueError(
                    "All `HashedCrossing` inputs should have an integer or "
                    f"string dtype. Received: inputs={inputs}"
                )