# Source: keras/distribute/keras_image_model_correctness_test.py
# Snapshot: 2023-06-19 00:49:18 +02:00 (183 lines, 7.1 KiB, Python)

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Correctness tests for tf.keras CNN models using DistributionStrategy."""
import numpy as np
import tensorflow.compat.v2 as tf
import keras
from keras.distribute import keras_correctness_test_base
from keras.optimizers.legacy import gradient_descent
from keras.testing_infra import test_utils
@test_utils.run_all_without_tensor_float_32(
    "Uses Dense layers, which call matmul. Even if Dense layers run in "
    "float64, the test sometimes fails with TensorFloat-32 enabled for unknown "
    "reasons"
)
@test_utils.run_v2_only()
class DistributionStrategyCnnCorrectnessTest(
    keras_correctness_test_base.TestDistributionStrategyCorrectnessBase
):
    """Correctness tests for a small CNN under DistributionStrategy.

    The base class drives training both with and without a strategy and
    compares results; this subclass only supplies the model (`get_model`)
    and the synthetic data (`get_data` / the partial-last-batch variant).
    """

    def get_model(
        self, initial_weights=None, distribution=None, input_shapes=None
    ):
        """Builds and compiles the CNN under test.

        Args:
            initial_weights: Optional list of weight arrays. When given, the
                freshly built model's weights are overwritten so distributed
                and non-distributed runs start from identical parameters.
            distribution: Optional `tf.distribute.Strategy`; when provided the
                model is created inside its scope via
                `MaybeDistributionScope`.
            input_shapes: Unused; accepted for signature compatibility with
                the base class.

        Returns:
            A compiled `keras.Model` mapping a (28, 28, 3) image input to a
            10-way softmax output.
        """
        del input_shapes
        with keras_correctness_test_base.MaybeDistributionScope(distribution):
            image = keras.layers.Input(shape=(28, 28, 3), name="image")
            c1 = keras.layers.Conv2D(
                name="conv1",
                filters=16,
                kernel_size=(3, 3),
                strides=(4, 4),
                kernel_regularizer=keras.regularizers.l2(1e-4),
            )(image)
            # `self.with_batch_norm` is provided by the base class / test
            # runner; it selects no BN, regular BN, or synchronized BN.
            if self.with_batch_norm == "regular":
                c1 = keras.layers.BatchNormalization(name="bn1")(c1)
            elif self.with_batch_norm == "sync":
                # Test with parallel batch norms to verify all-reduce works OK.
                bn1 = keras.layers.BatchNormalization(
                    name="bn1", synchronized=True
                )(c1)
                bn2 = keras.layers.BatchNormalization(
                    name="bn2", synchronized=True
                )(c1)
                c1 = keras.layers.Add()([bn1, bn2])
            c1 = keras.layers.MaxPooling2D(pool_size=(2, 2))(c1)
            logits = keras.layers.Dense(10, activation="softmax", name="pred")(
                keras.layers.Flatten()(c1)
            )
            model = keras.Model(inputs=[image], outputs=[logits])

            if initial_weights:
                model.set_weights(initial_weights)

            model.compile(
                optimizer=gradient_descent.SGD(learning_rate=0.1),
                loss="sparse_categorical_crossentropy",
                metrics=["sparse_categorical_accuracy"],
            )
        return model

    def _get_data(self, count, shape=(28, 28, 3), num_classes=10):
        """Generates a learnable synthetic classification dataset.

        Draws one random "center" image per class, then emits each example as
        its class center plus small Gaussian noise. NOTE(review): the exact
        sequence of `np.random` calls below defines the dataset — do not
        reorder, or distributed and baseline runs may diverge.

        Args:
            count: Number of examples to generate.
            shape: Per-example feature shape.
            num_classes: Number of distinct integer labels.

        Returns:
            Tuple `(x, y)` of float32 arrays shaped `(count, *shape)` and
            `(count, 1)`.
        """
        centers = np.random.randn(num_classes, *shape)

        features = []
        labels = []
        for _ in range(count):
            label = np.random.randint(0, num_classes, size=1)[0]
            offset = np.random.normal(loc=0, scale=0.1, size=np.prod(shape))
            offset = offset.reshape(shape)
            labels.append(label)
            features.append(centers[label] + offset)

        x = np.asarray(features, dtype=np.float32)
        y = np.asarray(labels, dtype=np.float32).reshape((count, 1))
        return x, y

    def get_data(self):
        """Returns training data sized to an exact multiple of the global batch."""
        x_train, y_train = self._get_data(
            count=keras_correctness_test_base._GLOBAL_BATCH_SIZE
            * keras_correctness_test_base._EVAL_STEPS
        )
        # Prediction reuses the training inputs; correctness, not
        # generalization, is what the base class checks.
        x_predict = x_train
        return x_train, y_train, x_predict

    def get_data_with_partial_last_batch_eval(self):
        """Returns train/eval data where the eval count (1000) leaves a partial batch."""
        x_train, y_train = self._get_data(count=1280)
        x_eval, y_eval = self._get_data(count=1000)
        return x_train, y_train, x_eval, y_eval, x_eval

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_cnn_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        """Baseline CNN correctness (no batch norm) across strategies."""
        if (
            distribution
            == tf.__internal__.distribute.combinations.central_storage_strategy_with_gpu_and_cpu  # noqa: E501
        ):
            # Known failure with CentralStorageStrategy; tracked upstream.
            self.skipTest("b/183958183")
        self.run_correctness_test(distribution, use_numpy, use_validation_data)

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_cnn_with_batch_norm_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        """CNN correctness with a regular (non-synchronized) BatchNormalization."""
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            with_batch_norm="regular",
        )

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations()
        + keras_correctness_test_base.multi_worker_mirrored_eager()
    )
    def test_cnn_with_sync_batch_norm_correctness(
        self, distribution, use_numpy, use_validation_data
    ):
        """CNN correctness with synchronized BatchNormalization (eager only)."""
        if not tf.executing_eagerly():
            self.skipTest(
                "BatchNorm with `synchronized` is not enabled in graph mode."
            )
        self.run_correctness_test(
            distribution, use_numpy, use_validation_data, with_batch_norm="sync"
        )

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations_eager()  # noqa: E501
        + keras_correctness_test_base.multi_worker_mirrored_eager()
        + keras_correctness_test_base.test_combinations_with_tpu_strategies_graph()  # noqa: E501
    )
    def test_cnn_correctness_with_partial_last_batch_eval(
        self, distribution, use_numpy, use_validation_data
    ):
        """CNN correctness when the final eval batch is smaller than the rest."""
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            partial_last_batch=True,
            training_epochs=1,
        )

    @tf.__internal__.distribute.combinations.generate(
        keras_correctness_test_base.all_strategy_and_input_config_combinations_eager()  # noqa: E501
        + keras_correctness_test_base.multi_worker_mirrored_eager()
        + keras_correctness_test_base.test_combinations_with_tpu_strategies_graph()  # noqa: E501
    )
    def test_cnn_with_batch_norm_correctness_and_partial_last_batch_eval(
        self, distribution, use_numpy, use_validation_data
    ):
        """Regular-BN CNN correctness combined with a partial last eval batch."""
        self.run_correctness_test(
            distribution,
            use_numpy,
            use_validation_data,
            with_batch_norm="regular",
            partial_last_batch=True,
        )
if __name__ == "__main__":
    # Hand control to TF's multi-process test runner — presumably required
    # by the multi-worker strategy combinations used above; confirm against
    # the base-class test harness.
    runner = tf.__internal__.distribute.multi_process_runner
    runner.test_main()