Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/data/experimental/kernel_tests/service/multi_process_cluster.py
2023-06-19 00:49:18 +02:00

166 lines
5.8 KiB
Python

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tf.data service test-cluster with local and remote workers."""
import tempfile
from tensorflow.core.protobuf import data_service_pb2
from tensorflow.core.protobuf import service_config_pb2
from tensorflow.python.data.experimental.kernel_tests.service import test_base as data_service_test_base
from tensorflow.python.data.experimental.service import server_lib
from tensorflow.python.distribute import multi_process_lib
from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
# Passed as the second positional argument to every `TestWorker` created in
# this module (both in-process workers and workers in child processes). Per its
# name, this is the worker shutdown quiet period in milliseconds.
_WORKER_SHUTDOWN_QUIET_PERIOD_MS = 100
# The code below reaches into private members of tf.data service objects
# (e.g. `DispatchServer._address` and `_stop()`).
# pylint: disable=protected-access
class _RemoteWorkerProcess(multi_process_lib.Process):
  """Child process that hosts one tf.data service worker.

  Used to simulate a remote worker: the worker server lives in its own
  process, and its address is sent back to the parent through a pipe.

  Constructor args:
    dispatcher_address: Dispatcher address, forwarded to `TestWorker`.
    port: Port for the worker server, forwarded to `TestWorker`.
    worker_tags: Worker tags, forwarded to `TestWorker` (may be None).
    pipe_writer: Write end of a pipe; the worker's address is sent on it.
  """

  def __init__(self, dispatcher_address, port, worker_tags, pipe_writer):
    super().__init__()
    self._dispatcher_address = dispatcher_address
    self._port = port
    self._worker_tags = worker_tags
    self._pipe_writer = pipe_writer

  def run(self):
    """Process entry point: delegates to `start_worker`."""
    self.start_worker()

  def start_worker(self):
    """Starts the worker, reports its address, then joins it."""
    worker = data_service_test_base.TestWorker(
        self._dispatcher_address,
        _WORKER_SHUTDOWN_QUIET_PERIOD_MS,
        worker_tags=self._worker_tags,
        port=self._port)
    self._worker = worker
    worker.start()
    # The parent blocks on the other end of this pipe until the address
    # arrives.
    address = worker.worker_address()
    self._pipe_writer.send(address)
    worker.join()
class MultiProcessCluster:
  """tf.data service cluster with local and remote workers.

  Represents a cluster with a dispatcher, `num_local_workers` local workers, and
  `num_remote_workers` remote workers. Remote workers run in separate processes.
  This is useful to test reading from local in-process workers. For example:

  ```
  cluster = multi_process_cluster.MultiProcessCluster(
      num_local_workers=1, num_remote_workers=3)
  num_elements = 10
  dataset = self.make_distributed_range_dataset(
      num_elements, cluster, target_workers="LOCAL")
  self.assertDatasetProduces(dataset, list(range(num_elements)))
  ```
  """

  def __init__(self,
               num_local_workers,
               num_remote_workers,
               worker_tags=None,
               worker_addresses=None,
               deployment_mode=data_service_pb2.DEPLOYMENT_MODE_COLOCATED):
    """Starts the dispatcher, then the local workers, then the remote workers.

    Args:
      num_local_workers: Number of workers started in this process.
      num_remote_workers: Number of workers started, one per child process.
      worker_tags: Optional tags forwarded to every worker.
      worker_addresses: Optional worker addresses put in the dispatcher config.
      deployment_mode: Deployment mode put in the dispatcher config.
    """
    self._work_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
    self._deployment_mode = deployment_mode
    self._start_dispatcher(worker_addresses)
    self._start_local_workers(num_local_workers, worker_tags)
    self._start_remote_workers(num_remote_workers, worker_tags)

  def _start_dispatcher(self, worker_addresses, port=0):
    """Starts a gRPC dispatcher with `fault_tolerant_mode=True`.

    Args:
      worker_addresses: Worker addresses for the dispatcher config (may be None).
      port: Port to bind; 0 means "pick an unused port".
    """
    if port == 0:
      port = test_util.pick_unused_port()
    self._dispatcher = server_lib.DispatchServer(
        service_config_pb2.DispatcherConfig(
            port=port,
            protocol="grpc",
            work_dir=self._work_dir,
            fault_tolerant_mode=True,
            worker_addresses=worker_addresses,
            deployment_mode=self._deployment_mode),
        start=True)

  def _start_local_workers(self, num_workers, worker_tags=None):
    """Resets the local-worker list and starts `num_workers` local workers."""
    self._local_workers = []
    for _ in range(num_workers):
      self.start_local_worker(worker_tags)

  def _start_remote_workers(self, num_workers, worker_tags=None):
    """Resets the remote-worker list and starts `num_workers` remote workers."""
    # List of (worker address, remote worker process) tuples.
    self._remote_workers = []
    for _ in range(num_workers):
      self.start_remote_worker(worker_tags)

  def start_local_worker(self, worker_tags=None):
    """Starts one in-process worker on an unused port and records it."""
    worker = data_service_test_base.TestWorker(
        self.dispatcher_address(),
        _WORKER_SHUTDOWN_QUIET_PERIOD_MS,
        port=test_util.pick_unused_port(),
        worker_tags=worker_tags)
    worker.start()
    self._local_workers.append(worker)

  def start_remote_worker(self, worker_tags=None):
    """Runs a tf.data service worker in a remote process.

    Blocks until the child process sends the worker's address over a pipe.

    Raises:
      EOFError: If the child process exits before sending its address.
    """
    pipe_reader, pipe_writer = multi_process_lib.multiprocessing.Pipe(
        duplex=False)
    try:
      worker_process = _RemoteWorkerProcess(
          self.dispatcher_address(),
          port=test_util.pick_unused_port(),
          worker_tags=worker_tags,
          pipe_writer=pipe_writer)
      worker_process.start()
      # Once `start()` returns the child owns its own handle to the write end.
      # Dropping the parent's copy lets `recv()` raise EOFError if the child
      # dies before sending, instead of blocking forever.
      pipe_writer.close()
      worker_address = pipe_reader.recv()
    finally:
      # Release both descriptors; `Connection.close()` is a no-op when the
      # connection is already closed.
      pipe_writer.close()
      pipe_reader.close()
    self._remote_workers.append((worker_address, worker_process))

  def restart_dispatcher(self):
    """Stops the dispatcher and starts a new one on the same port.

    The new dispatcher is configured with the addresses of all current local
    and remote workers.
    """
    # Take the text after the last ":" so hosts containing colons still parse.
    port = int(self.dispatcher_address().rpartition(":")[2])
    self._dispatcher._stop()
    self._start_dispatcher(
        worker_addresses=(self.local_worker_addresses() +
                          self.remote_worker_addresses()),
        port=port)

  def restart_local_workers(self):
    """Calls `restart()` on every local worker."""
    for worker in self._local_workers:
      worker.restart()

  def dispatcher_address(self):
    """Returns the dispatcher's address string."""
    return self._dispatcher._address

  def local_worker_addresses(self):
    """Returns the addresses of the local workers, in start order."""
    return [worker.worker_address() for worker in self._local_workers]

  def remote_worker_addresses(self):
    """Returns the addresses reported by the remote workers, in start order."""
    return [worker_address for (worker_address, _) in self._remote_workers]

  def _stop(self):
    """Stops local workers, kills remote worker processes, stops dispatcher.

    This also runs from `__del__`, so it tolerates a partially constructed
    cluster (e.g. `__init__` raised before every attribute was assigned) and
    still cleans up whatever was started.
    """
    for worker in getattr(self, "_local_workers", []):
      worker.stop()
    for (_, worker_process) in getattr(self, "_remote_workers", []):
      worker_process.kill()
    dispatcher = getattr(self, "_dispatcher", None)
    if dispatcher is not None:
      dispatcher._stop()

  def __del__(self):
    self._stop()
def test_main():
  """Entry point for test files that use this module.

  Call it from the `if __name__ == "__main__":` section of a test file. All
  work is delegated to `multi_process_lib.test_main()`.
  """
  multi_process_lib.test_main()