Intelegentny_Pszczelarz/.venv/Lib/site-packages/keras/callbacks_v1.py

529 lines
22 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Callbacks: utilities called at certain points during model training."""
import os
import numpy as np
import tensorflow.compat.v2 as tf
from keras import backend
from keras import callbacks
# isort: off
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export
@keras_export(v1=["keras.callbacks.TensorBoard"])
class TensorBoard(callbacks.TensorBoard):
    """Enable visualizations for TensorBoard.

    TensorBoard is a visualization tool provided with TensorFlow.

    This callback logs events for TensorBoard, including:

    * Metrics summary plots
    * Training graph visualization
    * Activation histograms
    * Sampled profiling

    If you have installed TensorFlow with pip, you should be able
    to launch TensorBoard from the command line:

    ```sh
    tensorboard --logdir=path_to_your_logs
    ```

    You can find more information about TensorBoard
    [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).

    Args:
        log_dir: the path of the directory where to save the log files to be
          parsed by TensorBoard.
        histogram_freq: frequency (in epochs) at which to compute activation
          and weight histograms for the layers of the model. If set to 0,
          histograms won't be computed. Validation data (or split) must be
          specified for histogram visualizations.
        write_graph: whether to visualize the graph in TensorBoard. The log
          file can become quite large when write_graph is set to True.
        write_grads: whether to visualize gradient histograms in TensorBoard.
          `histogram_freq` must be greater than 0.
        batch_size: number of `embeddings_data` samples fed to the network per
          forward pass when computing embeddings.
        write_images: whether to write model weights to visualize as image in
          TensorBoard.
        embeddings_freq: frequency (in epochs) at which selected embedding
          layers will be saved. If set to 0, embeddings won't be computed.
          Data to be visualized in TensorBoard's Embedding tab must be passed
          as `embeddings_data`.
        embeddings_layer_names: a list of names of layers to keep an eye on.
          If None or an empty list, every layer whose class is named
          `Embedding` will be watched.
        embeddings_metadata: a dictionary which maps layer name to a file name
          in which metadata for this embedding layer is saved.
          [Here are details](
          https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
          about the metadata file format. If the same metadata file is used
          for all embedding layers, a single string can be passed instead.
        embeddings_data: data to be embedded at layers specified in
          `embeddings_layer_names`. Numpy array (if the model has a single
          input) or list of Numpy arrays (if the model has multiple inputs).
          Learn more about embeddings [in this guide](
          https://www.tensorflow.org/programmers_guide/embedding).
        update_freq: `'batch'` or `'epoch'` or integer. When using `'batch'`,
          writes the losses and metrics to TensorBoard after each batch. The
          same applies for `'epoch'`. If using an integer, let's say `1000`,
          the callback will write the metrics and losses to TensorBoard every
          1000 samples. Note that writing too frequently to TensorBoard can
          slow down your training.
        profile_batch: Profile the batch to sample compute characteristics. By
          default, it will profile the second batch. Set profile_batch=0 to
          disable profiling.

    Raises:
        ValueError: If histogram_freq is set and no validation data is
          provided.

    @compatibility(eager)
    Using the `TensorBoard` callback will work when eager execution is enabled,
    with the restriction that outputting histogram summaries of weights and
    gradients is not supported. Consequently, `histogram_freq` will be ignored.
    @end_compatibility
    """

    def __init__(
        self,
        log_dir="./logs",
        histogram_freq=0,
        batch_size=32,
        write_graph=True,
        write_grads=False,
        write_images=False,
        embeddings_freq=0,
        embeddings_layer_names=None,
        embeddings_metadata=None,
        embeddings_data=None,
        update_freq="epoch",
        profile_batch=2,
    ):
        # Don't call super's init since it is an eager-only version.
        callbacks.Callback.__init__(self)
        self.log_dir = log_dir
        self.histogram_freq = histogram_freq
        if self.histogram_freq and tf.executing_eagerly():
            logging.warning(
                UserWarning(
                    "Weight and gradient histograms not supported for eager "
                    "execution, setting `histogram_freq` to `0`."
                )
            )
            self.histogram_freq = 0
        self.merged = None
        self.write_graph = write_graph
        self.write_grads = write_grads
        self.write_images = write_images
        self.batch_size = batch_size
        self._current_batch = 0
        self._total_batches_seen = 0
        self._total_val_batches_seen = 0
        self.embeddings_freq = embeddings_freq
        self.embeddings_layer_names = embeddings_layer_names
        self.embeddings_metadata = embeddings_metadata
        self.embeddings_data = embeddings_data
        # 'batch' is stored as 1 sample, so every batch passes the
        # `samples_seen_since >= update_freq` check in `on_batch_end`.
        if update_freq == "batch":
            self.update_freq = 1
        else:
            self.update_freq = update_freq
        self._samples_seen = 0
        self._samples_seen_at_last_write = 0

        # TODO(fishx): Add a link to the full profiler tutorial.
        self._profile_batch = profile_batch
        # True when the profiler was successfully started by this callback.
        # We track the status here to make sure callbacks do not interfere with
        # each other. The callback will only stop the profiler it started.
        self._profiler_started = False

        # TensorBoard should only write summaries on the chief when in a
        # Multi-Worker setting.
        self._chief_worker_only = True

    def _init_writer(self, model):
        """Sets file writer.

        In eager mode a v2 summary writer is created (and the graph is written
        only when the model does not run eagerly); otherwise a v1 `FileWriter`
        is created, optionally seeded with the backend graph.
        """
        if tf.executing_eagerly():
            self.writer = tf.summary.create_file_writer(self.log_dir)
            if not model.run_eagerly and self.write_graph:
                with self.writer.as_default():
                    tf.summary.graph(backend.get_graph())
        elif self.write_graph:
            self.writer = tf.compat.v1.summary.FileWriter(
                self.log_dir, backend.get_graph()
            )
        else:
            self.writer = tf.compat.v1.summary.FileWriter(self.log_dir)

    def _make_histogram_ops(self, model):
        """Defines histogram ops when histogram_freq > 0.

        Registers v1 summary ops for every layer of `model`: weight
        histograms, optional weight images, optional gradient histograms, and
        layer output histograms. The ops are collected later by
        `tf.compat.v1.summary.merge_all()` in `set_model`.
        """

        def is_indexed_slices(grad):
            # Compare by class name rather than importing the type.
            return type(grad).__name__ == "IndexedSlices"

        # only make histogram summary op if it hasn't already been made
        if self.histogram_freq and self.merged is None:
            for layer in model.layers:
                for weight in layer.weights:
                    mapped_weight_name = weight.name.replace(":", "_")
                    tf.compat.v1.summary.histogram(mapped_weight_name, weight)
                    if self.write_images:
                        w_img = tf.compat.v1.squeeze(weight)
                        shape = tuple(w_img.shape)
                        if len(shape) == 2:  # dense layer kernel case
                            if shape[0] > shape[1]:
                                w_img = tf.compat.v1.transpose(w_img)
                                shape = tuple(w_img.shape)
                            w_img = tf.reshape(
                                w_img, [1, shape[0], shape[1], 1]
                            )
                        elif len(shape) == 3:  # convnet case
                            if backend.image_data_format() == "channels_last":
                                # switch to channels_first to display
                                # every kernel as a separate image
                                w_img = tf.compat.v1.transpose(
                                    w_img, perm=[2, 0, 1]
                                )
                                shape = tuple(w_img.shape)
                            w_img = tf.reshape(
                                w_img, [shape[0], shape[1], shape[2], 1]
                            )
                        elif len(shape) == 1:  # bias case
                            w_img = tf.reshape(w_img, [1, shape[0], 1, 1])
                        else:
                            # not possible to handle 3D convnets etc.
                            continue

                        shape = tuple(w_img.shape)
                        # Image summaries need a 4-D tensor whose last axis is
                        # 1, 3 or 4 channels.
                        assert len(shape) == 4 and shape[-1] in [1, 3, 4]
                        tf.compat.v1.summary.image(mapped_weight_name, w_img)

                if self.write_grads:
                    for weight in layer.trainable_weights:
                        mapped_weight_name = weight.name.replace(":", "_")
                        grads = model.optimizer.get_gradients(
                            model.total_loss, weight
                        )
                        # Sparse gradients are summarized through their
                        # dense `values` component.
                        grads = [
                            grad.values if is_indexed_slices(grad) else grad
                            for grad in grads
                        ]
                        tf.compat.v1.summary.histogram(
                            f"{mapped_weight_name}_grad", grads
                        )

                if hasattr(layer, "output"):
                    if isinstance(layer.output, list):
                        for i, output in enumerate(layer.output):
                            tf.compat.v1.summary.histogram(
                                f"{layer.name}_out_{i}", output
                            )
                    else:
                        tf.compat.v1.summary.histogram(
                            f"{layer.name}_out", layer.output
                        )

    def set_model(self, model):
        """Sets Keras model and creates summary ops.

        Raises:
            ImportError: if embeddings are requested but
              `tensorboard.plugins.projector` cannot be imported.
        """

        self.model = model
        self._init_writer(model)
        # histogram summaries only enabled in graph mode
        if not tf.executing_eagerly():
            self._make_histogram_ops(model)
            self.merged = tf.compat.v1.summary.merge_all()

        # If both embedding_freq and embeddings_data are available, we will
        # visualize embeddings.
        if self.embeddings_freq and self.embeddings_data is not None:
            # Avoid circular dependency.
            from keras.engine import (
                training_utils_v1,
            )

            self.embeddings_data = training_utils_v1.standardize_input_data(
                self.embeddings_data, model.input_names
            )

            # If embedding_layer_names are not provided, get all of the
            # embedding layers from the model.
            embeddings_layer_names = self.embeddings_layer_names
            if not embeddings_layer_names:
                embeddings_layer_names = [
                    layer.name
                    for layer in self.model.layers
                    if type(layer).__name__ == "Embedding"
                ]

            self.assign_embeddings = []
            embeddings_vars = {}

            # Fed in `on_epoch_end`: the offset of the current batch within
            # `embeddings_data` and the number of rows in that batch.
            self.batch_id = batch_id = tf.compat.v1.placeholder(tf.int32)
            self.step = step = tf.compat.v1.placeholder(tf.int32)

            for layer in self.model.layers:
                if layer.name in embeddings_layer_names:
                    embedding_input = self.model.get_layer(layer.name).output
                    embedding_size = np.prod(embedding_input.shape[1:])
                    # Flatten every non-batch axis into a single row.
                    embedding_input = tf.reshape(
                        embedding_input, (step, int(embedding_size))
                    )
                    shape = (
                        self.embeddings_data[0].shape[0],
                        int(embedding_size),
                    )
                    embedding = tf.Variable(
                        tf.zeros(shape), name=layer.name + "_embedding"
                    )
                    embeddings_vars[layer.name] = embedding
                    batch = tf.compat.v1.assign(
                        embedding[batch_id : batch_id + step], embedding_input
                    )
                    self.assign_embeddings.append(batch)

            self.saver = tf.compat.v1.train.Saver(
                list(embeddings_vars.values())
            )

            # Create embeddings_metadata dictionary
            if isinstance(self.embeddings_metadata, str):
                embeddings_metadata = {
                    layer_name: self.embeddings_metadata
                    for layer_name in embeddings_vars.keys()
                }
            else:
                # If embedding_metadata is already a dictionary
                embeddings_metadata = self.embeddings_metadata

            try:
                # isort: off
                from tensorboard.plugins import projector
            except ImportError as e:
                raise ImportError(
                    "Failed to import TensorBoard. Please make sure that "
                    "TensorBoard integration is complete."
                ) from e

            # TODO(psv): Add integration tests to test embedding visualization
            # with TensorBoard callback. We are unable to write a unit test for
            # this because TensorBoard dependency assumes TensorFlow package is
            # installed.
            config = projector.ProjectorConfig()
            for layer_name, tensor in embeddings_vars.items():
                embedding = config.embeddings.add()
                embedding.tensor_name = tensor.name

                if (
                    embeddings_metadata is not None
                    and layer_name in embeddings_metadata
                ):
                    embedding.metadata_path = embeddings_metadata[layer_name]

            projector.visualize_embeddings(self.writer, config)

    def _fetch_callback(self, summary):
        """Writes a merged summary fetched by the test function.

        Registered in `on_epoch_begin` as the fetch callback for
        `self.merged`; each call advances the validation-batch step counter.
        """
        self.writer.add_summary(summary, self._total_val_batches_seen)
        self._total_val_batches_seen += 1

    def _write_custom_summaries(self, step, logs=None):
        """Writes metrics out as custom scalar summaries.

        Args:
            step: the global step to use for TensorBoard.
            logs: dict. Keys are scalar summary names, values are
              NumPy scalars.
        """
        logs = logs or {}
        if tf.executing_eagerly():
            # use v2 summary ops
            with self.writer.as_default(), tf.summary.record_if(True):
                for name, value in logs.items():
                    if isinstance(value, np.ndarray):
                        value = value.item()
                    tf.summary.scalar(name, value, step=step)
        else:
            # use FileWriter from v1 summary
            for name, value in logs.items():
                if isinstance(value, np.ndarray):
                    value = value.item()
                summary = tf.compat.v1.Summary()
                summary_value = summary.value.add()
                summary_value.simple_value = value
                summary_value.tag = name
                self.writer.add_summary(summary, step)
        self.writer.flush()

    def on_train_batch_begin(self, batch, logs=None):
        # `_total_batches_seen` is 0-based, so `profile_batch` is 1-based:
        # the default of 2 starts profiling before the second batch.
        if self._total_batches_seen == self._profile_batch - 1:
            self._start_profiler()

    def on_train_batch_end(self, batch, logs=None):
        return self.on_batch_end(batch, logs)

    def on_test_begin(self, logs=None):
        pass

    def on_test_end(self, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        """Writes scalar summaries for metrics on every training batch.

        Summaries are written once at least `update_freq` samples have been
        seen since the last write. Also stops the profiler if this callback
        started it.
        """
        # Don't output batch_size and batch number as TensorBoard summaries
        logs = logs or {}
        self._samples_seen += logs.get("size", 1)
        samples_seen_since = (
            self._samples_seen - self._samples_seen_at_last_write
        )
        if (
            self.update_freq != "epoch"
            and samples_seen_since >= self.update_freq
        ):
            batch_logs = {
                ("batch_" + k): v
                for k, v in logs.items()
                if k not in ["batch", "size", "num_steps"]
            }
            self._write_custom_summaries(self._total_batches_seen, batch_logs)
            self._samples_seen_at_last_write = self._samples_seen
        self._total_batches_seen += 1
        self._stop_profiler()

    def on_train_begin(self, logs=None):
        pass

    def on_epoch_begin(self, epoch, logs=None):
        """Add histogram op to Model eval_function callbacks, reset batch
        count."""

        # check if histogram summary should be run for this epoch
        if self.histogram_freq and epoch % self.histogram_freq == 0:
            # add the histogram summary op if it should run this epoch
            self.model._make_test_function()
            if self.merged not in self.model.test_function.fetches:
                self.model.test_function.fetches.append(self.merged)
                self.model.test_function.fetch_callbacks[
                    self.merged
                ] = self._fetch_callback

    def on_epoch_end(self, epoch, logs=None):
        """Checks if summary ops should run next epoch, logs scalar
        summaries.

        Raises:
            ValueError: if `embeddings_freq` is set but no `embeddings_data`
              was provided.
        """

        # don't output batch_size and
        # batch number as TensorBoard summaries
        logs = logs or {}
        logs = {
            ("epoch_" + k): v
            for k, v in logs.items()
            if k not in ["batch", "size", "num_steps"]
        }
        if self.update_freq == "epoch":
            step = epoch
        else:
            step = self._samples_seen
        self._write_custom_summaries(step, logs)

        # pop the histogram summary op after each epoch
        if self.histogram_freq:
            if self.merged in self.model.test_function.fetches:
                self.model.test_function.fetches.remove(self.merged)
            if self.merged in self.model.test_function.fetch_callbacks:
                self.model.test_function.fetch_callbacks.pop(self.merged)

        if self.embeddings_data is None and self.embeddings_freq:
            raise ValueError(
                "To visualize embeddings, embeddings_data must be provided."
            )

        if self.embeddings_freq and self.embeddings_data is not None:
            if epoch % self.embeddings_freq == 0:
                # We need a second forward-pass here because we're passing
                # the `embeddings_data` explicitly. This design allows to pass
                # arbitrary data as `embeddings_data` and results from the fact
                # that we need to know the size of the `tf.Variable`s which
                # hold the embeddings in `set_model`. At this point, however,
                # the `validation_data` is not yet set.

                embeddings_data = self.embeddings_data
                n_samples = embeddings_data[0].shape[0]
                sess = backend.get_session()
                for start in range(0, n_samples, self.batch_size):
                    step = min(self.batch_size, n_samples - start)
                    batch = slice(start, start + step)

                    if isinstance(self.model.input, list):
                        feed_dict = {
                            model_input: embeddings_data[idx][batch]
                            for idx, model_input in enumerate(self.model.input)
                        }
                    else:
                        feed_dict = {
                            self.model.input: embeddings_data[0][batch]
                        }

                    feed_dict.update({self.batch_id: start, self.step: step})

                    if not isinstance(backend.learning_phase(), int):
                        feed_dict[backend.learning_phase()] = False

                    sess.run(self.assign_embeddings, feed_dict=feed_dict)

                # Checkpoint once, after every batch has been assigned. The
                # path and global step are the same for each batch, so one
                # save yields the same final checkpoint without rewriting it
                # once per batch.
                if n_samples > 0:
                    self.saver.save(
                        sess,
                        os.path.join(self.log_dir, "keras_embedding.ckpt"),
                        epoch,
                    )

    def on_train_end(self, logs=None):
        self._stop_profiler()
        self.writer.close()

    def _start_profiler(self):
        """Starts the profiler if currently inactive."""
        if self._profiler_started:
            return
        try:
            tf.profiler.experimental.start(logdir=self.log_dir)
            self._profiler_started = True
        except tf.errors.AlreadyExistsError as e:
            # Profiler errors should not be fatal.
            logging.error("Failed to start profiler: %s", e.message)

    def _stop_profiler(self):
        """Stops the profiler if currently active."""
        if not self._profiler_started:
            return
        try:
            tf.profiler.experimental.stop()
        except tf.errors.UnavailableError as e:
            # Profiler errors should not be fatal.
            logging.error("Failed to stop profiler: %s", e.message)
        finally:
            self._profiler_started = False