698 lines
26 KiB
Python
698 lines
26 KiB
Python
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Tests for tf.keras models with callbacks, checkpointing with dist
|
||
|
strategy."""
|
||
|
|
||
|
import collections
import os
import tempfile

import numpy as np
import tensorflow.compat.v2 as tf
from absl.testing import parameterized

import keras
from keras import losses
from keras.distribute import distribute_strategy_test as keras_test_lib
from keras.distribute import distributed_training_utils_v1
from keras.distribute import optimizer_combinations
|
||
|
|
||
|
|
||
|
class Counter(keras.callbacks.Callback):
    """Callback that counts how many times each callback hook is invoked.

    On construction every hook named in `__init__` is replaced, on this
    instance only, by a thin wrapper that bumps a counter and then delegates
    to the original bound method.

    Attributes:
        method_counts: `collections.defaultdict(int)` mapping a hook name to
            the number of times it was called. A hook that was never called
            has no entry.
    """

    def __init__(self):
        # Let the base callback initialize itself before any hook is wrapped.
        super().__init__()
        self.method_counts = collections.defaultdict(int)
        methods_to_count = [
            "on_batch_begin",
            "on_batch_end",
            "on_epoch_begin",
            "on_epoch_end",
            "on_predict_batch_begin",
            "on_predict_batch_end",
            "on_predict_begin",
            "on_predict_end",
            "on_test_batch_begin",
            "on_test_batch_end",
            "on_test_begin",
            "on_test_end",
            "on_train_batch_begin",
            "on_train_batch_end",
            "on_train_begin",
            "on_train_end",
        ]
        for method_name in methods_to_count:
            # The instance attribute shadows the class-level method, so the
            # counting wrapper is what callers of the hook actually reach.
            setattr(
                self,
                method_name,
                self.wrap_with_counts(method_name, getattr(self, method_name)),
            )

    def wrap_with_counts(self, method_name, method):
        """Returns a callable that counts a call, then forwards to `method`.

        Args:
            method_name: key under which calls are tallied in
                `self.method_counts`.
            method: the callable to delegate to after counting.

        Returns:
            A function accepting arbitrary positional and keyword arguments
            that increments `self.method_counts[method_name]` and returns
            whatever `method` returns.
        """

        def _call_and_count(*args, **kwargs):
            self.method_counts[method_name] += 1
            return method(*args, **kwargs)

        return _call_and_count
|
||
|
|
||
|
|
||
|
class TestDistributionStrategyWithCallbacks(
    tf.test.TestCase, parameterized.TestCase
):
    """Verifies callback hook invocation counts for fit/evaluate/predict."""

    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.times(
            keras_test_lib.all_strategy_combinations()
        )
    )
    def test_callbacks_in_fit(self, distribution):
        """Counts the hooks fired by `fit` when validation data is given."""
        epochs = 2
        steps_per_epoch = 5
        validation_steps = 3

        with distribution.scope():
            model = keras_test_lib.get_model()
            model.compile(optimizer="sgd", loss="mse", metrics=["mae"])

        dataset = keras_test_lib.get_dataset(distribution)
        counter = Counter()

        model.fit(
            dataset,
            callbacks=[counter],
            epochs=epochs,
            steps_per_epoch=steps_per_epoch,
            validation_data=dataset,
            validation_steps=validation_steps,
            verbose=0,
        )

        # By default each training step produces one batch-level callback.
        train_batch_calls = steps_per_epoch
        if (
            isinstance(
                distribution, tf.compat.v1.distribute.experimental.TPUStrategy
            )
            and not tf.executing_eagerly()
        ):
            # TPU Strategy can have multi step training, driven by
            # extended.steps_per_run: one batch callback covers up to
            # steps_per_run steps, so round the quotient up. With
            # steps_per_run = 1 this is again steps_per_epoch.
            steps_per_run = distribution.extended.steps_per_run
            train_batch_calls = -(-steps_per_epoch // steps_per_run)

        train_batches = epochs * train_batch_calls
        test_batches = epochs * validation_steps
        expected_counts = {
            "on_train_begin": 1,
            "on_train_end": 1,
            "on_epoch_begin": epochs,
            "on_epoch_end": epochs,
            "on_batch_begin": train_batches,
            "on_batch_end": train_batches,
            "on_train_batch_begin": train_batches,
            "on_train_batch_end": train_batches,
            "on_test_begin": epochs,
            "on_test_end": epochs,
            "on_test_batch_begin": test_batches,
            "on_test_batch_end": test_batches,
        }
        self.assertDictEqual(counter.method_counts, expected_counts)

    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.times(
            keras_test_lib.all_strategy_combinations()
        )
    )
    def test_callbacks_in_eval(self, distribution):
        """Counts the hooks fired by a five-step `evaluate` call."""
        with distribution.scope():
            model = keras_test_lib.get_model()
            model.compile(optimizer="sgd", loss="mse", metrics=["mae"])

        counter = Counter()
        eval_dataset = keras_test_lib.get_dataset(distribution)
        model.evaluate(eval_dataset, steps=5, callbacks=[counter])

        expected_counts = {
            "on_test_begin": 1,
            "on_test_end": 1,
            "on_test_batch_begin": 5,
            "on_test_batch_end": 5,
        }
        self.assertDictEqual(counter.method_counts, expected_counts)

    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.times(
            keras_test_lib.all_strategy_combinations()
        )
    )
    def test_callbacks_in_predict(self, distribution):
        """Counts the hooks fired by a five-step `predict` call."""
        with distribution.scope():
            model = keras_test_lib.get_model()
            model.compile(optimizer="sgd", loss="mse", metrics=["mae"])

        dataset = keras_test_lib.get_dataset(distribution)
        counter = Counter()
        predict_dataset = keras_test_lib.get_predict_dataset(dataset)

        model.predict(predict_dataset, steps=5, callbacks=[counter])

        expected_counts = {
            "on_predict_begin": 1,
            "on_predict_end": 1,
            "on_predict_batch_begin": 5,
            "on_predict_batch_end": 5,
        }
        self.assertDictEqual(counter.method_counts, expected_counts)
|
||
|
|
||
|
|
||
|
class TestDistributionStrategyErrorCases(
    tf.test.TestCase, parameterized.TestCase
):
    """Exercises invalid inputs and unsupported configurations."""

    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.combine(
            distribution=[
                tf.__internal__.distribute.combinations.mirrored_strategy_with_gpu_and_cpu,  # noqa: E501
            ],
            mode=["graph"],
        )
    )
    def test_validating_dataset_input_tensors_with_shape_mismatch(
        self, distribution
    ):
        """Per-replica values of different shapes fail input validation."""
        # Device and input tensor shape details are left out of the expected
        # message since the order of the device and the corresponding input
        # tensor shape is not deterministic over different runs.
        expected_error = (
            "Input tensor shapes do not match for "
            "distributed tensor inputs "
            "PerReplica:.+"
        )
        with self.cached_session():

            @tf.function
            def replica_fn():
                replica_ctx = tf.distribute.get_replica_context()
                device = replica_ctx.replica_id_in_sync_group.device
                if device.endswith("GPU:0"):
                    return tf.constant([[1, 2]])
                return tf.constant([[1, 2], [1, 2]])

            per_replica_values = distribution.run(replica_fn)

            with self.assertRaisesRegex(
                ValueError, expected_error
            ), distribution.scope():
                distributed_training_utils_v1.validate_distributed_dataset_inputs(  # noqa: E501
                    distribution, per_replica_values, None
                )

    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.combine(
            distribution=[
                tf.__internal__.distribute.combinations.mirrored_strategy_with_gpu_and_cpu,  # noqa: E501
            ],
            mode=["graph", "eager"],
        )
    )
    def test_validating_dataset_input_tensors_with_dtype_mismatch(
        self, distribution
    ):
        """Per-replica values of different dtypes fail input validation."""
        # Device and input tensor dtype details are left out of the expected
        # message since the order of the device and the corresponding input
        # tensor dtype is not deterministic over different runs.
        expected_error = (
            "Input tensor dtypes do not match for "
            "distributed tensor inputs "
            "PerReplica:.+"
        )
        with self.cached_session():

            @tf.function
            def replica_fn():
                replica_ctx = tf.distribute.get_replica_context()
                device = replica_ctx.replica_id_in_sync_group.device
                if device.endswith("GPU:0"):
                    return tf.constant([[1, 2]], dtype=tf.int32)
                return tf.constant([[1, 2]], dtype=tf.float64)

            per_replica_values = distribution.run(replica_fn)

            with self.assertRaisesRegex(
                ValueError, expected_error
            ), distribution.scope():
                distributed_training_utils_v1.validate_distributed_dataset_inputs(  # noqa: E501
                    distribution, per_replica_values, None
                )

    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.combine(
            distribution=[
                tf.__internal__.distribute.combinations.mirrored_strategy_with_gpu_and_cpu,  # noqa: E501
            ],
            mode=["graph", "eager"],
        )
    )
    def test_unsupported_features(self, distribution, mode):
        """Unsupported `fit`/`evaluate`/`predict` usages raise ValueError.

        Covers `validation_split`, `sample_weight`, and omitting the step
        count for a dataset that repeats indefinitely.
        """
        with self.cached_session():
            with distribution.scope():
                model = keras_test_lib.get_model()
                model.compile(
                    tf.compat.v1.train.GradientDescentOptimizer(0.001),
                    "mse",
                    metrics=["mae"],
                )

            dataset = keras_test_lib.get_dataset(distribution)

            # Test with validation split
            with self.assertRaises(ValueError):
                model.fit(
                    dataset,
                    validation_split=0.5,
                    validation_steps=2,
                    epochs=1,
                    steps_per_epoch=2,
                    verbose=0,
                )

            # Test with sample weight.
            sample_weight = np.random.random((10,))
            with self.assertRaises(ValueError):
                model.fit(
                    dataset,
                    sample_weight=sample_weight,
                    epochs=1,
                    steps_per_epoch=2,
                    verbose=0,
                )

            # Test with not specifying the `steps` argument for dataset with
            # infinite cardinality.
            infinite_dataset = dataset.repeat()
            with self.assertRaises(ValueError):
                model.fit(infinite_dataset, epochs=1, verbose=0)

            with self.assertRaises(ValueError):
                model.evaluate(infinite_dataset, verbose=0)

            with self.assertRaises(ValueError):
                model.predict(infinite_dataset, verbose=0)

    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.combine(
            distribution=[
                tf.__internal__.distribute.combinations.mirrored_strategy_with_gpu_and_cpu,  # noqa: E501
                tf.__internal__.distribute.combinations.one_device_strategy,
            ],
            mode=["graph", "eager"],
        )
    )
    def test_distribution_strategy_on_subclassed_model(self, distribution):
        """Compiling a subclassed model succeeds eagerly, fails otherwise."""
        with distribution.scope():

            class _SimpleMLP(keras.Model):
                def __init__(self, num_labels):
                    super().__init__()
                    self.dense = keras.layers.Dense(num_labels)

                def call(self, inputs):
                    return self.dense(inputs)

            model = _SimpleMLP(3)

            if tf.executing_eagerly():
                model.compile("sgd")
            else:
                with self.assertRaisesRegex(
                    ValueError,
                    "We currently do not support distribution strategy with a "
                    "`Sequential` model that is created without `input_shape`/"
                    "`input_dim` set in its first layer or a subclassed model.",
                ):
                    model.compile("sgd")

    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.combine(
            distribution=[
                tf.__internal__.distribute.combinations.mirrored_strategy_with_gpu_and_cpu,  # noqa: E501
                tf.__internal__.distribute.combinations.one_device_strategy,
            ],
            mode=["graph", "eager"],
        )
    )
    def test_distribution_strategy_on_deferred_sequential_model(
        self, distribution
    ):
        """Compiling a shape-less Sequential succeeds only when eager."""
        with distribution.scope():
            model = keras.models.Sequential()
            for units, activation in ((16, "relu"), (3, "softmax")):
                model.add(keras.layers.Dense(units, activation=activation))

            if not tf.executing_eagerly():
                with self.assertRaisesRegex(
                    ValueError,
                    "We currently do not support distribution strategy with a "
                    "`Sequential` model that is created without "
                    "`input_shape`/`input_dim` set in its first layer or "
                    "a subclassed model.",
                ):
                    model.compile("sgd")
            else:
                model.compile("sgd")

    @tf.__internal__.distribute.combinations.generate(
        keras_test_lib.all_strategy_combinations_minus_default()
    )
    def test_standalone_loss_without_loss_reduction(self, distribution):
        """A default-reduction loss called inside a scope raises ValueError."""
        expected_error = (
            "Please use `tf.keras.losses.Reduction.SUM` or "
            "`tf.keras.losses.Reduction.NONE`"
        )
        with distribution.scope():
            loss_object = losses.MeanSquaredError()

            with self.assertRaisesRegex(ValueError, expected_error):
                labels = np.asarray([1, 0])
                loss_object(labels, labels)
|
||
|
|
||
|
|
||
|
class TestDistributionStrategyWithLossMasking(
    tf.test.TestCase, parameterized.TestCase
):
    """Checks loss computation when a `Masking` layer is in the model."""

    # TODO(priyag): Enable all strategies for this test. Currently it does not
    # work for TPU due to some invalid datatype.
    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.combine(
            distribution=[
                tf.__internal__.distribute.combinations.mirrored_strategy_with_gpu_and_cpu,  # noqa: E501
            ],
            mode=["graph", "eager"],
            optimizer=optimizer_combinations.gradient_descent_optimizer_keras_v2_fn,  # noqa: E501
        )
    )
    def test_masking(self, distribution, optimizer):
        """Fits a masked model and expects the first epoch loss to be 0."""
        with self.cached_session():
            np.random.seed(1337)
            features = np.array([[[1], [1]], [[0], [0]]])

            with distribution.scope():
                model = keras.models.Sequential()
                masking = keras.layers.Masking(mask_value=0, input_shape=(2, 1))
                model.add(masking)
                per_step_dense = keras.layers.TimeDistributed(
                    keras.layers.Dense(1, kernel_initializer="one")
                )
                model.add(per_step_dense)
                model.compile(loss="mse", optimizer=optimizer())

            targets = np.array([[[1], [1]], [[1], [1]]])
            dataset = (
                tf.data.Dataset.from_tensor_slices((features, targets))
                .repeat(100)
                .batch(10)
            )

            history = model.fit(x=dataset, epochs=1, steps_per_epoch=2)
            first_epoch_loss = history.history["loss"][0]
            self.assertEqual(first_epoch_loss, 0)
|
||
|
|
||
|
|
||
|
class TestDistributionStrategyWithNormalizationLayer(
    tf.test.TestCase, parameterized.TestCase
):
    """Checks `BatchNormalization` output statistics under distribution."""

    def _check_batchnorm_output_stats(
        self, distribution, optimizer, **norm_kwargs
    ):
        """Trains a one-layer BatchNormalization model and checks its output.

        After undoing the layer's `beta` shift and `gamma` scale, the
        predictions must have mean 0 and standard deviation 1, each within an
        absolute tolerance of 0.1.

        Args:
            distribution: strategy whose scope is used to build the model
                and to batch the datasets.
            optimizer: zero-argument callable returning the optimizer passed
                to `compile`.
            **norm_kwargs: extra keyword arguments forwarded to
                `keras.layers.BatchNormalization` (e.g. `fused`, `renorm`).
        """
        with self.cached_session():
            with distribution.scope():
                model = keras.models.Sequential()
                norm = keras.layers.BatchNormalization(
                    input_shape=(10, 20, 30), momentum=0.8, **norm_kwargs
                )
                model.add(norm)
                model.compile(loss="mse", optimizer=optimizer())

            # Centered on 5.0 with standard deviation 10.0 (`scale` is the
            # standard deviation, not the variance).
            x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 10, 20, 30))
            x = x.astype("float32")
            dataset = tf.data.Dataset.from_tensor_slices((x, x))
            dataset = dataset.repeat(100)
            dataset = keras_test_lib.batch_wrapper(dataset, 32, distribution)

            predict_dataset = tf.data.Dataset.from_tensor_slices(x)
            predict_dataset = predict_dataset.repeat(100)
            predict_dataset = keras_test_lib.batch_wrapper(
                predict_dataset, 32, distribution
            )

            model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10)
            out = model.predict(predict_dataset, steps=2)
            # Remove the learned affine transform so only the normalization
            # itself is being measured.
            out -= keras.backend.eval(norm.beta)
            out /= keras.backend.eval(norm.gamma)
            np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
            np.testing.assert_allclose(out.std(), 1.0, atol=1e-1)

    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.times(
            keras_test_lib.all_strategy_combinations(),
            tf.__internal__.test.combinations.combine(
                fused=[True, False],
                optimizer=optimizer_combinations.gradient_descent_optimizer_keras_v2_fn,  # noqa: E501
            ),
        )
    )
    def test_batchnorm_correctness(self, distribution, fused, optimizer):
        """Checks normalized output for both fused and non-fused layers."""
        self._check_batchnorm_output_stats(
            distribution, optimizer, fused=fused
        )

    # TODO(b/146181571): Enable this for all distribution strategies once
    # DistributedVariable.assign() returns a variable for MirroredStrategy.
    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.times(
            keras_test_lib.tpu_strategy_combinations(),
            tf.__internal__.test.combinations.combine(
                optimizer=optimizer_combinations.gradient_descent_optimizer_keras_v2_fn  # noqa: E501
            ),
        )
    )
    def test_batchnorm_correctness_with_renorm(self, distribution, optimizer):
        """Checks normalized output for a non-fused layer with `renorm`."""
        self._check_batchnorm_output_stats(
            distribution, optimizer, fused=False, renorm=True
        )
|
||
|
|
||
|
|
||
|
class TestDistributionStrategySaveLoadWeights(
    tf.test.TestCase, parameterized.TestCase
):
    """Round-trips model weights through `save_weights`/`load_weights`."""

    def _check_weights_round_trip(self, distribution, optimizer, file_name):
        """Trains a model, saves its weights and reloads them in a new model.

        The reloaded model is then used for `predict` and `fit` to confirm
        it is usable after `load_weights`.

        Args:
            distribution: strategy whose scope is used for both models.
            optimizer: zero-argument callable returning the optimizer passed
                to `compile`.
            file_name: name of the weights file, created inside this test's
                temporary directory.
        """
        # Keep the file inside the test's own temporary directory rather than
        # using `tempfile.mktemp`, which is deprecated because another
        # process may claim the returned name before it is used.
        weights_file = os.path.join(self.get_temp_dir(), file_name)

        with self.cached_session():
            dataset = keras_test_lib.get_dataset(distribution)
            with distribution.scope():
                model = keras_test_lib.get_model()
                model.compile(optimizer(), "mse")
                model.fit(dataset, epochs=1, steps_per_epoch=1)

                model.save_weights(weights_file)

                model_2 = keras_test_lib.get_model()
                model_2.compile(optimizer(), "mse")
                model_2.load_weights(weights_file)
                model_2.predict(
                    keras_test_lib.get_predict_dataset(distribution), steps=2
                )
                model_2.fit(dataset, epochs=1, steps_per_epoch=1)

    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.times(
            keras_test_lib.all_strategy_combinations_minus_default(),
            tf.__internal__.test.combinations.combine(
                optimizer=optimizer_combinations.rmsprop_optimizer_keras_v2_fn
            ),
        )
    )
    def test_save_load_h5(self, distribution, optimizer):
        """Saves to and loads from a weights file with an `.h5` suffix."""
        self._check_weights_round_trip(distribution, optimizer, "weights.h5")

    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.times(
            keras_test_lib.all_strategy_combinations_minus_default(),
            tf.__internal__.test.combinations.combine(
                optimizer=optimizer_combinations.rmsprop_optimizer_keras_v2_fn
            ),
        )
    )
    def test_save_load_trackable(self, distribution, optimizer):
        """Saves to and loads from a weights path that has no suffix."""
        # TODO(b/123533246): Enable the test for TPU once bug is fixed
        if (
            isinstance(
                distribution,
                (
                    tf.distribute.experimental.TPUStrategy,
                    tf.compat.v1.distribute.experimental.TPUStrategy,
                ),
            )
            and distribution.extended.steps_per_run > 1
        ):
            self.skipTest(
                "MultiStep TPU Strategy deadlocks with optimizer restore."
            )
        self._check_weights_round_trip(distribution, optimizer, "weights")
|
||
|
|
||
|
|
||
|
class TestDistributionStrategyValidation(
    tf.test.TestCase, parameterized.TestCase
):
    """Checks errors for layers/models built outside the strategy scope."""

    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.times(
            keras_test_lib.all_strategy_combinations_minus_default()
        )
    )
    def test_layer_outside_scope(self, distribution):
        """Layers built outside the scope, model built inside: must raise."""
        with self.cached_session(), self.assertRaisesRegex(
            ValueError, "was not created in the distribution strategy"
        ):
            inputs = keras.layers.Input(shape=(3,), name="input")
            outputs = keras.layers.Dense(4, name="dense")(inputs)
            with distribution.scope():
                model = keras.Model(inputs, outputs)
                sgd = tf.compat.v1.train.GradientDescentOptimizer(0.001)
                model.compile(
                    sgd,
                    "mse",
                    metrics=["mae", keras.metrics.CategoricalAccuracy()],
                )

    @tf.__internal__.distribute.combinations.generate(
        keras_test_lib.all_strategy_combinations_minus_default()
    )
    def test_model_outside_scope(self, distribution):
        """Model built outside the scope, compiled inside: must raise."""
        with self.cached_session(), self.assertRaisesRegex(
            ValueError, "was not created in the distribution strategy"
        ):
            inputs = keras.layers.Input(shape=(3,), name="input")
            outputs = keras.layers.Dense(4, name="dense")(inputs)
            model = keras.Model(inputs, outputs)
            with distribution.scope():
                sgd = tf.compat.v1.train.GradientDescentOptimizer(0.001)
                model.compile(
                    sgd,
                    "mse",
                    metrics=["mae", keras.metrics.CategoricalAccuracy()],
                )
|
||
|
|
||
|
|
||
|
class TestDistributionStrategyWithStaticShapes(
    tf.test.TestCase, parameterized.TestCase
):
    """Checks `Input` layers declared with a static `batch_size`."""

    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.combine(
            distribution=[
                tf.__internal__.distribute.combinations.mirrored_strategy_with_gpu_and_cpu,  # noqa: E501
            ],
            mode=["graph", "eager"],
        )
    )
    def test_input_batch_size_not_divisible_by_num_replicas(self, distribution):
        """A batch size of 5 is rejected with a divisibility error."""
        expected_error = (
            r"The `batch_size` argument \(5\) must be divisible by "
            r"the number of replicas \(2\)"
        )
        with distribution.scope(), self.assertRaisesRegex(
            ValueError, expected_error
        ):
            keras.layers.Input(shape=(3,), batch_size=5, name="input")

    @tf.__internal__.distribute.combinations.generate(
        tf.__internal__.test.combinations.combine(
            distribution=[
                tf.__internal__.distribute.combinations.mirrored_strategy_with_gpu_and_cpu,  # noqa: E501
            ],
            mode=["graph", "eager"],
        )
    )
    def test_static_input_batch_size(self, distribution):
        """A model with a static batch size of 10 fits/evaluates/predicts."""
        features = np.zeros((10, 3), dtype=np.float32)
        labels = np.zeros((10, 4), dtype=np.float32)
        dataset = (
            tf.data.Dataset.from_tensor_slices((features, labels))
            .repeat(100)
            .batch(10, drop_remainder=True)
        )

        with distribution.scope():
            model_input = keras.layers.Input(
                shape=(3,), batch_size=10, name="input"
            )
            model_output = keras.layers.Dense(4, name="dense")(model_input)
            model = keras.Model(model_input, model_output)
            model.compile(optimizer="sgd", loss="mse", metrics=["mae"])

        model.fit(dataset, epochs=1, steps_per_epoch=5)
        model.evaluate(dataset, steps=5)
        model.predict(dataset)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # When run as a script, hand control to the multi-process runner's test
    # entry point.
    tf.__internal__.distribute.multi_process_runner.test_main()
|