# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Correctness tests for tf.keras CNN models using DistributionStrategy."""

import numpy as np
import tensorflow.compat.v2 as tf

import keras
from keras.distribute import keras_correctness_test_base
from keras.optimizers.legacy import gradient_descent
from keras.testing_infra import test_utils


@test_utils.run_all_without_tensor_float_32(
    "Uses Dense layers, which call matmul. Even if Dense layers run in "
    "float64, the test sometimes fails with TensorFloat-32 enabled for unknown "
    "reasons"
)
@test_utils.run_v2_only()
class DistributionStrategyCnnCorrectnessTest(
    keras_correctness_test_base.TestDistributionStrategyCorrectnessBase
):
    """Verifies a small CNN trains identically with and without a strategy.

    The heavy lifting (running the model under each distribution strategy
    and comparing the results) is done by the base class's
    `run_correctness_test`; this class supplies the model and the data.

    NOTE(review): `self.with_batch_norm` is set by the base class /
    `run_correctness_test` machinery, not here — confirm against
    `keras_correctness_test_base` when modifying the batch-norm branches.
    """

    def get_model(
        self, initial_weights=None, distribution=None, input_shapes=None
    ):
        """Builds and compiles the CNN under test.

        Args:
            initial_weights: Optional weights to load into the model so
                distributed and non-distributed runs start identically.
            distribution: Optional `tf.distribute.Strategy`; the model is
                built inside its scope when provided.
            input_shapes: Unused; accepted for interface compatibility with
                the base class (explicitly deleted below).

        Returns:
            A compiled `keras.Model` mapping 28x28x3 images to 10-class
            softmax predictions, trained with SGD(lr=0.1) on sparse
            categorical cross-entropy.
        """
        del input_shapes
        # `MaybeDistributionScope` enters `distribution.scope()` only when a
        # strategy is given, so the same builder serves both code paths.
        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            image = keras.layers.Input(shape=(28, 28, 3), name="image")
            c1 = keras.layers.Conv2D(
                name="conv1",
                filters=16,
                kernel_size=(3, 3),
                strides=(4, 4),
                kernel_regularizer=keras.regularizers.l2(1e-4),
            )(image)
            if self.with_batch_norm == "regular":
                c1 = keras.layers.BatchNormalization(name="bn1")(c1)
            elif self.with_batch_norm == "sync":
                # Test with parallel batch norms to verify all-reduce works
                # OK: two synchronized BatchNormalization layers consume the
                # same conv output and their results are summed.
                bn1 = keras.layers.BatchNormalization(
                    name="bn1", synchronized=True
                )(c1)
                bn2 = keras.layers.BatchNormalization(
                    name="bn2", synchronized=True
                )(c1)
                c1 = keras.layers.Add()([bn1, bn2])
            c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1)
            logits = keras.layers.Dense(10, activation="softmax", name="pred")(
                keras.layers.Flatten()(c1)
            )
            model = keras.Model(inputs=[image], outputs=[logits])

            if initial_weights:
                model.set_weights(initial_weights)

            model.compile(
                optimizer=gradient_descent.SGD(learning_rate=0.1),
                loss="sparse_categorical_crossentropy",
                metrics=["sparse_categorical_accuracy"],
            )

            return model

    def _get_data(self, count, shape=(28, 28, 3), num_classes=10):
        """Synthesizes a linearly separable classification dataset.

        One random Gaussian "center" image is drawn per class; each example
        is its class center plus small (scale=0.1) Gaussian noise, so the
        classes are easy to learn and correctness comparisons are stable.

        Args:
            count: Number of examples to generate.
            shape: Per-example feature shape.
            num_classes: Number of distinct labels.

        Returns:
            Tuple `(x, y)` where `x` is float32 with shape
            `(count, *shape)` and `y` is float32 with shape `(count, 1)`.
        """
        centers = np.random.randn(num_classes, *shape)

        features = []
        labels = []
        for _ in range(count):
            label = np.random.randint(0, num_classes, size=1)[0]
            offset = np.random.normal(loc=0, scale=0.1, size=np.prod(shape))
            offset = offset.reshape(shape)
            labels.append(label)
            features.append(centers[label] + offset)

        x = np.asarray(features, dtype=np.float32)
        # Labels are float32 here even though the loss is sparse categorical;
        # this mirrors the base-class pipeline's expected dtype.
        y = np.asarray(labels, dtype=np.float32).reshape((count, 1))
        return x, y

    def get_data(self):
        """Returns `(x_train, y_train, x_predict)` sized to fill every
        evaluation step at the global batch size exactly (no partial batch).
        """
        x_train, y_train = self._get_data(
            count=keras_correctness_test_base._GLOBAL_BATCH_SIZE
            * keras_correctness_test_base._EVAL_STEPS
        )
        x_predict = x_train
        return x_train, y_train, x_predict

    def get_data_with_partial_last_batch_eval(self):
        """Returns train/eval/predict data where the eval set (1000) does
        not divide evenly into batches, exercising partial-last-batch
        handling.
        """
        x_train, y_train = self._get_data(count=1280)
        x_eval, y_eval = self._get_data(count=1000)
        return x_train, y_train, x_eval, y_eval, x_eval

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_cnn_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        """Baseline correctness run (no batch norm) across all strategies."""
        if (
            distribution
            == tf.__internal__.distribute.combinations.central_storage_strategy_with_gpu_and_cpu  # noqa: E501
        ):
            # Known failure on CentralStorageStrategy with GPU+CPU.
            self.skipTest("b/183958183")
        self.run_correctness_test(distribution, use_numpy, use_validation_data)

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_cnn_with_batch_norm_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        """Correctness run with regular (per-replica) batch normalization."""
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            with_batch_norm="regular",
        )

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_cnn_with_sync_batch_norm_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        """Correctness run with synchronized (cross-replica) batch norm."""
        if not tf.executing_eagerly():
            self.skipTest(
                "BatchNorm with `synchronized` is not enabled in graph mode."
            )

        self.run_correctness_test(
            distribution, use_numpy, use_validation_data, with_batch_norm="sync"
        )

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations_eager()  # noqa: E501
        + keras_correctness_test_base.multi_worker_mirrored_eager()
        + keras_correctness_test_base.test_combinations_with_tpu_strategies_graph()  # noqa: E501
    )
    def test_cnn_correctness_with_partial_last_batch_eval(
        self, distribution, use_numpy, use_validation_data
    ):
        """Correctness run whose eval set has a partial final batch."""
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            partial_last_batch=True,
            training_epochs=1,
        )

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations_eager()  # noqa: E501
        + keras_correctness_test_base.multi_worker_mirrored_eager()
        + keras_correctness_test_base.test_combinations_with_tpu_strategies_graph()  # noqa: E501
    )
    def test_cnn_with_batch_norm_correctness_and_partial_last_batch_eval(
        self, distribution, use_numpy, use_validation_data
    ):
        """Combines regular batch norm with a partial final eval batch."""
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            with_batch_norm="regular",
            partial_last_batch=True,
        )


if __name__ == "__main__":
    # Multi-worker combinations require the multi-process test runner rather
    # than the plain tf.test.main().
    tf.__internal__.distribute.multi_process_runner.test_main()