# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the base ProcessingLayer and a subclass that uses Combiners."""

import abc

import tensorflow.compat.v2 as tf

from keras.engine import data_adapter
from keras.engine.base_layer import Layer
from keras.utils import version_utils

# isort: off
from tensorflow.python.eager import context
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls

keras_kpl_gauge = tf.__internal__.monitoring.BoolGauge(
    "/tensorflow/api/keras/layers/preprocessing",
    "keras preprocessing layers usage",
    "method",
)


@keras_export("keras.layers.experimental.preprocessing.PreprocessingLayer")
class PreprocessingLayer(Layer, metaclass=abc.ABCMeta):
    """Base class for Preprocessing Layers.

    **Don't use this class directly: it's an abstract base class!** You may
    be looking for one of the many built-in
    [preprocessing layers](https://keras.io/guides/preprocessing_layers/)
    instead.

    Preprocessing layers are layers whose state gets computed before model
    training starts. They do not get updated during training. Most
    preprocessing layers implement an `adapt()` method for state computation.

    The `PreprocessingLayer` class is the base class you would subclass to
    implement your own preprocessing layers.
    """

    _must_restore_from_config = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._is_compiled = False
        self._is_adapted = False

        # Sets `is_adapted=False` when `reset_state` is called.
        self._reset_state_impl = self.reset_state
        self.reset_state = self._reset_state_wrapper

        self._adapt_function = None

    @property
    def is_adapted(self):
        """Whether the layer has been fit to data already."""
        return self._is_adapted

    @doc_controls.do_not_generate_docs
    def update_state(self, data):
        """Accumulates statistics for the preprocessing layer.

        Arguments:
            data: A mini-batch of inputs to the layer.
        """
        raise NotImplementedError

    @doc_controls.do_not_generate_docs
    def reset_state(self):
        """Resets the statistics of the preprocessing layer."""
        raise NotImplementedError

    @doc_controls.do_not_generate_docs
    def finalize_state(self):
        """Finalize the statistics for the preprocessing layer.

        This method is called at the end of `adapt` or after restoring a
        serialized preprocessing layer's state. This method handles any
        one-time operations that should occur on the layer's state before
        `Layer.__call__`.
        """
        pass

    @doc_controls.do_not_generate_docs
    def make_adapt_function(self):
        """Creates a function to execute one step of `adapt`.

        This method can be overridden to support custom adapt logic.
        This method is called by `PreprocessingLayer.adapt`.

        Typically, this method directly controls `tf.function` settings,
        and delegates the actual state update logic to
        `PreprocessingLayer.update_state`.

        This function is cached the first time `PreprocessingLayer.adapt`
        is called. The cache is cleared whenever `PreprocessingLayer.compile`
        is called.

        Returns:
            Function. The function created by this method should accept a
            `tf.data.Iterator`, retrieve a batch, and update the state of
            the layer.
        """
        if self._adapt_function is not None:
            return self._adapt_function

        def adapt_step(iterator):
            data = next(iterator)
            self._adapt_maybe_build(data)
            self.update_state(data)

        if self._steps_per_execution.numpy().item() == 1:
            adapt_fn = adapt_step
        else:

            def adapt_fn(iterator):
                for _ in tf.range(self._steps_per_execution):
                    adapt_step(iterator)

        if not self._run_eagerly:
            adapt_fn = tf.function(adapt_fn)

        self._adapt_function = adapt_fn
        return self._adapt_function
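
    # A minimal sketch of how the cached adapt function is consumed, assuming
    # `layer` is some concrete subclass that implements `update_state` (e.g.
    # `tf.keras.layers.Normalization`); kept as a comment since this module
    # only defines the abstract base class:
    #
    #   layer.compile()  # Sets `_steps_per_execution` and `_run_eagerly`.
    #   ds = tf.data.Dataset.from_tensor_slices([0.0, 2.0]).batch(1)
    #   adapt_fn = layer.make_adapt_function()
    #   adapt_fn(iter(ds))  # Pulls one batch and calls `update_state`.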

    def compile(self, run_eagerly=None, steps_per_execution=None):
        """Configures the layer for `adapt`.

        Arguments:
            run_eagerly: Bool. Defaults to `False`. If `True`, this `Model`'s
                logic will not be wrapped in a `tf.function`. Recommended to
                leave this as `None` unless your `Model` cannot be run inside
                a `tf.function`.
            steps_per_execution: Int. Defaults to 1. The number of batches to
                run during each `tf.function` call. Running multiple batches
                inside a single `tf.function` call can greatly improve
                performance on TPUs or small models with a large Python
                overhead.
        """
        if steps_per_execution is None:
            steps_per_execution = 1
        self._configure_steps_per_execution(steps_per_execution)

        if run_eagerly is None:
            run_eagerly = self.dynamic
        self._run_eagerly = run_eagerly

        self._is_compiled = True

    def adapt(self, data, batch_size=None, steps=None):
        """Fits the state of the preprocessing layer to the data being passed.

        After calling `adapt` on a layer, a preprocessing layer's state will
        not update during training. In order to make preprocessing layers
        efficient in any distribution context, they are kept constant with
        respect to any compiled `tf.Graph`s that call the layer. This does
        not affect the layer use when adapting each layer only once, but if
        you adapt a layer multiple times you will need to take care to
        re-compile any compiled functions as follows:

        * If you are adding a preprocessing layer to a `keras.Model`, you
          need to call `model.compile` after each subsequent call to `adapt`.
        * If you are calling a preprocessing layer inside
          `tf.data.Dataset.map`, you should call `map` again on the input
          `tf.data.Dataset` after each `adapt`.
        * If you are using a `tf.function` directly which calls a
          preprocessing layer, you need to call `tf.function` again on your
          callable after each subsequent call to `adapt`.

        `tf.keras.Model` example with multiple adapts:

        >>> layer = tf.keras.layers.Normalization(
        ...     axis=None)
        >>> layer.adapt([0, 2])
        >>> model = tf.keras.Sequential(layer)
        >>> model.predict([0, 1, 2])
        array([-1.,  0.,  1.], dtype=float32)
        >>> layer.adapt([-1, 1])
        >>> model.compile()  # This is needed to re-compile model.predict!
        >>> model.predict([0, 1, 2])
        array([0., 1., 2.], dtype=float32)

        `tf.data.Dataset` example with multiple adapts:

        >>> layer = tf.keras.layers.Normalization(
        ...     axis=None)
        >>> layer.adapt([0, 2])
        >>> input_ds = tf.data.Dataset.range(3)
        >>> normalized_ds = input_ds.map(layer)
        >>> list(normalized_ds.as_numpy_iterator())
        [array([-1.], dtype=float32),
         array([0.], dtype=float32),
         array([1.], dtype=float32)]
        >>> layer.adapt([-1, 1])
        >>> normalized_ds = input_ds.map(layer)  # Re-map over the dataset.
        >>> list(normalized_ds.as_numpy_iterator())
        [array([0.], dtype=float32),
         array([1.], dtype=float32),
         array([2.], dtype=float32)]

        `adapt()` is meant only as a single machine utility to compute layer
        state. To analyze a dataset that cannot fit on a single machine, see
        [Tensorflow Transform](
        https://www.tensorflow.org/tfx/transform/get_started) for a
        multi-machine, map-reduce solution.

        Arguments:
            data: The data to train on. It can be passed either as a
                `tf.data.Dataset`, or as a numpy array.
            batch_size: Integer or `None`. Number of samples per state
                update. If unspecified, `batch_size` will default to 32. Do
                not specify the `batch_size` if your data is in the form of
                datasets, generators, or `keras.utils.Sequence` instances
                (since they generate batches).
            steps: Integer or `None`. Total number of steps (batches of
                samples). When training with input tensors such as
                TensorFlow data tensors, the default `None` is equal to the
                number of samples in your dataset divided by the batch size,
                or 1 if that cannot be determined. If `data` is a `tf.data`
                dataset and `steps` is `None`, the epoch will run until the
                input dataset is exhausted. When passing an infinitely
                repeating dataset, you must specify the `steps` argument.
                This argument is not supported with array inputs.
        """
        _disallow_inside_tf_function("adapt")
        if not version_utils.should_use_v2():
            raise RuntimeError("`adapt` is only supported in tensorflow v2.")
        if not self._is_compiled:
            self.compile()  # Compile with defaults.
        if self.built:
            self.reset_state()
        data_handler = data_adapter.DataHandler(
            data,
            batch_size=batch_size,
            steps_per_epoch=steps,
            epochs=1,
            steps_per_execution=self._steps_per_execution,
            distribute=False,
        )
        self._adapt_function = self.make_adapt_function()
        for _, iterator in data_handler.enumerate_epochs():
            with data_handler.catch_stop_iteration():
                for _ in data_handler.steps():
                    self._adapt_function(iterator)
                    if data_handler.should_sync:
                        context.async_wait()
        self.finalize_state()
        self._is_adapted = True
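
    # A minimal sketch of the subclassing contract `adapt` relies on; the
    # `RunningMax` class and its variable are illustrative only, not part of
    # this API:
    #
    #   class RunningMax(PreprocessingLayer):
    #       def build(self, input_shape):
    #           self.max = tf.Variable(0.0, trainable=False)
    #
    #       def update_state(self, data):
    #           data = tf.cast(data, tf.float32)
    #           self.max.assign(tf.maximum(self.max, tf.reduce_max(data)))
    #
    #       def reset_state(self):
    #           self.max.assign(0.0)
    #
    #       def call(self, inputs):
    #           return tf.cast(inputs, tf.float32) / self.max
    #
    #   layer = RunningMax()
    #   layer.adapt([1.0, 2.0, 4.0])  # Sets `layer.max` to 4.0.
    #   layer.is_adapted  # True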

    def _reset_state_wrapper(self):
        """Calls `reset_state` and sets `adapted` to `False`."""
        self._reset_state_impl()
        self._is_adapted = False

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def _configure_steps_per_execution(self, steps_per_execution):
        self._steps_per_execution = tf.Variable(
            steps_per_execution,
            dtype="int64",
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )

    # TODO(omalleyt): Unify this logic with `Layer._maybe_build`.
    def _adapt_maybe_build(self, data):
        if not self.built:
            try:
                # If this is a Numpy array or tensor, we can get shape from
                # .shape. If not, an attribute error will be thrown.
                data_shape = data.shape
                data_shape_nones = tuple([None] * len(data.shape))
            except AttributeError:
                # The input has an unknown number of dimensions.
                data_shape = None
                data_shape_nones = None

            # TODO (b/159261555): move this to base layer build.
            batch_input_shape = getattr(self, "_batch_input_shape", None)
            if batch_input_shape is None:
                # Set the number of dimensions.
                self._batch_input_shape = data_shape_nones
            self.build(data_shape)
            self.built = True


def _disallow_inside_tf_function(method_name):
    """Disallow calling a method inside a `tf.function`."""
    if tf.inside_function():
        error_msg = (
            "Detected a call to `PreprocessingLayer.{method_name}` inside a "
            "`tf.function`. `PreprocessingLayer.{method_name}` is a "
            "high-level endpoint that manages its own `tf.function`. "
            "Please move the call to `PreprocessingLayer.{method_name}` "
            "outside of all enclosing `tf.function`s. Note that you can "
            "call a `PreprocessingLayer` directly on `Tensor`s inside a "
            "`tf.function` like: `layer(x)`, or update its state like: "
            "`layer.update_state(x)`."
        ).format(method_name=method_name)
        raise RuntimeError(error_msg)
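

# A hedged usage sketch of the constraint enforced above: `adapt` must run
# outside any `tf.function`, while calling the layer (or `update_state`)
# inside one is fine. `layer` stands in for any concrete preprocessing layer:
#
#   layer = tf.keras.layers.Normalization(axis=None)
#   layer.adapt([0.0, 2.0])  # OK: called eagerly, outside `tf.function`.
#
#   @tf.function
#   def transform(x):
#       return layer(x)  # OK: direct layer call inside a `tf.function`.
#
#   @tf.function
#   def broken(x):
#       layer.adapt(x)  # RuntimeError via `_disallow_inside_tf_function`.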