# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tf.data service test-cluster with local and remote workers."""

import tempfile

from tensorflow.core.protobuf import data_service_pb2
from tensorflow.core.protobuf import service_config_pb2
from tensorflow.python.data.experimental.kernel_tests.service import test_base as data_service_test_base
from tensorflow.python.data.experimental.service import server_lib
from tensorflow.python.distribute import multi_process_lib
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest

# Shutdown quiet period (milliseconds) passed to every `TestWorker` created by
# this module, both in-process and remote.
_WORKER_SHUTDOWN_QUIET_PERIOD_MS = 100

# pylint: disable=protected-access


class _RemoteWorkerProcess(multi_process_lib.Process):
  """Runs a worker server in a new process to simulate a remote worker."""

  def __init__(self, dispatcher_address, port, worker_tags, pipe_writer):
    """Records the worker configuration; the worker starts in `run`.

    Args:
      dispatcher_address: Dispatcher address, passed to the `TestWorker`.
      port: Port for the worker server.
      worker_tags: Tags passed to the `TestWorker` (may be `None`).
      pipe_writer: Write end of a pipe. Once the worker has started, its
        address is sent through it so the parent process can read it.
    """
    super().__init__()
    self._dispatcher_address = dispatcher_address
    self._port = port
    self._worker_tags = worker_tags
    self._pipe_writer = pipe_writer

  def run(self):
    """Process entry point: starts the worker in this child process."""
    self.start_worker()

  def start_worker(self):
    """Starts the worker, reports its address, then blocks on `join()`."""
    self._worker = data_service_test_base.TestWorker(
        self._dispatcher_address,
        _WORKER_SHUTDOWN_QUIET_PERIOD_MS,
        port=self._port,
        worker_tags=self._worker_tags)
    self._worker.start()
    # Tell the parent process where this worker is serving.
    self._pipe_writer.send(self._worker.worker_address())
    # Keep this process alive while the worker server runs.
    self._worker.join()


