# Source: Intelegentny_Pszczelarz/.venv/Lib/site-packages/keras/engine/training_distributed_v1.py
# Snapshot: 2023-06-19 00:49:18 +02:00 — 924 lines, 31 KiB, Python

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Part of the Keras training engine related to distributed training."""
import numpy as np
import tensorflow.compat.v2 as tf
from keras import backend
from keras import callbacks as cbks
from keras.distribute import distribute_coordinator_utils as dc
from keras.distribute import distributed_training_utils_v1 as dist_utils
from keras.engine import partial_batch_padding_handler as padding_util
from keras.engine import training_arrays_v1
from keras.engine import training_utils_v1
from keras.utils.generic_utils import Progbar
from keras.utils.mode_keys import ModeKeys
# isort: off
from tensorflow.python.distribute import input_lib
from tensorflow.python.platform import tf_logging as logging
def _per_replica_execution_function(model, mode):
exec_func = model._make_execution_function(mode)
return (
exec_func.inputs,
exec_func.outputs,
exec_func.updates_op,
exec_func.session_kwargs,
)
def _build_model(strategy, model, mode, inputs, targets=None):
    """Build the distributed network for `model` under `strategy`.

    When the model was compiled with distribution cloning enabled, a clone is
    created on each replica; otherwise the distributed network is built
    in place.
    """
    if not model._compile_distribution:
        dist_utils._build_distributed_network(
            model, strategy, mode, inputs, targets
        )
    else:
        dist_utils.clone_model_on_replicas(
            model, strategy, mode, inputs=inputs, targets=targets
        )
def _make_train_step_fn(model, mode, strategy, output_labels):
    """Create step fn.

    Args:
        model: a Keras Model instance.
        mode: One of ModeKeys.TRAIN/ModeKeys.TEST/ModeKeys.PREDICT.
        strategy: a `tf.distribute.Strategy` instance.
        output_labels: the output labels for the step function.

    Returns:
        A step function to run by `tf.distribute.Strategy`.
    """

    def _step_fn(ctx, inputs):
        """A step fn that returns update ops."""
        # Inputs may arrive as a (features, targets) pair or as features only.
        if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
            inputs, targets = inputs
        else:
            targets = None

        # When input feature is a dictionary of tensors, dictionary is
        # flattended to an array and passed as a model input. This results in
        # input mismatch when model input layer names are not sorted in
        # alphabetical order as `nest.flatten()`sorts dictionary elements by
        # keys. As so, transform input tensors into an array and order it along
        # `model._feed_input_names`.
        if isinstance(inputs, dict):
            inputs = [
                inputs[input_name] for input_name in model._feed_input_names
            ]

        # Build (or clone) the distributed network before running the
        # per-replica execution function on it.
        _build_model(strategy, model, mode, inputs, targets)

        (
            grouped_inputs,
            grouped_outputs,
            grouped_updates,
            grouped_session_args,
        ) = strategy.extended.call_for_each_replica(
            _per_replica_execution_function,
            args=(dist_utils.get_distributed_model(model, mode), mode),
        )
        # Unwrap the per-replica values into flat lists usable by a single
        # backend function.
        (
            all_inputs,
            all_outputs,
            all_updates,
            all_session_args,
        ) = dist_utils.unwrap_values(
            strategy,
            grouped_inputs,
            grouped_outputs,
            grouped_updates,
            grouped_session_args,
        )
        combined_fn = backend.function(
            all_inputs,
            all_outputs,
            updates=all_updates,
            name="distributed_" + str(mode) + "_function",
            **all_session_args
        )

        # Register each labeled output with the step context, choosing the
        # cross-replica reduction for it.
        for label, output in zip(output_labels, combined_fn.outputs):
            if label == "loss":
                reduce_op = tf.distribute.ReduceOp.SUM
            else:
                # We reduce all other metrics using mean for now. This is
                # temporary workaround until new metrics are in place.
                reduce_op = tf.distribute.ReduceOp.MEAN
            ctx.set_last_step_output(label, output, reduce_op)

        # TODO(priyag, sourabhbajaj): Ignoring these things from the
        # combined_fn: feed_dict, session kwargs, run options, run_metadata for
        # now. These should be handled appropriately
        return combined_fn.updates_op

    return _step_fn
def experimental_tpu_fit_loop(
    model,
    dataset,
    epochs=100,
    verbose=1,
    callbacks=None,
    initial_epoch=0,
    steps_per_epoch=None,
    val_dataset=None,
    validation_steps=None,
    validation_freq=1,
):
    """Fit loop for training with TPU tf.distribute.Strategy.

    Args:
        model: Keras Model instance.
        dataset: Dataset that returns inputs and targets
        epochs: Number of times to iterate over the data
        verbose: Integer, Verbosity mode, 0, 1 or 2
        callbacks: List of callbacks to be called during training
        initial_epoch: Epoch at which to start training
            (useful for resuming a previous training run)
        steps_per_epoch: Total number of steps (batches of samples)
            before declaring one epoch finished and starting the
            next epoch. Ignored with the default value of `None`.
        val_dataset: Dataset for validation data.
        validation_steps: Number of steps to run validation for
            (only if doing validation from data tensors).
            Ignored with the default value of `None`.
        validation_freq: Only relevant if validation data is provided. Integer
            or `collections.abc.Container` instance (e.g. list, tuple, etc.). If
            an integer, specifies how many training epochs to run before a new
            validation run is performed, e.g. `validation_freq=2` runs
            validation every 2 epochs. If a Container, specifies the epochs on
            which to run validation, e.g. `validation_freq=[1, 2, 10]` runs
            validation at the end of the 1st, 2nd, and 10th epochs.

    Returns:
        The model's `History` object.

    Raises:
        ValueError: in case of invalid arguments.
    """
    mode = ModeKeys.TRAIN

    current_strategy = model._distribution_strategy
    # Run at most `steps_per_run` steps per device loop; clamp in case the
    # epoch is shorter than the strategy's configured run length.
    iteration_value = min(
        steps_per_epoch, current_strategy.extended.steps_per_run
    )
    steps_per_run = backend.variable(
        value=iteration_value, dtype="int32", name="steps_per_run"
    )

    # TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
    iterator = dist_utils.get_iterator(dataset, current_strategy)

    scope = dist_utils.distributed_scope(
        strategy=current_strategy, learning_phase=1
    )
    scope.__enter__()

    out_labels = model.metrics_names or []

    step_fn = _make_train_step_fn(
        model, ModeKeys.TRAIN, current_strategy, out_labels
    )

    # Add initial dummy values for loss and other metric tensors.
    initial_loop_values = {}
    initial_loop_values["loss"] = tf.constant(1e7)
    for m in model._get_training_eval_metrics():
        tensor = m.result()
        initial_loop_values[m.name] = tf.zeros(tensor.shape, tensor.dtype)

    ctx = current_strategy.extended.experimental_run_steps_on_iterator(
        step_fn,
        iterator,
        iterations=steps_per_run,
        initial_loop_values=initial_loop_values,
    )
    train_op = ctx.run_op
    output_tensors = ctx.last_step_outputs

    do_validation = bool(validation_steps)

    if model._compile_distribution:
        dist_utils._copy_weights_to_distributed_model(model, mode)

    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=do_validation,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        verbose=verbose,
        count_mode="steps",
        mode=mode,
    )

    # Calculate the steps each time on the device. Full-size chunks of
    # `steps_per_run`, plus one remainder chunk if the epoch doesn't divide
    # evenly.
    steps_to_run = [current_strategy.extended.steps_per_run] * (
        steps_per_epoch // current_strategy.extended.steps_per_run
    )
    if steps_per_epoch % current_strategy.extended.steps_per_run:
        steps_to_run.append(
            steps_per_epoch % current_strategy.extended.steps_per_run
        )
    target_steps = len(steps_to_run)

    callbacks._call_begin_hook(mode)

    initial_epoch = model._maybe_load_initial_epoch_from_ckpt(
        initial_epoch, mode
    )

    for epoch in range(initial_epoch, epochs):
        dist_utils._reset_metrics(model)
        callbacks.on_epoch_begin(epoch)
        epoch_logs = {}
        step_index = 0
        prev_step_count = None
        current_step = 0
        while current_step < target_steps:
            step_count = steps_to_run[current_step]
            batch_logs = {
                "batch": step_index,
                "size": 1,
                "num_steps": step_count,
            }
            callbacks._call_batch_hook(mode, "begin", step_index, batch_logs)
            # Only re-assign the steps_per_run variable when the chunk size
            # actually changes, to avoid a session round-trip per step.
            if prev_step_count is None or step_count != prev_step_count:
                backend.get_session().run(steps_per_run.assign(step_count))
                prev_step_count = step_count
            try:
                _, outputs = backend.batch_get_value([train_op, output_tensors])
            except tf.errors.OutOfRangeError:
                # BUGFIX: the multiplication must happen inside the parens;
                # `"..." % steps_per_epoch * epochs` formatted with
                # steps_per_epoch and then repeated the string `epochs` times.
                logging.warning(
                    "Your dataset iterator ran out of data; "
                    "interrupting training. Make sure that your dataset "
                    "can generate at least `steps_per_epoch * epochs` "
                    "batches (in this case, %d batches)."
                    % (steps_per_epoch * epochs)
                )
                break

            batch_logs.update(outputs)
            callbacks._call_batch_hook(mode, "end", step_index, batch_logs)
            step_index = step_index + step_count
            current_step += 1

            if callbacks.model.stop_training:
                break

        if do_validation and training_utils_v1.should_run_validation(
            validation_freq, epoch
        ):
            logging.info("Running validation at fit epoch: %s", epoch)

            if model._compile_distribution:
                # Since we create a new clone from the original model we need to
                # copy the weights back to the original model before we can run
                # validation.
                dist_utils._copy_weights_to_original_model(
                    model, ModeKeys.TRAIN
                )

            val_outs = experimental_tpu_test_loop(
                model,
                val_dataset,
                steps=validation_steps,
                verbose=verbose,
                callbacks=callbacks,
            )
            if not isinstance(val_outs, list):
                val_outs = [val_outs]
            # Same labels assumed.
            for label, val_out in zip(out_labels, val_outs):
                epoch_logs["val_" + label] = val_out

        callbacks.on_epoch_end(epoch, epoch_logs)
        if callbacks.model.stop_training:
            break

    model._successful_loop_finish = True
    callbacks._call_end_hook(mode)

    if model._compile_distribution:
        # Copy the weights back from the replicated model to the original model.
        dist_utils._copy_weights_to_original_model(model, ModeKeys.TRAIN)
    scope.__exit__(None, None, None)
    return model.history
def experimental_tpu_test_loop(
    model, dataset, verbose=0, steps=None, callbacks=None
):
    """Test loop for evaluating with TPU tf.distribute.Strategy.

    Args:
        model: Keras Model instance.
        dataset: Dataset for input data.
        verbose: Integer, Verbosity mode 0 or 1.
        steps: Total number of steps (batches of samples)
            before declaring predictions finished.
            Ignored with the default value of `None`.
        callbacks: List of callbacks to be called during training

    Returns:
        Scalar loss (if the model has a single output and no metrics)
        or list of scalars (if the model has multiple outputs
        and/or metrics). The attribute `model.metrics_names` will give you
        the display labels for the outputs.

    Raises:
        ValueError: if `steps` is not provided and cannot be inferred.
    """
    mode = ModeKeys.TEST
    current_strategy = model._distribution_strategy
    iterator = dist_utils.get_iterator(dataset, current_strategy)

    scope = dist_utils.distributed_scope(
        strategy=current_strategy, learning_phase=0
    )
    scope.__enter__()

    out_labels = model.metrics_names

    def _test_step_fn(inputs):
        """A fn that returns output of single test step."""
        if isinstance(inputs, (tuple, list)) and len(inputs) == 2:
            inputs, targets = inputs
        else:
            targets = None

        (
            tf.distribute.get_replica_context().merge_call(
                _build_model, args=(model, mode, inputs, targets)
            )
        )

        (_, outputs, updates, _) = _per_replica_execution_function(
            dist_utils.get_distributed_model(model, mode), mode
        )
        # Make sure metric-update ops run before the outputs are read.
        with tf.control_dependencies([updates]):
            return [tf.identity(out) for out in outputs]

    test_input_data = iterator.get_next()
    per_replica_outputs = current_strategy.run(
        _test_step_fn, args=(test_input_data,)
    )
    output_tensors = {}
    for label, output in zip(out_labels, per_replica_outputs):
        if label == "loss":
            reduce_op = tf.distribute.ReduceOp.SUM
        else:
            # We reduce all other metrics using mean for now. This is temporary
            # workaround until new metrics are in place.
            reduce_op = tf.distribute.ReduceOp.MEAN
        output_tensors[label] = current_strategy.reduce(
            reduce_op, output, axis=None
        )
    test_op = tf.group(list(output_tensors.values()))

    if verbose >= 1:
        progbar = Progbar(target=steps)

    if model._compile_distribution:
        dist_utils._copy_weights_to_distributed_model(model, mode)

    dist_utils._reset_metrics(model)

    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=False,
        epochs=1,
        steps_per_epoch=steps,
        verbose=verbose,
        count_mode="steps",
        mode=ModeKeys.TEST,
    )
    callbacks._call_begin_hook(mode)

    outs = [0.0] * len(model.metrics_names)
    if steps is not None:
        target_steps = steps
    else:
        raise ValueError(
            "Number of steps could not be inferred from the data, "
            "please pass the steps argument."
        )

    current_step = 0
    while current_step < target_steps:
        batch_logs = {"batch": current_step, "size": 1}
        callbacks._call_batch_hook(mode, "begin", current_step, batch_logs)
        try:
            _, batch_outs = backend.batch_get_value([test_op, output_tensors])
        except tf.errors.OutOfRangeError:
            warning_msg = (
                "Make sure that your dataset can generate at least "
                "`steps` batches (in this case, {} batches).".format(steps)
            )
            logging.warning(
                "Your dataset iterator ran out of data; "
                "interrupting evaluation. " + warning_msg
            )
            target_steps = current_step
            break
        for i, label in enumerate(model.metrics_names):
            if i == 0:
                # Loss is stateless metrics.
                outs[i] += batch_outs[label]
            else:
                # For all stateful metrics, the aggregation is handled by
                # mirrored vars.
                outs[i] = batch_outs[label]

        batch_logs = callbacks.make_logs(model, batch_logs, outs, mode)
        callbacks._call_batch_hook(mode, "end", current_step, batch_logs)
        if verbose == 1:
            progbar.update(current_step + 1)
        current_step += 1

    if verbose >= 1:
        # Progress bar finishes at the end.
        progbar.update(target_steps)
    callbacks._call_end_hook(mode)

    scope.__exit__(None, None, None)
    # ROBUSTNESS FIX: if the iterator was exhausted before completing a single
    # step, `target_steps` is 0 — guard the mean computation against
    # ZeroDivisionError.
    if len(outs) > 0 and target_steps > 0:
        outs[0] /= target_steps

    if len(outs) == 1:
        return outs[0]
    return outs
def experimental_tpu_predict_loop(
    model, dataset, verbose=0, steps=None, callbacks=None
):
    """Predict loop for predicting with TPU tf.distribute.Strategy.

    Args:
        model: Keras Model instance.
        dataset: Dataset for input data.
        verbose: Integer, Verbosity mode 0 or 1.
        steps: Total number of steps (batches of samples)
            before declaring `_predict_loop` finished.
            Ignored with the default value of `None`.
        callbacks: List of callbacks to be called during training

    Returns:
        Array of predictions (if the model has a single output)
        or list of arrays of predictions
        (if the model has multiple outputs).

    Raises:
        ValueError: if `steps` is not provided and cannot be inferred.
    """
    mode = ModeKeys.PREDICT
    dataset_fully_shaped = dist_utils.is_dataset_shape_fully_defined(dataset)
    padding_handler = None
    if not dataset_fully_shaped:
        # TODO(hongjunchoi): Investigate whether operations from
        # PartialBatchPaddingHandler are unnecessarily pruned out
        # during graph optimization.
        padding_handler = padding_util.PartialBatchPaddingHandler(
            model._feed_output_shapes
        )
        batch_size, _, prefetch_buffer = input_lib._get_dataset_attributes(
            dataset
        )
        padding_handler.padded_batch_size = batch_size
        padding_handler.padding_mask = dataset.reduce(
            padding_handler.padding_mask, padding_handler.update_mask
        )

        dataset = dataset.map(padding_handler.pad_batch)
        dataset = dataset.unbatch()
        # Upon this point, it is guaranteed that the dataset does not
        # have partial batches. Thus, we set `drop_remainder=True` to
        # get static shape information about the elements in the dataset.
        dataset = dataset.batch(batch_size, drop_remainder=True)

        if prefetch_buffer is not None:
            dataset = dataset.prefetch(prefetch_buffer)

    current_strategy = model._distribution_strategy
    iterator = dist_utils.get_iterator(dataset, current_strategy)

    scope = dist_utils.distributed_scope(
        strategy=current_strategy, learning_phase=0
    )
    scope.__enter__()

    def _predict_step_fn(inputs):
        """A fn that returns output of single prediction step."""

        (
            tf.distribute.get_replica_context().merge_call(
                _build_model, args=(model, mode, inputs)
            )
        )

        (_, outputs, updates, _) = _per_replica_execution_function(
            dist_utils.get_distributed_model(model, mode), mode
        )

        with tf.control_dependencies([updates]):
            return [tf.identity(out) for out in outputs]

    # TODO(hongjunchoi): When numpy array is passed as an input to `predict()`
    # use numpy arrays directly to avoid cumulating unnecessary input pipeline
    # ops.
    predict_input_data = iterator.get_next()
    per_replica_outputs = current_strategy.run(
        _predict_step_fn, args=(predict_input_data,)
    )
    output_tensors = dist_utils.flatten_per_replica_values(
        current_strategy, per_replica_outputs
    )
    # PERF FIX: build the group op once, outside the fetch loop. The original
    # called `tf.group` inside the loop, adding a new op to the graph on every
    # iteration.
    predict_ops = tf.group(output_tensors)

    if verbose >= 1:
        progbar = Progbar(target=steps)

    if model._compile_distribution:
        dist_utils._copy_weights_to_distributed_model(model, mode)

    dist_utils._reset_metrics(model)

    callbacks = cbks.configure_callbacks(
        callbacks,
        model,
        do_validation=False,
        epochs=1,
        steps_per_epoch=steps,
        verbose=verbose,
        count_mode="steps",
        mode=mode,
    )
    callbacks._call_begin_hook(mode)

    # Since we do not know how many samples we will see, we cannot pre-allocate
    # the returned Numpy arrays. Instead, we store one array per batch seen
    # and concatenate them upon returning.
    num_model_outputs = len(model.output_names)
    unconcatenated_outs = [[] for _ in range(num_model_outputs)]
    if steps is not None:
        target_steps = steps
    else:
        raise ValueError(
            "Number of steps could not be inferred from the data, "
            "please pass the steps argument."
        )

    current_step = 0
    while current_step < target_steps:
        batch_logs = {"batch": current_step, "size": 1}
        callbacks._call_batch_hook(mode, "begin", current_step, batch_logs)
        try:
            _, batch_outs = backend.batch_get_value(
                [predict_ops, output_tensors]
            )
        except tf.errors.OutOfRangeError:
            warning_msg = (
                "Make sure that your dataset can generate at least "
                "`steps` batches (in this case, {} batches).".format(steps)
            )
            logging.warning(
                "Your dataset iterator ran out of data; "
                "interrupting evaluation. " + warning_msg
            )
            break

        # TODO(priyag): maybe need to unwrap the outputs first for
        # MirroredStrategy.
        for i in range(num_model_outputs):
            # Each model output is replicated `num_replicas_in_sync` times in
            # the flattened output list; slice out this output's replicas.
            output_start_index = i * current_strategy.num_replicas_in_sync
            output_end_index = (
                output_start_index + current_strategy.num_replicas_in_sync
            )
            single_model_output = batch_outs[
                output_start_index:output_end_index
            ]
            unconcatenated_outs[i].extend(single_model_output)

        batch_logs = callbacks.make_logs(model, batch_logs, batch_outs, mode)
        callbacks._call_batch_hook(mode, "end", current_step, batch_logs)
        if verbose == 1:
            progbar.update(current_step + 1)
        current_step += 1

    if verbose >= 1:
        # Progress bar finishes at the end.
        progbar.update(current_step)
    callbacks._call_end_hook(mode)

    scope.__exit__(None, None, None)
    if len(unconcatenated_outs) == 1:
        prediction_result = np.concatenate(unconcatenated_outs[0], axis=0)
    else:
        prediction_result = [
            np.concatenate(out, axis=0) for out in unconcatenated_outs
        ]

    if padding_handler:
        prediction_result = padding_handler.apply_mask(prediction_result)
    return prediction_result
class DistributionSingleWorkerTrainingLoop(training_utils_v1.TrainingLoop):
    """Training loop for distribution strategy with single worker."""

    def fit(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose=1,
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_freq=1,
        **kwargs
    ):
        """Fit loop for Distribution Strategies.

        Validates inputs/callbacks, standardizes the data into a
        distribution-aware dataset, then dispatches to either the custom TPU
        graph-mode loop or the generic array fit loop.
        """
        dist_utils.validate_callbacks(
            input_callbacks=callbacks, optimizer=model.optimizer
        )
        dist_utils.validate_inputs(x, y)

        batch_size, steps_per_epoch = dist_utils.process_batch_and_step_size(
            model._distribution_strategy,
            x,
            batch_size,
            steps_per_epoch,
            ModeKeys.TRAIN,
            validation_split=validation_split,
        )
        batch_size = model._validate_or_infer_batch_size(
            batch_size, steps_per_epoch, x
        )
        dataset = model._distribution_standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            class_weight=class_weight,
            batch_size=batch_size,
            validation_split=validation_split,
            shuffle=shuffle,
            epochs=epochs,
        )
        # When not cloning per replica, standardize once more inside the
        # strategy scope so variables are created under the strategy.
        if not dist_utils.is_distributing_by_cloning(model):
            with model._distribution_strategy.scope():
                (dataset, _, _) = model._standardize_user_data(
                    dataset,
                    sample_weight=sample_weight,
                    class_weight=class_weight,
                    batch_size=batch_size,
                    validation_split=validation_split,
                    shuffle=shuffle,
                )

        val_dataset = None
        if validation_data:
            (
                val_x,
                val_y,
                val_sample_weights,
            ) = training_utils_v1.unpack_validation_data(validation_data)
            dist_utils.validate_inputs(val_x, val_y)
            _, validation_steps = dist_utils.process_batch_and_step_size(
                model._distribution_strategy,
                val_x,
                batch_size,
                validation_steps,
                ModeKeys.TEST,
            )

            val_dataset = model._distribution_standardize_user_data(
                val_x,
                val_y,
                sample_weight=val_sample_weights,
                class_weight=None,
                batch_size=batch_size,
                validation_split=validation_split,
                shuffle=shuffle,
                allow_partial_batch=True,
            )
        elif validation_split:
            # Splitting a distributed dataset by fraction is unsupported.
            raise ValueError(
                "validation_split argument is not supported with "
                "distribution strategies."
            )

        if backend.is_tpu_strategy(model._distribution_strategy):
            # The TPU graph-mode loop requires a known step count.
            steps_per_epoch = training_utils_v1.infer_steps_for_dataset(
                model,
                dataset,
                steps_per_epoch,
                epochs,
                steps_name="steps_per_epoch",
            )
            if steps_per_epoch is None:
                raise ValueError(
                    "Number of steps could not be inferred from the data, "
                    "please pass the steps_per_epoch argument."
                )

            if not tf.executing_eagerly():
                # Run TPU training in a custom loop in graph mode.
                return experimental_tpu_fit_loop(
                    model,
                    dataset,
                    epochs=epochs,
                    verbose=verbose,
                    callbacks=callbacks,
                    val_dataset=val_dataset,
                    initial_epoch=initial_epoch,
                    steps_per_epoch=steps_per_epoch,
                    validation_steps=validation_steps,
                    validation_freq=validation_freq,
                )

        return training_arrays_v1.fit_loop(
            model,
            dataset,
            batch_size=batch_size,
            epochs=epochs,
            verbose=verbose,
            callbacks=callbacks,
            val_inputs=val_dataset,
            shuffle=shuffle,
            initial_epoch=initial_epoch,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
            steps_name="steps_per_epoch",
        )

    def evaluate(
        self,
        model,
        x=None,
        y=None,
        batch_size=None,
        verbose=1,
        sample_weight=None,
        steps=None,
        callbacks=None,
        **kwargs
    ):
        """Evaluate loop for Distribution Strategies.

        Mirrors `fit`: standardize the data, then dispatch to the TPU
        graph-mode test loop or the generic array test loop.
        """
        dist_utils.validate_inputs(x, y)
        batch_size, steps = dist_utils.process_batch_and_step_size(
            model._distribution_strategy, x, batch_size, steps, ModeKeys.TEST
        )
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        dataset = model._distribution_standardize_user_data(
            x,
            y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            allow_partial_batch=True,
        )

        if backend.is_tpu_strategy(model._distribution_strategy):
            steps = training_utils_v1.infer_steps_for_dataset(
                model, dataset, steps, steps_name="steps"
            )
            if steps is None:
                raise ValueError(
                    "Number of steps could not be inferred from the data, "
                    "please pass the steps argument."
                )

            if not tf.executing_eagerly():
                # Run TPU evaluation in a custom loop in graph mode.
                return experimental_tpu_test_loop(
                    model,
                    dataset,
                    verbose=verbose,
                    steps=steps,
                    callbacks=callbacks,
                )
        return training_arrays_v1.test_loop(
            model,
            inputs=dataset,
            batch_size=batch_size,
            verbose=verbose,
            steps=steps,
            callbacks=callbacks,
        )

    def predict(
        self,
        model,
        x,
        batch_size=None,
        verbose=0,
        steps=None,
        callbacks=None,
        **kwargs
    ):
        """Predict loop for Distribution Strategies.

        Mirrors `evaluate`, dispatching to the TPU graph-mode predict loop
        or the generic array predict loop.
        """
        dist_utils.validate_inputs(x=x, y=None)
        batch_size, steps = dist_utils.process_batch_and_step_size(
            model._distribution_strategy, x, batch_size, steps, ModeKeys.PREDICT
        )
        batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
        dataset = model._distribution_standardize_user_data(
            x, batch_size=batch_size, allow_partial_batch=True
        )
        if backend.is_tpu_strategy(model._distribution_strategy):
            steps = training_utils_v1.infer_steps_for_dataset(
                model, dataset, steps, steps_name="steps"
            )
            if steps is None:
                raise ValueError(
                    "Number of steps could not be inferred from the data, "
                    "please pass the steps argument."
                )
            if not tf.executing_eagerly():
                return experimental_tpu_predict_loop(
                    model,
                    dataset,
                    verbose=verbose,
                    steps=steps,
                    callbacks=callbacks,
                )
        return training_arrays_v1.predict_loop(
            model,
            dataset,
            batch_size=batch_size,
            verbose=verbose,
            steps=steps,
            callbacks=callbacks,
        )
def _train_with_multi_worker(method):
    """Decorator handles multi worker training with distribution strategy."""

    def wrapper(model, **kwargs):
        def _worker_fn(_):
            # Drop callbacks that don't work across workers before running
            # the wrapped single-worker method on this worker.
            user_callbacks = kwargs.pop("callbacks", None)
            kwargs["callbacks"] = dist_utils.filter_distributed_callbacks(
                user_callbacks, model
            )
            return method(model, **kwargs)

        return dc.run_distribute_coordinator(
            _worker_fn, model._distribution_strategy
        )

    return wrapper
class DistributionMultiWorkerTrainingLoop(training_utils_v1.TrainingLoop):
    """Training loop for distribution strategy with multiple worker."""

    def __init__(self, single_worker_loop):
        # Delegate the actual work to a single-worker loop wrapped, where
        # needed, by the multi-worker coordinator.
        self._single_worker_loop = single_worker_loop

    def fit(self, *args, **kwargs):
        wrapped_fit = _train_with_multi_worker(self._single_worker_loop.fit)
        return wrapped_fit(*args, **kwargs)

    def evaluate(self, *args, **kwargs):
        wrapped_evaluate = _train_with_multi_worker(
            self._single_worker_loop.evaluate
        )
        return wrapped_evaluate(*args, **kwargs)

    def predict(self, *args, **kwargs):
        # Currently predict is still using the single worker implementation.
        return self._single_worker_loop.predict(*args, **kwargs)