239 lines
9.1 KiB
Python
239 lines
9.1 KiB
Python
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""Initial server communication to determine session parameters."""
|
||
|
|
||
|
|
||
|
from google.protobuf import message
|
||
|
import requests
|
||
|
|
||
|
from absl import logging
|
||
|
from tensorboard import version
|
||
|
from tensorboard.plugins.scalar import metadata as scalars_metadata
|
||
|
from tensorboard.uploader.proto import server_info_pb2
|
||
|
|
||
|
|
||
|
# Request timeout for communicating with the remote server, in seconds.
_REQUEST_TIMEOUT_SECONDS = 10

# Minimum interval between initiating WriteScalar RPCs, if not specified
# by server_info, in milliseconds.
_DEFAULT_MIN_SCALAR_REQUEST_INTERVAL = 5000

# Minimum interval between initiating WriteTensor RPCs, if not specified
# by server_info, in milliseconds.
_DEFAULT_MIN_TENSOR_REQUEST_INTERVAL = 1000

# Minimum interval between initiating blob write RPC streams, if not specified
# by server_info, in milliseconds.
# This may differ from the above RPC rate limits, because blob streams
# are not batched, so sending a sequence of N blobs requires N streams, which
# could reasonably be sent more frequently.
_DEFAULT_MIN_BLOB_REQUEST_INTERVAL = 1000

# Maximum WriteScalar request size, if not specified by server_info, in bytes.
# The server-side limit is 4 MiB [1]; we should pad a bit to mitigate any errors
# in our bookkeeping. Currently, we pad a lot because WriteScalar is relatively
# slow and we would otherwise risk Deadline Exceeded errors.
#
# [1]: https://github.com/grpc/grpc/blob/e70d8582b4b0eedc45e3d25a57b58a08b94a9f4a/include/grpc/impl/codegen/grpc_types.h#L447 # pylint: disable=line-too-long
_DEFAULT_MAX_SCALAR_REQUEST_SIZE = 128 * (2**10)  # 128KiB

# Maximum WriteTensor request size, if not specified by server_info, in bytes.
# The server-side limit is 4 MiB [1]; we should pad a bit to mitigate any errors
# in our bookkeeping. Currently, we pad a lot.
#
# [1]: https://github.com/grpc/grpc/blob/e70d8582b4b0eedc45e3d25a57b58a08b94a9f4a/include/grpc/impl/codegen/grpc_types.h#L447 # pylint: disable=line-too-long
_DEFAULT_MAX_TENSOR_REQUEST_SIZE = 512 * (2**10)  # 512KiB

# Maximum WriteBlob request size, if not specified by server_info, in bytes.
# The server-side limit is 4 MiB [1]; we pad with a 256 KiB chunk to mitigate
# any errors in our bookkeeping.
#
# [1]: https://github.com/grpc/grpc/blob/e70d8582b4b0eedc45e3d25a57b58a08b94a9f4a/include/grpc/impl/codegen/grpc_types.h#L447 # pylint: disable=line-too-long
_DEFAULT_MAX_BLOB_REQUEST_SIZE = 4 * (2**20) - 256 * (2**10)  # 4MiB-256KiB

# Maximum blob size, if not specified by server_info, in bytes.
_DEFAULT_MAX_BLOB_SIZE = 10 * (2**20)  # 10MiB

# Maximum tensor point size, if not specified by server_info, in bytes.
_DEFAULT_MAX_TENSOR_POINT_SIZE = 16 * (2**10)  # 16KiB
||
|
|
||
|
|
||
|
def _server_info_request(upload_plugins):
    """Builds the `ServerInfoRequest` for the initial handshake.

    Args:
      upload_plugins: List of plugin names requested by the user and to be
        verified by the server.

    Returns:
      A `server_info_pb2.ServerInfoRequest` message.
    """
    req = server_info_pb2.ServerInfoRequest()
    req.plugin_specification.upload_plugins[:] = upload_plugins
    req.version = version.VERSION
    return req
|
||
|
|
||
|
|
||
|
def fetch_server_info(origin, upload_plugins):
    """Fetches server info from a remote server.

    Args:
      origin: The server with which to communicate. Should be a string
        like "https://tensorboard.dev", including protocol, host, and (if
        needed) port.
      upload_plugins: List of plugin names requested by the user and to be
        verified by the server.

    Returns:
      A `server_info_pb2.ServerInfoResponse` message.

    Raises:
      CommunicationError: Upon failure to connect to or successfully
        communicate with the remote server.
    """
    endpoint = "%s/api/uploader" % origin
    server_info_request = _server_info_request(upload_plugins)
    post_body = server_info_request.SerializeToString()
    logging.info("Requested server info: <%r>", server_info_request)
    try:
        response = requests.post(
            endpoint,
            data=post_body,
            timeout=_REQUEST_TIMEOUT_SECONDS,
            headers={"User-Agent": "tensorboard/%s" % version.VERSION},
        )
    except requests.RequestException as e:
        # Chain the underlying exception (PEP 3134) so the network-level
        # traceback is preserved for debugging instead of being discarded.
        raise CommunicationError("Failed to connect to backend: %s" % e) from e
    if not response.ok:
        raise CommunicationError(
            "Non-OK status from backend (%d %s): %r"
            % (response.status_code, response.reason, response.content)
        )
    try:
        return server_info_pb2.ServerInfoResponse.FromString(response.content)
    except message.DecodeError as e:
        # Likewise, keep the protobuf decode error as the cause.
        raise CommunicationError(
            "Corrupt response from backend (%s): %r" % (e, response.content)
        ) from e
|
||
|
|
||
|
|
||
|
def create_server_info(frontend_origin, api_endpoint, upload_plugins):
    """Manually creates server info given a frontend and backend.

    Args:
      frontend_origin: The origin of the TensorBoard.dev frontend, like
        "https://tensorboard.dev" or "http://localhost:8000".
      api_endpoint: As to `server_info_pb2.ApiServer.endpoint`.
      upload_plugins: List of plugin names requested by the user and to be
        verified by the server.

    Returns:
      A `server_info_pb2.ServerInfoResponse` message.
    """
    info = server_info_pb2.ServerInfoResponse()
    info.compatibility.verdict = server_info_pb2.VERDICT_OK
    info.api_server.endpoint = api_endpoint
    # Choose an experiment-ID placeholder that cannot collide with literal
    # text in the origin: keep wrapping it in braces until it no longer
    # occurs as a substring of `frontend_origin`.
    placeholder = "{{EID}}"
    while placeholder in frontend_origin:
        placeholder = "{%s}" % placeholder
    fmt = info.url_format
    fmt.template = "%s/experiment/%s/" % (frontend_origin, placeholder)
    fmt.id_placeholder = placeholder
    info.plugin_control.allowed_plugins[:] = upload_plugins
    return info
|
||
|
|
||
|
|
||
|
def experiment_url(server_info, experiment_id):
    """Formats a URL that will resolve to the provided experiment.

    Args:
      server_info: A `server_info_pb2.ServerInfoResponse` message.
      experiment_id: A string; the ID of the experiment to link to.

    Returns:
      A URL resolving to the given experiment, as a string.
    """
    fmt = server_info.url_format
    return fmt.template.replace(fmt.id_placeholder, experiment_id)
|
||
|
|
||
|
|
||
|
def allowed_plugins(server_info):
    """Determines which plugins may upload data.

    This pulls from the `plugin_control` on the `server_info` when that
    submessage is set, else falls back to a default.

    Args:
      server_info: A `server_info_pb2.ServerInfoResponse` message.

    Returns:
      A `frozenset` of plugin names.
    """
    if not server_info.HasField("plugin_control"):
        # Old server: gracefully degrade to scalars only, which have
        # been supported since launch. TODO(@wchargin): Promote this
        # branch to an error once we're confident that we won't roll
        # back to old server versions.
        return frozenset((scalars_metadata.PLUGIN_NAME,))
    return frozenset(server_info.plugin_control.allowed_plugins)
|
||
|
|
||
|
|
||
|
def upload_limits(server_info):
    """Returns UploadLimits, from server_info if possible, otherwise from defaults.

    Args:
      server_info: A `server_info_pb2.ServerInfoResponse` message.

    Returns:
      An instance of UploadLimits.
    """
    limits = (
        server_info.upload_limits
        if server_info.HasField("upload_limits")
        else server_info_pb2.UploadLimits()
    )

    # Backfill any unset (zero-valued) field with the client-side default.
    if not limits.max_scalar_request_size:
        limits.max_scalar_request_size = _DEFAULT_MAX_SCALAR_REQUEST_SIZE
    if not limits.max_tensor_request_size:
        limits.max_tensor_request_size = _DEFAULT_MAX_TENSOR_REQUEST_SIZE
    if not limits.max_blob_request_size:
        limits.max_blob_request_size = _DEFAULT_MAX_BLOB_REQUEST_SIZE
    if not limits.min_scalar_request_interval:
        limits.min_scalar_request_interval = _DEFAULT_MIN_SCALAR_REQUEST_INTERVAL
    if not limits.min_tensor_request_interval:
        limits.min_tensor_request_interval = _DEFAULT_MIN_TENSOR_REQUEST_INTERVAL
    if not limits.min_blob_request_interval:
        limits.min_blob_request_interval = _DEFAULT_MIN_BLOB_REQUEST_INTERVAL
    if not limits.max_blob_size:
        limits.max_blob_size = _DEFAULT_MAX_BLOB_SIZE
    if not limits.max_tensor_point_size:
        limits.max_tensor_point_size = _DEFAULT_MAX_TENSOR_POINT_SIZE
    return limits
|
||
|
|
||
|
|
||
|
class CommunicationError(RuntimeError):
    """Raised upon failure to communicate with the server."""