# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities related to distribute coordinator.

This module only contains utilities supporting the legacy TF1 code path that
involves distribute coordinator, and is not expected to change in any way. It
is subject to cleanup once TF1 is no longer supported.

TODO(rchao): Remove this module once TF1 is not supported.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import json
import os
import threading
import time

import tensorflow.compat.v2 as tf

# isort: off
from tensorflow.core.protobuf import cluster_pb2
from tensorflow.python.platform import tf_logging as logging

_worker_context = threading.local()
_thread_local = threading.local()


def get_current_worker_context():
    """Returns the current task context."""
    try:
        return _worker_context.current
    except AttributeError:
        return None
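

# Illustrative sketch (not part of the public API): inside a `worker_fn` run
# by the coordinator, the current `_WorkerContext` can be queried for
# task-specific configuration. The `worker_fn` below is hypothetical.
#
#   def worker_fn(strategy):
#       context = get_current_worker_context()
#       assert context is not None
#       logging.info(
#           "task_type=%r, task_id=%r, master=%r, is_chief=%r",
#           context.task_type,
#           context.task_id,
#           context.master_target,
#           context.is_chief,
#       )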


class _TaskType:
    PS = "ps"
    WORKER = "worker"
    CHIEF = "chief"
    EVALUATOR = "evaluator"
    CLIENT = "client"


def _get_num_workers(cluster_spec):
    """Gets number of workers including chief."""
    if not cluster_spec:
        return 0
    return len(cluster_spec.as_dict().get(_TaskType.WORKER, [])) + len(
        cluster_spec.as_dict().get(_TaskType.CHIEF, [])
    )
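

# For illustration (hypothetical addresses): a cluster with one chief and two
# workers counts three workers in total, since the chief also acts as a
# worker.
#
#   spec = tf.train.ClusterSpec(
#       {"chief": ["host0:2222"], "worker": ["host1:2222", "host2:2222"]}
#   )
#   assert _get_num_workers(spec) == 3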


class _WorkerContext:
    """The worker context class.

    This context object provides configuration information for each task. One
    context manager with a worker context object will be created per
    invocation of the `worker_fn`, where `get_current_worker_context` can be
    called to access the worker context object.
    """

    def __init__(
        self,
        strategy,
        cluster_spec,
        task_type,
        task_id,
        session_config=None,
        rpc_layer="grpc",
        worker_barrier=None,
    ):
        """Initialize the worker context object.

        Args:
          strategy: a `DistributionStrategy` object.
          cluster_spec: a ClusterSpec object. It can be empty or None in the
            local training case.
          task_type: a string indicating the role of the corresponding task,
            such as "worker" or "ps". It can be None if it is local training
            or in-graph replicated training.
          task_id: an integer indicating id of the corresponding task. It can
            be None if it is local training or in-graph replicated training.
          session_config: an optional `tf.compat.v1.ConfigProto` object.
          rpc_layer: optional string specifying the RPC protocol for
            communication with worker masters. If None or empty, hosts in the
            `cluster_spec` will be used directly.
          worker_barrier: optional, the barrier object for worker
            synchronization.
        """
        self._strategy = strategy
        self._cluster_spec = cluster_spec
        self._task_type = task_type
        self._task_id = task_id
        self._session_config = session_config
        self._worker_barrier = worker_barrier
        self._rpc_layer = rpc_layer
        self._master_target = self._get_master_target()
        self._num_workers = _get_num_workers(cluster_spec)
        self._is_chief_node = self._is_chief()

    def _debug_message(self):
        if self._cluster_spec:
            return "[cluster_spec: %r, task_type: %r, task_id: %r]" % (
                self._cluster_spec,
                self.task_type,
                self.task_id,
            )
        else:
            return "[local]"

    def __enter__(self):
        old_context = get_current_worker_context()
        if old_context:
            raise ValueError(
                "You cannot run distribute coordinator in a `worker_fn`.\t"
                + self._debug_message()
            )

        _worker_context.current = self

    def __exit__(
        self, unused_exception_type, unused_exception_value, unused_traceback
    ):
        _worker_context.current = None

    def _get_master_target(self):
        """Return the master target for a task."""
        # If cluster_spec is None or empty, we use local master.
        if not self._cluster_spec or self._task_type == _TaskType.EVALUATOR:
            return ""

        # If task_type is None, then it is in-graph replicated training. In
        # this case we use the chief or first worker's master target.
        if not self._task_type:
            if _TaskType.CHIEF in self._cluster_spec.jobs:
                task_type = _TaskType.CHIEF
                task_id = 0
            else:
                assert _TaskType.WORKER in self._cluster_spec.jobs
                task_type = _TaskType.WORKER
                task_id = 0
        else:
            task_type = self._task_type
            task_id = self._task_id

        prefix = ""
        if self._rpc_layer:
            prefix = self._rpc_layer + "://"
        return prefix + self._cluster_spec.job_tasks(task_type)[task_id or 0]

    def _is_chief(self):
        """Return whether the task is the chief worker."""
        if not self._cluster_spec or self._task_type in [
            _TaskType.CHIEF,
            _TaskType.EVALUATOR,
            None,
        ]:
            return True

        # If not local and chief not in the cluster_spec, use the first worker
        # as chief.
        if (
            _TaskType.CHIEF not in self._cluster_spec.jobs
            and self._task_type == _TaskType.WORKER
            and self._task_id == 0
        ):
            return True
        return False

    def wait_for_other_workers(self):
        """Waits for other workers to reach the same call to this method.

        Raises:
          ValueError: if `worker_barrier` is not passed to the __init__
            method.
        """
        if not self._worker_barrier:
            # TODO(yuefengz): we should throw an error in independent worker
            # mode.
            return
        self._worker_barrier.wait()

    def session_creator(
        self,
        scaffold=None,
        config=None,
        checkpoint_dir=None,
        checkpoint_filename_with_path=None,
        max_wait_secs=7200,
    ):
        """Returns a session creator.

        The returned session creator will be configured with the correct
        master target and session configs. It will also run either init ops
        or ready ops by querying the `strategy` object when `create_session`
        is called on it.

        Args:
          scaffold: A `Scaffold` used for gathering or building supportive
            ops. If not specified a default one is created. It's used to
            finalize the graph.
          config: `ConfigProto` proto used to configure the session.
          checkpoint_dir: A string. Optional path to a directory from which
            to restore variables.
          checkpoint_filename_with_path: Full file name path to the
            checkpoint file. Only one of `checkpoint_dir` or
            `checkpoint_filename_with_path` can be specified.
          max_wait_secs: Maximum time to wait for the session to become
            available.

        Returns:
          a descendant of SessionCreator.
        """
        if config:
            session_config = copy.deepcopy(config)
            session_config.MergeFrom(self._session_config)
        else:
            session_config = self._session_config

        if (
            not self._strategy
            or self._strategy.extended.experimental_should_init
        ):
            logging.info(
                "Creating chief session creator with config: %r", config
            )
            return tf.compat.v1.train.ChiefSessionCreator(
                scaffold,
                master=self.master_target,
                config=session_config,
                checkpoint_dir=checkpoint_dir,
                checkpoint_filename_with_path=checkpoint_filename_with_path,
            )
        else:
            logging.info(
                "Creating worker session creator with config: %r", config
            )
            return tf.compat.v1.train.WorkerSessionCreator(
                scaffold,
                master=self.master_target,
                config=session_config,
                max_wait_secs=max_wait_secs,
            )

    @property
    def session_config(self):
        return copy.deepcopy(self._session_config)

    @property
    def has_barrier(self):
        """Whether the barrier is set or not."""
        return self._worker_barrier is not None

    @property
    def distributed_mode(self):
        """Whether it is distributed training or not."""
        return (
            bool(self._cluster_spec) and self._task_type != _TaskType.EVALUATOR
        )

    @property
    def cluster_spec(self):
        """Returns a copy of the cluster_spec object."""
        return copy.deepcopy(self._cluster_spec)

    @property
    def task_type(self):
        """Returns the role of the corresponding task."""
        return self._task_type

    @property
    def task_id(self):
        """Returns the id or index of the corresponding task."""
        return self._task_id

    @property
    def master_target(self):
        """Returns the session master for the corresponding task to connect
        to."""
        return self._master_target

    @property
    def is_chief(self):
        """Returns whether the task is a chief node."""
        return self._is_chief_node

    @property
    def num_workers(self):
        """Returns number of workers in the cluster, including chief."""
        return self._num_workers

    @property
    def experimental_should_init(self):
        """Whether to run init ops."""
        return self._strategy.extended.experimental_should_init

    @property
    def should_checkpoint(self):
        """Whether to save checkpoint."""
        return self._strategy.extended.should_checkpoint

    @property
    def should_save_summary(self):
        """Whether to save summaries."""
        return self._strategy.extended.should_save_summary
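

# Single-process sketch of how the coordinator wires a `_WorkerContext` up to
# `get_current_worker_context`; normally `_run_single_worker` below creates
# the context, so this is for illustration only.
#
#   ctx = _WorkerContext(
#       strategy=None, cluster_spec=None, task_type=None, task_id=None
#   )
#   with ctx:
#       assert get_current_worker_context() is ctx
#       assert ctx.master_target == ""  # empty target means local master
#       assert ctx.is_chief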


def _run_single_worker(
    worker_fn,
    strategy,
    cluster_spec,
    task_type,
    task_id,
    session_config,
    rpc_layer="",
    worker_barrier=None,
    coord=None,
):
    """Runs a single worker by calling `worker_fn` under context."""
    session_config = copy.deepcopy(session_config)
    strategy = copy.deepcopy(strategy)
    # If there is an EVALUATOR task, we run single-machine eval on that task.
    if task_type == _TaskType.EVALUATOR:
        # It is possible to not have a strategy object for EVALUATOR task.
        if strategy:
            strategy.configure(session_config)
    else:
        assert strategy
        strategy.configure(session_config, cluster_spec, task_type, task_id)

    context = _WorkerContext(
        strategy,
        cluster_spec,
        task_type,
        task_id,
        session_config=session_config,
        rpc_layer=rpc_layer,
        worker_barrier=worker_barrier,
    )
    with context:
        if coord:
            with coord.stop_on_exception():
                return worker_fn(strategy)
        else:
            return worker_fn(strategy)


def _split_cluster_for_evaluator(cluster_spec, task_type):
    """Split the cluster for evaluator since it needn't talk to other tasks."""
    # Splitting the cluster is important to prevent the evaluator from talking
    # to other tasks in the cluster. Since we allow the evaluator not to use
    # distribution strategies, ops in the evaluator task may have unspecified
    # devices, and those ops may end up on other tasks if we don't split the
    # cluster.
    # Note: if you bypass distribute coordinator and bring up the cluster
    # yourself, you can equivalently set device filters to split clusters.
    # This is already done by distribution strategy's `update_config_proto`
    # method.
    new_cluster_spec = normalize_cluster_spec(cluster_spec).as_dict()
    if task_type == _TaskType.EVALUATOR:
        assert _TaskType.EVALUATOR in new_cluster_spec
        new_cluster_spec = {
            _TaskType.EVALUATOR: new_cluster_spec[_TaskType.EVALUATOR]
        }
    else:
        new_cluster_spec.pop(_TaskType.EVALUATOR, None)
    return normalize_cluster_spec(new_cluster_spec)
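

# For illustration (hypothetical addresses): the evaluator keeps only its own
# job, while every other task drops the "evaluator" job.
#
#   full = {
#       "chief": ["host0:2222"],
#       "worker": ["host1:2222"],
#       "evaluator": ["host2:2222"],
#   }
#   _split_cluster_for_evaluator(full, "evaluator").as_dict()
#   # -> {"evaluator": ["host2:2222"]}
#   _split_cluster_for_evaluator(full, "worker").as_dict()
#   # -> {"chief": ["host0:2222"], "worker": ["host1:2222"]}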


def _run_std_server(
    cluster_spec=None,
    task_type=None,
    task_id=None,
    session_config=None,
    rpc_layer=None,
    environment=None,
):
    """Runs a standard server."""
    # Check if the Server is already running. If so, assert that no
    # configuration options have changed, and return the existing Server. This
    # allows us to call `run_distribute_coordinator` multiple times.
    if getattr(_thread_local, "server", None) is not None:
        assert _thread_local.cluster_spec == cluster_spec
        assert _thread_local.task_type == task_type
        assert _thread_local.task_id == task_id
        assert _thread_local.session_config_str == repr(session_config)
        assert _thread_local.rpc_layer == rpc_layer
        assert _thread_local.environment == environment
        return _thread_local.server
    else:
        # This method is not thread-safe.
        _thread_local.server_started = True
        _thread_local.cluster_spec = cluster_spec
        _thread_local.task_type = task_type
        _thread_local.task_id = task_id
        _thread_local.session_config_str = repr(session_config)
        _thread_local.rpc_layer = rpc_layer
        _thread_local.environment = environment

    assert cluster_spec
    target = cluster_spec.task_address(task_type, task_id)
    if rpc_layer:
        target = rpc_layer + "://" + target

    class _FakeServer:
        """A fake server that runs a master session."""

        def start(self):
            # A tensorflow server starts when a remote session is created.
            logging.info(
                "Creating a remote session to start a TensorFlow server, "
                "target = %r, session_config=%r",
                target,
                session_config,
            )
            tf.compat.v1.Session(target=target, config=session_config)

        def join(self):
            while True:
                time.sleep(5)

    if environment == "google":
        server = _FakeServer()
    else:
        if session_config:
            logging.info(
                "Starting standard TensorFlow server, target = %r, "
                "session_config = %r",
                target,
                session_config,
            )
        else:
            logging.info(
                "Starting standard TensorFlow server, target = %r", target
            )
        cluster_spec = _split_cluster_for_evaluator(cluster_spec, task_type)
        server = tf.distribute.Server(
            cluster_spec,
            job_name=task_type,
            task_index=task_id,
            config=session_config,
            protocol=rpc_layer,
        )

    server.start()
    _thread_local.server = server
    return server
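

# Sketch (hypothetical local port) of the per-thread memoization above:
# calling `_run_std_server` again with identical arguments returns the same
# server object instead of starting a new one.
#
#   spec = normalize_cluster_spec({"worker": ["localhost:12345"]})
#   s1 = _run_std_server(cluster_spec=spec, task_type="worker", task_id=0)
#   s2 = _run_std_server(cluster_spec=spec, task_type="worker", task_id=0)
#   assert s1 is s2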


def _configure_session_config_for_std_servers(
    strategy, eval_strategy, session_config, cluster_spec, task_type, task_id
):
    """Call strategy's `configure` to mutate the session_config.

    The session_config is currently needed as default config for a TensorFlow
    server. In the future, we should be able to remove this method and only
    pass the session config to a client session.
    """
    if task_type == _TaskType.EVALUATOR:
        if eval_strategy:
            eval_strategy.configure(session_config=session_config)
    else:
        # The strategy may be shared in standalone client mode.
        strategy = copy.deepcopy(strategy)
        strategy.configure(
            session_config=session_config,
            cluster_spec=cluster_spec,
            task_type=task_type,
            task_id=task_id,
        )
    # Remove the device filters specific to the strategy, so that the
    # TensorFlow server brought up with one strategy can be used by other
    # strategies. The device filters can be set in the client side as well.
    del session_config.device_filters[:]


# TODO(yuefengz): propagate cluster_spec in the STANDALONE_CLIENT mode.
# TODO(yuefengz): we may need a smart way to figure out whether the current
# task is the special task when we support cluster_spec propagation.
def run_distribute_coordinator(
    worker_fn,
    strategy,
    eval_fn=None,
    eval_strategy=None,
    cluster_spec=None,
    task_type=None,
    task_id=None,
    session_config=None,
    rpc_layer="grpc",
):
    """Runs the coordinator for distributed TensorFlow.

    This function runs a split coordinator for distributed TensorFlow in its
    default mode, i.e. the STANDALONE_CLIENT mode. Given a `cluster_spec`
    specifying server addresses and their roles in a cluster, this coordinator
    will figure out how to set them up, give the underlying function the right
    targets for master sessions via a scope object, and coordinate their
    training. The cluster consisting of standard servers needs to be brought
    up either with the standard server binary or with a binary running
    distribute coordinator with `task_type` set to a non-client type, which
    will then turn into standard servers.

    In addition to being the distribute coordinator, this is also the source
    of configurations for each job in the distributed training. As there are
    multiple ways to configure a distributed TensorFlow cluster, its context
    object provides these configurations so that users or higher-level APIs
    don't have to figure out the configuration for each job by themselves.

    In the between-graph replicated training, this coordinator will create
    multiple threads and each calls the `worker_fn` which is supposed to
    create its own graph and connect to one worker master given by its context
    object. In the in-graph replicated training, it has only one thread
    calling this `worker_fn`.

    Another mode is the INDEPENDENT_WORKER mode where each server runs a
    distribute coordinator which will start a standard server and optionally
    runs `worker_fn` depending on whether it is between-graph training or
    in-graph replicated training.

    The `strategy` object is expected to be a DistributionStrategy object
    which has implemented methods needed by distributed coordinator such as
    `configure(session_config, cluster_spec, task_type, task_id)` which
    configures the strategy object for a specific task and
    `experimental_should_init` property which instructs the distribute
    coordinator whether to run init ops for a task. The distribute coordinator
    will make a copy of the `strategy` object, call its `configure` method and
    pass it to `worker_fn` as an argument.

    The `worker_fn` defines the training logic and is called under its own
    worker context which can be accessed via `get_current_worker_context`. A
    worker context provides access to configurations for each task, e.g. the
    task_type, task_id, master target and so on. Since `worker_fn` will be
    called in a thread and possibly multiple times, the caller should be
    careful when it accesses global data. For example, it is unsafe to define
    flags in a `worker_fn` or to define different environment variables for
    different `worker_fn`s.

    The `worker_fn` for the between-graph replication is defined as if there
    is only one worker corresponding to the `worker_fn` and possibly ps jobs.
    For example, when training with parameter servers, it assigns variables to
    parameter servers and all other operations to that worker. In the in-graph
    replication case, the `worker_fn` has to define operations for all worker
    jobs. Using a distribution strategy can simplify the `worker_fn` by not
    having to worry about the replication and device assignment of variables
    and operations.

    This method is intended to be invoked by high-level APIs so that users
    don't have to explicitly call it to run this coordinator. For those who
    don't use high-level APIs, to change a program to use this coordinator,
    wrap everything in the program after global data definitions such as
    commandline flag definitions into the `worker_fn` and get task-specific
    configurations from the worker context.

    The `cluster_spec` can be either passed by the argument or parsed from the
    "TF_CONFIG" environment variable. Example of a TF_CONFIG:
    ```
    cluster = {'chief': ['host0:2222'],
               'ps': ['host1:2222', 'host2:2222'],
               'worker': ['host3:2222', 'host4:2222', 'host5:2222']}
    os.environ['TF_CONFIG'] = json.dumps({'cluster': cluster})
    ```

    If `cluster_spec` is not given in any format, it becomes local training
    and this coordinator will connect to a local session.

    For evaluation, if "evaluator" exists in the cluster_spec, a separate
    thread will be created to call `eval_fn` with its `task_type` set to
    "evaluator". If `eval_fn` is not defined, fall back to `worker_fn`. This
    implies that evaluation will be done on a single machine if there is an
    "evaluator" task. If "evaluator" doesn't exist in the cluster_spec, it
    entirely depends on the `worker_fn` for how to do evaluation.

    Args:
      worker_fn: the function to be called. The function should accept a
        `strategy` object and will be given access to a context object via a
        context manager scope.
      strategy: a DistributionStrategy object specifying whether it should run
        between-graph replicated training or not, whether to run init ops,
        etc. This object will also be configured given `session_config`,
        `cluster_spec`, `task_type` and `task_id`.
      eval_fn: optional function for "evaluator" task. If `eval_fn` is not
        passed in but an "evaluator" task is found in the `cluster_spec`, the
        `worker_fn` will be used for this task.
      eval_strategy: optional DistributionStrategy object for "evaluator"
        task.
      cluster_spec: a dict, ClusterDef or ClusterSpec specifying servers and
        roles in a cluster. If not set or empty, fall back to local training.
      task_type: the current task type, optional if this is a client.
      task_id: the current task id, optional if this is a client.
      session_config: an optional `tf.compat.v1.ConfigProto` object which will
        be passed to `strategy`'s `configure` method and used to create a
        session.
      rpc_layer: optional string, the protocol for RPC, e.g. "grpc".

    Raises:
      ValueError: if `cluster_spec` is supplied but not a dict or a ClusterDef
        or a ClusterSpec.

    Returns:
      In the client job, return the value returned by `worker_fn` if
      it is in-graph replication or INDEPENDENT_WORKER mode; return None
      otherwise.
    """
    tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
    rpc_layer = tf_config.get("rpc_layer", rpc_layer)
    environment = tf_config.get("environment", None)

    if not cluster_spec:
        cluster_spec = tf_config.get("cluster", {})
        task_env = tf_config.get("task", {})
        if task_env:
            task_type = task_env.get("type", task_type)
            task_id = int(task_env.get("index", task_id))

    if cluster_spec:
        # TODO(yuefengz): validate cluster_spec.
        cluster_spec = normalize_cluster_spec(cluster_spec)
    elif hasattr(strategy.extended, "_cluster_resolver"):
        cluster_resolver = strategy.extended._cluster_resolver
        task_type = cluster_resolver.task_type
        task_id = cluster_resolver.task_id
        rpc_layer = cluster_resolver.rpc_layer or rpc_layer
        environment = cluster_resolver.environment
        cluster_spec = cluster_resolver.cluster_spec()

    # Setting the session config is necessary for some strategies such as
    # CollectiveAllReduceStrategy.
    session_config = session_config or tf.compat.v1.ConfigProto(
        allow_soft_placement=True
    )

    if cluster_spec:
        logging.info(
            "Running Distribute Coordinator with cluster_spec = %r, "
            "task_type = %r, task_id = %r, environment = %r, rpc_layer = %r",
            cluster_spec.as_dict(),
            task_type,
            task_id,
            environment,
            rpc_layer,
        )

    if not cluster_spec:
        # `mode` is ignored in the local case.
        logging.info("Running local Distribute Coordinator.")
        _run_single_worker(
            worker_fn, strategy, None, None, None, session_config, rpc_layer
        )
        if eval_fn:
            _run_single_worker(
                eval_fn,
                eval_strategy,
                None,
                None,
                None,
                session_config,
                rpc_layer,
            )
        else:
            logging.warning(
                "Skipped evaluation since `eval_fn` is not passed in."
            )
    else:
        if not eval_fn:
            logging.warning(
                "`eval_fn` is not passed in. The `worker_fn` will be "
                'used if an "evaluator" task exists in the cluster.'
            )
        eval_fn = eval_fn or worker_fn
        if not eval_strategy:
            logging.warning(
                "`eval_strategy` is not passed in. No distribution "
                "strategy will be used for evaluation."
            )

        # Every one starts a standard server, get session config from
        # `configure` method.
        _configure_session_config_for_std_servers(
            strategy,
            eval_strategy,
            session_config,
            cluster_spec,
            task_type,
            task_id,
        )

        if task_type != _TaskType.EVALUATOR and not getattr(
            strategy.extended, "_std_server_started", False
        ):
            # Right now, with eager mode, context is configured with a std
            # server at the very beginning while with graph mode the std
            # server is started when distribute coordinator is called. We
            # should consolidate these two paths.
            server = _run_std_server(
                cluster_spec=cluster_spec,
                task_type=task_type,
                task_id=task_id,
                session_config=session_config,
                rpc_layer=rpc_layer,
                environment=environment,
            )
        if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
            if strategy.extended.experimental_between_graph:
                # All jobs run `worker_fn` if between-graph.
                return _run_single_worker(
                    worker_fn,
                    strategy,
                    cluster_spec,
                    task_type,
                    task_id,
                    session_config,
                    rpc_layer,
                )
            else:
                # Only one node runs `worker_fn` if in-graph.
                context = _WorkerContext(
                    strategy, cluster_spec, task_type, task_id
                )
                if context.is_chief:
                    return _run_single_worker(
                        worker_fn,
                        strategy,
                        cluster_spec,
                        None,
                        None,
                        session_config,
                        rpc_layer,
                    )
                else:
                    server.join()
        elif task_type == _TaskType.EVALUATOR:
            return _run_single_worker(
                eval_fn,
                eval_strategy,
                cluster_spec,
                task_type,
                task_id,
                session_config,
                rpc_layer,
            )
        else:
            if task_type != _TaskType.PS:
                raise ValueError(f"Unexpected task_type: {task_type!r}")
            server.join()
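

# Minimal usage sketch (hypothetical `worker_fn`; the strategy and TF_CONFIG
# shape follow the docstring above). A client process in STANDALONE_CLIENT
# mode would run something like:
#
#   os.environ["TF_CONFIG"] = json.dumps(
#       {
#           "cluster": {
#               "chief": ["host0:2222"],
#               "worker": ["host1:2222", "host2:2222"],
#           }
#       }
#   )
#
#   def worker_fn(strategy):
#       context = get_current_worker_context()
#       with strategy.scope():
#           ...  # build and train a model against context.master_target
#
#   run_distribute_coordinator(
#       worker_fn,
#       strategy=tf.distribute.experimental.MultiWorkerMirroredStrategy(),
#       rpc_layer="grpc",
#   )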


def normalize_cluster_spec(cluster_spec):
    """Makes `cluster_spec` into a `ClusterSpec` object.

    Args:
      cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
        cluster configurations.

    Returns:
      a `ClusterSpec` object.

    Raises:
      ValueError: if `cluster_spec` is not a dict or a `ClusterSpec` or a
        `ClusterDef`.
    """
    if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
        return tf.train.ClusterSpec(cluster_spec)
    elif not isinstance(cluster_spec, tf.train.ClusterSpec):
        raise ValueError(
            "`cluster_spec` should be a dict, a `tf.train.ClusterSpec`, or a "
            "`tf.train.ClusterDef` object"
        )
    return cluster_spec
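

# For illustration (hypothetical addresses): dicts and `ClusterDef` protos are
# converted, while an existing `ClusterSpec` is returned unchanged.
#
#   spec = normalize_cluster_spec({"worker": ["host1:2222", "host2:2222"]})
#   assert isinstance(spec, tf.train.ClusterSpec)
#   assert normalize_cluster_spec(spec) is spec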