# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Python interface for creating TensorFlow servers."""

from tensorflow.core.protobuf import cluster_pb2
from tensorflow.core.protobuf import device_filters_pb2
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _make_server_def(server_or_cluster_def, job_name, task_index, protocol,
                     config):
  """Creates a `tf.train.ServerDef` protocol buffer.

  Args:
    server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
      protocol buffer, or a `tf.train.ClusterSpec` object, describing the
      server to be defined and/or the cluster of which it is a member.
    job_name: (Optional.) Specifies the name of the job of which the server is
      a member. Defaults to the value in `server_or_cluster_def`, if
      specified.
    task_index: (Optional.) Specifies the task index of the server in its job.
      Defaults to the value in `server_or_cluster_def`, if specified.
      Otherwise defaults to the job's only task index if the server's job has
      exactly one task.
    protocol: (Optional.) Specifies the protocol to be used by the server.
      Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the value
      in `server_or_cluster_def`, if specified. Otherwise defaults to
      `"grpc"`.
    config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
      configuration options for all sessions that run on this server.

  Returns:
    A `tf.train.ServerDef`.

  Raises:
    TypeError: If the arguments do not have the appropriate type.
    ValueError: If an argument is not specified and cannot be inferred.
  """
  if isinstance(server_or_cluster_def, tensorflow_server_pb2.ServerDef):
    # Work on a copy so the overrides below never mutate the caller's proto.
    server_def = tensorflow_server_pb2.ServerDef()
    server_def.MergeFrom(server_or_cluster_def)
    if job_name is not None:
      server_def.job_name = job_name
    if task_index is not None:
      server_def.task_index = task_index
    if protocol is not None:
      server_def.protocol = protocol
  else:
    try:
      cluster_spec = ClusterSpec(server_or_cluster_def)
    except TypeError as e:
      raise TypeError("Could not convert `server_or_cluster_def` to a "
                      "`tf.train.ServerDef` or `tf.train.ClusterSpec`.") from e

    # A missing job name / task index can only be inferred when the choice is
    # unambiguous, i.e. when exactly one candidate exists.
    if job_name is None:
      job_names = cluster_spec.jobs
      if len(job_names) != 1:
        raise ValueError("Must specify an explicit `job_name`.")
      job_name = job_names[0]
    if task_index is None:
      task_indices = cluster_spec.task_indices(job_name)
      if len(task_indices) != 1:
        raise ValueError("Must specify an explicit `task_index`.")
      task_index = task_indices[0]
    if protocol is None:
      protocol = "grpc"

    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_spec.as_cluster_def(),
        job_name=job_name,
        task_index=task_index,
        protocol=protocol)

  if config is not None:
    server_def.default_session_config.MergeFrom(config)
  return server_def


