# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for Keras classes with v1 and v2 versions."""

import tensorflow.compat.v2 as tf

from keras.utils.generic_utils import LazyLoader

# TODO(b/134426265): Switch back to single-quotes once the issue
# with copybara is fixed.

training = LazyLoader("training", globals(), "keras.engine.training")
training_v1 = LazyLoader("training_v1", globals(), "keras.engine.training_v1")
base_layer = LazyLoader("base_layer", globals(), "keras.engine.base_layer")
base_layer_v1 = LazyLoader(
    "base_layer_v1", globals(), "keras.engine.base_layer_v1"
)
callbacks = LazyLoader("callbacks", globals(), "keras.callbacks")
callbacks_v1 = LazyLoader("callbacks_v1", globals(), "keras.callbacks_v1")


class ModelVersionSelector:
    """Chooses between Keras v1 and v2 Model class."""

    def __new__(cls, *args, **kwargs):
        """Allocate an instance of `cls` after swapping in the right Model.

        `should_use_v2()` decides which of `training.Model` and
        `training_v1.Model` is requested from `swap_class`. The positional
        and keyword arguments are accepted but not used here.
        """
        prefer_v2 = should_use_v2()
        selected = swap_class(
            cls, training.Model, training_v1.Model, prefer_v2
        )
        return super(ModelVersionSelector, selected).__new__(selected)


class LayerVersionSelector:
    """Chooses between Keras v1 and v2 Layer class."""

    def __new__(cls, *args, **kwargs):
        """Allocate an instance of `cls` after swapping in the right Layer.

        `should_use_v2()` decides which of `base_layer.Layer` and
        `base_layer_v1.Layer` is requested from `swap_class`. The positional
        and keyword arguments are accepted but not used here.
        """
        prefer_v2 = should_use_v2()
        selected = swap_class(
            cls, base_layer.Layer, base_layer_v1.Layer, prefer_v2
        )
        return super(LayerVersionSelector, selected).__new__(selected)


class TensorBoardVersionSelector:
    """Chooses between Keras v1 and v2 TensorBoard callback class."""

    def __new__(cls, *args, **kwargs):
        """Allocate (or fully construct) the right TensorBoard callback.

        When the requested class is exactly `callbacks_v1.TensorBoard` and
        `swap_class` answers with `callbacks.TensorBoard`, the v2 class is
        called directly with `*args, **kwargs` and that result is returned.
        In every other case a bare instance of the selected class is
        allocated.
        """
        prefer_v2 = should_use_v2()
        requested = cls
        selected = swap_class(
            requested,
            callbacks.TensorBoard,
            callbacks_v1.TensorBoard,
            prefer_v2,
        )
        replaced_v1_with_v2 = (
            requested == callbacks_v1.TensorBoard
            and selected == callbacks.TensorBoard
        )
        if replaced_v1_with_v2:
            # Since the v2 class is not a subclass of the v1 class, __init__
            # has to be called manually.
            return selected(*args, **kwargs)
        return super(TensorBoardVersionSelector, selected).__new__(selected)


def should_use_v2():
    """Determine if v1 or v2 version should be used.

    Returns:
      `True` when executing eagerly, or when eager execution is enabled
      outside functions and the default graph's name does not start with
      `"wrapped_function"`; `False` otherwise.
    """
    if tf.executing_eagerly():
        return True
    if not tf.compat.v1.executing_eagerly_outside_functions():
        return False

    # Check for a v1 `wrap_function` FuncGraph.
    # Code inside a `wrap_function` is treated like v1 code.
    default_graph = tf.compat.v1.get_default_graph()
    graph_name = getattr(default_graph, "name", False)
    if graph_name and graph_name.startswith("wrapped_function"):
        return False
    return True


def swap_class(cls, v2_cls, v1_cls, use_v2):
    """Swaps in v2_cls or v1_cls depending on graph mode.

    Args:
      cls: The class whose hierarchy is inspected.
      v2_cls: The v2 class to use when `use_v2` is true.
      v1_cls: The v1 class to use when `use_v2` is false.
      use_v2: Whether the v2 class is the one wanted.

    Returns:
      `cls` itself when it is `object`; the wanted class when `cls` is
      `v2_cls` or `v1_cls`; otherwise `cls` after its `__bases__` have been
      reassigned in place with the unwanted class swapped out recursively.
    """
    if use_v2:
        wanted, unwanted = v2_cls, v1_cls
    else:
        wanted, unwanted = v1_cls, v2_cls

    if cls == object:
        return cls
    if cls in (v2_cls, v1_cls):
        return wanted

    # Recursively search superclasses to swap in the right Keras class.
    #
    # `v1_cls` often extends `v2_cls`, so it may still call `swap_class`
    # even if it doesn't need to. That being said, it may be the safest
    # not to over optimize this logic for the sake of correctness,
    # especially if we swap v1 & v2 classes that don't extend each
    # other, or when the inheritance order is different.
    cls.__bases__ = tuple(
        swap_class(parent, v2_cls, v1_cls, use_v2)
        if issubclass(parent, unwanted)
        else parent
        for parent in cls.__bases__
    )
    return cls


def disallow_legacy_graph(cls_name, method_name):
    """Raise if called while eager execution is not enabled outside functions.

    Args:
      cls_name: Class name interpolated into the error message.
      method_name: Method name interpolated into the error message.

    Raises:
      ValueError: If `tf.compat.v1.executing_eagerly_outside_functions()`
        returns a falsy value.
    """
    if tf.compat.v1.executing_eagerly_outside_functions():
        return
    raise ValueError(
        f"Calling `{cls_name}.{method_name}` in graph mode is not "
        f"supported when the `{cls_name}` instance was constructed with "
        f"eager mode enabled. Please construct your `{cls_name}` instance "
        f"in graph mode or call `{cls_name}.{method_name}` with "
        "eager mode enabled."
    )


def is_v1_layer_or_model(obj):
    """Return whether `obj` is a `base_layer_v1.Layer` or `training_v1.Model`."""
    legacy_types = (base_layer_v1.Layer, training_v1.Model)
    return isinstance(obj, legacy_types)