# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to connect to remote servers."""

import copy

from absl import logging

from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python import pywrap_tfe
from tensorflow.python.distribute import device_util
from tensorflow.python.distribute.cluster_resolver import cluster_resolver
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.platform import remote_utils
from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

# Scheme prefix stripped from host addresses before they are put in a
# `ClusterSpec`.
_GRPC_PREFIX = "grpc://"
# Values of `ClusterResolver.master()` that are treated as "the master is the
# local process", in which case `connect_to_cluster` is a no-op.
_LOCAL_MASTERS = ("", "local")


@tf_export("config.experimental_connect_to_host")
def connect_to_remote_host(remote_host=None, job_name="worker"):
  """Connects to a single machine to enable remote execution on it.

  Will make devices on the remote host available to use. Note that calling this
  more than once will work, but will invalidate any tensor handles on the old
  remote devices.

  Using the default job_name of worker, you can schedule ops to run remotely as
  follows:
  ```python
  # When eager execution is enabled, connect to the remote host.
  tf.config.experimental_connect_to_host("exampleaddr.com:9876")

  with ops.device("job:worker/replica:0/task:0/device:CPU:0"):
    # The following tensors should be resident on the remote device, and the op
    # will also execute remotely.
    x1 = array_ops.ones([2, 2])
    x2 = array_ops.ones([2, 2])
    y = math_ops.matmul(x1, x2)
  ```

  Args:
    remote_host: a single remote server address, or a (possibly nested) list of
      them, each in host-port format. A leading `grpc://` is stripped.
    job_name: The job name under which the new server will be accessible.

  Raises:
    ValueError: if remote_host is None or empty.
  """
  if not remote_host:
    raise ValueError("Must provide at least one remote_host")

  # `nest.flatten` lets callers pass either a single address or a list.
  remote_hosts = nest.flatten(remote_host)
  cluster_spec = server_lib.ClusterSpec(
      {job_name: [_strip_prefix(host, _GRPC_PREFIX) for host in remote_hosts]})

  connect_to_cluster(cluster_spec)


