# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test base for tf.data service tests."""
import tempfile

from tensorflow.core.protobuf import service_config_pb2
from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.experimental.service import server_lib
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest

# Resolved to a fresh temporary directory when a `TestCluster` is created.
TMP_WORK_DIR = "tmp_work_dir_placeholder"
# `""` indicates not to use a work directory.
NO_WORK_DIR = ""
# Use a shorter-than-default heartbeat interval so that tests run faster.
TEST_HEARTBEAT_INTERVAL_MS = 100
TEST_DISPATCHER_TIMEOUT_MS = 1000
TEST_WORKER_TIMEOUT_MS = 200
TEST_JOB_GC_CHECK_INTERVAL_MS = 1000
PROTOCOL = "grpc"


def all_cluster_configurations():
  with_work_dir = combinations.combine(
      work_dir=TMP_WORK_DIR, fault_tolerant_mode=[True, False])
  without_work_dir = combinations.combine(
      work_dir=NO_WORK_DIR, fault_tolerant_mode=False)
  return with_work_dir + without_work_dir
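

# A minimal sketch, not referenced elsewhere in this module, of how
# `all_cluster_configurations()` is typically consumed. It assumes
# `test_base.default_test_combinations()` is available and that a hypothetical
# test method receives `work_dir` and `fault_tolerant_mode` as parameters:
#
#   @combinations.generate(_example_test_combinations())
#   def testSomething(self, work_dir, fault_tolerant_mode):
#     cluster = TestCluster(
#         num_workers=1,
#         work_dir=work_dir,
#         fault_tolerant_mode=fault_tolerant_mode)
def _example_test_combinations():
  """Illustrative only: combinations a typical data service test generates."""
  return combinations.times(test_base.default_test_combinations(),
                            all_cluster_configurations())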


def _make_worker(dispatcher_address,
                 data_transfer_protocol,
                 shutdown_quiet_period_ms=0,
                 port=0,
                 worker_tags=None,
                 cross_trainer_cache_size_bytes=None):
  """Creates a worker server."""
  defaults = server_lib.WorkerConfig(dispatcher_address=dispatcher_address)
  config_proto = service_config_pb2.WorkerConfig(
      dispatcher_address=dispatcher_address,
      worker_address=defaults.worker_address,
      port=port,
      protocol=PROTOCOL,
      worker_tags=worker_tags,
      heartbeat_interval_ms=TEST_HEARTBEAT_INTERVAL_MS,
      dispatcher_timeout_ms=TEST_DISPATCHER_TIMEOUT_MS,
      data_transfer_protocol=data_transfer_protocol,
      data_transfer_address=defaults.worker_address,
      shutdown_quiet_period_ms=shutdown_quiet_period_ms,
      cross_trainer_cache_size_bytes=cross_trainer_cache_size_bytes)
  return server_lib.WorkerServer(config_proto, start=False)


# pylint: disable=protected-access
class TestWorker:
  """A tf.data service worker."""

  def __init__(self,
               dispatcher_address,
               shutdown_quiet_period_ms,
               data_transfer_protocol=None,
               port=0,
               worker_tags=None,
               cross_trainer_cache_size_bytes=None):
    self._dispatcher_address = dispatcher_address
    self._shutdown_quiet_period_ms = shutdown_quiet_period_ms
    self._server = _make_worker(
        dispatcher_address,
        data_transfer_protocol,
        shutdown_quiet_period_ms,
        port=port,
        worker_tags=worker_tags,
        cross_trainer_cache_size_bytes=cross_trainer_cache_size_bytes)
    self._running = False
    self._data_transfer_protocol = data_transfer_protocol

  def stop(self):
    self._server._stop()
    self._running = False

  def start(self):
    self._server.start()
    self._port = int(self._server._address.split(":")[1])
    self._running = True

  def restart(self, use_same_port=True):
    """Restarts the worker, stopping it first if it is already running."""
    if self._running:
      self.stop()
    port = 0
    if use_same_port:
      port = self._port
    self._server = _make_worker(self._dispatcher_address,
                                self._data_transfer_protocol,
                                self._shutdown_quiet_period_ms, port)
    self._server.start()
    self._port = int(self._server._address.split(":")[1])
    self._running = True

  def join(self):
    self._server.join()

  def num_tasks(self):
    return self._server._num_tasks()

  def snapshot_task_progresses(self):
    return self._server._snapshot_task_progresses()

  def worker_address(self):
    return self._server._address


class TestCluster:
  """Test tf.data service cluster."""

  def __init__(
      self,
      num_workers,
      dispatcher_port=0,
      work_dir=TMP_WORK_DIR,
      fault_tolerant_mode=True,
      job_gc_check_interval_ms=TEST_JOB_GC_CHECK_INTERVAL_MS,
      job_gc_timeout_ms=None,
      worker_timeout_ms=TEST_WORKER_TIMEOUT_MS,
      worker_shutdown_quiet_period_ms=0,
      start=True,
      data_transfer_protocol=None,
  ):
    """Creates a tf.data service test cluster.

    Args:
      num_workers: The number of workers to initially add to the cluster.
      dispatcher_port: The port to use for the dispatcher.
      work_dir: The work directory to use for the dispatcher. If set to
        `TMP_WORK_DIR`, the cluster will create a new temporary directory to
        use as the work directory. If set to `NO_WORK_DIR`, no work directory
        will be used.
      fault_tolerant_mode: Whether the dispatcher should write its state to a
        journal so that it can recover from restarts.
      job_gc_check_interval_ms: How often the dispatcher scans for old and
        unused jobs to delete, in milliseconds.
      job_gc_timeout_ms: How long a job needs to be unused before it becomes a
        candidate for garbage collection, in milliseconds.
      worker_timeout_ms: How long to wait for a worker to heartbeat before
        considering it missing, in milliseconds.
      worker_shutdown_quiet_period_ms: When shutting down a worker, how long to
        wait for the gRPC server to process the final requests.
      start: Whether to immediately start the servers in the cluster. If
        `False`, the servers can be started later by calling
        `start_dispatcher()` and `start_workers()`.
      data_transfer_protocol: (Optional.) The protocol to use for transferring
        data with the tf.data service.
    """
    if work_dir == TMP_WORK_DIR:
      work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
    self._worker_shutdown_quiet_period_ms = worker_shutdown_quiet_period_ms
    self._data_transfer_protocol = data_transfer_protocol
    self.dispatcher = server_lib.DispatchServer(
        server_lib.DispatcherConfig(
            port=dispatcher_port,
            work_dir=work_dir,
            protocol=PROTOCOL,
            fault_tolerant_mode=fault_tolerant_mode,
            job_gc_check_interval_ms=job_gc_check_interval_ms,
            job_gc_timeout_ms=job_gc_timeout_ms,
            worker_timeout_ms=worker_timeout_ms,
        ),
        start=start,
    )

    self.workers = []
    for _ in range(num_workers):
      self.add_worker(start=start)

  def dispatcher_address(self):
    return self.dispatcher.target.split("://")[1]

  def add_worker(self, start=True):
    worker = TestWorker(self.dispatcher_address(),
                        self._worker_shutdown_quiet_period_ms,
                        self._data_transfer_protocol)
    if start:
      worker.start()
    self.workers.append(worker)

  def start_dispatcher(self):
    self.dispatcher.start()

  def start_workers(self):
    for worker in self.workers:
      worker.start()

  def stop_dispatcher(self):
    # pylint: disable=protected-access
    self.dispatcher._stop()

  def stop_worker(self, index):
    self.workers[index].stop()

  def stop_workers(self):
    for worker in self.workers:
      worker.stop()

  # pylint: disable=protected-access
  def restart_dispatcher(self):
    """Stops `dispatcher` and creates a new dispatcher with the same port.

    Restarting is supported only when the dispatcher is configured with
    `fault_tolerant_mode=True`.
    """
    if not self.dispatcher._config.fault_tolerant_mode:
      raise ValueError(
          "Trying to restart the dispatcher without fault-tolerance.")
    port = int(self.dispatcher_address().split(":")[1])
    self.dispatcher._stop()
    self.dispatcher = server_lib.DispatchServer(
        server_lib.DispatcherConfig(
            port=port,
            work_dir=self.dispatcher._config.work_dir,
            protocol=PROTOCOL,
            fault_tolerant_mode=self.dispatcher._config.fault_tolerant_mode))

  def num_registered_workers(self):
    return self.dispatcher._num_workers()

  def num_tasks_on_workers(self):
    return sum(worker.num_tasks() for worker in self.workers)

  def snapshot_streams(self, path):
    return self.dispatcher._snapshot_streams(path)

  def __del__(self):
    # Destroy workers before the dispatcher for clean shutdown.
    self.workers.clear()
    del self.dispatcher
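

# A minimal sketch, never invoked by this module, of the intended `TestCluster`
# lifecycle in a fault-tolerance test. The surrounding test class and its
# assertions are assumed and not shown here.
def _example_cluster_restart_usage():
  """Illustrative only: exercises dispatcher and worker restarts."""
  cluster = TestCluster(num_workers=2, fault_tolerant_mode=True)
  # Restarting the dispatcher requires `fault_tolerant_mode=True` so that its
  # journaled state can be recovered from the work directory.
  cluster.restart_dispatcher()
  # Workers can be restarted individually; reusing the same port lets the
  # dispatcher re-associate the worker when it re-registers.
  cluster.workers[0].restart(use_same_port=True)
  return cluster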


class TestBase(test_base.DatasetTestBase):
  """Base class for tf.data service tests."""

  def make_distributed_dataset(self,
                               dataset,
                               cluster,
                               processing_mode="parallel_epochs",
                               job_name=None,
                               consumer_index=None,
                               num_consumers=None,
                               max_outstanding_requests=None,
                               data_transfer_protocol=None,
                               compression="AUTO",
                               cross_trainer_cache=None,
                               target_workers="AUTO"):
    # pylint: disable=protected-access
    return dataset.apply(
        data_service_ops._distribute(
            processing_mode,
            cluster.dispatcher_address(),
            job_name=job_name,
            consumer_index=consumer_index,
            num_consumers=num_consumers,
            max_outstanding_requests=max_outstanding_requests,
            task_refresh_interval_hint_ms=20,
            data_transfer_protocol=data_transfer_protocol,
            compression=compression,
            cross_trainer_cache=cross_trainer_cache,
            target_workers=target_workers))

  def make_distributed_range_dataset(self,
                                     num_elements,
                                     cluster,
                                     processing_mode="parallel_epochs",
                                     job_name=None,
                                     max_outstanding_requests=None,
                                     data_transfer_protocol=None,
                                     compression="AUTO",
                                     cross_trainer_cache=None,
                                     target_workers="AUTO"):
    dataset = dataset_ops.Dataset.range(num_elements)
    return self.make_distributed_dataset(
        dataset,
        cluster,
        processing_mode=processing_mode,
        job_name=job_name,
        max_outstanding_requests=max_outstanding_requests,
        data_transfer_protocol=data_transfer_protocol,
        compression=compression,
        cross_trainer_cache=cross_trainer_cache,
        target_workers=target_workers)
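
  # A minimal sketch, not called by any test in this module, of the typical
  # read pattern against a distributed range dataset. It assumes `getNext`
  # from `DatasetTestBase` and uses the `read` helper defined later in this
  # class.
  def _example_distributed_range_read(self, cluster, num_elements=10):
    dataset = self.make_distributed_range_dataset(num_elements, cluster)
    get_next = self.getNext(dataset)
    results = []
    self.read(get_next, results, num_elements)
    return results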

  def make_coordinated_read_dataset(
      self,
      cluster,
      num_consumers,
      sharding_policy=data_service_ops.ShardingPolicy.OFF):
    """Creates a dataset that performs coordinated reads.

    The dataset simulates `num_consumers` consumers by using parallel
    interleave to read with `num_consumers` threads, one for each consumer.
    The nth element of the dataset is produced by consumer `n % num_consumers`.
    The dataset executed on each worker will produce groups of `num_consumers`
    sequentially increasing numbers. For example, if `num_consumers=3` a worker
    dataset could produce [0, 1, 2, 9, 10, 11, 21, 22, 23]. This enables
    `checkCoordinatedReadGroups` below to assess whether the values received in
    each step came from the same group.

    Args:
      cluster: A tf.data service `TestCluster`.
      num_consumers: The number of consumers to simulate.
      sharding_policy: The sharding policy to use. Currently only OFF and
        DYNAMIC are supported.

    Returns:
      A dataset that simulates reading with `num_consumers` consumers.
    """
    if sharding_policy not in [
        data_service_ops.ShardingPolicy.OFF,
        data_service_ops.ShardingPolicy.DYNAMIC
    ]:
      raise ValueError(f"Unsupported sharding policy: {sharding_policy}")

    # Start from 0 so that we can detect when a new worker is added with
    # ShardingPolicy.OFF.
    ds = dataset_ops.Dataset.from_tensors(math_ops.cast(0, dtypes.int64))
    ds = ds.concatenate(dataset_ops.Dataset.random())

    # Ensure that all elements in the same group are consecutive.
    def make_group(x):
      # Avoid overflowing an int64 in (x + 1) * num_consumers below.
      x = x % (2**32)
      return dataset_ops.Dataset.range(x * num_consumers,
                                       (x + 1) * num_consumers)

    ds = ds.flat_map(make_group)
    consumers = []
    for consumer_index in range(num_consumers):
      consumers.append(
          self.make_distributed_dataset(
              ds,
              cluster,
              job_name="test",
              processing_mode=sharding_policy,
              consumer_index=consumer_index,
              num_consumers=num_consumers))
    # Use parallel interleave to read from consumers in parallel.
    ds = dataset_ops.Dataset.from_tensor_slices(consumers)
    ds = ds.interleave(
        lambda x: x,
        cycle_length=num_consumers,
        num_parallel_calls=num_consumers)
    return ds

  def checkCoordinatedReadGroups(self, results, num_consumers):
    """Validates results from a `make_coordinated_read_dataset` dataset.

    Each group of `num_consumers` results should be consecutive, indicating
    that they were produced by the same worker.

    Args:
      results: The elements produced by the dataset.
      num_consumers: The number of consumers.
    """
    groups = [
        results[start:start + num_consumers]
        for start in range(0, len(results), num_consumers)
    ]
    incorrect_groups = []
    for group in groups:
      # Check that each group of `num_consumers` results is consecutive.
      for offset in range(1, len(group)):
        if group[0] + offset != group[offset]:
          incorrect_groups.append(group)
          break
    self.assertEmpty(
        incorrect_groups,
        "Incorrect groups: {}.\nAll groups: {}".format(incorrect_groups,
                                                       groups))

  def read(self, get_next, results, count):
    for _ in range(count):
      results.append(self.evaluate(get_next()))
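

# A minimal sketch, never invoked here, of how a coordinated-read test is
# expected to wire the helpers above together. `test` is assumed to be a
# `TestBase` instance inside a running test; the element and consumer counts
# are arbitrary illustration values.
def _example_coordinated_read_check(test, cluster):
  """Illustrative only: reads coordinated groups and validates them."""
  num_consumers = 3
  ds = test.make_coordinated_read_dataset(cluster, num_consumers)
  get_next = test.getNext(ds, requires_initialization=True)
  results = []
  # Read ten full groups, then check that each group is consecutive.
  test.read(get_next, results, 10 * num_consumers)
  test.checkCoordinatedReadGroups(results, num_consumers)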