@tf_export("distribute.Server", v1=["distribute.Server", "train.Server"])
@deprecation.deprecated_endpoints("train.Server")
class Server:
  """An in-process TensorFlow server, for use in distributed training.

  A `tf.distribute.Server` instance encapsulates a set of devices and a
  `tf.compat.v1.Session` target that
  can participate in distributed training. A server belongs to a
  cluster (specified by a `tf.train.ClusterSpec`), and
  corresponds to a particular task in a named job. The server can
  communicate with any other server in the same cluster.
  """

  def __init__(self,
               server_or_cluster_def,
               job_name=None,
               task_index=None,
               protocol=None,
               config=None,
               start=True):
    """Creates a new server with the given definition.

    The `job_name`, `task_index`, and `protocol` arguments are optional, and
    override any information provided in `server_or_cluster_def`.

    Args:
      server_or_cluster_def: A `tf.train.ServerDef` or `tf.train.ClusterDef`
        protocol buffer, or a `tf.train.ClusterSpec` object, describing the
        server to be created and/or the cluster of which it is a member.
      job_name: (Optional.) Specifies the name of the job of which the server
        is a member. Defaults to the value in `server_or_cluster_def`, if
        specified.
      task_index: (Optional.) Specifies the task index of the server in its
        job. Defaults to the value in `server_or_cluster_def`, if specified.
        Otherwise defaults to the job's only task index if the server's job
        has exactly one task.
      protocol: (Optional.) Specifies the protocol to be used by the server.
        Acceptable values include `"grpc", "grpc+verbs"`. Defaults to the
        value in `server_or_cluster_def`, if specified. Otherwise defaults to
        `"grpc"`.
      config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to `True`.

    Raises:
      TypeError: If `server_or_cluster_def` cannot be converted to a server
        or cluster definition.
      ValueError: If `job_name` or `task_index` is not given and cannot be
        inferred.
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        creating the TensorFlow server.
    """
    self._server_def = _make_server_def(server_or_cluster_def, job_name,
                                        task_index, protocol, config)
    self._server = c_api.TF_NewServer(self._server_def.SerializeToString())
    if start:
      self.start()

  def __del__(self):
    # At shutdown, `errors` may have been garbage collected.
    if errors is not None:
      exception = errors.UnimplementedError
    else:
      exception = Exception
    try:
      c_api.TF_ServerStop(self._server)
      # Clean shutdown of servers is not yet implemented, so
      # we leak instead of calling c_api.TF_DeleteServer here.
      # See:
      # https://github.com/tensorflow/tensorflow/blob/0495317a6e9dd4cac577b9d5cf9525e62b571018/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h#L73
    except AttributeError:
      # At shutdown, `c_api` may have been garbage collected. This also
      # covers an instance whose `_server` attribute was never assigned.
      pass
    except exception:
      pass
    self._server = None

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the TensorFlow server.
    """
    c_api.TF_ServerStart(self._server)

  def join(self):
    """Blocks until the server has shut down.

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the TensorFlow server.
    """
    c_api.TF_ServerJoin(self._server)

  @property
  def server_def(self):
    """Returns the `tf.train.ServerDef` for this server.

    Returns:
      A `tf.train.ServerDef` protocol buffer that describes the configuration
      of this server.
    """
    return self._server_def

  @property
  def target(self):
    """Returns the target for a `tf.compat.v1.Session` to connect to this server.

    To create a
    `tf.compat.v1.Session` that
    connects to this server, use the following snippet:

    ```python
    server = tf.distribute.Server(...)
    with tf.compat.v1.Session(server.target):
      # ...
    ```

    Returns:
      A string containing a session target for this server.
    """
    return c_api.TF_ServerTarget(self._server)

  @staticmethod
  def create_local_server(config=None, start=True):
    """Creates a new single-process cluster running on the local host.

    This method is a convenience wrapper for creating a
    `tf.distribute.Server` with a cluster definition that specifies a
    single-process cluster containing a single task in a job called
    `"localhost"`.

    Args:
      config: (Optional.) A `tf.compat.v1.ConfigProto` that specifies default
        configuration options for all sessions that run on this server.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to `True`.

    Returns:
      A local `tf.distribute.Server`.
    """
    # Specifying port 0 means that the OS will choose a free port for the
    # server.
    return Server({"localhost": ["localhost:0"]},
                  protocol="grpc",
                  config=config,
                  start=start)


@tf_export("train.ClusterSpec")
class ClusterSpec:
  """Represents a cluster as a set of "tasks", organized into "jobs".

  A `tf.train.ClusterSpec` represents the set of processes that
  participate in a distributed TensorFlow computation. Every
  `tf.distribute.Server` is constructed in a particular cluster.

  To create a cluster with two jobs and five tasks, you specify the
  mapping from job names to lists of network addresses (typically
  hostname-port pairs).

  ```python
  cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
                                             "worker1.example.com:2222",
                                             "worker2.example.com:2222"],
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```

  Each job may also be specified as a sparse mapping from task indices
  to network addresses. This enables a server to be configured without
  needing to know the identity of (for example) all other worker
  tasks:

  ```python
  cluster = tf.train.ClusterSpec({"worker": {1: "worker1.example.com:2222"},
                                  "ps": ["ps0.example.com:2222",
                                         "ps1.example.com:2222"]})
  ```
  """

  def __init__(self, cluster):
    """Creates a `ClusterSpec`.

    Args:
      cluster: A dictionary mapping one or more job names to (i) a list of
        network addresses, or (ii) a dictionary mapping integer task indices
        to network addresses; or a `tf.train.ClusterDef` protocol buffer; or
        another `tf.train.ClusterSpec`.

    Raises:
      TypeError: If `cluster` is not a dictionary mapping strings to lists
        of strings, and not a `tf.train.ClusterDef` protobuf or a
        `tf.train.ClusterSpec`.
    """
    if isinstance(cluster, dict):
      self._cluster_spec = {}
      for job_name, tasks in cluster.items():
        if isinstance(tasks, (list, tuple)):
          # Dense job: the position in the sequence is the task index.
          job_tasks = dict(enumerate(tasks))
        elif isinstance(tasks, dict):
          # Sparse job: normalize the task indices to integers.
          job_tasks = {int(i): task for i, task in tasks.items()}
        else:
          raise TypeError("The tasks for job %r must be a list or a dictionary "
                          "from integers to strings." % job_name)
        self._cluster_spec[job_name] = job_tasks
      self._make_cluster_def()
    elif isinstance(cluster, (cluster_pb2.ClusterDef, ClusterSpec)):
      if isinstance(cluster, ClusterSpec):
        # Copy the other spec's proto rather than sharing it.
        self._cluster_def = cluster_pb2.ClusterDef()
        self._cluster_def.MergeFrom(cluster.as_cluster_def())
      else:
        # A `ClusterDef` passed in directly is stored by reference.
        self._cluster_def = cluster
      self._cluster_spec = {
          job_def.name: dict(job_def.tasks.items())
          for job_def in self._cluster_def.job
      }
    else:
      raise TypeError("`cluster` must be a dictionary mapping one or more "
                      "job names to lists of network addresses, or a "
                      "`ClusterDef` protocol buffer")

  def __bool__(self):
    # A spec is truthy when it defines at least one job.
    return bool(self._cluster_spec)

  # Python 2.x
  __nonzero__ = __bool__

  def __eq__(self, other):
    # Delegates to equality of the internal job-name -> tasks mapping.
    return self._cluster_spec == other

  def __ne__(self, other):
    return self._cluster_spec != other

  def __repr__(self):
    key_values = self.as_dict()
    # Jobs are emitted in sorted order so the representation is stable.
    string_items = [
        repr(k) + ": " + repr(key_values[k]) for k in sorted(key_values)
    ]
    return "ClusterSpec({" + ", ".join(string_items) + "})"

  def as_dict(self):
    """Returns a dictionary from job names to their tasks.

    For each job, if the task index space is dense, the corresponding
    value will be a list of network addresses; otherwise it will be a
    dictionary mapping (sparse) task indices to the corresponding
    addresses. A job with no tasks maps to an empty dictionary.

    Returns:
      A dictionary mapping job names to lists or dictionaries
      describing the tasks in those jobs.
    """
    ret = {}
    for job in self.jobs:
      task_indices = self.task_indices(job)
      if not task_indices:
        ret[job] = {}
        continue
      if max(task_indices) + 1 == len(task_indices):
        # Return a list because the task indices are dense. This
        # matches the behavior of `as_dict()` before support for
        # sparse jobs was added.
        ret[job] = self.job_tasks(job)
      else:
        ret[job] = {i: self.task_address(job, i) for i in task_indices}
    return ret

  def as_cluster_def(self):
    """Returns a `tf.train.ClusterDef` protocol buffer based on this cluster."""
    return self._cluster_def

  @property
  def jobs(self):
    """Returns a list of job names in this cluster.

    Returns:
      A list of strings, corresponding to the names of jobs in this cluster.
    """
    return list(self._cluster_spec.keys())

  def num_tasks(self, job_name):
    """Returns the number of tasks defined in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      The number of tasks defined in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError as e:
      raise ValueError("No such job in cluster: %r" % job_name) from e
    return len(job)

  def task_indices(self, job_name):
    """Returns a list of valid task indices in the given job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A sorted list of valid task indices in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError as e:
      raise ValueError("No such job in cluster: %r" % job_name) from e
    return sorted(job.keys())

  def task_address(self, job_name, task_index):
    """Returns the address of the given task in the given job.

    Args:
      job_name: The string name of a job in this cluster.
      task_index: A non-negative integer.

    Returns:
      The address of the given task in the given job.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster,
        or no task with index `task_index` is defined in that job.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError as e:
      raise ValueError("No such job in cluster: %r" % job_name) from e
    try:
      return job[task_index]
    except KeyError as e:
      raise ValueError("No task with index %r in job %r" %
                       (task_index, job_name)) from e

  def job_tasks(self, job_name):
    """Returns a list of task addresses in the given job, indexed by task ID.

    NOTE: For backwards compatibility, this method returns a list. If
    the given job was defined with a sparse set of task indices, the
    length of this list may not reflect the number of tasks defined in
    this job. Use the `tf.train.ClusterSpec.num_tasks` method
    to find the number of tasks defined in a particular job.

    Args:
      job_name: The string name of a job in this cluster.

    Returns:
      A list of task addresses, where the index in the list
      corresponds to the task index of each task. The list may contain
      `None` if the job was defined with a sparse set of task indices.
      The list is empty if the job defines no tasks.

    Raises:
      ValueError: If `job_name` does not name a job in this cluster.
    """
    try:
      job = self._cluster_spec[job_name]
    except KeyError as e:
      raise ValueError("No such job in cluster: %r" % job_name) from e
    if not job:
      # `max()` of an empty job would raise an unrelated ValueError; a job
      # without tasks simply has no addresses.
      return []
    ret = [None] * (max(job.keys()) + 1)
    for i, task in job.items():
      ret[i] = task
    return ret

  def _make_cluster_def(self):
    """Creates a `tf.train.ClusterDef` based on the given `cluster_spec`.

    Raises:
      TypeError: If `cluster_spec` is not a dictionary mapping strings to lists
        of strings.
    """
    self._cluster_def = cluster_pb2.ClusterDef()

    # NOTE(mrry): Sort by job_name to produce deterministic protobufs.
    for job_name, tasks in sorted(self._cluster_spec.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError as e:
        raise TypeError("Job name %r must be bytes or unicode" %
                        job_name) from e

      job_def = self._cluster_def.job.add()
      job_def.name = job_name

      for i, task_address in sorted(tasks.items()):
        try:
          task_address = compat.as_bytes(task_address)
        except TypeError as e:
          raise TypeError("Task address %r must be bytes or unicode" %
                          task_address) from e
        job_def.tasks[i] = task_address


@tf_export("config.experimental.ClusterDeviceFilters")
class ClusterDeviceFilters:
  """Represent a collection of device filters for the remote workers in cluster.

  NOTE: this is an experimental API and subject to changes.

  Set device filters for selective jobs and tasks. For each remote worker, the
  device filters are a list of strings. When any filters are present, the remote
  worker will ignore all devices which do not match any of its filters. Each
  filter can be partially specified, e.g. "/job:ps", "/job:worker/replica:3",
  etc. Note that a device is always visible to the worker it is located on.

  For example, to set the device filters for a parameter server cluster:

  ```python
  cdf = tf.config.experimental.ClusterDeviceFilters()
  for i in range(num_workers):
    cdf.set_device_filters('worker', i, ['/job:ps'])
  for i in range(num_ps):
    cdf.set_device_filters('ps', i, ['/job:worker'])

  tf.config.experimental_connect_to_cluster(cluster_def,
                                            cluster_device_filters=cdf)
  ```

  The device filters can be partially specified. For remote tasks that do not
  have device filters specified, all devices will be visible to them.
  """

  def __init__(self):
    # `_device_filters` is a dict mapping job names to job device filters.
    # Job device filters further maps task IDs to task device filters.
    # Task device filters are a list of strings, each one is a device filter.
    self._device_filters = {}

    # Cached `ClusterDeviceFilters` proto built from `_device_filters`;
    # `None` until first requested and after every update.
    self._cluster_device_filters = None

  def set_device_filters(self, job_name, task_index, device_filters):
    """Set the device filters for given job name and task id.

    Args:
      job_name: The name of the job the task belongs to.
      task_index: The index of the task within `job_name`.
      device_filters: An iterable of device filter strings. Its contents are
        copied, so later changes to the argument do not affect this object.
    """
    # Materialize once so a one-shot iterator is not exhausted by the type
    # check before it can be stored.
    filters = list(device_filters)
    assert all(isinstance(df, str) for df in filters)
    self._device_filters.setdefault(job_name, {})[task_index] = filters
    # Due to updates in data, invalidate the cached proto.
    self._cluster_device_filters = None

  def _as_cluster_device_filters(self):
    """Returns the `ClusterDeviceFilters` proto, building it if not cached."""
    if self._cluster_device_filters is None:
      self._make_cluster_device_filters()
    return self._cluster_device_filters

  def _make_cluster_device_filters(self):
    """Creates `ClusterDeviceFilters` proto based on the `_device_filters`.

    Raises:
      TypeError: If `_device_filters` is not a dictionary mapping strings to
        a map of task indices and device filters.
    """
    self._cluster_device_filters = device_filters_pb2.ClusterDeviceFilters()

    # Sort by job_name to produce deterministic protobufs.
    for job_name, tasks in sorted(self._device_filters.items()):
      try:
        job_name = compat.as_bytes(job_name)
      except TypeError as e:
        raise TypeError("Job name %r must be bytes or unicode" %
                        job_name) from e

      jdf = self._cluster_device_filters.jobs.add()
      jdf.name = job_name

      for i, task_device_filters in sorted(tasks.items()):
        for tdf in task_device_filters:
          try:
            tdf = compat.as_bytes(tdf)
          except TypeError as e:
            raise TypeError("Device filter %r must be bytes or unicode" %
                            tdf) from e
          jdf.tasks[i].device_filters.append(tdf)