# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A data provider that talks to a gRPC server."""

import contextlib

import grpc

from tensorboard.util import tensor_util
from tensorboard.util import timing
from tensorboard import errors
from tensorboard.data import provider
from tensorboard.data.proto import data_provider_pb2
from tensorboard.data.proto import data_provider_pb2_grpc


def make_stub(channel):
    """Wraps a gRPC channel with a service stub."""
    return data_provider_pb2_grpc.TensorBoardDataProviderStub(channel)


class GrpcDataProvider(provider.DataProvider):
    """Data provider that talks over gRPC."""

    def __init__(self, addr, stub):
        """Initializes a GrpcDataProvider.

        Args:
          addr: String address of the remote peer. Used cosmetically for
            data location.
          stub: `data_provider_pb2_grpc.TensorBoardDataProviderStub` value.
            See `make_stub` to construct one from a channel.
        """
        self._addr = addr
        self._stub = stub

    def __str__(self):
        return "GrpcDataProvider(addr=%r)" % self._addr

    def experiment_metadata(self, ctx, *, experiment_id):
        """Fetches experiment metadata via the `GetExperiment` RPC."""
        req = data_provider_pb2.GetExperimentRequest()
        req.experiment_id = experiment_id
        with _translate_grpc_error():
            res = self._stub.GetExperiment(req)
        return provider.ExperimentMetadata(
            data_location=res.data_location,
            experiment_name=res.name,
            experiment_description=res.description,
            creation_time=_timestamp_proto_to_float(res.creation_time),
        )

    def list_plugins(self, ctx, *, experiment_id):
        """Returns the plugin names reported by the `ListPlugins` RPC."""
        req = data_provider_pb2.ListPluginsRequest()
        req.experiment_id = experiment_id
        with _translate_grpc_error():
            res = self._stub.ListPlugins(req)
        return [p.name for p in res.plugins]

    def list_runs(self, ctx, *, experiment_id):
        """Returns a `provider.Run` for each run from the `ListRuns` RPC.

        The server's run name is used as both `run_id` and `run_name`.
        """
        req = data_provider_pb2.ListRunsRequest()
        req.experiment_id = experiment_id
        with _translate_grpc_error():
            res = self._stub.ListRuns(req)
        return [
            provider.Run(
                run_id=run.name,
                run_name=run.name,
                start_time=run.start_time,
            )
            for run in res.runs
        ]

    @timing.log_latency
    def list_scalars(
        self, ctx, *, experiment_id, plugin_name, run_tag_filter=None
    ):
        """Lists scalar time series via the `ListScalars` RPC.

        Returns:
          A dict mapping run name to a dict mapping tag name to a
          `provider.ScalarTimeSeries`.
        """
        with timing.log_latency("build request"):
            req = data_provider_pb2.ListScalarsRequest()
            req.experiment_id = experiment_id
            req.plugin_filter.plugin_name = plugin_name
            _populate_rtf(run_tag_filter, req.run_tag_filter)
        with timing.log_latency("_stub.ListScalars"):
            with _translate_grpc_error():
                res = self._stub.ListScalars(req)
        with timing.log_latency("build result"):
            result = _index_by_run_and_tag(
                res,
                lambda tag_entry: provider.ScalarTimeSeries(
                    **_time_series_kwargs(tag_entry.metadata)
                ),
            )
        return result

    @timing.log_latency
    def read_scalars(
        self,
        ctx,
        *,
        experiment_id,
        plugin_name,
        downsample=None,
        run_tag_filter=None,
    ):
        """Reads scalar data via the `ReadScalars` RPC.

        Args:
          downsample: Integer copied into `downsample.num_points` of the
            request. NOTE(review): despite the `None` default, a value
            appears to be required, since `None` is assigned directly to
            an integer proto field.

        Returns:
          A dict mapping run name to a dict mapping tag name to a list of
          `provider.ScalarDatum`s.
        """
        with timing.log_latency("build request"):
            req = data_provider_pb2.ReadScalarsRequest()
            req.experiment_id = experiment_id
            req.plugin_filter.plugin_name = plugin_name
            _populate_rtf(run_tag_filter, req.run_tag_filter)
            req.downsample.num_points = downsample
        with timing.log_latency("_stub.ReadScalars"):
            with _translate_grpc_error():
                res = self._stub.ReadScalars(req)
        with timing.log_latency("build result"):
            result = _index_by_run_and_tag(res, _scalar_data)
        return result

    @timing.log_latency
    def list_tensors(
        self, ctx, *, experiment_id, plugin_name, run_tag_filter=None
    ):
        """Lists tensor time series via the `ListTensors` RPC.

        Returns:
          A dict mapping run name to a dict mapping tag name to a
          `provider.TensorTimeSeries`.
        """
        with timing.log_latency("build request"):
            req = data_provider_pb2.ListTensorsRequest()
            req.experiment_id = experiment_id
            req.plugin_filter.plugin_name = plugin_name
            _populate_rtf(run_tag_filter, req.run_tag_filter)
        with timing.log_latency("_stub.ListTensors"):
            with _translate_grpc_error():
                res = self._stub.ListTensors(req)
        with timing.log_latency("build result"):
            result = _index_by_run_and_tag(
                res,
                lambda tag_entry: provider.TensorTimeSeries(
                    **_time_series_kwargs(tag_entry.metadata)
                ),
            )
        return result

    @timing.log_latency
    def read_tensors(
        self,
        ctx,
        *,
        experiment_id,
        plugin_name,
        downsample=None,
        run_tag_filter=None,
    ):
        """Reads tensor data via the `ReadTensors` RPC.

        Args:
          downsample: Integer copied into `downsample.num_points` of the
            request. NOTE(review): despite the `None` default, a value
            appears to be required, since `None` is assigned directly to
            an integer proto field.

        Returns:
          A dict mapping run name to a dict mapping tag name to a list of
          `provider.TensorDatum`s whose `numpy` field is produced by
          `tensor_util.make_ndarray`.
        """
        with timing.log_latency("build request"):
            req = data_provider_pb2.ReadTensorsRequest()
            req.experiment_id = experiment_id
            req.plugin_filter.plugin_name = plugin_name
            _populate_rtf(run_tag_filter, req.run_tag_filter)
            req.downsample.num_points = downsample
        with timing.log_latency("_stub.ReadTensors"):
            with _translate_grpc_error():
                res = self._stub.ReadTensors(req)
        with timing.log_latency("build result"):
            result = _index_by_run_and_tag(res, _tensor_data)
        return result

    @timing.log_latency
    def list_blob_sequences(
        self, ctx, experiment_id, plugin_name, run_tag_filter=None
    ):
        """Lists blob sequence time series via the `ListBlobSequences` RPC.

        Returns:
          A dict mapping run name to a dict mapping tag name to a
          `provider.BlobSequenceTimeSeries`.
        """
        with timing.log_latency("build request"):
            req = data_provider_pb2.ListBlobSequencesRequest()
            req.experiment_id = experiment_id
            req.plugin_filter.plugin_name = plugin_name
            _populate_rtf(run_tag_filter, req.run_tag_filter)
        with timing.log_latency("_stub.ListBlobSequences"):
            with _translate_grpc_error():
                res = self._stub.ListBlobSequences(req)
        with timing.log_latency("build result"):
            result = _index_by_run_and_tag(
                res,
                lambda tag_entry: provider.BlobSequenceTimeSeries(
                    max_length=tag_entry.metadata.max_length,
                    **_time_series_kwargs(tag_entry.metadata),
                ),
            )
        return result

    @timing.log_latency
    def read_blob_sequences(
        self,
        ctx,
        experiment_id,
        plugin_name,
        downsample=None,
        run_tag_filter=None,
    ):
        """Reads blob sequence references via the `ReadBlobSequences` RPC.

        Args:
          downsample: Integer copied into `downsample.num_points` of the
            request. NOTE(review): despite the `None` default, a value
            appears to be required, since `None` is assigned directly to
            an integer proto field.

        Returns:
          A dict mapping run name to a dict mapping tag name to a list of
          `provider.BlobSequenceDatum`s.
        """
        with timing.log_latency("build request"):
            req = data_provider_pb2.ReadBlobSequencesRequest()
            req.experiment_id = experiment_id
            req.plugin_filter.plugin_name = plugin_name
            _populate_rtf(run_tag_filter, req.run_tag_filter)
            req.downsample.num_points = downsample
        with timing.log_latency("_stub.ReadBlobSequences"):
            with _translate_grpc_error():
                res = self._stub.ReadBlobSequences(req)
        with timing.log_latency("build result"):
            result = _index_by_run_and_tag(res, _blob_sequence_data)
        return result

    @timing.log_latency
    def read_blob(self, ctx, blob_key):
        """Reads a blob via the streaming `ReadBlob` RPC.

        Returns:
          The `data` fields of all streamed responses, concatenated as
          `bytes`.
        """
        with timing.log_latency("build request"):
            req = data_provider_pb2.ReadBlobRequest()
            req.blob_key = blob_key
        with timing.log_latency("list(_stub.ReadBlob)"):
            # Drain the stream inside the translator so that errors raised
            # mid-stream are translated too.
            with _translate_grpc_error():
                responses = list(self._stub.ReadBlob(req))
        with timing.log_latency("build result"):
            return b"".join(res.data for res in responses)