@tf_export("config.experimental_connect_to_cluster")
def connect_to_cluster(cluster_spec_or_resolver,
                       job_name="localhost",
                       task_index=0,
                       protocol=None,
                       make_master_device_default=True,
                       cluster_device_filters=None):
  """Connects to the given cluster.

  Will make devices on the cluster available to use. Note that calling this more
  than once will work, but will invalidate any tensor handles on the old remote
  devices.

  If the given local job name is not present in the cluster specification, it
  will be automatically added, using an unused port on the localhost.

  Device filters can be specified to isolate groups of remote tasks to avoid
  undesired accesses between workers. Workers accessing resources or launching
  ops / functions on filtered remote devices will result in errors (unknown
  devices). For any remote task, if no device filter is present, all cluster
  devices will be visible; if any device filter is specified, it can only see
  devices matching at least one filter. Devices on the task itself are always
  visible. Device filters can be partially specified.

  For example, for a cluster set up for parameter server training, the following
  device filters might be specified:

  ```python
  cdf = tf.config.experimental.ClusterDeviceFilters()
  # For any worker, only the devices on PS nodes and itself are visible
  for i in range(num_workers):
    cdf.set_device_filters('worker', i, ['/job:ps'])
  # Similarly for any ps, only the devices on workers and itself are visible
  for i in range(num_ps):
    cdf.set_device_filters('ps', i, ['/job:worker'])

  tf.config.experimental_connect_to_cluster(cluster_def,
                                            cluster_device_filters=cdf)
  ```

  Args:
    cluster_spec_or_resolver: A `ClusterSpec` or `ClusterResolver` describing
      the cluster.
    job_name: The name of the local job.
    task_index: The local task index.
    protocol: The communication protocol, such as `"grpc"`. If unspecified, will
      use the default from `python/platform/remote_utils.py`.
    make_master_device_default: If True and a cluster resolver is passed, will
      automatically enter the master task device scope, which indicates the
      master becomes the default device to run ops. It won't do anything if
      a cluster spec is passed. Will throw an error if the caller is currently
      already in some device scope.
    cluster_device_filters: an instance of
      `tf.config.experimental.ClusterDeviceFilters` that specify device filters
      to the remote tasks in cluster.

  Raises:
    ValueError: if called outside eager mode; if `cluster_spec_or_resolver` is
      neither a `ClusterSpec` nor a `ClusterResolver`; if
      `cluster_device_filters` is truthy but not a `ClusterDeviceFilters`; if
      `make_master_device_default` is set but the resolver's master cannot be
      found in the cluster; or if the caller is already inside a device scope
      that differs from the master device.
  """
  if not context.executing_eagerly():
    raise ValueError(
        "`tf.config.experimental_connect_to_cluster` can only be called in "
        "eager mode."
    )
  protocol = protocol or remote_utils.get_default_communication_protocol()
  if isinstance(cluster_spec_or_resolver, server_lib.ClusterSpec):
    cluster_spec = cluster_spec_or_resolver
  elif isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver):
    if cluster_spec_or_resolver.master() in _LOCAL_MASTERS:
      # Do nothing if the master is local.
      return
    cluster_spec = cluster_spec_or_resolver.cluster_spec()
  else:
    raise ValueError(
        "`cluster_spec_or_resolver` must be a `ClusterSpec` or a "
        "`ClusterResolver`.")

  # Work on a private copy so adding the local job below never touches the
  # caller's spec.
  cluster_def = copy.deepcopy(cluster_spec.as_cluster_def())
  if cluster_device_filters:
    if isinstance(cluster_device_filters, server_lib.ClusterDeviceFilters):
      cluster_device_filters = copy.deepcopy(
          cluster_device_filters._as_cluster_device_filters())  # pylint: disable=protected-access
    else:
      # NOTE(review): the message names `tf.train.experimental...` while the
      # docstring example uses `tf.config.experimental...`; the runtime string
      # is left untouched here -- confirm which public name is intended.
      raise ValueError("`cluster_device_filters` must be an instance of "
                       "`tf.train.experimental.ClusterDeviceFilters`.")

  # Check whether the server def has changed. We need to do the check before the
  # local job is added to the cluster.
  is_server_def_changed = False
  current_server_def = context.get_server_def()
  if current_server_def and job_name not in cluster_spec.jobs:
    # Drop the previously auto-added local job so the comparison below is made
    # against the remote part of the cluster only. Indices are walked from the
    # end so that a deletion never shifts an entry that has not been visited
    # yet (deleting while enumerating forward would skip the next element).
    current_jobs = current_server_def.cluster.job
    for i in reversed(range(len(current_jobs))):
      if current_jobs[i].name == job_name:
        del current_jobs[i]
  if (current_server_def is None or current_server_def.cluster != cluster_def or
      current_server_def.job_name != job_name or
      current_server_def.task_index != task_index):
    is_server_def_changed = True

  # Automatically add local job, if not part of the cluster spec.
  if job_name not in cluster_spec.jobs:
    local_port = pywrap_tfe.TF_PickUnusedPortOrDie()
    job_def = cluster_def.job.add()
    job_def.name = job_name
    # TODO(fishx): Update this to make sure remote worker has valid ip address
    # to connect with local.
    job_def.tasks[0] = "localhost:{}".format(local_port)

  if context.context().coordination_service is None:
    service_type = remote_utils.coordination_service_type(protocol)
    service_leader = ""
    # Maybe enable coordination service for the communication protocol
    # TODO(b/243839559): Fix UPTC + Coordination service crashing
    # Check if cluster_spec_or_resolver is an instance of
    # tpu_cluster_resolver.TPUClusterResolver
    if (isinstance(cluster_spec_or_resolver, cluster_resolver.ClusterResolver)
        and hasattr(cluster_spec_or_resolver, "tpu_hardware_feature")):
      service_leader = (
          cluster_spec_or_resolver.get_coordination_service_leader())
      # Maybe enable coordination service internally.
      if cluster_spec_or_resolver.environment == "google":
        is_uptc_sess = ".uptc-worker." in cluster_spec_or_resolver.master()
        service_type = remote_utils.coordination_service_type(
            protocol, is_uptc_sess)
      # Enable coordination service for Cloud TPU.
      else:
        service_type = "standalone"

    if service_type:
      # If `enable_health_check` is true, coordination service agent would
      # do connecting (and tasks would send heartbeat if connection is set up)
      # while creating eager contexts. Enabling health check does not mutate
      # coordination service.
      context.context().configure_coordination_service(
          service_type=service_type,
          service_leader=service_leader,
          enable_health_check=False)

  default_session_config = copy.deepcopy(context.context().config)

  for name in cluster_spec.jobs:
    # assuming any of the non-local job is the worker jobs.
    # should we use cluster_spec_or_resolver.get_job_name() instead when
    # it is available?
    # maybe consolidate this with the 'master' logic below
    if name == job_name:
      continue

    # Each non-local job overwrites the previous value, so the last non-local
    # job in `cluster_spec.jobs` ends up as the collective group leader.
    default_session_config.experimental.collective_group_leader = (
        f"/job:{name}/replica:0/task:0"
    )

  logging.info("default session config: %s", default_session_config)

  server_def = ServerDef(
      cluster=cluster_def,
      job_name=job_name,
      task_index=task_index,
      protocol=protocol,
      default_session_config=default_session_config,
      cluster_device_filters=cluster_device_filters,
  )

  if is_server_def_changed:
    context.set_server_def(server_def)
  else:
    context.update_server_def(server_def)

  if make_master_device_default and isinstance(
      cluster_spec_or_resolver,
      cluster_resolver.ClusterResolver) and cluster_spec_or_resolver.master():
    master = cluster_spec_or_resolver.master()
    master_job_name, master_task_id = _find_master_task(cluster_spec, master)

    if not master_job_name:
      raise ValueError(
          "`make_master_device_default` is set to True but cannot find "
          "master %s in the cluster" % master)

    master_device = "/job:{}/replica:0/task:{}".format(master_job_name,
                                                       master_task_id)
    master_device = device_util.canonicalize(master_device)
    current_device = device_util.current()
    if current_device:
      current_device = device_util.canonicalize(current_device)
    if current_device and current_device != master_device:
      raise ValueError("`connect_to_cluster` is called inside existing device "
                       "scope %s, which is different from the master device "
                       "scope %s to enter. This is not allowed." %
                       (current_device, master_device))
    # TODO(b/138389076): Think of the entering device scope behavior in the
    # failure recovery case when dealing with preemptions.
    if not current_device:
      logging.info("Entering into master device scope: %s", master_device)
      # The scope is entered and deliberately never exited, so the master
      # device stays the default after this function returns.
      ops.device(master_device).__enter__()


def _find_master_task(cluster_spec, master):
  """Locates the task in `cluster_spec` whose address matches `master`.

  A task matches when either string contains the other, which tolerates a
  scheme prefix (for example `grpc://`) on one side only.

  Args:
    cluster_spec: The `ClusterSpec` to search.
    master: The master address string reported by the cluster resolver.

  Returns:
    A `(job_name, task_id)` tuple for the first matching task, scanning jobs
    and task indices in the order `cluster_spec` yields them, or
    `(None, None)` if no task matches.
  """
  for candidate_job in cluster_spec.jobs:
    for task_id in cluster_spec.task_indices(candidate_job):
      task_address = cluster_spec.task_address(candidate_job, task_id)
      if master in task_address or task_address in master:
        # Return immediately: a bare `break` here would only leave the inner
        # loop and let a later job overwrite the match.
        return candidate_job, task_id
  return None, None


def _strip_prefix(s, prefix):
  """Returns `s` without a leading `prefix`; `s` unchanged if it has none."""
  return s[len(prefix):] if s.startswith(prefix) else s