# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Callbacks: utilities called at certain points during model training."""

import os

import numpy as np
import tensorflow.compat.v2 as tf

from keras import backend
from keras import callbacks

# isort: off
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export

@keras_export(v1=["keras.callbacks.TensorBoard"])
class TensorBoard(callbacks.TensorBoard):

    """Enable visualizations for TensorBoard.

    TensorBoard is a visualization tool provided with TensorFlow.

    This callback logs events for TensorBoard, including:
    * Metrics summary plots
    * Training graph visualization
    * Activation histograms
    * Sampled profiling

    If you have installed TensorFlow with pip, you should be able
    to launch TensorBoard from the command line:

    ```sh
    tensorboard --logdir=path_to_your_logs
    ```

    You can find more information about TensorBoard
    [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).
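
    For example, a minimal usage sketch (`model`, `x_train`, and `y_train`
    are placeholders for your own compiled model and data):

    ```python
    # Any compiled Keras model and matching training data work here.
    tensorboard = TensorBoard(log_dir="./logs", update_freq="epoch")
    model.fit(x_train, y_train, epochs=5, callbacks=[tensorboard])
    ```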

    Args:
        log_dir: the path of the directory where to save the log files to be
            parsed by TensorBoard.
        histogram_freq: frequency (in epochs) at which to compute activation
            and weight histograms for the layers of the model. If set to 0,
            histograms won't be computed. Validation data (or split) must be
            specified for histogram visualizations.
        write_graph: whether to visualize the graph in TensorBoard. The log
            file can become quite large when write_graph is set to True.
        write_grads: whether to visualize gradient histograms in TensorBoard.
            `histogram_freq` must be greater than 0.
        batch_size: size of the batch of inputs to feed to the network for
            histogram computation.
        write_images: whether to write model weights to visualize as an
            image in TensorBoard.
        embeddings_freq: frequency (in epochs) at which selected embedding
            layers will be saved. If set to 0, embeddings won't be computed.
            Data to be visualized in TensorBoard's Embedding tab must be
            passed as `embeddings_data`.
        embeddings_layer_names: a list of names of layers to keep an eye on.
            If None or an empty list, all the embedding layers will be
            watched.
        embeddings_metadata: a dictionary which maps layer names to file
            names in which metadata for the embedding layer is saved.
            [Here are details](
            https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
            about the metadata file format. If the same metadata file is
            used for all embedding layers, a single string can be passed.
        embeddings_data: data to be embedded at layers specified in
            `embeddings_layer_names`. Numpy array (if the model has a single
            input) or list of Numpy arrays (if the model has multiple
            inputs). Learn more about embeddings [in this guide](
            https://www.tensorflow.org/programmers_guide/embedding).
        update_freq: `'batch'`, `'epoch'`, or integer. When using `'batch'`,
            writes the losses and metrics to TensorBoard after each batch.
            The same applies for `'epoch'`. If using an integer, let's say
            `1000`, the callback will write the metrics and losses to
            TensorBoard every 1000 samples. Note that writing too frequently
            to TensorBoard can slow down your training.
        profile_batch: Profile the batch to sample compute characteristics.
            By default, it will profile the second batch. Set
            profile_batch=0 to disable profiling.

    Raises:
        ValueError: If histogram_freq is set and no validation data is
            provided.

    @compatibility(eager)
    Using the `TensorBoard` callback will work when eager execution is
    enabled, with the restriction that outputting histogram summaries of
    weights and gradients is not supported. Consequently, `histogram_freq`
    will be ignored.
    @end_compatibility
    """

    def __init__(
        self,
        log_dir="./logs",
        histogram_freq=0,
        batch_size=32,
        write_graph=True,
        write_grads=False,
        write_images=False,
        embeddings_freq=0,
        embeddings_layer_names=None,
        embeddings_metadata=None,
        embeddings_data=None,
        update_freq="epoch",
        profile_batch=2,
    ):
        # Don't call super's init since it is an eager-only version.
        callbacks.Callback.__init__(self)
        self.log_dir = log_dir
        self.histogram_freq = histogram_freq
        if self.histogram_freq and tf.executing_eagerly():
            logging.warning(
                UserWarning(
                    "Weight and gradient histograms not supported for eager "
                    "execution, setting `histogram_freq` to `0`."
                )
            )
            self.histogram_freq = 0
        self.merged = None
        self.write_graph = write_graph
        self.write_grads = write_grads
        self.write_images = write_images
        self.batch_size = batch_size
        self._current_batch = 0
        self._total_batches_seen = 0
        self._total_val_batches_seen = 0
        self.embeddings_freq = embeddings_freq
        self.embeddings_layer_names = embeddings_layer_names
        self.embeddings_metadata = embeddings_metadata
        self.embeddings_data = embeddings_data
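        # Note: `'batch'` is normalized to the integer 1 below, i.e. a
        # threshold of a single sample, so summaries are written after
        # every batch (see `on_batch_end`).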
        if update_freq == "batch":
            self.update_freq = 1
        else:
            self.update_freq = update_freq
        self._samples_seen = 0
        self._samples_seen_at_last_write = 0
        # TODO(fishx): Add a link to the full profiler tutorial.
        self._profile_batch = profile_batch
        # True when the profiler was successfully started by this callback.
        # We track the status here to make sure callbacks do not interfere
        # with each other. The callback will only stop the profiler it
        # started.
        self._profiler_started = False

        # TensorBoard should only write summaries on the chief when in a
        # Multi-Worker setting.
        self._chief_worker_only = True

    def _init_writer(self, model):
        """Sets file writer."""
        if tf.executing_eagerly():
            self.writer = tf.summary.create_file_writer(self.log_dir)
            if not model.run_eagerly and self.write_graph:
                with self.writer.as_default():
                    tf.summary.graph(backend.get_graph())
        elif self.write_graph:
            self.writer = tf.compat.v1.summary.FileWriter(
                self.log_dir, backend.get_graph()
            )
        else:
            self.writer = tf.compat.v1.summary.FileWriter(self.log_dir)

    def _make_histogram_ops(self, model):
        """Defines histogram ops when histogram_freq > 0."""
        # Only make the histogram summary op if it hasn't already been made.
        if self.histogram_freq and self.merged is None:
            for layer in self.model.layers:
                for weight in layer.weights:
                    mapped_weight_name = weight.name.replace(":", "_")
                    tf.compat.v1.summary.histogram(mapped_weight_name, weight)
                    if self.write_images:
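                        # `tf.summary.image` expects a 4-D tensor of shape
                        # [batch, height, width, channels], so the weight is
                        # squeezed and reshaped into that layout below.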
                        w_img = tf.compat.v1.squeeze(weight)
                        shape = tuple(w_img.shape)
                        if len(shape) == 2:  # dense layer kernel case
                            if shape[0] > shape[1]:
                                w_img = tf.compat.v1.transpose(w_img)
                                shape = tuple(w_img.shape)
                            w_img = tf.reshape(
                                w_img, [1, shape[0], shape[1], 1]
                            )
                        elif len(shape) == 3:  # convnet case
                            if backend.image_data_format() == "channels_last":
                                # Switch to channels_first to display
                                # every kernel as a separate image.
                                w_img = tf.compat.v1.transpose(
                                    w_img, perm=[2, 0, 1]
                                )
                                shape = tuple(w_img.shape)
                            w_img = tf.reshape(
                                w_img, [shape[0], shape[1], shape[2], 1]
                            )
                        elif len(shape) == 1:  # bias case
                            w_img = tf.reshape(w_img, [1, shape[0], 1, 1])
                        else:
                            # Not possible to handle 3D convnets etc.
                            continue

                        shape = tuple(w_img.shape)
                        assert len(shape) == 4 and shape[-1] in [1, 3, 4]
                        tf.compat.v1.summary.image(mapped_weight_name, w_img)

                if self.write_grads:
                    for weight in layer.trainable_weights:
                        mapped_weight_name = weight.name.replace(":", "_")
                        grads = model.optimizer.get_gradients(
                            model.total_loss, weight
                        )

                        def is_indexed_slices(grad):
                            return type(grad).__name__ == "IndexedSlices"

                        grads = [
                            grad.values if is_indexed_slices(grad) else grad
                            for grad in grads
                        ]
                        tf.compat.v1.summary.histogram(
                            f"{mapped_weight_name}_grad", grads
                        )

                if hasattr(layer, "output"):
                    if isinstance(layer.output, list):
                        for i, output in enumerate(layer.output):
                            tf.compat.v1.summary.histogram(
                                f"{layer.name}_out_{i}", output
                            )
                    else:
                        tf.compat.v1.summary.histogram(
                            f"{layer.name}_out", layer.output
                        )

    def set_model(self, model):
        """Sets Keras model and creates summary ops."""

        self.model = model
        self._init_writer(model)
        # Histogram summaries are only enabled in graph mode.
        if not tf.executing_eagerly():
            self._make_histogram_ops(model)
            self.merged = tf.compat.v1.summary.merge_all()

        # If both embeddings_freq and embeddings_data are available, we will
        # visualize embeddings.
        if self.embeddings_freq and self.embeddings_data is not None:
            # Avoid circular dependency.
            from keras.engine import training_utils_v1

            self.embeddings_data = training_utils_v1.standardize_input_data(
                self.embeddings_data, model.input_names
            )

            # If embeddings_layer_names are not provided, get all of the
            # embedding layers from the model.
            embeddings_layer_names = self.embeddings_layer_names
            if not embeddings_layer_names:
                embeddings_layer_names = [
                    layer.name
                    for layer in self.model.layers
                    if type(layer).__name__ == "Embedding"
                ]

            self.assign_embeddings = []
            embeddings_vars = {}

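            # `batch_id` and `step` are fed at visualization time (see
            # `on_epoch_end`): `batch_id` is the offset of the current
            # sub-batch inside the embedding variable and `step` is the
            # sub-batch size, so embeddings are computed batch by batch.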
            self.batch_id = batch_id = tf.compat.v1.placeholder(tf.int32)
            self.step = step = tf.compat.v1.placeholder(tf.int32)

            for layer in self.model.layers:
                if layer.name in embeddings_layer_names:
                    embedding_input = self.model.get_layer(layer.name).output
                    embedding_size = np.prod(embedding_input.shape[1:])
                    embedding_input = tf.reshape(
                        embedding_input, (step, int(embedding_size))
                    )
                    shape = (
                        self.embeddings_data[0].shape[0],
                        int(embedding_size),
                    )
                    embedding = tf.Variable(
                        tf.zeros(shape), name=layer.name + "_embedding"
                    )
                    embeddings_vars[layer.name] = embedding
                    batch = tf.compat.v1.assign(
                        embedding[batch_id : batch_id + step], embedding_input
                    )
                    self.assign_embeddings.append(batch)

            self.saver = tf.compat.v1.train.Saver(
                list(embeddings_vars.values())
            )

            # Create the embeddings_metadata dictionary.
            if isinstance(self.embeddings_metadata, str):
                embeddings_metadata = {
                    layer_name: self.embeddings_metadata
                    for layer_name in embeddings_vars.keys()
                }
            else:
                # embeddings_metadata is already a dictionary.
                embeddings_metadata = self.embeddings_metadata

            try:
                # isort: off
                from tensorboard.plugins import projector
            except ImportError:
                raise ImportError(
                    "Failed to import TensorBoard. Please make sure that "
                    "TensorBoard integration is complete."
                )

            # TODO(psv): Add integration tests to test embedding
            # visualization with TensorBoard callback. We are unable to
            # write a unit test for this because the TensorBoard dependency
            # assumes the TensorFlow package is installed.
            config = projector.ProjectorConfig()
            for layer_name, tensor in embeddings_vars.items():
                embedding = config.embeddings.add()
                embedding.tensor_name = tensor.name

                if (
                    embeddings_metadata is not None
                    and layer_name in embeddings_metadata
                ):
                    embedding.metadata_path = embeddings_metadata[layer_name]

            projector.visualize_embeddings(self.writer, config)

    def _fetch_callback(self, summary):
        self.writer.add_summary(summary, self._total_val_batches_seen)
        self._total_val_batches_seen += 1

    def _write_custom_summaries(self, step, logs=None):
        """Writes metrics out as custom scalar summaries.

        Args:
            step: the global step to use for TensorBoard.
            logs: dict. Keys are scalar summary names, values are
                NumPy scalars.
        """
        logs = logs or {}
        if tf.executing_eagerly():
            # Use v2 summary ops.
            with self.writer.as_default(), tf.summary.record_if(True):
                for name, value in logs.items():
                    if isinstance(value, np.ndarray):
                        value = value.item()
                    tf.summary.scalar(name, value, step=step)
        else:
            # Use FileWriter from v1 summary.
            for name, value in logs.items():
                if isinstance(value, np.ndarray):
                    value = value.item()
                summary = tf.compat.v1.Summary()
                summary_value = summary.value.add()
                summary_value.simple_value = value
                summary_value.tag = name
                self.writer.add_summary(summary, step)
        self.writer.flush()

    def on_train_batch_begin(self, batch, logs=None):
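        # `_total_batches_seen` is zero-based and only incremented in
        # `on_batch_end`, while `profile_batch` is one-based, so this starts
        # the profiler at the beginning of the requested batch; it is
        # stopped again in `on_batch_end`.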
        if self._total_batches_seen == self._profile_batch - 1:
            self._start_profiler()

    def on_train_batch_end(self, batch, logs=None):
        return self.on_batch_end(batch, logs)

    def on_test_begin(self, logs=None):
        pass

    def on_test_end(self, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        """Writes scalar summaries for metrics on every training batch.

        Also stops the profiler if it was started for this batch.
        """
        # Don't output batch_size and batch number as TensorBoard summaries.
        logs = logs or {}
        self._samples_seen += logs.get("size", 1)
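        # When `update_freq` is an integer, it is a number of samples: write
        # summaries once at least that many samples have been seen since
        # the last write. (`'batch'` was normalized to 1 in `__init__`.)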
        samples_seen_since = (
            self._samples_seen - self._samples_seen_at_last_write
        )
        if (
            self.update_freq != "epoch"
            and samples_seen_since >= self.update_freq
        ):
            batch_logs = {
                ("batch_" + k): v
                for k, v in logs.items()
                if k not in ["batch", "size", "num_steps"]
            }
            self._write_custom_summaries(self._total_batches_seen, batch_logs)
            self._samples_seen_at_last_write = self._samples_seen
        self._total_batches_seen += 1
        self._stop_profiler()

    def on_train_begin(self, logs=None):
        pass

    def on_epoch_begin(self, epoch, logs=None):
        """Adds the histogram op to Model eval_function callbacks, resets
        the batch count."""

        # Check if the histogram summary should be run for this epoch.
        if self.histogram_freq and epoch % self.histogram_freq == 0:
            # Add the histogram summary op if it should run this epoch.
            self.model._make_test_function()
            if self.merged not in self.model.test_function.fetches:
                self.model.test_function.fetches.append(self.merged)
                self.model.test_function.fetch_callbacks[
                    self.merged
                ] = self._fetch_callback

    def on_epoch_end(self, epoch, logs=None):
        """Checks if summary ops should run next epoch, logs scalar
        summaries."""

        # Don't output batch_size and batch number as TensorBoard summaries.
        logs = {
            ("epoch_" + k): v
            for k, v in logs.items()
            if k not in ["batch", "size", "num_steps"]
        }
        if self.update_freq == "epoch":
            step = epoch
        else:
            step = self._samples_seen
        self._write_custom_summaries(step, logs)

        # Pop the histogram summary op after each epoch.
        if self.histogram_freq:
            if self.merged in self.model.test_function.fetches:
                self.model.test_function.fetches.remove(self.merged)
            if self.merged in self.model.test_function.fetch_callbacks:
                self.model.test_function.fetch_callbacks.pop(self.merged)

        if self.embeddings_data is None and self.embeddings_freq:
            raise ValueError(
                "To visualize embeddings, embeddings_data must be provided."
            )

        if self.embeddings_freq and self.embeddings_data is not None:
            if epoch % self.embeddings_freq == 0:
                # We need a second forward pass here because we're passing
                # the `embeddings_data` explicitly. This design allows
                # passing arbitrary data as `embeddings_data` and results
                # from the fact that we need to know the size of the
                # `tf.Variable`s which hold the embeddings in `set_model`.
                # At this point, however, the `validation_data` is not yet
                # set.

                embeddings_data = self.embeddings_data
                n_samples = embeddings_data[0].shape[0]
                i = 0
                sess = backend.get_session()
                while i < n_samples:
                    step = min(self.batch_size, n_samples - i)
                    batch = slice(i, i + step)

                    if isinstance(self.model.input, list):
                        feed_dict = {
                            model_input: embeddings_data[idx][batch]
                            for idx, model_input in enumerate(self.model.input)
                        }
                    else:
                        feed_dict = {
                            self.model.input: embeddings_data[0][batch]
                        }

                    feed_dict.update({self.batch_id: i, self.step: step})

                    if not isinstance(backend.learning_phase(), int):
                        feed_dict[backend.learning_phase()] = False

                    sess.run(self.assign_embeddings, feed_dict=feed_dict)
                    self.saver.save(
                        sess,
                        os.path.join(self.log_dir, "keras_embedding.ckpt"),
                        epoch,
                    )

                    i += self.batch_size

    def on_train_end(self, logs=None):
        self._stop_profiler()
        self.writer.close()

    def _start_profiler(self):
        """Starts the profiler if currently inactive."""
        if self._profiler_started:
            return
        try:
            tf.profiler.experimental.start(logdir=self.log_dir)
            self._profiler_started = True
        except tf.errors.AlreadyExistsError as e:
            # Profiler errors should not be fatal.
            logging.error("Failed to start profiler: %s", e.message)

    def _stop_profiler(self):
        """Stops the profiler if currently active."""
        if not self._profiler_started:
            return
        try:
            tf.profiler.experimental.stop()
        except tf.errors.UnavailableError as e:
            # Profiler errors should not be fatal.
            logging.error("Failed to stop profiler: %s", e.message)
        finally:
            self._profiler_started = False