def _index_by_run_and_tag(res, convert):
    """Builds `{run_name: {tag_name: convert(tag_entry)}}` from a response.

    Args:
      res: Response proto with a repeated `runs` field. Each run entry has
        a `run_name` and a repeated `tags` field whose entries have a
        `tag_name`.
      convert: Function mapping one tag entry to its value in the result.

    Returns:
      A nested dict. If a run name (or a tag name within one run) repeats,
      the last entry wins.
    """
    return {
        run_entry.run_name: {
            tag_entry.tag_name: convert(tag_entry)
            for tag_entry in run_entry.tags
        }
        for run_entry in res.runs
    }


def _time_series_kwargs(time_series):
    """Extracts the metadata fields shared by all `*TimeSeries` values."""
    summary_metadata = time_series.summary_metadata
    return {
        "max_step": time_series.max_step,
        "max_wall_time": time_series.max_wall_time,
        "plugin_content": summary_metadata.plugin_data.content,
        "description": summary_metadata.summary_description,
        "display_name": summary_metadata.display_name,
    }


def _scalar_data(tag_entry):
    """Converts a `ReadScalars` tag entry to a list of `ScalarDatum`s."""
    d = tag_entry.data
    return [
        provider.ScalarDatum(step=step, wall_time=wt, value=value)
        for (step, wt, value) in zip(d.step, d.wall_time, d.value)
    ]