class MultiProcessCluster:
  """tf.data service cluster with local and remote workers.

  Represents a cluster with a dispatcher, `num_local_workers` local workers,
  and `num_remote_workers` remote workers. Local workers run inside the
  current process; each remote worker runs in its own separate process. This
  is useful for testing reads from local, in-process workers. For example:

  ```
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=1, num_remote_workers=3)
  num_elements = 10
  dataset = self.make_distributed_range_dataset(
      num_elements, cluster, target_workers="LOCAL")
  self.assertDatasetProduces(dataset, list(range(num_elements)))
  ```
  """

  def __init__(self,
               num_local_workers,
               num_remote_workers,
               worker_tags=None,
               worker_addresses=None,
               deployment_mode=data_service_pb2.DEPLOYMENT_MODE_COLOCATED):
    """Starts the dispatcher, then the local workers, then the remote workers.

    Args:
      num_local_workers: Number of workers to run in the current process.
      num_remote_workers: Number of workers to run, each in its own process.
      worker_tags: Optional tags passed to every worker.
      worker_addresses: Optional value for the dispatcher's
        `DispatcherConfig.worker_addresses`.
      deployment_mode: Value for the dispatcher's
        `DispatcherConfig.deployment_mode`.
    """
    self._work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
    self._deployment_mode = deployment_mode
    self._start_dispatcher(worker_addresses)
    self._start_local_workers(num_local_workers, worker_tags)
    self._start_remote_workers(num_remote_workers, worker_tags)

  def _start_dispatcher(self, worker_addresses, port=0):
    """Creates and starts a fault-tolerant gRPC dispatcher.

    Args:
      worker_addresses: Value for `DispatcherConfig.worker_addresses`.
      port: Port to bind; `0` means pick an unused port.
    """
    if port == 0:
      port = test_util.pick_unused_port()
    self._dispatcher = server_lib.DispatchServer(
        service_config_pb2.DispatcherConfig(
            port=port,
            protocol="grpc",
            work_dir=self._work_dir,
            fault_tolerant_mode=True,
            worker_addresses=worker_addresses,
            deployment_mode=self._deployment_mode),
        start=True)

  def _start_local_workers(self, num_workers, worker_tags=None):
    """Starts `num_workers` in-process workers, resetting the local list."""
    self._local_workers = []
    for _ in range(num_workers):
      self.start_local_worker(worker_tags)

  def _start_remote_workers(self, num_workers, worker_tags=None):
    """Starts `num_workers` remote workers, resetting the remote list."""
    # List of (worker address, remote worker process) tuples.
    self._remote_workers = []
    for _ in range(num_workers):
      self.start_remote_worker(worker_tags)

  def start_local_worker(self, worker_tags=None):
    """Starts one worker in the current process and records it."""
    worker = data_service_test_base.TestWorker(
        self.dispatcher_address(),
        _WORKER_SHUTDOWN_QUIET_PERIOD_MS,
        port=test_util.pick_unused_port(),
        worker_tags=worker_tags)
    worker.start()
    self._local_workers.append(worker)

  def start_remote_worker(self, worker_tags=None):
    """Runs a tf.data service worker in a remote process.

    Blocks until the child process reports the worker's address.

    Raises:
      RuntimeError: If the child process exits before reporting its address.
    """
    pipe_reader, pipe_writer = multi_process_lib.multiprocessing.Pipe(
        duplex=False)
    worker_process = _RemoteWorkerProcess(
        self.dispatcher_address(),
        port=test_util.pick_unused_port(),
        worker_tags=worker_tags,
        pipe_writer=pipe_writer)
    worker_process.start()
    # The child now holds its own copy of the write end. Closing the parent's
    # copy makes `recv` raise `EOFError` if the child dies before sending,
    # instead of blocking forever.
    pipe_writer.close()
    try:
      worker_address = pipe_reader.recv()
    except EOFError as e:
      # Make sure the child is gone and reaped before reporting the failure.
      worker_process.kill()
      worker_process.join()
      raise RuntimeError(
          "Remote tf.data service worker process exited before reporting "
          f"its address (exit code {worker_process.exitcode}).") from e
    finally:
      pipe_reader.close()
    self._remote_workers.append((worker_address, worker_process))

  def restart_dispatcher(self):
    """Stops the dispatcher and starts a new one on the same port.

    The new dispatcher reuses this cluster's `work_dir` and is configured with
    the addresses of all current local and remote workers.
    """
    # `rsplit` keeps the port parsing correct even if the host part contains
    # ':' characters.
    port = int(self.dispatcher_address().rsplit(":", 1)[1])
    self._dispatcher._stop()
    self._start_dispatcher(
        worker_addresses=(self.local_worker_addresses() +
                          self.remote_worker_addresses()),
        port=port)

  def restart_local_workers(self):
    """Calls `restart()` on every in-process worker."""
    for worker in self._local_workers:
      worker.restart()

  def dispatcher_address(self):
    """Returns the dispatcher's address string (`host:port`)."""
    return self._dispatcher._address

  def local_worker_addresses(self):
    """Returns the addresses of the in-process workers."""
    return [worker.worker_address() for worker in self._local_workers]

  def remote_worker_addresses(self):
    """Returns the addresses reported by the remote worker processes."""
    return [worker_address for (worker_address, _) in self._remote_workers]

  def _stop(self):
    """Stops all workers and the dispatcher.

    This also runs from `__del__`, so it tolerates a partially constructed
    cluster (for example, when `__init__` raised part-way through startup).
    """
    for worker in getattr(self, "_local_workers", []):
      worker.stop()
    for (_, worker_process) in getattr(self, "_remote_workers", []):
      worker_process.kill()
      # Reap the killed child so it does not linger as a zombie process.
      worker_process.join()
    dispatcher = getattr(self, "_dispatcher", None)
    if dispatcher is not None:
      dispatcher._stop()

  def __del__(self):
    self._stop()


def test_main():
  """Main function to be called within `__main__` of a test file."""
  multi_process_lib.test_main()