473 lines
17 KiB
Python
473 lines
17 KiB
Python
|
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""A Python interface for creating dataset servers."""
|
||
|
|
||
|
import collections
|
||
|
|
||
|
# pylint: disable=invalid-import-order,g-bad-import-order, unused-import
|
||
|
from tensorflow.core.protobuf import service_config_pb2
|
||
|
from tensorflow.python import pywrap_tensorflow
|
||
|
from tensorflow.python.data.experimental.service import _pywrap_server_lib
|
||
|
from tensorflow.python.data.experimental.service import _pywrap_utils
|
||
|
from tensorflow.python.util.tf_export import tf_export
|
||
|
|
||
|
|
||
|
def _get_time_or_placeholder(value):
|
||
|
"""Modifies time-based config values to account for special behaviors."""
|
||
|
|
||
|
# Servers interpret time values of 0 to mean "choose a reasonable
|
||
|
# default". However, the Python API uses `None` for this, and allows 0 as a
|
||
|
# normal value. To account for this, if a user explicitly configures the
|
||
|
# interval/timeout to 0, we interpret it to mean "a very small number", and
|
||
|
# replace it with 1.
|
||
|
if value == 0:
|
||
|
return 1
|
||
|
# `None` indicates that the user wants to leave the behavior to the runtime.
|
||
|
if value is None:
|
||
|
return 0
|
||
|
return value
|
||
|
|
||
|
|
||
|
@tf_export("data.experimental.service.DispatcherConfig")
class DispatcherConfig(
    collections.namedtuple("DispatcherConfig", [
        "port",
        "protocol",
        "work_dir",
        "fault_tolerant_mode",
        "worker_addresses",
        "job_gc_check_interval_ms",
        "job_gc_timeout_ms",
        "worker_timeout_ms",
    ])):
  """Configuration class for tf.data service dispatchers.

  Fields:
    port: Specifies the port to bind to. A value of 0 indicates that the server
      may bind to any available port.
    protocol: The protocol to use for communicating with the tf.data service,
      e.g. "grpc".
    work_dir: A directory to store dispatcher state in. This argument is
      required for the dispatcher to be able to recover from restarts.
    fault_tolerant_mode: Whether the dispatcher should write its state to a
      journal so that it can recover from restarts. Dispatcher state, including
      registered datasets and created jobs, is synchronously written to the
      journal before responding to RPCs. If `True`, `work_dir` must also be
      specified.
    worker_addresses: If the job uses auto-sharding, it needs to specify a fixed
      list of worker addresses that will register with the dispatcher. The
      worker addresses should be in the format `"host"` or `"host:port"`, where
      `"port"` is an integer, named port, or `%port%` to match any port.
    job_gc_check_interval_ms: How often the dispatcher should scan through to
      delete old and unused jobs, in milliseconds. If not set, the runtime will
      select a reasonable default. A higher value will reduce load on the
      dispatcher, while a lower value will reduce the time it takes for the
      dispatcher to garbage collect expired jobs.
    job_gc_timeout_ms: How long a job needs to be unused before it becomes a
      candidate for garbage collection, in milliseconds. A value of -1 indicates
      that jobs should never be garbage collected. If not set, the runtime will
      select a reasonable default. A higher value will cause jobs to stay around
      longer with no consumers. This is useful if there is a large gap in
      time between when consumers read from the job. A lower value will reduce
      the time it takes to reclaim the resources from expired jobs.
    worker_timeout_ms: How long to wait for a worker to heartbeat before
      considering it missing. If not set, the runtime will select a reasonable
      default.
  """

  def __new__(cls,
              port=0,
              protocol=None,
              work_dir=None,
              fault_tolerant_mode=False,
              worker_addresses=None,
              job_gc_check_interval_ms=None,
              job_gc_timeout_ms=None,
              worker_timeout_ms=None):
    # Fall back to the runtime's default RPC protocol when none is given.
    if protocol is None:
      protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
    # Translate the Python-level `None`/0 conventions into the values the
    # server runtime expects. `worker_timeout_ms` is intentionally passed
    # through unchanged.
    gc_check_interval = _get_time_or_placeholder(job_gc_check_interval_ms)
    gc_timeout = _get_time_or_placeholder(job_gc_timeout_ms)
    return super(DispatcherConfig,
                 cls).__new__(cls, port, protocol, work_dir,
                              fault_tolerant_mode, worker_addresses,
                              gc_check_interval, gc_timeout, worker_timeout_ms)
|
||
|
|
||
|
|
||
|
@tf_export("data.experimental.service.DispatchServer", v1=[])
class DispatchServer:
  """An in-process tf.data service dispatch server.

  A `tf.data.experimental.service.DispatchServer` coordinates a cluster of
  `tf.data.experimental.service.WorkerServer`s. When the workers start, they
  register themselves with the dispatcher.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...     dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When starting a dedicated tf.data dispatch process, use join() to block
  after starting up the server, until the server terminates.

  ```
  dispatcher = tf.data.experimental.service.DispatchServer(
      tf.data.experimental.service.DispatcherConfig(port=5050))
  dispatcher.join()
  ```

  Call stop() to gracefully terminate the dispatcher. The server automatically
  stops when all references to it have been deleted.

  To start a `DispatchServer` in fault-tolerant mode, set `work_dir` and
  `fault_tolerant_mode` like below:

  ```
  dispatcher = tf.data.experimental.service.DispatchServer(
      tf.data.experimental.service.DispatcherConfig(
          port=5050,
          work_dir="gs://my-bucket/dispatcher/work_dir",
          fault_tolerant_mode=True))
  ```
  """

  def __init__(self, config=None, start=True):
    """Creates a new dispatch server.

    Args:
      config: (Optional.) A `tf.data.experimental.service.DispatcherConfig`
        configuration. If `None`, the dispatcher will use default
        configuration values.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to True.

    Raises:
      ValueError: If `fault_tolerant_mode` is enabled without a `work_dir`.
    """
    config = config or DispatcherConfig()
    if config.fault_tolerant_mode and not config.work_dir:
      # Note: the message names `DispatchServer` (this class); an earlier
      # version referred to a nonexistent `DispatcherServer`.
      raise ValueError(
          "Cannot enable fault tolerant mode without configuring a work dir. "
          "Make sure to set `work_dir` in the `config` object passed to "
          "`DispatchServer`.")
    self._config = config
    # Accept either the Python namedtuple config or a raw config proto.
    if isinstance(config, service_config_pb2.DispatcherConfig):
      config_proto = config
    else:
      config_proto = service_config_pb2.DispatcherConfig(
          port=config.port,
          protocol=config.protocol,
          work_dir=config.work_dir,
          fault_tolerant_mode=config.fault_tolerant_mode,
          worker_addresses=config.worker_addresses,
          job_gc_check_interval_ms=config.job_gc_check_interval_ms,
          job_gc_timeout_ms=config.job_gc_timeout_ms,
          worker_timeout_ms=config.worker_timeout_ms,
      )
    self._server = _pywrap_server_lib.TF_DATA_NewDispatchServer(
        config_proto.SerializeToString())
    if start:
      self._server.start()

  def start(self):
    """Starts this server.

    >>> dispatcher = tf.data.experimental.service.DispatchServer(start=False)
    >>> dispatcher.start()

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the server.
    """
    self._server.start()

  def join(self):
    """Blocks until the server has shut down.

    This is useful when starting a dedicated dispatch process.

    ```
    dispatcher = tf.data.experimental.service.DispatchServer(
        tf.data.experimental.service.DispatcherConfig(port=5050))
    dispatcher.join()
    ```

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the server.
    """
    self._server.join()

  def stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._stop()

  @property
  def target(self):
    """Returns a target that can be used to connect to the server.

    >>> dispatcher = tf.data.experimental.service.DispatchServer()
    >>> dataset = tf.data.Dataset.range(10)
    >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
    ...     processing_mode="parallel_epochs", service=dispatcher.target))

    The returned string will be in the form protocol://address, e.g.
    "grpc://localhost:5050".
    """
    return "{0}://localhost:{1}".format(self._config.protocol,
                                        self._server.bound_port())

  def _stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._server.stop()

  def __del__(self):
    # `_server` does not exist if `__init__` raised before creating it (e.g.
    # fault-tolerant mode without a work dir); guard so finalizing a
    # half-constructed object does not raise an AttributeError.
    if hasattr(self, "_server"):
      self._stop()

  @property
  def _address(self):
    """Returns the address of the server.

    The returned string will be in the form address:port, e.g. "localhost:1000".
    """
    return "localhost:{0}".format(self._server.bound_port())

  def _num_workers(self):
    """Returns the number of workers registered with the dispatcher."""
    return self._server.num_workers()

  def _snapshot_streams(self, path):
    """Returns information about all the streams for a snapshot."""
    return self._server.snapshot_streams(path)
|
||
|
|
||
|
|
||
|
@tf_export("data.experimental.service.WorkerConfig")
class WorkerConfig(
    collections.namedtuple("WorkerConfig", [
        "dispatcher_address", "worker_address", "port", "protocol",
        "heartbeat_interval_ms", "dispatcher_timeout_ms",
        "data_transfer_protocol", "data_transfer_address"
    ])):
  """Configuration class for tf.data service workers.

  Fields:
    dispatcher_address: Specifies the address of the dispatcher.
    worker_address: Specifies the address of the worker server. This address is
      passed to the dispatcher so that the dispatcher can tell clients how to
      connect to this worker.
    port: Specifies the port to bind to. A value of 0 indicates that the worker
      can bind to any available port.
    protocol: A string indicating the protocol to be used by the worker to
      connect to the dispatcher. E.g. "grpc".
    heartbeat_interval_ms: How often the worker should heartbeat to the
      dispatcher, in milliseconds. If not set, the runtime will select a
      reasonable default. A higher value will reduce the load on the dispatcher,
      while a lower value will reduce the time it takes to reclaim resources
      from finished jobs.
    dispatcher_timeout_ms: How long, in milliseconds, to retry requests to the
      dispatcher before giving up and reporting an error. Defaults to 1 hour.
    data_transfer_protocol: A string indicating the protocol to be used by the
      worker to transfer data to the client. E.g. "grpc".
    data_transfer_address: A string indicating the data transfer address of the
      worker server.
  """

  def __new__(cls,
              dispatcher_address,
              worker_address=None,
              port=0,
              protocol=None,
              heartbeat_interval_ms=None,
              dispatcher_timeout_ms=None,
              data_transfer_protocol=None,
              data_transfer_address=None):
    # "%port%" is resolved to the actually-bound port by the runtime.
    if worker_address is None:
      worker_address = "localhost:%port%"
    if protocol is None:
      protocol = _pywrap_utils.TF_DATA_DefaultProtocol()
    if data_transfer_protocol is None:
      data_transfer_protocol = (
          _pywrap_utils.TF_DATA_DefaultDataTransferProtocol())
    if data_transfer_address is None:
      data_transfer_address = "localhost:%port%"
    # Translate the Python-level `None`/0 conventions into the values the
    # server runtime expects.
    heartbeat_interval_ms = _get_time_or_placeholder(heartbeat_interval_ms)
    dispatcher_timeout_ms = _get_time_or_placeholder(dispatcher_timeout_ms)

    return super(WorkerConfig,
                 cls).__new__(cls, dispatcher_address, worker_address, port,
                              protocol, heartbeat_interval_ms,
                              dispatcher_timeout_ms, data_transfer_protocol,
                              data_transfer_address)
|
||
|
|
||
|
|
||
|
@tf_export("data.experimental.service.WorkerServer", v1=[])
class WorkerServer:
  """An in-process tf.data service worker server.

  A `tf.data.experimental.service.WorkerServer` performs `tf.data.Dataset`
  processing for user-defined datasets, and provides the resulting elements over
  RPC. A worker is associated with a single
  `tf.data.experimental.service.DispatchServer`.

  >>> dispatcher = tf.data.experimental.service.DispatchServer()
  >>> dispatcher_address = dispatcher.target.split("://")[1]
  >>> worker = tf.data.experimental.service.WorkerServer(
  ...     tf.data.experimental.service.WorkerConfig(
  ...     dispatcher_address=dispatcher_address))
  >>> dataset = tf.data.Dataset.range(10)
  >>> dataset = dataset.apply(tf.data.experimental.service.distribute(
  ...     processing_mode="parallel_epochs", service=dispatcher.target))
  >>> print(list(dataset.as_numpy_iterator()))
  [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

  When starting a dedicated tf.data worker process, use join() to block
  after starting up the worker, until the worker terminates.

  ```
  worker = tf.data.experimental.service.WorkerServer(
      port=5051, dispatcher_address="localhost:5050")
  worker.join()
  ```

  Call stop() to gracefully terminate the worker. The worker automatically stops
  when all references to it have been deleted.
  """

  def __init__(self, config, start=True):
    """Creates a new worker server.

    Args:
      config: A `tf.data.experimental.service.WorkerConfig` configuration.
      start: (Optional.) Boolean, indicating whether to start the server after
        creating it. Defaults to True.

    Raises:
      ValueError: If `config` does not specify a `dispatcher_address`.
    """
    if config.dispatcher_address is None:
      raise ValueError(
          "Must specify a `dispatcher_address` in the `config` passed "
          "to `WorkerServer`.")
    # Accept either the Python namedtuple config or a raw config proto.
    if isinstance(config, service_config_pb2.WorkerConfig):
      config_proto = config
    else:
      config_proto = service_config_pb2.WorkerConfig(
          dispatcher_address=config.dispatcher_address,
          worker_address=config.worker_address,
          port=config.port,
          protocol=config.protocol,
          heartbeat_interval_ms=config.heartbeat_interval_ms,
          dispatcher_timeout_ms=config.dispatcher_timeout_ms,
          data_transfer_protocol=config.data_transfer_protocol,
          data_transfer_address=config.data_transfer_address)
    self._server = _pywrap_server_lib.TF_DATA_NewWorkerServer(
        config_proto.SerializeToString())
    if start:
      self._server.start()

  def start(self):
    """Starts this server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        starting the server.
    """
    self._server.start()

  def join(self):
    """Blocks until the server has shut down.

    This is useful when starting a dedicated worker process.

    ```
    worker_server = tf.data.experimental.service.WorkerServer(
        port=5051, dispatcher_address="localhost:5050")
    worker_server.join()
    ```

    This method currently blocks forever.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        joining the server.
    """
    self._server.join()

  def stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._stop()

  def _stop(self):
    """Stops the server.

    Raises:
      tf.errors.OpError: Or one of its subclasses if an error occurs while
        stopping the server.
    """
    self._server.stop()

  def __del__(self):
    # `_server` does not exist if `__init__` raised before creating it (e.g.
    # a missing `dispatcher_address`); guard so finalizing a half-constructed
    # object does not raise an AttributeError.
    if hasattr(self, "_server"):
      self._stop()

  @property
  def _address(self):
    """Returns the address of the server.

    The returned string will be in the form address:port, e.g. "localhost:1000".
    """
    return "localhost:{0}".format(self._server.bound_port())

  def _num_tasks(self):
    """Returns the number of tasks currently being executed on the worker."""
    return self._server.num_tasks()

  def _snapshot_task_progresses(self):
    """Returns the progresses of the snapshot tasks currently being executed.

    Returns:
      An `Iterable[common_pb2.SnapshotTaskProgress]`.
    """
    return self._server.snapshot_task_progresses()
|