# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities related to distributed training."""

import contextlib

import tensorflow.compat.v2 as tf
from absl import flags

from keras import backend

FLAGS = flags.FLAGS
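
# NOTE: `get_strategy()` below reads the `keras_distribute_strategy_class` and
# `keras_distribute_strategy_tpu_addr` flags. They are assumed to be defined
# elsewhere (e.g. via `flags.DEFINE_string`) before `get_strategy()` is called.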


# TODO(b/118776054): Currently we support global batch size for TPUStrategy and
# core MirroredStrategy only. Remove this check when contrib MirroredStrategy is
# no longer needed.
def global_batch_size_supported(distribution_strategy):
    """Returns whether `distribution_strategy` supports a global batch size."""
    return distribution_strategy.extended._global_batch_size


def call_replica_local_fn(fn, *args, **kwargs):
    """Call a function that uses replica-local variables.

    This function correctly handles calling `fn` in a cross-replica
    context.
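
    Example usage (an illustrative sketch; the Keras metric below stands in
    for any object backed by replica-local variables):

    ```python
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        mean_metric = tf.keras.metrics.Mean()

    # Safe to call from a cross-replica context.
    call_replica_local_fn(mean_metric.reset_state, strategy=strategy)
    ```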

    Args:
      fn: The function to call.
      *args: Positional arguments to `fn`.
      **kwargs: Keyword arguments to `fn`.

    Returns:
      The result of calling `fn`.
    """
    # TODO(b/132666209): Remove this function when we support assign_*
    # for replica-local variables.
    strategy = None
    if "strategy" in kwargs:
        strategy = kwargs.pop("strategy")
    else:
        if tf.distribute.has_strategy():
            strategy = tf.distribute.get_strategy()

    # TODO(b/120571621): TPUStrategy does not implement replica-local variables.
    is_tpu = backend.is_tpu_strategy(strategy)
    if (not is_tpu) and strategy and tf.distribute.in_cross_replica_context():
        with strategy.scope():
            return strategy.extended.call_for_each_replica(fn, args, kwargs)
    return fn(*args, **kwargs)


def is_distributed_variable(v):
    """Returns whether `v` is a distributed variable.
    """
    return isinstance(v, tf.distribute.DistributedValues) and isinstance(
        v, tf.Variable
    )


def get_strategy():
    """Creates a `tf.distribute.Strategy` object from flags.

    Example usage:

    ```python
    strategy = utils.get_strategy()
    with strategy.scope():
        model = tf.keras.Sequential([tf.keras.layers.Dense(10)])

    model.compile(...)
    train_ds, test_ds = ...
    model.fit(train_ds, validation_data=test_ds, epochs=10)
    ```
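
    The strategy class is selected with the `keras_distribute_strategy_class`
    flag, e.g. (an assumed invocation; the script name is illustrative):

    ```
    python train.py --keras_distribute_strategy_class=mirrored
    ```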

    Returns:
      `tf.distribute.Strategy` instance.
    """
    cls = FLAGS.keras_distribute_strategy_class
    accepted_strats = {
        "tpu",
        "multi_worker_mirrored",
        "mirrored",
        "parameter_server",
        "one_device",
    }
    if cls == "tpu":
        tpu_addr = FLAGS.keras_distribute_strategy_tpu_addr
        if not tpu_addr:
            raise ValueError(
                "When using a TPU strategy, you must set the flag "
                "`keras_distribute_strategy_tpu_addr` (TPU address)."
            )
        cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=tpu_addr
        )
        tf.config.experimental_connect_to_cluster(cluster_resolver)
        tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
        strategy = tf.distribute.experimental.TPUStrategy(cluster_resolver)
    elif cls == "multi_worker_mirrored":
        strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy()
    elif cls == "mirrored":
        strategy = tf.distribute.MirroredStrategy()
    elif cls == "parameter_server":
        cluster_resolver = (
            tf.distribute.cluster_resolver.TFConfigClusterResolver()
        )
        strategy = tf.distribute.experimental.ParameterServerStrategy(
            cluster_resolver
        )
    elif cls == "one_device":
        strategy = tf.distribute.OneDeviceStrategy("/gpu:0")
    else:
        raise ValueError(
            "Unknown distribution strategy flag. Received: "
            f"keras_distribute_strategy_class={cls}. "
            f"It should be one of {accepted_strats}"
        )
    return strategy


def maybe_preemption_handler_scope(model):
    """Returns a preemption-handler watch scope for `model`, if available.

    If `model` has a `_preemption_handler`, that handler's
    `_watch_error_scope()` is returned; otherwise `contextlib.nullcontext()`
    is returned, so callers can always use the result as a context manager.
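
    Example usage (an illustrative sketch; whether a preemption handler has
    been attached to `model` depends on the caller's setup):

    ```python
    with maybe_preemption_handler_scope(model):
        model.fit(train_ds, epochs=10)
    ```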
    """
    if getattr(model, "_preemption_handler", None):
        preemption_checkpoint_scope = (
            model._preemption_handler._watch_error_scope()
        )
    else:
        preemption_checkpoint_scope = contextlib.nullcontext()

    return preemption_checkpoint_scope