207 lines
8.3 KiB
Python
207 lines
8.3 KiB
Python
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Summary creation methods for the HParams plugin.
|
|
|
|
Typical usage for exporting summaries in a hyperparameters-tuning experiment:
1. Create the experiment (once) by calling experiment_pb() and exporting
   the resulting summary into a top-level (empty) run.
2. In each training session in the experiment, call session_start_pb() before
   the session starts, exporting the resulting summary into a uniquely named
   run for the session, say <session_name>.
3. Train the model in the session, exporting each metric as a scalar summary
   in runs of the form <session_name>/<sub_dir>, where <sub_dir> can be empty
   (in which case the run is just the <session_name>) and depends on the
   metric. The name of such a metric is a (group, tag) pair given by
   (<sub_dir>, tag) where tag is the tag of the scalar summary.
   When calling experiment_pb() in step 1, you'll need to pass all the metric
   names used in the experiment.
4. When the session completes, call session_end_pb() and export the resulting
   summary into the same session run <session_name>.
|
|
"""
|
|
|
|
|
|
import time
|
|
|
|
|
|
import tensorflow as tf
|
|
|
|
from tensorboard.plugins.hparams import api_pb2
|
|
from tensorboard.plugins.hparams import metadata
|
|
from tensorboard.plugins.hparams import plugin_data_pb2
|
|
|
|
|
|
def experiment_pb(
    hparam_infos, metric_infos, user="", description="", time_created_secs=None
):
    """Builds the summary that defines a hyperparameter-tuning experiment.

    Args:
      hparam_infos: Array of api_pb2.HParamInfo messages describing the
        hyperparameters used in the experiment.
      metric_infos: Array of api_pb2.MetricInfo messages describing the
        metrics used in the experiment. The module-level documentation
        explains how metric names are formed.
      user: String. An id for the user running the experiment.
      description: String. A description for the experiment. May contain
        markdown.
      time_created_secs: float. Creation time of the experiment, in seconds
        since the UNIX epoch. When None, the current time is used.

    Returns:
      A summary protobuffer containing the experiment definition.
    """
    # Only an explicit None selects "now"; any other value is used as given.
    created_at = (
        time.time() if time_created_secs is None else time_created_secs
    )
    plugin_data = plugin_data_pb2.HParamsPluginData(
        experiment=api_pb2.Experiment(
            description=description,
            user=user,
            time_created_secs=created_at,
            hparam_infos=hparam_infos,
            metric_infos=metric_infos,
        )
    )
    return _summary(metadata.EXPERIMENT_TAG, plugin_data)
|
|
|
|
|
|
def session_start_pb(
    hparams, model_uri="", monitor_url="", group_name="", start_time_secs=None
):
    """Builds the summary holding a SessionStartInfo protobuffer.

    The summary carries the metadata of one training session. Create one
    such summary per training session, each in a different run.

    Args:
      hparams: A dictionary with string keys, mapping each hyperparameter
        name to the value used in the session. Supported value types are
        `bool`, `int`, `float`, `str`, `list` and `tuple`.
        The type of each value must correspond to the type of the
        hyperparameter (defined in the corresponding api_pb2.HParamInfo
        member of the Experiment protobuf) as follows:

          +-----------------+---------------------------------+
          |Hyperparameter   | Allowed (Python) value types    |
          |type             |                                 |
          +-----------------+---------------------------------+
          |DATA_TYPE_BOOL   | bool                            |
          |DATA_TYPE_FLOAT64| int, float                      |
          |DATA_TYPE_STRING | str, tuple, list                |
          +-----------------+---------------------------------+

        Tuple and list instances are stored as their string
        representation.
      model_uri: See the comment for the field with the same name of
        plugin_data_pb2.SessionStartInfo.
      monitor_url: See the comment for the field with the same name of
        plugin_data_pb2.SessionStartInfo.
      group_name: See the comment for the field with the same name of
        plugin_data_pb2.SessionStartInfo.
      start_time_secs: float. The session start time, in seconds since the
        UNIX epoch. When None, the current time is used.

    Returns:
      The summary protobuffer described above.

    Raises:
      TypeError: If a hyperparameter value has an unsupported type.
    """
    started_at = time.time() if start_time_secs is None else start_time_secs
    info = plugin_data_pb2.SessionStartInfo(
        model_uri=model_uri,
        monitor_url=monitor_url,
        group_name=group_name,
        start_time_secs=started_at,
    )
    for name, value in hparams.items():
        # bool must be tested ahead of the numeric check: in Python,
        # isinstance(True, int) is True.
        if isinstance(value, bool):
            info.hparams[name].bool_value = value
            continue
        if isinstance(value, (float, int)):
            info.hparams[name].number_value = value
            continue
        if isinstance(value, str):
            info.hparams[name].string_value = value
            continue
        if isinstance(value, (list, tuple)):
            # Sequences are recorded through their str() representation.
            info.hparams[name].string_value = str(value)
            continue
        raise TypeError(
            "hparams[%s]=%s has type: %s which is not supported"
            % (name, value, type(value))
        )
    plugin_data = plugin_data_pb2.HParamsPluginData(session_start_info=info)
    return _summary(metadata.SESSION_START_INFO_TAG, plugin_data)
|
|
|
|
|
|
def session_end_pb(status, end_time_secs=None):
    """Builds the summary holding a SessionEndInfo protobuffer.

    The summary carries the status information of a completed training
    session and should be exported once the session has finished. Create one
    such summary per training session, each in a different run.

    Args:
      status: A tensorboard.hparams.Status enumeration value denoting the
        status of the session.
      end_time_secs: float. The session end time, in seconds since the UNIX
        epoch. When None, the current time is used.

    Returns:
      The summary protobuffer described above.
    """
    # Only an explicit None selects "now"; any other value is used as given.
    ended_at = time.time() if end_time_secs is None else end_time_secs
    plugin_data = plugin_data_pb2.HParamsPluginData(
        session_end_info=plugin_data_pb2.SessionEndInfo(
            status=status, end_time_secs=ended_at
        )
    )
    return _summary(metadata.SESSION_END_INFO_TAG, plugin_data)
|
|
|
|
|
|
def _summary(tag, hparams_plugin_data):
    """Wraps an HParamsPluginData message in a single-value summary.

    Helper shared by the public `*_pb` functions of this module.

    Args:
      tag: string. The tag to use for the summary value.
      hparams_plugin_data: The HParamsPluginData message to embed in the
        summary metadata.

    Returns:
      A `tf.compat.v1.Summary` with one value carrying `tag`, the plugin
      metadata, and `_TF_NULL_TENSOR` as its tensor.
    """
    # Serialize the metadata built by the plugin so that it can be re-parsed
    # below as TensorFlow's own `SummaryMetadata` type.
    serialized_metadata = metadata.create_summary_metadata(
        hparams_plugin_data
    ).SerializeToString()
    result = tf.compat.v1.Summary()
    result.value.add(
        tag=tag,
        metadata=tf.compat.v1.SummaryMetadata.FromString(serialized_metadata),
        tensor=_TF_NULL_TENSOR,
    )
    return result
|
|
|
|
|
|
# TensorFlow-proto counterpart of `metadata.NULL_TENSOR`. The `TensorProto`
# class is not referenced by name here; as a slight kludge it is taken from
# the type of a `tf.make_tensor_proto` result, and an instance of it is then
# populated from the serialized `metadata.NULL_TENSOR`.
_TF_NULL_TENSOR = type(tf.make_tensor_proto(0))()
_TF_NULL_TENSOR.ParseFromString(metadata.NULL_TENSOR.SerializeToString())
|