# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the TFOpLambda layer."""
import tensorflow.compat.v2 as tf

from keras import backend
from keras.engine import keras_tensor
from keras.engine.base_layer import Layer

# isort: off
from tensorflow.python.platform import tf_logging
from tensorflow.python.util.tf_export import (
    get_canonical_name_for_symbol,
)
from tensorflow.python.util.tf_export import (
    get_symbol_from_name,
)


class ClassMethod(Layer):
    """Wraps a TF API Class's class method in a `Layer` object.

    It is inserted by the Functional API construction whenever users call
    a supported TF Class's class method on KerasTensors.

    This is useful in the case where users do something like:
    x = keras.Input(...)
    y = keras.Input(...)
    out = tf.RaggedTensor.from_row_splits(x, y)
    """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(self, cls_ref, method_name, **kwargs):
        """Creates the layer.

        Args:
          cls_ref: The class whose class method should be invoked.
          method_name: Name of the class method to look up on `cls_ref`.
          **kwargs: Forwarded to `Layer.__init__`. A `name` is generated
            when none is supplied, and `autocast` is always set to False.
        """
        self.cls_ref = cls_ref
        self.method_name = method_name
        self.cls_symbol = get_canonical_name_for_symbol(
            self.cls_ref, add_prefix_to_v1_names=True
        ) or get_canonical_name_for_symbol(
            self.cls_ref, api_name="keras", add_prefix_to_v1_names=True
        )
        if "name" not in kwargs:
            # `cls_symbol` is falsy for classes that are not exported in the
            # public API (`get_config` checks for exactly that case), so fall
            # back to the class's own name instead of concatenating `None`,
            # mirroring what `TFOpLambda` does for unexported functions.
            if self.cls_symbol:
                name_prefix = "tf." + self.cls_symbol
            else:
                name_prefix = self.cls_ref.__name__
            kwargs["name"] = backend.unique_object_name(
                name_prefix + "." + self.method_name,
                zero_based=True,
                avoid_observed_names=True,
            )
        kwargs["autocast"] = False

        # Do not individually trace op layers in the SavedModel.
        self._must_restore_from_config = True

        super().__init__(**kwargs)

        # Preserve all argument data structures when saving/loading a config
        # (e.g., don't unnest lists that contain one element)
        self._preserve_input_structure_in_config = True

        self._call_spec.expects_training_arg = False
        self._call_spec.expects_mask_arg = False

    def call(self, args, kwargs):
        """Invokes the wrapped class method with `*args` and `**kwargs`."""
        return getattr(self.cls_ref, self.method_name)(*args, **kwargs)

    def get_config(self):
        """Returns the layer config.

        Raises:
          ValueError: If the wrapped class has no canonical API symbol, in
            which case it could not be looked up again by `from_config`.
        """
        if not self.cls_symbol:
            # Report `cls_ref`: `cls_symbol` is known to be empty here, so
            # interpolating it would always print an uninformative `None`.
            raise ValueError(
                "This Keras class method conversion tried to convert "
                f"a method belonging to class {self.cls_ref}, a class "
                "that is not publicly exposed in the TensorFlow API. "
                "To ensure cross-version compatibility of Keras models "
                "that use op layers, only op layers produced from "
                "public TensorFlow API symbols can be serialized."
            )
        config = {
            "cls_symbol": self.cls_symbol,
            "method_name": self.method_name,
        }
        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuilds the layer, resolving `cls_symbol` back to a class.

        Raises:
          ValueError: If the stored symbol name cannot be resolved.
        """
        config = config.copy()
        symbol_name = config.pop("cls_symbol")
        cls_ref = get_symbol_from_name(symbol_name)
        if not cls_ref:
            raise ValueError(
                f"TensorFlow symbol `{symbol_name}` could not be found."
            )

        config["cls_ref"] = cls_ref

        return cls(**config)


class KerasOpDispatcher(tf.__internal__.dispatch.GlobalOpDispatcher):
    """A global dispatcher that allows building a functional model with TF
    Ops."""

    def handle(self, op, args, kwargs):
        """Handle the specified operation with the specified arguments.

        Wraps `op` in a `TFOpLambda` layer when any flattened argument is a
        `KerasTensor`; otherwise returns `self.NOT_SUPPORTED`.
        """
        if any(
            isinstance(x, keras_tensor.KerasTensor)
            for x in tf.nest.flatten([args, kwargs])
        ):
            return TFOpLambda(op)(*args, **kwargs)
        else:
            return self.NOT_SUPPORTED


KerasOpDispatcher().register()


class InstanceProperty(Layer):
    """Wraps an instance property access (e.g. `x.foo`) in a Keras Layer.

    This layer takes an attribute name `attr_name` in the constructor and,
    when called on input tensor `obj` returns `obj.attr_name`.

    KerasTensors specialized for specific extension types use it to
    represent instance property accesses on the represented object in the
    case where the property needs to be dynamically accessed as opposed to
    being statically computed from the typespec, e.g.

    x = keras.Input(..., ragged=True)
    out = x.flat_values
    """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(self, attr_name, **kwargs):
        """Creates the layer.

        Args:
          attr_name: Name of the attribute to read from the layer input.
          **kwargs: Forwarded to `Layer.__init__`. A `name` is generated
            when none is supplied, and `autocast` is always set to False.
        """
        self.attr_name = attr_name

        if "name" not in kwargs:
            kwargs["name"] = backend.unique_object_name(
                "input." + self.attr_name,
                zero_based=True,
                avoid_observed_names=True,
            )
        kwargs["autocast"] = False

        # Do not individually trace op layers in the SavedModel.
        self._must_restore_from_config = True

        super().__init__(**kwargs)

        # Preserve all argument data structures when saving/loading a config
        # (e.g., don't unnest lists that contain one element)
        self._preserve_input_structure_in_config = True

    def call(self, obj):
        """Returns `getattr(obj, self.attr_name)`."""
        return getattr(obj, self.attr_name)

    def get_config(self):
        """Returns the base layer config extended with `attr_name`."""
        config = {"attr_name": self.attr_name}
        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuilds the layer from `config`."""
        return cls(**config)


