# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocessing stage."""

import numpy as np
import tensorflow.compat.v2 as tf

from keras.engine import base_preprocessing_layer
from keras.engine import functional
from keras.engine import sequential
from keras.utils import tf_utils


# Sequential methods should take precedence.
class PreprocessingStage(
    sequential.Sequential, base_preprocessing_layer.PreprocessingLayer
):
    """A sequential preprocessing stage.

    This preprocessing stage wraps a list of preprocessing layers into a
    Sequential-like object that enables you to `adapt()` the whole list via
    a single `adapt()` call on the preprocessing stage.

    Args:
      layers: List of layers. Can include layers that aren't preprocessing
        layers.
      name: String. Optional name for the preprocessing stage object.
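
    Example (an illustrative sketch, not part of the original file; it
    assumes `tf.keras.layers.Normalization` is an adaptable preprocessing
    layer whose `adapt()` accepts the `reset_state` argument this stage
    forwards):

    >>> stage = PreprocessingStage([tf.keras.layers.Normalization()])
    >>> stage.adapt(np.ones((4, 3)))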
"""
|
|
|
|
    def adapt(self, data, reset_state=True):
        """Adapt the state of the layers of the preprocessing stage to the data.

        Args:
          data: A batched Dataset object, a NumPy array, or an EagerTensor.
            Data to be iterated over to adapt the state of the layers in this
            preprocessing stage.
          reset_state: Whether this call to `adapt` should reset the state of
            the layers in this preprocessing stage.
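
        Example (an illustrative sketch, not part of the original file; it
        assumes `tf.keras.layers.Normalization` adapts from a batched
        Dataset as a standard preprocessing layer):

        >>> stage = PreprocessingStage([tf.keras.layers.Normalization()])
        >>> ds = tf.data.Dataset.from_tensor_slices(np.ones((8, 3))).batch(2)
        >>> stage.adapt(ds)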
        """
        if not isinstance(
            data, (tf.data.Dataset, np.ndarray, tf.__internal__.EagerTensor)
        ):
            raise ValueError(
                "`adapt()` requires a batched Dataset, an EagerTensor, or a "
                f"Numpy array as input. Received data={data}"
            )
        if isinstance(data, tf.data.Dataset):
            # Validate the dataset to try and ensure we haven't been passed
            # one with infinite size. That would cause an infinite loop here.
            if tf_utils.dataset_is_infinite(data):
                raise ValueError(
                    "The dataset passed to `adapt()` has an infinite number "
                    "of elements. Please use dataset.take(...) to make the "
                    "number of elements finite."
                )

        for current_layer_index in range(0, len(self.layers)):
            if not hasattr(self.layers[current_layer_index], "adapt"):
                # Skip any layer that does not need adapting.
                continue

            def map_fn(x):
                """Maps this object's inputs to those at current_layer_index.

                Args:
                  x: Batch of inputs seen at the entry of this
                    `PreprocessingStage` instance.

                Returns:
                  Batch of inputs to be processed by layer
                  `self.layers[current_layer_index]`.
                """
                if current_layer_index == 0:
                    return x
                # Feed the batch through all layers preceding the one being
                # adapted, so that it sees data in its own input space.
                for i in range(current_layer_index):
                    x = self.layers[i](x)
                return x

            if isinstance(data, tf.data.Dataset):
                current_layer_data = data.map(map_fn)
            else:
                current_layer_data = map_fn(data)
            self.layers[current_layer_index].adapt(
                current_layer_data, reset_state=reset_state
            )


# Functional methods should take precedence.
class FunctionalPreprocessingStage(
    functional.Functional, base_preprocessing_layer.PreprocessingLayer
):
    """A functional preprocessing stage.

    This preprocessing stage wraps a graph of preprocessing layers into a
    Functional-like object that enables you to `adapt()` the whole graph via
    a single `adapt()` call on the preprocessing stage.

    A preprocessing stage is not a complete model, so it cannot be trained
    with `fit()`. However, it is possible to add regular, potentially
    trainable layers to a preprocessing stage.

    A functional preprocessing stage is created in the same way as
    `Functional` models: by passing two arguments to `__init__`. The first
    argument is the `keras.Input` tensors that represent the inputs to the
    stage. The second argument specifies the output tensors that represent
    the outputs of this stage. Both arguments can be a nested structure of
    tensors.

    Example:

    >>> inputs = {'x2': tf.keras.Input(shape=(5,)),
    ...           'x1': tf.keras.Input(shape=(1,))}
    >>> norm_layer = tf.keras.layers.Normalization()
    >>> y = norm_layer(inputs['x2'])
    >>> w, z = tf.keras.layers.Lambda(lambda x: (x, x))(inputs['x1'])
    >>> outputs = [y, [w, z]]
    >>> stage = FunctionalPreprocessingStage(inputs, outputs)

    Args:
      inputs: An input tensor (must be created via `tf.keras.Input()`), or a
        list, a dict, or a nested structure of input tensors.
      outputs: An output tensor, or a list, a dict, or a nested structure of
        output tensors.
      name: String, optional. Name of the preprocessing stage.
    """

    def fit(self, *args, **kwargs):
        raise ValueError(
            "Preprocessing stage is not a complete model, and hence should "
            "not be `fit`. Instead, you may feed data to `adapt` the stage "
            "to set appropriate states of the layers in the stage."
        )

    def adapt(self, data, reset_state=True):
        """Adapt the state of the layers of the preprocessing stage to the data.

        Args:
          data: A batched Dataset object, a NumPy array, an EagerTensor, or a
            list, dict, or nested structure of NumPy arrays or EagerTensors.
            The elements of a Dataset object must conform to the inputs of
            the stage. The first dimension of a NumPy array or EagerTensor is
            understood to be the batch dimension. Data to be iterated over to
            adapt the state of the layers in this preprocessing stage.
          reset_state: Whether this call to `adapt` should reset the state of
            the layers in this preprocessing stage.

        Examples:

        >>> # For a stage with dict input
        >>> inputs = {'x2': tf.keras.Input(shape=(5,)),
        ...           'x1': tf.keras.Input(shape=(1,))}
        >>> outputs = [inputs['x1'], inputs['x2']]
        >>> stage = FunctionalPreprocessingStage(inputs, outputs)
        >>> ds = tf.data.Dataset.from_tensor_slices({'x1': tf.ones((4, 1)),
        ...                                          'x2': tf.ones((4, 5))})
        >>> sorted(ds.element_spec.items())  # Check element_spec
        [('x1', TensorSpec(shape=(1,), dtype=tf.float32, name=None)),
         ('x2', TensorSpec(shape=(5,), dtype=tf.float32, name=None))]
        >>> stage.adapt(ds)
        >>> data_np = {'x1': np.ones((4, 1)), 'x2': np.ones((4, 5))}
        >>> stage.adapt(data_np)
        """
        if not isinstance(data, tf.data.Dataset):
            data = self._flatten_to_reference_inputs(data)
            if any(
                not isinstance(datum, (np.ndarray, tf.__internal__.EagerTensor))
                for datum in data
            ):
                raise ValueError(
                    "`adapt()` requires a batched Dataset, a list of "
                    f"EagerTensors or Numpy arrays as input, got {type(data)}"
                )
            # `from_tensor_slices` strips the leading (batch) dimension, so
            # re-batch with batch size 1 to keep the data batched.
            ds_input = [
                tf.data.Dataset.from_tensor_slices(x).batch(1) for x in data
            ]

        if isinstance(data, tf.data.Dataset):
            # Validate the dataset to try and ensure we haven't been passed
            # one with infinite size. That would cause an infinite loop here.
            if tf_utils.dataset_is_infinite(data):
                raise ValueError(
                    "The dataset passed to `adapt()` has an infinite number "
                    "of elements. Please use dataset.take(...) to make the "
                    "number of elements finite."
                )
            # Unzip the dataset object into a list of single-input datasets.
            ds_input = _unzip_dataset(data)

        # Dictionary mapping reference tensors to datasets.
        ds_dict = {}
        tensor_usage_count = self._tensor_usage_count
        for x, y in zip(self.inputs, ds_input):
            x_id = str(id(x))
            ds_dict[x_id] = [y] * tensor_usage_count[x_id]

        nodes_by_depth = self._nodes_by_depth
        depth_keys = sorted(nodes_by_depth.keys(), reverse=True)
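        # Keras functional graphs store input nodes at the largest depth and
        # output nodes at depth 0, so traversing depths in descending order
        # ensures the datasets for all of a node's inputs exist before the
        # node itself is visited.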

        def build_map_fn(node, args, kwargs):
            # `Dataset.map` passes the components of a tuple element as
            # separate positional arguments. For a tuple spec, repack them
            # into a single tuple argument for the layer call; otherwise,
            # call the layer directly on the lone element.
            if not isinstance(args.element_spec, tuple):

                def map_fn(*x):
                    return tf.nest.flatten(node.layer(*x, **kwargs))

            else:

                def map_fn(*x):
                    return tf.nest.flatten(node.layer(x, **kwargs))

            return map_fn

        for depth in depth_keys:
            for node in nodes_by_depth[depth]:
                # Input node.
                if node.is_input:
                    continue

                # Node whose inputs have not all been computed yet.
                if any(t_id not in ds_dict for t_id in node.flat_input_ids):
                    continue

                args, kwargs = node.map_arguments(ds_dict)
                args = tf.data.Dataset.zip(
                    tf.__internal__.nest.list_to_tuple(*args)
                )

                if node.layer.stateful and hasattr(node.layer, "adapt"):
                    node.layer.adapt(args, reset_state=reset_state)

                map_fn = build_map_fn(node, args, kwargs)
                outputs = args.map(map_fn)
                outputs = _unzip_dataset(outputs)

                # Update ds_dict.
                for x_id, y in zip(node.flat_output_ids, outputs):
                    ds_dict[x_id] = [y] * tensor_usage_count[x_id]


def _unzip_dataset(ds):
    """Unzip a dataset into a list of single-element datasets.

    Args:
      ds: A Dataset object.

    Returns:
      A list of Dataset objects, each corresponding to one element of the
      `element_spec` of the input Dataset object.

    Example:

    >>> ds1 = tf.data.Dataset.from_tensor_slices([1, 2, 3])
    >>> ds2 = tf.data.Dataset.from_tensor_slices([4, 5, 6])
    >>> ds_zipped_tuple = tf.data.Dataset.zip((ds1, ds2))
    >>> ds_unzipped_tuple = _unzip_dataset(ds_zipped_tuple)
    >>> ds_zipped_dict = tf.data.Dataset.zip({'ds1': ds1, 'ds2': ds2})
    >>> ds_unzipped_dict = _unzip_dataset(ds_zipped_dict)

    The two elements of both `ds_unzipped_tuple` and `ds_unzipped_dict` are
    then the same as `ds1` and `ds2`.
    """
    element_count = len(tf.nest.flatten(ds.element_spec))
    ds_unzipped = []
    for i in range(element_count):

        # Bind the current index via a default argument so that each mapped
        # function captures its own `i` rather than the loop's final value.
        def map_fn(*x, j=i):
            return tf.nest.flatten(x)[j]

        ds_unzipped.append(ds.map(map_fn))
    return ds_unzipped