def _tensor_data(tag_entry):
    """Converts a `ReadTensors` tag entry to a list of `TensorDatum`s."""
    d = tag_entry.data
    return [
        provider.TensorDatum(
            step=step,
            wall_time=wt,
            numpy=tensor_util.make_ndarray(value),
        )
        for (step, wt, value) in zip(d.step, d.wall_time, d.value)
    ]


def _blob_sequence_data(tag_entry):
    """Converts a `ReadBlobSequences` tag entry to `BlobSequenceDatum`s."""
    d = tag_entry.data
    return [
        provider.BlobSequenceDatum(
            step=step,
            wall_time=wt,
            values=tuple(
                # An empty URL string from the proto is exposed as `None`.
                provider.BlobReference(
                    blob_key=ref.blob_key, url=ref.url or None
                )
                for ref in blob_sequence.blob_refs
            ),
        )
        for (step, wt, blob_sequence) in zip(d.step, d.wall_time, d.values)
    ]


# gRPC status codes that `_translate_grpc_error` converts, mapped to the
# `tensorboard.errors` type raised in their place.
_GRPC_STATUS_TO_ERROR = {
    grpc.StatusCode.INVALID_ARGUMENT: errors.InvalidArgumentError,
    grpc.StatusCode.NOT_FOUND: errors.NotFoundError,
    grpc.StatusCode.PERMISSION_DENIED: errors.PermissionDeniedError,
}


@contextlib.contextmanager
def _translate_grpc_error():
    """Translates selected `grpc.RpcError`s raised in the body.

    RPC errors whose status code appears in `_GRPC_STATUS_TO_ERROR` are
    re-raised as the mapped `tensorboard.errors` type, carrying the RPC's
    details string and chained to the original error. All other
    `grpc.RpcError`s propagate unchanged.
    """
    try:
        yield
    except grpc.RpcError as e:
        error_type = _GRPC_STATUS_TO_ERROR.get(e.code())
        if error_type is None:
            raise
        raise error_type(e.details()) from e


def _populate_rtf(run_tag_filter, rtf_proto):
    """Copies `run_tag_filter` into `rtf_proto`.

    A `None` filter, or a `None` `runs`/`tags` attribute, leaves the
    corresponding proto field untouched. Names are sorted before copying.
    """
    if run_tag_filter is None:
        return
    if run_tag_filter.runs is not None:
        rtf_proto.runs.names[:] = sorted(run_tag_filter.runs)
    if run_tag_filter.tags is not None:
        rtf_proto.tags.names[:] = sorted(run_tag_filter.tags)


def _timestamp_proto_to_float(ts):
    """Converts `timestamp_pb2.Timestamp` to float seconds since epoch."""
    return ts.ToNanoseconds() / 1e9