class InstanceMethod(InstanceProperty):
    """Wraps an instance method access (e.g. `x.foo(arg)`) in a Keras Layer.

    This layer takes an attribute name `attr_name` in the constructor and,
    when called on input tensor `obj` with additional arguments `args` and
    `kwargs` returns `obj.attr_name(*args, **kwargs)`.

    KerasTensors specialized for specific extension types use it to
    represent dynamic instance method calls on the represented object, e.g.

    x = keras.Input(..., ragged=True)
    new_values = keras.Input(...)
    out = x.with_values(new_values)
    """

    def call(self, obj, args, kwargs):
        """Returns `obj.<attr_name>(*args, **kwargs)`."""
        method = getattr(obj, self.attr_name)
        return method(*args, **kwargs)


class TFOpLambda(Layer):
    """Wraps TF API symbols in a `Layer` object.

    It is inserted by the Functional API construction whenever users call
    a supported TF symbol on KerasTensors.

    Like Lambda layers, this layer tries to raise warnings when it detects
    users explicitly use variables in the call. (To let them know
    that the layer will not capture the variables).

    This is useful in the case where users do something like:
    x = keras.Input(...)
    y = tf.Variable(...)
    out = x * y
    """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(self, function, **kwargs):
        """Creates the layer.

        Args:
          function: The callable this layer invokes when called.
          **kwargs: Forwarded to `Layer.__init__`. A `name` is generated
            when none is supplied, and `autocast` is always set to False.
        """
        self.function = function
        self.symbol = get_canonical_name_for_symbol(
            self.function, add_prefix_to_v1_names=True
        ) or get_canonical_name_for_symbol(
            self.function, api_name="keras", add_prefix_to_v1_names=True
        )
        if "name" not in kwargs:
            # Generate a name.
            # TFOpLambda layers avoid already-observed names,
            # because users cannot easily control the generated names.
            # Without this avoidance, users would be more likely to run
            # into unavoidable duplicate layer name collisions.
            # (For standard layers users could just set `name` when creating
            # the layer to work around a collision, but they can't do that
            # for auto-generated layers)
            if self.symbol:
                name = "tf." + self.symbol
            else:
                name = self.function.__name__
            kwargs["name"] = backend.unique_object_name(
                name, zero_based=True, avoid_observed_names=True
            )
        kwargs["autocast"] = False

        # Decorate the function to produce this layer's call method
        def _call_wrapper(*args, **kwargs):
            return self._call_wrapper(*args, **kwargs)

        self.call = tf.__internal__.decorator.make_decorator(
            function, _call_wrapper
        )

        # Do not individually trace op layers in the SavedModel.
        self._must_restore_from_config = True

        super().__init__(**kwargs)

        # Preserve all argument data structures when saving/loading a config
        # (e.g., don't unnest lists that contain one element)
        self._preserve_input_structure_in_config = True

        # Warning on every invocation will be quite irksome in Eager mode.
        self._already_warned = False

        self._call_spec.expects_training_arg = False
        self._call_spec.expects_mask_arg = False

    def _call_wrapper(self, *args, **kwargs):
        """Runs `self.function` while recording variable creation and use.

        Any `name` keyword argument is dropped before the call. Variables
        created or watched during the call are passed to
        `_check_variables`.
        """
        created_variables = []

        def _variable_creator(next_creator, **creator_kwargs):
            var = next_creator(**creator_kwargs)
            created_variables.append(var)
            return var

        with tf.GradientTape(
            watch_accessed_variables=True
        ) as tape, tf.variable_creator_scope(_variable_creator):
            # We explicitly drop `name` arguments here,
            # to guard against the case where an op explicitly has a
            # `name` passed (which is susceptible to producing
            # multiple ops w/ the same name when the layer is reused)
            kwargs.pop("name", None)
            result = self.function(*args, **kwargs)
        self._check_variables(created_variables, tape.watched_variables())
        return result

    def _check_variables(self, created_variables, accessed_variables):
        """Validates variables touched by a call against `self.weights`.

        Args:
          created_variables: Variables created during the call.
          accessed_variables: Variables watched by the tape during the call.

        Raises:
          ValueError: If a created variable is not among `self.weights`.
        """
        if not created_variables and not accessed_variables:
            # In the common case that a Lambda layer does not touch a
            # Variable, we don't want to incur the runtime cost of assembling
            # any state used for checking only to immediately discard it.
            return

        tracked_weights = {v.ref() for v in self.weights}
        untracked_new_vars = [
            v for v in created_variables if v.ref() not in tracked_weights
        ]
        if untracked_new_vars:
            variable_str = "\n".join(f"  {i}" for i in untracked_new_vars)
            raise ValueError(
                "The following Variables were created within a Lambda layer "
                f"({self.name}) but are not tracked by said layer: "
                f"{variable_str}\n"
                "The layer cannot safely ensure proper Variable reuse "
                "across multiple calls, and consequently this behavior "
                "is disallowed for safety reasons. Lambda layers are "
                "not well suited for stateful computation; instead, "
                "writing a subclassed Layer is the recommend "
                "way to define layers with Variables."
            )

        untracked_used_vars = [
            v for v in accessed_variables if v.ref() not in tracked_weights
        ]
        # Only warn once per layer instance (see `_already_warned`).
        if untracked_used_vars and not self._already_warned:
            variable_str = "\n".join(f"  {i}" for i in untracked_used_vars)
            self._warn(
                "The following Variables were used in a Lambda layer's call "
                f"({self.name}), but are not present in its tracked objects: "
                f"{variable_str}. This is a strong indication that the Lambda "
                "layer should be rewritten as a subclassed Layer."
            )
            self._already_warned = True

    def _warn(self, msg):
        # This method will be overridden in a unit test to raise an error,
        # because self.assertWarns is not universally implemented.
        return tf_logging.warning(msg)

    def get_config(self):
        """Returns the layer config.

        Raises:
          ValueError: If the wrapped function has no canonical API symbol,
            in which case it could not be looked up again by `from_config`.
        """
        if not self.symbol:
            raise ValueError(
                f"This Keras op layer was generated from {self.function}, a "
                "method that is not publicly exposed in the TensorFlow API. "
                "This may have happened if the method was explicitly "
                "decorated to add dispatching support, and it was used "
                "during Functional model construction. "
                "To ensure cross-version compatibility of Keras models "
                "that use op layers, only op layers produced from "
                "public TensorFlow API symbols can be serialized."
            )
        config = {"function": self.symbol}

        base_config = super().get_config()
        return {**base_config, **config}

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuilds the layer, resolving the stored symbol to a callable.

        Raises:
          ValueError: If the stored symbol name cannot be resolved.
        """
        config = config.copy()
        symbol_name = config["function"]
        function = get_symbol_from_name(symbol_name)
        if not function:
            raise ValueError(f"TF symbol `{symbol_name}` could not be found.")

        config["function"] = function

        return cls(**config)


def _delegate_property(keras_tensor_cls, property_name):
    """Register property on a KerasTensor class.

    Calling this multiple times with the same arguments should be a no-op.

    This method exposes a property on the KerasTensor class that will use an
    `InstanceProperty` layer to access the property on the represented
    intermediate values in the model.

    Args:
      keras_tensor_cls: The KerasTensor subclass that should expose the
        property.
      property_name: The name of the property to expose and delegate to the
        represented (Composite)Tensor.
    """
    # We use a lambda because we can't create a Keras layer at import time
    # due to dynamic layer class versioning.
    property_access = property(
        lambda self: InstanceProperty(property_name)(self)
    )
    setattr(keras_tensor_cls, property_name, property_access)


def _delegate_method(keras_tensor_cls, method_name):
    """Register method on a KerasTensor class.

    Calling this function multiple times with the same arguments should be a
    no-op.

    This method exposes an instance method on the KerasTensor class that will
    use an `InstanceMethod` layer to run the desired method on the represented
    intermediate values in the model.

    Args:
      keras_tensor_cls: The KerasTensor subclass that should expose the
        method.
      method_name: The name of the method to expose and delegate to the
        represented (Composite)Tensor.
    """

    def delegate(self, *args, **kwargs):
        # The layer is created on each call rather than once up front.
        return InstanceMethod(method_name)(self, args, kwargs)

    setattr(keras_tensor_cls, method_name, delegate)


# We do not support the `uniform_row_length` property because it
# returns either `None` or an int tensor, and code that relies on it tends
# to check `is None` directly. Delegating it here would always return a
# `KerasTensor`, regardless of what can be statically inferred. This would
# never equal `None`, breaking code that expects it to be partially-static
# in unpredictable ways.
for ragged_property in [
    "values",
    "flat_values",
    "row_splits",
    "nested_row_splits",
]:
    _delegate_property(keras_tensor.RaggedKerasTensor, ragged_property)

for ragged_method_name in [
    "value_rowids",
    "nested_value_rowids",
    "nrows",
    "row_starts",
    "row_limits",
    "row_lengths",
    "nested_row_lengths",
    "bounding_shape",
    "with_values",
    "with_flat_values",
    "with_row_splits_dtype",
    "merge_dims",
    "to_tensor",
    "to_sparse",
]:
    _delegate_method(keras_tensor.RaggedKerasTensor, ragged_method_name)

for sparse_property in [
    "indices",
    "values",
    "dense_shape",
]:
    _delegate_property(keras_tensor.SparseKerasTensor, sparse_property)

for sparse_method in [
    "with_values",
]:
    _delegate_method(keras_tensor.SparseKerasTensor, sparse_method)


class TFClassMethodDispatcher(tf.__internal__.dispatch.OpDispatcher):
    """A class method dispatcher that allows building a functional model with
    TF class methods."""

    def __init__(self, cls, method_name):
        """Stores the class and the name of the class method to wrap."""
        self.cls = cls
        self.method_name = method_name

    def handle(self, args, kwargs):
        """Handle the specified operation with the specified arguments.

        Wraps the class method in a `ClassMethod` layer when any flattened
        argument is a `KerasTensor`; otherwise returns `self.NOT_SUPPORTED`.
        """
        if any(
            isinstance(x, keras_tensor.KerasTensor)
            for x in tf.nest.flatten([args, kwargs])
        ):
            # The first positional argument is dropped before forwarding.
            # NOTE(review): presumably it is the class object itself, which
            # `ClassMethod` already holds as `cls_ref` -- confirm.
            return ClassMethod(self.cls, self.method_name)(args[1:], kwargs)
        else:
            return self.NOT_SUPPORTED


for ragged_class_method in [
    "from_value_rowids",
    "from_row_splits",
    "from_row_lengths",
    "from_row_starts",
    "from_row_limits",
    "from_uniform_row_length",
    "from_nested_value_rowids",
    "from_nested_row_splits",
    "from_nested_row_lengths",
    "from_tensor",
    "from_sparse",
]:
    TFClassMethodDispatcher(tf.RaggedTensor, ragged_class_method).register(
        getattr(tf.RaggedTensor, ragged_class_method)
    )


class SlicingOpLambda(TFOpLambda):
    """Wraps TF slicing API symbols in a `Layer` object.

    It is inserted by `TFSlicingOpDispatcher` whenever users call one of the
    registered slicing ops on KerasTensors.

    The dispatcher replaces `slice` objects in the call arguments with
    `{"start", "stop", "step"}` dicts; this layer turns those dicts back into
    `slice` objects before invoking the wrapped op. Apart from that
    conversion, calls go through the inherited `TFOpLambda` call wrapper.
    """

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def __init__(self, function, **kwargs):
        """Creates the layer.

        Args:
          function: The slicing op this layer invokes when called.
          **kwargs: Forwarded to `TFOpLambda.__init__`.
        """
        super().__init__(function, **kwargs)

        original_call = self.call

        # Decorate the function to produce this layer's call method
        def _call_wrapper(*args, **kwargs):
            # Turn any slice dicts in the args back into `slice` objects.
            # This conversion cannot use nest.flatten/map_structure,
            # because dicts are flattened by nest while slices aren't.
            # So, map_structure would only see the individual elements in the
            # dict.
            # This can't use map_structure_up_to either because the
            # 'shallowness' of the shallow tree would have to vary depending
            # on if only one dim or multiple are being sliced.
            new_args = [_restore_slices(arg) for arg in args]
            new_kwargs = {
                key: _restore_slices(value) for key, value in kwargs.items()
            }
            return original_call(*new_args, **new_kwargs)

        self.call = tf.__internal__.decorator.make_decorator(
            original_call, _call_wrapper
        )


def _slice_to_dict(x):
    """Returns a `start`/`stop`/`step` dict for a `slice`, else `x` as is."""
    if isinstance(x, slice):
        return {"start": x.start, "stop": x.stop, "step": x.step}
    return x


def _dict_to_slice(x):
    """Returns a `slice` for a dict made by `_slice_to_dict`, else `x` as is.

    Only dicts that carry all of the `start`, `stop` and `step` keys are
    converted. Any other dict is passed through untouched rather than
    failing with a `KeyError`.
    """
    if isinstance(x, dict) and all(
        key in x for key in ("start", "stop", "step")
    ):
        return slice(x["start"], x["stop"], x["step"])
    return x


def _restore_slices(value):
    """Converts slice dicts in `value`, looking one level into sequences.

    A list or tuple is returned as a new list with each element converted.
    """
    value = _dict_to_slice(value)
    if isinstance(value, (list, tuple)):
        return [_dict_to_slice(item) for item in value]
    return value


class TFSlicingOpDispatcher(tf.__internal__.dispatch.OpDispatcher):
    """A per-op dispatcher that allows building a functional model with TF
    slicing ops."""

    def __init__(self, op):
        """Stores the slicing op this dispatcher wraps."""
        self.op = op

    def handle(self, args, kwargs):
        """Handle the specified operation with the specified arguments.

        `slice` objects in the arguments are first replaced by dicts. The op
        is then wrapped in a `SlicingOpLambda` layer when any flattened
        argument is a `KerasTensor`; otherwise `self.NOT_SUPPORTED` is
        returned.
        """
        args = tf.nest.map_structure(_slice_to_dict, args)
        kwargs = tf.nest.map_structure(_slice_to_dict, kwargs)
        if any(
            isinstance(x, keras_tensor.KerasTensor)
            for x in tf.nest.flatten([args, kwargs])
        ):
            return SlicingOpLambda(self.op)(*args, **kwargs)
        else:
            return self.NOT_SUPPORTED


for slicing_op in [
    tf.__operators__.getitem,
    tf.compat.v1.boolean_mask,
    tf.boolean_mask,
    tf.__operators__.ragged_getitem,
]:
    TFSlicingOpDispatcher(slicing_op).register(slicing_op)