# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Callbacks: utilities called at certain points during model training."""
|
|
|
|
import collections
|
|
import copy
|
|
import csv
|
|
import json
|
|
import os
|
|
import re
|
|
import sys
|
|
import time
|
|
|
|
import numpy as np
|
|
import tensorflow.compat.v2 as tf
|
|
|
|
from keras import backend
|
|
from keras.distribute import distributed_file_utils
|
|
from keras.distribute import worker_training_state
|
|
from keras.optimizers import optimizer
|
|
from keras.optimizers.schedules import learning_rate_schedule
|
|
from keras.utils import generic_utils
|
|
from keras.utils import io_utils
|
|
from keras.utils import tf_utils
|
|
from keras.utils import version_utils
|
|
from keras.utils.data_utils import Sequence
|
|
from keras.utils.generic_utils import Progbar
|
|
from keras.utils.mode_keys import ModeKeys
|
|
|
|
# isort: off
|
|
from tensorflow.python.platform import tf_logging as logging
|
|
from tensorflow.python.util import deprecation
|
|
from tensorflow.python.util.tf_export import keras_export
|
|
from tensorflow.tools.docs import doc_controls
|
|
|
|
try:
|
|
import requests
|
|
except ImportError:
|
|
requests = None
|
|
|
|
|
|


# Note: `configure_callbacks` is only used in TF1.
def configure_callbacks(
    callbacks,
    model,
    do_validation=False,
    batch_size=None,
    epochs=None,
    steps_per_epoch=None,
    samples=None,
    verbose=1,
    count_mode="steps",
    mode=ModeKeys.TRAIN,
):
    """Configures callbacks for use in various training loops.

    Args:
        callbacks: List of Callbacks.
        model: Model being trained.
        do_validation: Whether or not validation loop will be run.
        batch_size: Number of samples per batch.
        epochs: Number of epochs to train.
        steps_per_epoch: Number of batches to run per training epoch.
        samples: Number of training samples.
        verbose: int, 0 or 1. Keras logging verbosity to pass to
            ProgbarLogger.
        count_mode: One of 'steps' or 'samples'. Per-batch or per-sample
            count.
        mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or
            ModeKeys.PREDICT. Which loop mode to configure callbacks for.

    Returns:
        Instance of CallbackList used to control all Callbacks.
    """
    # Check if callbacks have already been configured.
    if isinstance(callbacks, CallbackList):
        return callbacks

    if not callbacks:
        callbacks = []

    # Add additional callbacks during training.
    if mode == ModeKeys.TRAIN:
        model.history = History()
        callbacks = [BaseLogger()] + (callbacks or []) + [model.history]
        if verbose:
            callbacks.append(ProgbarLogger(count_mode))
    callback_list = CallbackList(callbacks)

    # Set callback model
    callback_model = model._get_callback_model()
    callback_list.set_model(callback_model)

    set_callback_parameters(
        callback_list,
        model,
        do_validation=do_validation,
        batch_size=batch_size,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        samples=samples,
        verbose=verbose,
        mode=mode,
    )

    callback_list.model.stop_training = False
    return callback_list


def set_callback_parameters(
    callback_list,
    model,
    do_validation=False,
    batch_size=None,
    epochs=None,
    steps_per_epoch=None,
    samples=None,
    verbose=1,
    mode=ModeKeys.TRAIN,
):
    """Sets callback parameters.

    Args:
        callback_list: CallbackList instance.
        model: Model being trained.
        do_validation: Whether or not validation loop will be run.
        batch_size: Number of samples per batch.
        epochs: Number of epochs to train.
        steps_per_epoch: Number of batches to run per training epoch.
        samples: Number of training samples.
        verbose: int, 0 or 1. Keras logging verbosity to pass to
            ProgbarLogger.
        mode: String. One of ModeKeys.TRAIN, ModeKeys.TEST, or
            ModeKeys.PREDICT. Which loop mode to configure callbacks for.
    """
    metric_names = None
    for cbk in callback_list:
        if isinstance(cbk, (BaseLogger, ProgbarLogger)):
            if not metric_names:
                metric_names = model.metrics_names
            cbk.stateful_metrics = metric_names[1:]  # Exclude `loss`

    # Set callback parameters
    callback_metrics = []
    # When we have deferred build scenario with iterator input, we will
    # compile when we standardize first batch of data.
    if mode != ModeKeys.PREDICT:
        if not metric_names:
            metric_names = model.metrics_names
        callback_metrics = copy.copy(metric_names)
        if do_validation:
            callback_metrics += ["val_" + n for n in metric_names]
    callback_params = {
        "batch_size": batch_size,
        "epochs": epochs,
        "steps": steps_per_epoch,
        "samples": samples,
        "verbose": verbose,
        "do_validation": do_validation,
        "metrics": callback_metrics,
    }
    callback_list.set_params(callback_params)


def _is_generator_like(data):
    """Checks if data is a generator, Sequence, or Iterator."""
    return (
        hasattr(data, "__next__")
        or hasattr(data, "next")
        or isinstance(
            data, (Sequence, tf.compat.v1.data.Iterator, tf.data.Iterator)
        )
    )
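

# For illustration (a sketch; the values below are made up, not part of this
# module): `_is_generator_like` is True for plain Python generators,
# `keras.utils.Sequence` subclasses, and `tf.data` iterators, and False for
# in-memory arrays:
#
#     _is_generator_like(x for x in range(3))              # True: generator
#     _is_generator_like(iter(tf.data.Dataset.range(3)))   # True: tf.data.Iterator
#     _is_generator_like(np.zeros((8, 4)))                 # False: plain array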


def make_logs(model, logs, outputs, mode, prefix=""):
    """Computes logs for sending to `on_batch_end` methods."""
    metric_names = model.metrics_names
    if mode in {ModeKeys.TRAIN, ModeKeys.TEST} and metric_names:
        for label, output in zip(metric_names, outputs):
            logs[prefix + label] = output
    else:
        logs["outputs"] = outputs
    return logs
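

# A minimal sketch of `make_logs` behavior (hypothetical values): with
# `model.metrics_names == ["loss", "acc"]` and `outputs == [0.25, 0.9]`,
# TRAIN/TEST modes produce `{"loss": 0.25, "acc": 0.9}` (each key optionally
# prefixed with `prefix`), while PREDICT mode produces
# `{"outputs": [0.25, 0.9]}`.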


@keras_export("keras.callbacks.CallbackList")
class CallbackList:
    """Container abstracting a list of callbacks."""

    def __init__(
        self,
        callbacks=None,
        add_history=False,
        add_progbar=False,
        model=None,
        **params,
    ):
        """Container for `Callback` instances.

        This object wraps a list of `Callback` instances, making it possible
        to call them all at once via a single endpoint
        (e.g. `callback_list.on_epoch_end(...)`).

        Args:
            callbacks: List of `Callback` instances.
            add_history: Whether a `History` callback should be added, if one
                does not already exist in the `callbacks` list.
            add_progbar: Whether a `ProgbarLogger` callback should be added,
                if one does not already exist in the `callbacks` list.
            model: The `Model` these callbacks are used with.
            **params: If provided, parameters will be passed to each
                `Callback` via `Callback.set_params`.
        """
        self.callbacks = tf.nest.flatten(callbacks) if callbacks else []
        self._add_default_callbacks(add_history, add_progbar)

        if model:
            self.set_model(model)
        if params:
            self.set_params(params)

        # Performance optimization: determines if batch hooks need to be
        # called.
        self._supports_tf_logs = all(
            getattr(cb, "_supports_tf_logs", False) for cb in self.callbacks
        )
        self._batch_hooks_support_tf_logs = all(
            getattr(cb, "_supports_tf_logs", False)
            for cb in self.callbacks
            if cb._implements_train_batch_hooks()
            or cb._implements_test_batch_hooks()
            or cb._implements_predict_batch_hooks()
        )

        self._should_call_train_batch_hooks = any(
            cb._implements_train_batch_hooks() for cb in self.callbacks
        )
        self._should_call_test_batch_hooks = any(
            cb._implements_test_batch_hooks() for cb in self.callbacks
        )
        self._should_call_predict_batch_hooks = any(
            cb._implements_predict_batch_hooks() for cb in self.callbacks
        )

        self._disallow_batch_hooks_in_ps_strategy()

        # Performance check: Check batch hooks for slowness compared to batch
        # time. Only run check for custom callbacks (i.e. not present in this
        # file).
        self._check_timing = any(
            cbk.__class__.__name__ not in globals() for cbk in self.callbacks
        )
        self._num_batches_for_timing_check = 5
        self._hook_times = {}
        self._batch_start_time = None
        self._batch_times = []

    def _add_default_callbacks(self, add_history, add_progbar):
        """Adds `Callback`s that are always present."""
        self._progbar = None
        self._history = None

        for cb in self.callbacks:
            if isinstance(cb, ProgbarLogger):
                self._progbar = cb
            elif isinstance(cb, History):
                self._history = cb

        if self._history is None and add_history:
            self._history = History()
            self.callbacks.append(self._history)

        if self._progbar is None and add_progbar:
            self._progbar = ProgbarLogger(count_mode="steps")
            self.callbacks.append(self._progbar)

    def _process_logs(self, logs, is_batch_hook=False):
        """Turns tensors into numpy arrays or Python scalars if necessary."""
        if logs is None:
            return {}
        if self._supports_tf_logs:
            return logs
        if is_batch_hook and self._batch_hooks_support_tf_logs:
            return logs
        return tf_utils.sync_to_numpy_or_python_type(logs)

    def append(self, callback):
        self.callbacks.append(callback)

    def set_params(self, params):
        self.params = params
        for callback in self.callbacks:
            callback.set_params(params)

    def set_model(self, model):
        self.model = model
        if self._history:
            model.history = self._history
        for callback in self.callbacks:
            callback.set_model(model)

    def _call_batch_hook(self, mode, hook, batch, logs=None):
        """Helper function for all batch_{begin | end} methods."""
        if not self.callbacks:
            return

        if hook == "begin":
            self._call_batch_begin_hook(mode, batch, logs)
        elif hook == "end":
            self._call_batch_end_hook(mode, batch, logs)
        else:
            raise ValueError(
                f"Unrecognized hook: {hook}. "
                'Expected values are ["begin", "end"]'
            )

    def _call_batch_begin_hook(self, mode, batch, logs):
        """Helper function for `on_*_batch_begin` methods."""
        hook_name = f"on_{mode}_batch_begin"
        self._call_batch_hook_helper(hook_name, batch, logs)

        if self._check_timing:
            self._batch_start_time = time.time()

    def _call_batch_end_hook(self, mode, batch, logs):
        """Helper function for `on_*_batch_end` methods."""
        hook_name = f"on_{mode}_batch_end"

        if self._check_timing and batch >= 1:
            batch_time = time.time() - self._batch_start_time
            self._batch_times.append(batch_time)

        self._call_batch_hook_helper(hook_name, batch, logs)

        if len(self._batch_times) >= self._num_batches_for_timing_check:
            end_hook_name = hook_name
            begin_hook_name = f"on_{mode}_batch_begin"
            avg_batch_time = sum(self._batch_times) / len(self._batch_times)
            avg_end_hook_time = sum(self._hook_times[end_hook_name]) / len(
                self._hook_times[end_hook_name]
            )
            avg_begin_hook_time = sum(self._hook_times[begin_hook_name]) / len(
                self._hook_times[begin_hook_name]
            )

            threshold_time = 1.0 * avg_batch_time
            warning_msg = (
                "Callback method `{hook}` is slow compared to "
                "the batch time (batch time: {batch_time:.4f}s vs "
                "`{hook}` time: {hook_time:.4f}s). Check your callbacks."
            )
            if avg_begin_hook_time > threshold_time:
                logging.warning(
                    warning_msg.format(
                        hook=begin_hook_name,
                        batch_time=avg_batch_time,
                        hook_time=avg_begin_hook_time,
                    )
                )
            if avg_end_hook_time > threshold_time:
                logging.warning(
                    warning_msg.format(
                        hook=end_hook_name,
                        batch_time=avg_batch_time,
                        hook_time=avg_end_hook_time,
                    )
                )
            self._check_timing = False
            self._batch_start_time = None
            self._batch_times = []
            self._hook_times = {}

    def _call_batch_hook_helper(self, hook_name, batch, logs):
        """Helper function for `on_*_batch_*` methods."""
        if self._check_timing:
            start_time = time.time()

        logs = self._process_logs(logs, is_batch_hook=True)
        for callback in self.callbacks:
            hook = getattr(callback, hook_name)
            hook(batch, logs)

        if self._check_timing:
            if hook_name not in self._hook_times:
                self._hook_times[hook_name] = []
            self._hook_times[hook_name].append(time.time() - start_time)

    def _call_begin_hook(self, mode):
        """Helper function for on_{train|test|predict}_begin methods."""
        if mode == ModeKeys.TRAIN:
            self.on_train_begin()
        elif mode == ModeKeys.TEST:
            self.on_test_begin()
        else:
            self.on_predict_begin()

    def _call_end_hook(self, mode):
        """Helper function for on_{train|test|predict}_end methods."""
        if mode == ModeKeys.TRAIN:
            self.on_train_end()
        elif mode == ModeKeys.TEST:
            self.on_test_end()
        else:
            self.on_predict_end()

    def on_batch_begin(self, batch, logs=None):
        if self._should_call_train_batch_hooks:
            self._call_batch_hook(ModeKeys.TRAIN, "begin", batch, logs=logs)

    def on_batch_end(self, batch, logs=None):
        if self._should_call_train_batch_hooks:
            self._call_batch_hook(ModeKeys.TRAIN, "end", batch, logs=logs)

    def on_epoch_begin(self, epoch, logs=None):
        """Calls the `on_epoch_begin` methods of its callbacks.

        This function should only be called during TRAIN mode.

        Args:
            epoch: Integer, index of epoch.
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_epoch_begin(epoch, logs)

    def on_epoch_end(self, epoch, logs=None):
        """Calls the `on_epoch_end` methods of its callbacks.

        This function should only be called during TRAIN mode.

        Args:
            epoch: Integer, index of epoch.
            logs: Dict, metric results for this training epoch, and for the
                validation epoch if validation is performed. Validation
                result keys are prefixed with `val_`.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_epoch_end(epoch, logs)

    def on_train_batch_begin(self, batch, logs=None):
        """Calls the `on_train_batch_begin` methods of its callbacks.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict, contains the return value of `model.train_step`.
                Typically, the values of the `Model`'s metrics are returned.
                Example: `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        if self._should_call_train_batch_hooks:
            self._call_batch_hook(ModeKeys.TRAIN, "begin", batch, logs=logs)

    def on_train_batch_end(self, batch, logs=None):
        """Calls the `on_train_batch_end` methods of its callbacks.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """
        if self._should_call_train_batch_hooks:
            self._call_batch_hook(ModeKeys.TRAIN, "end", batch, logs=logs)

    def on_test_batch_begin(self, batch, logs=None):
        """Calls the `on_test_batch_begin` methods of its callbacks.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict, contains the return value of `model.test_step`.
                Typically, the values of the `Model`'s metrics are returned.
                Example: `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        if self._should_call_test_batch_hooks:
            self._call_batch_hook(ModeKeys.TEST, "begin", batch, logs=logs)

    def on_test_batch_end(self, batch, logs=None):
        """Calls the `on_test_batch_end` methods of its callbacks.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """
        if self._should_call_test_batch_hooks:
            self._call_batch_hook(ModeKeys.TEST, "end", batch, logs=logs)

    def on_predict_batch_begin(self, batch, logs=None):
        """Calls the `on_predict_batch_begin` methods of its callbacks.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict, contains the return value of `model.predict_step`,
                which is typically a dict with a key 'outputs' containing
                the model's outputs.
        """
        if self._should_call_predict_batch_hooks:
            self._call_batch_hook(ModeKeys.PREDICT, "begin", batch, logs=logs)

    def on_predict_batch_end(self, batch, logs=None):
        """Calls the `on_predict_batch_end` methods of its callbacks.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """
        if self._should_call_predict_batch_hooks:
            self._call_batch_hook(ModeKeys.PREDICT, "end", batch, logs=logs)

    def on_train_begin(self, logs=None):
        """Calls the `on_train_begin` methods of its callbacks.

        Args:
            logs: Dict. Currently, no data is passed via this argument
                for this method, but that may change in the future.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_train_begin(logs)

    def on_train_end(self, logs=None):
        """Calls the `on_train_end` methods of its callbacks.

        Args:
            logs: Dict. Currently, no data is passed via this argument
                for this method, but that may change in the future.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_train_end(logs)

    def on_test_begin(self, logs=None):
        """Calls the `on_test_begin` methods of its callbacks.

        Args:
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_test_begin(logs)

    def on_test_end(self, logs=None):
        """Calls the `on_test_end` methods of its callbacks.

        Args:
            logs: Dict. Currently, no data is passed via this argument
                for this method, but that may change in the future.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_test_end(logs)

    def on_predict_begin(self, logs=None):
        """Calls the `on_predict_begin` methods of its callbacks.

        Args:
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_predict_begin(logs)

    def on_predict_end(self, logs=None):
        """Calls the `on_predict_end` methods of its callbacks.

        Args:
            logs: Dict. Currently, no data is passed via this argument
                for this method, but that may change in the future.
        """
        logs = self._process_logs(logs)
        for callback in self.callbacks:
            callback.on_predict_end(logs)

    def __iter__(self):
        return iter(self.callbacks)

    def _disallow_batch_hooks_in_ps_strategy(self):
        """Error out if batch-level callbacks are passed with PSStrategy."""
        strategy = tf.distribute.get_strategy()
        if strategy._should_use_with_coordinator:
            unsupported_callbacks = []
            for cb in self.callbacks:
                # These Callbacks can accept RemoteValues directly.
                if getattr(cb, "_supports_tf_logs", False):
                    continue
                if (
                    cb._implements_train_batch_hooks()
                    or cb._implements_test_batch_hooks()
                    or cb._implements_predict_batch_hooks()
                ):
                    unsupported_callbacks.append(cb)
            if unsupported_callbacks:
                raise ValueError(
                    "Batch-level `Callback`s are not supported with "
                    "`ParameterServerStrategy`. Found unsupported "
                    f"callbacks: {unsupported_callbacks}"
                )

    def make_logs(self, model, logs, outputs, mode, prefix=""):
        """Computes logs for sending to `on_batch_end` methods."""
        if not self.callbacks:
            return logs

        return make_logs(model, logs, outputs, mode, prefix=prefix)


@keras_export("keras.callbacks.Callback")
class Callback:
    """Abstract base class used to build new callbacks.

    Callbacks can be passed to Keras methods such as `fit`, `evaluate`, and
    `predict` in order to hook into the various stages of the model training
    and inference lifecycle.

    To create a custom callback, subclass `keras.callbacks.Callback` and
    override the method associated with the stage of interest. See
    https://www.tensorflow.org/guide/keras/custom_callback for more
    information.

    Example:

    >>> training_finished = False
    >>> class MyCallback(tf.keras.callbacks.Callback):
    ...     def on_train_end(self, logs=None):
    ...         global training_finished
    ...         training_finished = True
    >>> model = tf.keras.Sequential([
    ...     tf.keras.layers.Dense(1, input_shape=(1,))])
    >>> model.compile(loss='mean_squared_error')
    >>> model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),
    ...           callbacks=[MyCallback()])
    >>> assert training_finished == True

    If you want to use `Callback` objects in a custom training loop:

    1. You should pack all your callbacks into a single
       `callbacks.CallbackList` so they can all be called together.
    2. You will need to manually call all the `on_*` methods at the
       appropriate locations in your loop. Like this:

    Example:

    ```python
    callbacks = tf.keras.callbacks.CallbackList([...])
    callbacks.append(...)
    callbacks.on_train_begin(...)
    for epoch in range(EPOCHS):
        callbacks.on_epoch_begin(epoch)
        for i, data in dataset.enumerate():
            callbacks.on_train_batch_begin(i)
            batch_logs = model.train_step(data)
            callbacks.on_train_batch_end(i, batch_logs)
        epoch_logs = ...
        callbacks.on_epoch_end(epoch, epoch_logs)
    final_logs = ...
    callbacks.on_train_end(final_logs)
    ```

    Attributes:
        params: Dict. Training parameters
            (e.g. verbosity, batch size, number of epochs...).
        model: Instance of `keras.models.Model`.
            Reference of the model being trained.

    The `logs` dictionary that callback methods
    take as argument will contain keys for quantities relevant to
    the current batch or epoch (see method-specific docstrings).
    """

    def __init__(self):
        self.validation_data = None
        self.model = None
        # Whether this Callback should only run on the chief worker in a
        # Multi-Worker setting.
        # TODO(omalleyt): Make this attr public once solution is stable.
        self._chief_worker_only = None
        self._supports_tf_logs = False

    def set_params(self, params):
        self.params = params

    def set_model(self, model):
        self.model = model

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_batch_begin(self, batch, logs=None):
        """A backwards compatibility alias for `on_train_batch_begin`."""

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_batch_end(self, batch, logs=None):
        """A backwards compatibility alias for `on_train_batch_end`."""

    @doc_controls.for_subclass_implementers
    def on_epoch_begin(self, epoch, logs=None):
        """Called at the start of an epoch.

        Subclasses should override for any actions to run. This function
        should only be called during TRAIN mode.

        Args:
            epoch: Integer, index of epoch.
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    def on_epoch_end(self, epoch, logs=None):
        """Called at the end of an epoch.

        Subclasses should override for any actions to run. This function
        should only be called during TRAIN mode.

        Args:
            epoch: Integer, index of epoch.
            logs: Dict, metric results for this training epoch, and for the
                validation epoch if validation is performed. Validation
                result keys are prefixed with `val_`. For training epoch,
                the values of the `Model`'s metrics are returned. Example:
                `{'loss': 0.2, 'accuracy': 0.7}`.
        """

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_train_batch_begin(self, batch, logs=None):
        """Called at the beginning of a training batch in `fit` methods.

        Subclasses should override for any actions to run.

        Note that if the `steps_per_execution` argument to `compile` in
        `tf.keras.Model` is set to `N`, this method will only be called every
        `N` batches.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """
        # For backwards compatibility.
        self.on_batch_begin(batch, logs=logs)

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_train_batch_end(self, batch, logs=None):
        """Called at the end of a training batch in `fit` methods.

        Subclasses should override for any actions to run.

        Note that if the `steps_per_execution` argument to `compile` in
        `tf.keras.Model` is set to `N`, this method will only be called every
        `N` batches.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """
        # For backwards compatibility.
        self.on_batch_end(batch, logs=logs)

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_test_batch_begin(self, batch, logs=None):
        """Called at the beginning of a batch in `evaluate` methods.

        Also called at the beginning of a validation batch in the `fit`
        methods, if validation data is provided.

        Subclasses should override for any actions to run.

        Note that if the `steps_per_execution` argument to `compile` in
        `tf.keras.Model` is set to `N`, this method will only be called every
        `N` batches.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_test_batch_end(self, batch, logs=None):
        """Called at the end of a batch in `evaluate` methods.

        Also called at the end of a validation batch in the `fit`
        methods, if validation data is provided.

        Subclasses should override for any actions to run.

        Note that if the `steps_per_execution` argument to `compile` in
        `tf.keras.Model` is set to `N`, this method will only be called every
        `N` batches.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_predict_batch_begin(self, batch, logs=None):
        """Called at the beginning of a batch in `predict` methods.

        Subclasses should override for any actions to run.

        Note that if the `steps_per_execution` argument to `compile` in
        `tf.keras.Model` is set to `N`, this method will only be called every
        `N` batches.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    @generic_utils.default
    def on_predict_batch_end(self, batch, logs=None):
        """Called at the end of a batch in `predict` methods.

        Subclasses should override for any actions to run.

        Note that if the `steps_per_execution` argument to `compile` in
        `tf.keras.Model` is set to `N`, this method will only be called every
        `N` batches.

        Args:
            batch: Integer, index of batch within the current epoch.
            logs: Dict. Aggregated metric results up until this batch.
        """

    @doc_controls.for_subclass_implementers
    def on_train_begin(self, logs=None):
        """Called at the beginning of training.

        Subclasses should override for any actions to run.

        Args:
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    def on_train_end(self, logs=None):
        """Called at the end of training.

        Subclasses should override for any actions to run.

        Args:
            logs: Dict. Currently the output of the last call to
                `on_epoch_end()` is passed to this argument for this method
                but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    def on_test_begin(self, logs=None):
        """Called at the beginning of evaluation or validation.

        Subclasses should override for any actions to run.

        Args:
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    def on_test_end(self, logs=None):
        """Called at the end of evaluation or validation.

        Subclasses should override for any actions to run.

        Args:
            logs: Dict. Currently the output of the last call to
                `on_test_batch_end()` is passed to this argument for this
                method but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    def on_predict_begin(self, logs=None):
        """Called at the beginning of prediction.

        Subclasses should override for any actions to run.

        Args:
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """

    @doc_controls.for_subclass_implementers
    def on_predict_end(self, logs=None):
        """Called at the end of prediction.

        Subclasses should override for any actions to run.

        Args:
            logs: Dict. Currently no data is passed to this argument for this
                method but that may change in the future.
        """

    def _implements_train_batch_hooks(self):
        """Determines if this Callback should be called for each train
        batch."""
        return (
            not generic_utils.is_default(self.on_batch_begin)
            or not generic_utils.is_default(self.on_batch_end)
            or not generic_utils.is_default(self.on_train_batch_begin)
            or not generic_utils.is_default(self.on_train_batch_end)
        )

    def _implements_test_batch_hooks(self):
        """Determines if this Callback should be called for each test
        batch."""
        return not generic_utils.is_default(
            self.on_test_batch_begin
        ) or not generic_utils.is_default(self.on_test_batch_end)

    def _implements_predict_batch_hooks(self):
        """Determines if this Callback should be called for each predict
        batch."""
        return not generic_utils.is_default(
            self.on_predict_batch_begin
        ) or not generic_utils.is_default(self.on_predict_batch_end)
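

# A minimal custom-callback sketch (hypothetical; not part of this module's
# API). Overriding only epoch-level hooks means the `_implements_*_batch_hooks`
# checks above return False, so `CallbackList` skips per-batch dispatch for
# this callback entirely:
#
#     class LossHistory(Callback):
#         def on_train_begin(self, logs=None):
#             self.losses = []
#
#         def on_epoch_end(self, epoch, logs=None):
#             self.losses.append((logs or {}).get("loss"))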


@keras_export("keras.callbacks.BaseLogger")
class BaseLogger(Callback):
    """Callback that accumulates epoch averages of metrics.

    This callback is automatically applied to every Keras model.

    Args:
        stateful_metrics: Iterable of string names of metrics that
            should *not* be averaged over an epoch.
            Metrics in this list will be logged as-is in `on_epoch_end`.
            All others will be averaged in `on_epoch_end`.
    """

    def __init__(self, stateful_metrics=None):
        super().__init__()
        self.stateful_metrics = set(stateful_metrics or [])

    def on_epoch_begin(self, epoch, logs=None):
        self.seen = 0
        self.totals = {}

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        batch_size = logs.get("size", 0)
        # In case of distribution strategy we can potentially run multiple
        # steps at the same time, we should account for that in the `seen`
        # calculation.
        num_steps = logs.get("num_steps", 1)
        self.seen += batch_size * num_steps

        for k, v in logs.items():
            if k in self.stateful_metrics:
                self.totals[k] = v
            else:
                if k in self.totals:
                    self.totals[k] += v * batch_size
                else:
                    self.totals[k] = v * batch_size

    def on_epoch_end(self, epoch, logs=None):
        if logs is not None:
            for k in self.params["metrics"]:
                if k in self.totals:
                    # Make value available to next callbacks.
                    if k in self.stateful_metrics:
                        logs[k] = self.totals[k]
                    else:
                        logs[k] = self.totals[k] / self.seen
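

# Sketch of the averaging arithmetic above (hypothetical numbers): two batches
# with `size=32` and losses 0.5 and 0.3 give
# `totals["loss"] = 0.5 * 32 + 0.3 * 32 = 25.6` and `seen = 64`, so
# `on_epoch_end` reports `25.6 / 64 = 0.4`, the sample-weighted epoch mean.
# Stateful metrics skip this weighting and keep the last logged value.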


@keras_export("keras.callbacks.TerminateOnNaN")
class TerminateOnNaN(Callback):
    """Callback that terminates training when a NaN loss is encountered."""

    def __init__(self):
        super().__init__()
        self._supports_tf_logs = True

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        loss = logs.get("loss")
        if loss is not None:
            loss = tf_utils.sync_to_numpy_or_python_type(loss)
            if np.isnan(loss) or np.isinf(loss):
                io_utils.print_msg(
                    f"Batch {batch}: Invalid loss, terminating training"
                )
                self.model.stop_training = True
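

# Typical usage (a sketch): pass an instance to `fit` so training halts on a
# divergent loss instead of continuing with NaN weights, e.g.
# `model.fit(x, y, callbacks=[TerminateOnNaN()])`.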


@keras_export("keras.callbacks.ProgbarLogger")
class ProgbarLogger(Callback):
    """Callback that prints metrics to stdout.

    Args:
        count_mode: One of `"steps"` or `"samples"`.
            Whether the progress bar should
            count samples seen or steps (batches) seen.
        stateful_metrics: Iterable of string names of metrics that
            should *not* be averaged over an epoch.
            Metrics in this list will be logged as-is.
            All others will be averaged over time (e.g. loss).
            If not provided, defaults to the `Model`'s metrics.

    Raises:
        ValueError: In case of invalid `count_mode`.
    """

    def __init__(self, count_mode: str = "samples", stateful_metrics=None):
        super().__init__()
        self._supports_tf_logs = True
        if count_mode == "samples":
            self.use_steps = False
        elif count_mode == "steps":
            self.use_steps = True
        else:
            raise ValueError(
                f"Unknown `count_mode`: {count_mode}. "
                'Expected values are ["samples", "steps"]'
            )
        # Defaults to all Model's metrics except for loss.
        self.stateful_metrics = (
            set(stateful_metrics) if stateful_metrics else set()
        )

        self.seen = 0
        self.progbar = None
        self.target = None
        self.verbose = 1
        self.epochs = 1

        self._train_step, self._test_step, self._predict_step = (
            None,
            None,
            None,
        )
        self._call_batch_hooks = True

        self._called_in_fit = False

    def set_params(self, params):
        self.verbose = params["verbose"]
        self.epochs = params["epochs"]
        if self.use_steps and "steps" in params:
            self.target = params["steps"]
        elif not self.use_steps and "samples" in params:
            self.target = params["samples"]
        else:
            self.target = (
                None  # Will be inferred at the end of the first epoch.
            )

        self._call_batch_hooks = self.verbose == 1
        if self.target is None:
            try:
                self._train_step = self.model._train_counter
                self._test_step = self.model._test_counter
                self._predict_step = self.model._predict_counter
            except AttributeError:
                self._call_batch_hooks = True

    def on_train_begin(self, logs=None):
        # When this logger is called inside `fit`, validation is silent.
        self._called_in_fit = True

    def on_test_begin(self, logs=None):
        if not self._called_in_fit:
            self._reset_progbar()
            self._maybe_init_progbar()

    def on_predict_begin(self, logs=None):
        self._reset_progbar()
        self._maybe_init_progbar()

    def on_epoch_begin(self, epoch, logs=None):
        self._reset_progbar()
        self._maybe_init_progbar()
        if self.verbose and self.epochs > 1:
            io_utils.print_msg(f"Epoch {epoch + 1}/{self.epochs}")

    def on_train_batch_end(self, batch, logs=None):
        self._batch_update_progbar(batch, logs)

    def on_test_batch_end(self, batch, logs=None):
        if not self._called_in_fit:
            self._batch_update_progbar(batch, logs)

    def on_predict_batch_end(self, batch, logs=None):
        # Don't pass prediction results.
        self._batch_update_progbar(batch, None)

    def on_epoch_end(self, epoch, logs=None):
        self._finalize_progbar(logs, self._train_step)

    def on_test_end(self, logs=None):
        if not self._called_in_fit:
            self._finalize_progbar(logs, self._test_step)

    def on_predict_end(self, logs=None):
        self._finalize_progbar(logs, self._predict_step)

    def _reset_progbar(self):
        self.seen = 0
        self.progbar = None

    def _maybe_init_progbar(self):
        """Instantiate a `Progbar` if not yet, and update the stateful
        metrics."""
        # TODO(rchao): Legacy TF1 code path may use list for
        # `self.stateful_metrics`. Remove "cast to set" when TF1 support is
        # dropped.
        self.stateful_metrics = set(self.stateful_metrics)

        if self.model:
            # Update the existing stateful metrics as `self.model.metrics` may
            # contain updated metrics after `MetricsContainer` is built in the
            # first train step.
            self.stateful_metrics = self.stateful_metrics.union(
                set(m.name for m in self.model.metrics)
            )

        if self.progbar is None:
            self.progbar = Progbar(
                target=self.target,
                verbose=self.verbose,
                stateful_metrics=self.stateful_metrics,
                unit_name="step" if self.use_steps else "sample",
            )

        self.progbar._update_stateful_metrics(self.stateful_metrics)

    def _implements_train_batch_hooks(self):
        return self._call_batch_hooks

    def _implements_test_batch_hooks(self):
        return self._call_batch_hooks

    def _implements_predict_batch_hooks(self):
        return self._call_batch_hooks

    def _batch_update_progbar(self, batch, logs=None):
        """Updates the progbar."""
        logs = logs or {}
        self._maybe_init_progbar()
        if self.use_steps:
            self.seen = batch + 1  # One-indexed.
        else:
            # v1 path only.
            logs = copy.copy(logs)
            batch_size = logs.pop("size", 0)
            num_steps = logs.pop("num_steps", 1)
            logs.pop("batch", None)
            add_seen = num_steps * batch_size
            self.seen += add_seen

        if self.verbose == 1:
            # Only block async when verbose = 1.
            logs = tf_utils.sync_to_numpy_or_python_type(logs)
            self.progbar.update(self.seen, list(logs.items()), finalize=False)

    def _finalize_progbar(self, logs, counter):
        logs = tf_utils.sync_to_numpy_or_python_type(logs or {})
        if self.target is None:
            if counter is not None:
                counter = counter.numpy()
                if not self.use_steps:
                    counter *= logs.get("size", 1)
            self.target = counter or self.seen
            self.progbar.target = self.target
        self.progbar.update(self.target, list(logs.items()), finalize=True)


@keras_export("keras.callbacks.History")
class History(Callback):
    """Callback that records events into a `History` object.

    This callback is automatically applied to
    every Keras model. The `History` object
    gets returned by the `fit` method of models.

    Example:

    >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
    >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
    ...                     epochs=10, verbose=1)
    >>> print(history.params)
    {'verbose': 1, 'epochs': 10, 'steps': 1}
    >>> # check the keys of history object
    >>> print(history.history.keys())
    dict_keys(['loss'])

    """

    def __init__(self):
        super().__init__()
        self.history = {}

    def on_train_begin(self, logs=None):
        self.epoch = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epoch.append(epoch)
        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)

        # Set the history attribute on the model after the epoch ends. This
        # will make sure that the state which is set is the latest one.
        self.model.history = self


@keras_export("keras.callbacks.ModelCheckpoint")
class ModelCheckpoint(Callback):
    """Callback to save the Keras model or model weights at some frequency.

    `ModelCheckpoint` callback is used in conjunction with training using
    `model.fit()` to save a model or weights (in a checkpoint file) at some
    interval, so the model or weights can be loaded later to continue the
    training from the state saved.

    A few options this callback provides include:

    - Whether to only keep the model that has achieved the "best performance"
      so far, or whether to save the model at the end of every epoch
      regardless of performance.
    - Definition of 'best'; which quantity to monitor and whether it should
      be maximized or minimized.
    - The frequency it should save at. Currently, the callback supports
      saving at the end of every epoch, or after a fixed number of training
      batches.
    - Whether only weights are saved, or the whole model is saved.

    Note: If you get `WARNING:tensorflow:Can save best model only with <name>
    available, skipping` see the description of the `monitor` argument for
    details on how to get this right.

    Example:

    ```python
    model.compile(loss=..., optimizer=...,
                  metrics=['accuracy'])

    EPOCHS = 10
    checkpoint_filepath = '/tmp/checkpoint'
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True)

    # Model weights are saved at the end of every epoch, if it's the best
    # seen so far.
    model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

    # The model weights (that are considered the best) are loaded into the
    # model.
    model.load_weights(checkpoint_filepath)
    ```

    Args:
        filepath: string or `PathLike`, path to save the model file. e.g.
            filepath = os.path.join(working_dir, 'ckpt', file_name).
            `filepath` can contain named formatting options, which will be
            filled in with the value of `epoch` and keys in `logs` (passed in
            `on_epoch_end`). For example: if `filepath` is
            `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then the model
            checkpoints will be saved with the epoch number and the
            validation loss in the filename. The directory of the filepath
            should not be reused by any other callbacks to avoid conflicts.
        monitor: The metric name to monitor. Typically the metrics are set by
            the `Model.compile` method. Note:

            * Prefix the name with `"val_"` to monitor validation metrics.
            * Use `"loss"` or `"val_loss"` to monitor the model's total loss.
            * If you specify metrics as strings, like `"accuracy"`, pass the
              same string (with or without the `"val_"` prefix).
            * If you pass `metrics.Metric` objects, `monitor` should be set
              to `metric.name`
            * If you're not sure about the metric names you can check the
              contents of the `history.history` dictionary returned by
              `history = model.fit()`
            * Multi-output models set additional prefixes on the metric
              names.

        verbose: Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1
            displays messages when the callback takes an action.
        save_best_only: if `save_best_only=True`, it only saves when the
            model is considered the "best" and the latest best model
            according to the quantity monitored will not be overwritten. If
            `filepath` doesn't contain formatting options like `{epoch}` then
            `filepath` will be overwritten by each new better model.
        mode: one of {'auto', 'min', 'max'}. If `save_best_only=True`, the
            decision to overwrite the current save file is made based on
            either the maximization or the minimization of the monitored
            quantity. For `val_acc`, this should be `max`, for `val_loss`
            this should be `min`, etc. In `auto` mode, the mode is set to
            `max` if the quantities monitored are 'acc' or start with
            'fmeasure' and are set to `min` for the rest of the quantities.
        save_weights_only: if True, then only the model's weights will be
            saved (`model.save_weights(filepath)`), else the full model is
            saved (`model.save(filepath)`).
        save_freq: `'epoch'` or integer. When using `'epoch'`, the callback
            saves the model after each epoch. When using integer, the
            callback saves the model at end of this many batches. If the
            `Model` is compiled with `steps_per_execution=N`, then the saving
            criteria will be checked every Nth batch. Note that if the saving
            isn't aligned to epochs, the monitored metric may potentially be
            less reliable (it could reflect as little as 1 batch, since the
            metrics get reset every epoch). Defaults to `'epoch'`.
        options: Optional `tf.train.CheckpointOptions` object if
            `save_weights_only` is true or optional
            `tf.saved_model.SaveOptions` object if `save_weights_only` is
            false.
        initial_value_threshold: Floating point initial "best" value of the
            metric to be monitored. Only applies if `save_best_only=True`.
            Only overwrites the model weights already saved if the
            performance of the current model is better than this value.
        **kwargs: Additional arguments for backwards compatibility. Possible
            key is `period`.
    """

    def __init__(
        self,
        filepath,
        monitor: str = "val_loss",
        verbose: int = 0,
        save_best_only: bool = False,
        save_weights_only: bool = False,
        mode: str = "auto",
        save_freq="epoch",
        options=None,
        initial_value_threshold=None,
        **kwargs,
    ):
        super().__init__()
        self._supports_tf_logs = True
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = io_utils.path_to_string(filepath)
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.save_freq = save_freq
        self.epochs_since_last_save = 0
        self._batches_seen_since_last_saving = 0
        self._last_batch_seen = 0
        self.best = initial_value_threshold

        if save_weights_only:
            if options is None or isinstance(
                options, tf.train.CheckpointOptions
            ):
                self._options = options or tf.train.CheckpointOptions()
            else:
                raise TypeError(
                    "If save_weights_only is True, then `options` must be "
                    "either None or a tf.train.CheckpointOptions. "
                    f"Got {options}."
                )
        else:
            if options is None or isinstance(
                options, tf.saved_model.SaveOptions
            ):
                self._options = options or tf.saved_model.SaveOptions()
            else:
                raise TypeError(
                    "If save_weights_only is False, then `options` must be "
                    "either None or a tf.saved_model.SaveOptions. "
                    f"Got {options}."
                )

        # Deprecated field `load_weights_on_restart` is for loading the
        # checkpoint file from `filepath` at the start of `model.fit()`
        # TODO(rchao): Remove the arg during next breaking release.
        if "load_weights_on_restart" in kwargs:
            self.load_weights_on_restart = kwargs["load_weights_on_restart"]
            logging.warning(
                "`load_weights_on_restart` argument is deprecated. "
                "Please use `model.load_weights()` for loading weights "
                "before the start of `model.fit()`."
            )
        else:
            self.load_weights_on_restart = False

        # Deprecated field `period` is for the number of epochs between which
        # the model is saved.
        if "period" in kwargs:
            self.period = kwargs["period"]
            logging.warning(
                "`period` argument is deprecated. Please use `save_freq` "
                "to specify the frequency in number of batches seen."
            )
        else:
            self.period = 1

        if mode not in ["auto", "min", "max"]:
            logging.warning(
                "ModelCheckpoint mode %s is unknown, fallback to auto mode.",
                mode,
            )
            mode = "auto"

        if mode == "min":
            self.monitor_op = np.less
            if self.best is None:
                self.best = np.Inf
        elif mode == "max":
            self.monitor_op = np.greater
            if self.best is None:
                self.best = -np.Inf
        else:
            if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
                self.monitor_op = np.greater
                if self.best is None:
                    self.best = -np.Inf
            else:
                self.monitor_op = np.less
                if self.best is None:
                    self.best = np.Inf

        if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
            raise ValueError(
                f"Unrecognized save_freq: {self.save_freq}. "
                'Expected save_freq are "epoch" or integer'
            )

        # Only the chief worker writes model checkpoints, but all workers
        # restore checkpoint at on_train_begin().
        self._chief_worker_only = False

    def on_train_begin(self, logs=None):
        if self.load_weights_on_restart:
            filepath_to_load = (
                self._get_most_recently_modified_file_matching_pattern(
                    self.filepath
                )
            )
            if filepath_to_load is not None and self._checkpoint_exists(
                filepath_to_load
            ):
                try:
                    # `filepath` may contain placeholders such as
                    # `{epoch:02d}`, and thus it attempts to load the most
                    # recently modified file with file name matching the
                    # pattern.
                    self.model.load_weights(filepath_to_load)
                except (IOError, ValueError) as e:
                    raise ValueError(
                        f"Error loading file from {filepath_to_load}. "
                        f"Reason: {e}"
                    )

    def _implements_train_batch_hooks(self):
        # Only call batch hooks when saving on batch
        return self.save_freq != "epoch"

    def on_train_batch_end(self, batch, logs=None):
        if self._should_save_on_batch(batch):
            self._save_model(epoch=self._current_epoch, batch=batch, logs=logs)

    def on_epoch_begin(self, epoch, logs=None):
        self._current_epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        self.epochs_since_last_save += 1

        if self.save_freq == "epoch":
            self._save_model(epoch=epoch, batch=None, logs=logs)

    def _should_save_on_batch(self, batch):
        """Handles batch-level saving logic, supports steps_per_execution."""
        if self.save_freq == "epoch":
            return False

        if batch <= self._last_batch_seen:  # New epoch.
            add_batches = batch + 1  # batches are zero-indexed.
        else:
            add_batches = batch - self._last_batch_seen
        self._batches_seen_since_last_saving += add_batches
        self._last_batch_seen = batch

        if self._batches_seen_since_last_saving >= self.save_freq:
            self._batches_seen_since_last_saving = 0
            return True
        return False
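
    # Worked example of the counting above (hypothetical numbers): with
    # `save_freq=100` and `steps_per_execution=25`, this hook sees batch
    # indices 24, 49, 74, 99, ...; each call adds `batch -
    # self._last_batch_seen` (25) to `_batches_seen_since_last_saving`, which
    # reaches 100 at batch 99, triggering a save and resetting the counter.
    # A batch index at or below `_last_batch_seen` signals a new epoch, so
    # the count restarts from `batch + 1`.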

    def _save_model(self, epoch, batch, logs):
        """Saves the model.

        Args:
            epoch: the epoch this iteration is in.
            batch: the batch this iteration is in. `None` if the `save_freq`
                is set to `epoch`.
            logs: the `logs` dict passed in to `on_batch_end` or
                `on_epoch_end`.
        """
        logs = logs or {}

        if (
            isinstance(self.save_freq, int)
            or self.epochs_since_last_save >= self.period
        ):
            # Block only when saving interval is reached.
            logs = tf_utils.sync_to_numpy_or_python_type(logs)
            self.epochs_since_last_save = 0
            filepath = self._get_file_path(epoch, batch, logs)

            # Create host directory if it doesn't exist.
            dirname = os.path.dirname(filepath)
            if dirname and not tf.io.gfile.exists(dirname):
                tf.io.gfile.makedirs(dirname)

            try:
                if self.save_best_only:
                    current = logs.get(self.monitor)
                    if current is None:
                        logging.warning(
                            "Can save best model only with %s available, "
                            "skipping.",
                            self.monitor,
                        )
                    else:
                        if self.monitor_op(current, self.best):
                            if self.verbose > 0:
                                io_utils.print_msg(
                                    f"\nEpoch {epoch + 1}: {self.monitor} "
                                    "improved "
                                    f"from {self.best:.5f} to {current:.5f}, "
                                    f"saving model to {filepath}"
                                )
                            self.best = current
                            if self.save_weights_only:
                                self.model.save_weights(
                                    filepath,
                                    overwrite=True,
                                    options=self._options,
                                )
                            else:
                                self.model.save(
                                    filepath,
                                    overwrite=True,
                                    options=self._options,
                                )
                        else:
                            if self.verbose > 0:
                                io_utils.print_msg(
                                    f"\nEpoch {epoch + 1}: "
                                    f"{self.monitor} did not improve "
                                    f"from {self.best:.5f}"
                                )
                else:
                    if self.verbose > 0:
                        io_utils.print_msg(
                            f"\nEpoch {epoch + 1}: saving model to {filepath}"
                        )
                    if self.save_weights_only:
                        self.model.save_weights(
                            filepath, overwrite=True, options=self._options
                        )
                    else:
                        self.model.save(
                            filepath, overwrite=True, options=self._options
                        )

                self._maybe_remove_file()
            except IsADirectoryError:  # h5py 3.x
                raise IOError(
                    "Please specify a non-directory filepath for "
                    "ModelCheckpoint. Filepath used is an existing "
                    f"directory: {filepath}"
                )
            except IOError as e:  # h5py 2.x
                # `e.errno` appears to be `None` so checking the content of
                # `e.args[0]`.
                if "is a directory" in str(e.args[0]).lower():
                    raise IOError(
                        "Please specify a non-directory filepath for "
                        "ModelCheckpoint. Filepath used is an existing "
                        f"directory: {filepath}"
                    )
                # Re-throw the error for any other causes.
                raise e

    def _get_file_path(self, epoch, batch, logs):
        """Returns the file path for checkpoint."""
        try:
            # `filepath` may contain placeholders such as
            # `{epoch:02d}`, `{batch:02d}` and `{mape:.2f}`. A mismatch
            # between logged metrics and the path's placeholders can cause
            # formatting to fail.
            if batch is None or "batch" in logs:
                file_path = self.filepath.format(epoch=epoch + 1, **logs)
            else:
                file_path = self.filepath.format(
                    epoch=epoch + 1, batch=batch + 1, **logs
                )
        except KeyError as e:
            raise KeyError(
                f'Failed to format this callback filepath: "{self.filepath}". '
                f"Reason: {e}"
            )
        self._write_filepath = distributed_file_utils.write_filepath(
            file_path, self.model.distribute_strategy
        )
        return self._write_filepath
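
    # Formatting sketch (hypothetical values): with
    # `filepath="weights.{epoch:02d}-{val_loss:.2f}.hdf5"`, `epoch=4`, and
    # `logs={"val_loss": 0.123}`, the formatting above yields
    # "weights.05-0.12.hdf5" (note the one-based epoch number).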

    def _maybe_remove_file(self):
        # Remove the checkpoint directory in multi-worker training where this
        # worker should not checkpoint. It is a dummy directory previously
        # saved for sync distributed training.
        distributed_file_utils.remove_temp_dir_with_filepath(
            self._write_filepath, self.model.distribute_strategy
        )

    def _checkpoint_exists(self, filepath):
        """Returns whether the checkpoint file that `filepath` refers to
        exists."""
        if filepath.endswith(".h5"):
            return tf.io.gfile.exists(filepath)
        tf_saved_model_exists = tf.io.gfile.exists(filepath)
        tf_weights_only_checkpoint_exists = tf.io.gfile.exists(
            filepath + ".index"
        )
        return tf_saved_model_exists or tf_weights_only_checkpoint_exists
    def _get_most_recently_modified_file_matching_pattern(self, pattern):
        """Returns the most recently modified filepath matching pattern.

        Pattern may contain python formatting placeholders. If
        `tf.train.latest_checkpoint()` does not return None, use that;
        otherwise, check for the most recently modified file that matches
        the pattern.

        In the rare case where more than one pattern-matching file shares
        the most recent modified time, return the filepath that is largest
        (by the `>` operator, lexicographically using the numeric
        equivalents). This provides a tie-breaker when multiple files are
        most recent. Note that a larger `filepath` can sometimes indicate a
        later time of modification (for instance, when epoch/batch is used
        as a formatting option), but not necessarily (when accuracy or loss
        is used). The tie-breaker is a best effort to return the most
        recent file and to avoid a nondeterministic result.

        Modified time of a file is obtained with `os.path.getmtime()`.

        This utility function is best demonstrated via an example:

        ```python
        file_pattern = 'f.batch{batch:02d}epoch{epoch:02d}.h5'
        test_dir = self.get_temp_dir()
        path_pattern = os.path.join(test_dir, file_pattern)
        file_paths = [
            os.path.join(test_dir, file_name) for file_name in
            ['f.batch03epoch02.h5',
             'f.batch02epoch02.h5', 'f.batch01epoch01.h5']
        ]
        for file_path in file_paths:
            # Write something to each of the files
        self.assertEqual(
            _get_most_recently_modified_file_matching_pattern(path_pattern),
            file_paths[-1])
        ```

        Args:
            pattern: The file pattern that may optionally contain python
                placeholders such as `{epoch:02d}`.

        Returns:
            The most recently modified file's full filepath matching
            `pattern`. If `pattern` does not contain any placeholder, this
            returns the filepath that exactly matches `pattern`. Returns
            `None` if no match is found.
        """
        dir_name = os.path.dirname(pattern)
        base_name = os.path.basename(pattern)
        base_name_regex = "^" + re.sub(r"{.*}", r".*", base_name) + "$"

        # If tf.train.latest_checkpoint tells us there exists a latest
        # checkpoint, use that as it is more robust than
        # `os.path.getmtime()`.
        latest_tf_checkpoint = tf.train.latest_checkpoint(dir_name)
        if latest_tf_checkpoint is not None and re.match(
            base_name_regex, os.path.basename(latest_tf_checkpoint)
        ):
            return latest_tf_checkpoint

        latest_mod_time = 0
        file_path_with_latest_mod_time = None
        n_file_with_latest_mod_time = 0
        file_path_with_largest_file_name = None

        if tf.io.gfile.exists(dir_name):
            for file_name in os.listdir(dir_name):
                # Only consider if `file_name` matches the pattern.
                if re.match(base_name_regex, file_name):
                    file_path = os.path.join(dir_name, file_name)
                    mod_time = os.path.getmtime(file_path)
                    if (
                        file_path_with_largest_file_name is None
                        or file_path > file_path_with_largest_file_name
                    ):
                        file_path_with_largest_file_name = file_path
                    if mod_time > latest_mod_time:
                        latest_mod_time = mod_time
                        file_path_with_latest_mod_time = file_path
                        # If a file with a later modified time is found,
                        # reset the counter for the number of files with
                        # the latest modified time.
                        n_file_with_latest_mod_time = 1
                    elif mod_time == latest_mod_time:
                        # If a file's modified time ties with the most
                        # recent, increment the counter for the number of
                        # files with the latest modified time by 1.
                        n_file_with_latest_mod_time += 1

        if n_file_with_latest_mod_time == 1:
            # Return the sole file that has the most recent modified time.
            return file_path_with_latest_mod_time
        else:
            # If more than one file has the latest modified time, return
            # the file path with the largest file name.
            return file_path_with_largest_file_name


@keras_export("keras.callbacks.BackupAndRestore", v1=[])
class BackupAndRestore(Callback):
    """Callback to back up and restore the training state.

    `BackupAndRestore` callback is intended to recover training from an
    interruption that has happened in the middle of a `Model.fit` execution,
    by backing up the training states in a temporary checkpoint file (with
    the help of a `tf.train.CheckpointManager`), at the end of each epoch.
    Each backup overwrites the previously written checkpoint file, so at any
    given time there is at most one such checkpoint file for
    backup/restoring purposes.

    If training restarts before completion, the training state (which
    includes the `Model` weights and epoch number) is restored to the most
    recently saved state at the beginning of a new `Model.fit` run. At the
    completion of a `Model.fit` run, the temporary checkpoint file is
    deleted.

    Note that the user is responsible for bringing jobs back after the
    interruption. This callback is important for the backup and restore
    mechanism for fault tolerance purposes, and the model to be restored
    from a previous checkpoint is expected to be the same as the one used
    to back up. If the user changes arguments passed to compile or fit, the
    checkpoint saved for fault tolerance can become invalid.

    Note:

    1. This callback is not compatible with disabling eager execution.
    2. A checkpoint is saved at the end of each epoch. After restoring,
    `Model.fit` redoes any partial work during the unfinished epoch in
    which the training got restarted (so the work done before the
    interruption doesn't affect the final model state).
    3. This works for both single worker and multi-worker modes. When
    `Model.fit` is used with `tf.distribute`, it supports
    `tf.distribute.MirroredStrategy`,
    `tf.distribute.MultiWorkerMirroredStrategy`, `tf.distribute.TPUStrategy`,
    and `tf.distribute.experimental.ParameterServerStrategy`.

    Example:

    >>> class InterruptingCallback(tf.keras.callbacks.Callback):
    ...   def on_epoch_begin(self, epoch, logs=None):
    ...     if epoch == 4:
    ...       raise RuntimeError('Interrupting!')
    >>> callback = tf.keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup")
    >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
    >>> try:
    ...   model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
    ...             batch_size=1, callbacks=[callback, InterruptingCallback()],
    ...             verbose=0)
    ... except:
    ...   pass
    >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
    ...                     epochs=10, batch_size=1, callbacks=[callback],
    ...                     verbose=0)
    >>> # Only 6 more epochs are run, since first training got interrupted at
    >>> # zero-indexed epoch 4, second training will continue from 4 to 9.
    >>> len(history.history['loss'])
    6

    Besides the option to save at the end of every epoch or every N steps,
    if you are doing distributed training with
    `tf.distribute.MultiWorkerMirroredStrategy` on Google Cloud Platform or
    Google Borg, you can also use the `save_before_preemption` argument
    to enable saving a checkpoint right before a worker gets preempted
    by other jobs and training gets interrupted. See
    `tf.distribute.experimental.PreemptionCheckpointHandler` for more
    details.

    Args:
        backup_dir: String, path to store the checkpoint.
            e.g. `backup_dir = os.path.join(working_dir, 'backup')`.
            This is the directory in which the system stores temporary
            files to recover the model from jobs terminated unexpectedly.
            The directory cannot be reused elsewhere to store other files,
            e.g. by the `BackupAndRestore` callback of another training
            run, or by another callback (e.g. `ModelCheckpoint`) of the
            same training.
        save_freq: `'epoch'`, integer, or `False`. When set to `'epoch'`
            the callback saves the checkpoint at the end of each epoch.
            When set to an integer, the callback saves the checkpoint every
            `save_freq` batches. Set `save_freq` to `False` if only using
            preemption checkpointing (with `save_before_preemption=True`).
        delete_checkpoint: Boolean, defaults to True. This
            `BackupAndRestore` callback works by saving a checkpoint to
            back up the training state. If `delete_checkpoint=True`, the
            checkpoint will be deleted after training is finished. Use
            `False` if you'd like to keep the checkpoint for future usage.
        save_before_preemption: A boolean value instructing whether to turn
            on the automatic checkpoint saving for preemption/maintenance
            events. This only supports
            `tf.distribute.MultiWorkerMirroredStrategy` on Google Cloud
            Platform or Google Borg for now.
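
    A minimal sketch of step-based backup (the directory and frequency are
    illustrative):

    ```python
    callback = tf.keras.callbacks.BackupAndRestore(
        backup_dir="/tmp/backup",
        save_freq=100,  # back up every 100 batches instead of every epoch
    )
    model.fit(x_train, y_train, epochs=10, callbacks=[callback])
    ```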
    """

    def __init__(
        self,
        backup_dir,
        save_freq="epoch",
        delete_checkpoint=True,
        save_before_preemption=False,
    ):
        super().__init__()
        self.backup_dir = backup_dir
        self._supports_tf_logs = True
        self._supported_strategies = (
            tf.distribute.MirroredStrategy,
            tf.distribute.MultiWorkerMirroredStrategy,
            tf.distribute.experimental.TPUStrategy,
            tf.distribute.TPUStrategy,
            tf.distribute.experimental.ParameterServerStrategy,
        )
        self.save_freq = save_freq
        self.delete_checkpoint = delete_checkpoint
        self.save_before_preemption = save_before_preemption
        self._batches_count = 0
        self._current_epoch = 0

        if not tf.executing_eagerly():
            if tf.inside_function():
                raise ValueError(
                    "This Callback's method contains Python state and "
                    "should be called outside of `tf.function`s."
                )
            else:  # Legacy graph mode:
                raise ValueError(
                    "BackupAndRestore only supports eager mode. In graph "
                    "mode, consider using ModelCheckpoint to manually save "
                    "and restore weights with `model.load_weights()` and by "
                    "providing `initial_epoch` in `model.fit()` for fault "
                    "tolerance."
                )
        if (not save_freq) and (not save_before_preemption):
            raise ValueError(
                "Either `save_freq` or `save_before_preemption` must be set."
            )

        # Only the chief worker writes model checkpoints, but all workers
        # restore the checkpoint at on_train_begin().
        self._chief_worker_only = False

    def on_train_begin(self, logs=None):
        # TrainingState is used to manage the training state needed for
        # failure-recovery of a worker in training.

        if self.model._distribution_strategy and not isinstance(
            self.model.distribute_strategy, self._supported_strategies
        ):
            raise NotImplementedError(
                f"{type(self.model.distribute_strategy)} is not supported "
                "yet. Currently BackupAndRestore callback only supports "
                "the default (empty) strategy, MirroredStrategy, "
                "MultiWorkerMirroredStrategy, TPUStrategy and "
                "ParameterServerStrategy."
            )
        self.model._training_state = worker_training_state.WorkerTrainingState(
            self.model,
            self.backup_dir,
            self.save_freq,
            self.save_before_preemption,
        )
        self._training_state = self.model._training_state
        self._training_state.restore()

    def on_train_batch_begin(self, batch, logs=None):
        self._training_state._ckpt_saved_batch.assign(batch)

    def on_train_batch_end(self, batch, logs=None):
        self._training_state.backup_if_preempted()
        if self.save_freq and self.save_freq != "epoch":
            self._batches_count += 1
            if self._batches_count >= self.save_freq:
                self._batches_count = 0
                self._backup(epoch=self._current_epoch, batch=batch)

    def _implements_train_batch_hooks(self):
        return self.save_freq != "epoch"

    def on_train_end(self, logs=None):
        if self.delete_checkpoint:
            # On exit of training, delete the training state backup file
            # saved for the purpose of worker recovery unless the user opts
            # out.
            self._training_state.delete_backup()
        # Clean up the training state.
        del self._training_state
        del self.model._training_state

    def on_epoch_begin(self, epoch, logs=None):
        self._training_state._ckpt_saved_epoch.assign(epoch)
        self._current_epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        # Back up the model and current epoch for possible future recovery.
        if self.save_freq == "epoch":
            self._backup(epoch=epoch)

    def _backup(self, epoch, batch=0):
        self._training_state.back_up(epoch=epoch, batch=batch)


@keras_export("keras.callbacks.experimental.BackupAndRestore", v1=[])
@deprecation.deprecated_endpoints(
    "keras.callbacks.experimental.BackupAndRestore"
)
class BackupAndRestoreExperimental(BackupAndRestore):
    """Deprecated. Please use `tf.keras.callbacks.BackupAndRestore` instead.

    Caution: `tf.keras.callbacks.experimental.BackupAndRestore` endpoint is
    deprecated and will be removed in a future release. Please use
    `tf.keras.callbacks.BackupAndRestore`.
    """

    def __init__(self, *args, **kwargs):
        logging.warning(
            "`tf.keras.callbacks.experimental.BackupAndRestore` endpoint is "
            "deprecated and will be removed in a future release. Please use "
            "`tf.keras.callbacks.BackupAndRestore`."
        )
        super().__init__(*args, **kwargs)


@keras_export("keras.callbacks.EarlyStopping")
class EarlyStopping(Callback):
    """Stop training when a monitored metric has stopped improving.

    Assume the goal of training is to minimize the loss. The metric to be
    monitored would then be `'loss'`, and the mode would be `'min'`. A
    `model.fit()` training loop checks at the end of every epoch whether
    the loss is still decreasing, considering the `min_delta` and
    `patience` if applicable. Once the loss is found to no longer be
    decreasing, `model.stop_training` is marked True and training
    terminates.

    The quantity to be monitored needs to be available in the `logs` dict.
    To make it so, pass the loss or metrics at `model.compile()`.

    Args:
        monitor: Quantity to be monitored.
        min_delta: Minimum change in the monitored quantity to qualify as
            an improvement, i.e. an absolute change of less than
            `min_delta` will count as no improvement.
        patience: Number of epochs with no improvement after which training
            will be stopped.
        verbose: Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1
            displays messages when the callback takes an action.
        mode: One of `{"auto", "min", "max"}`. In `"min"` mode, training
            will stop when the quantity monitored has stopped decreasing;
            in `"max"` mode it will stop when the quantity monitored has
            stopped increasing; in `"auto"` mode, the direction is
            automatically inferred from the name of the monitored quantity.
        baseline: Baseline value for the monitored quantity. Training will
            stop if the model doesn't show improvement over the baseline.
        restore_best_weights: Whether to restore model weights from the
            epoch with the best value of the monitored quantity. If False,
            the model weights obtained at the last step of training are
            used. An epoch will be restored regardless of the performance
            relative to the `baseline`. If no epoch improves on `baseline`,
            training will run for `patience` epochs and restore weights
            from the best epoch in that set.
        start_from_epoch: Number of epochs to wait before starting to
            monitor improvement. This allows for a warm-up period in which
            no improvement is expected and thus training will not be
            stopped.

    Example:

    >>> callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
    >>> # This callback will stop the training when there is no improvement in
    >>> # the loss for three consecutive epochs.
    >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
    >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
    ...                     epochs=10, batch_size=1, callbacks=[callback],
    ...                     verbose=0)
    >>> len(history.history['loss'])  # Only 4 epochs are run.
    4
    """

    def __init__(
        self,
        monitor="val_loss",
        min_delta=0,
        patience=0,
        verbose=0,
        mode="auto",
        baseline=None,
        restore_best_weights=False,
        start_from_epoch=0,
    ):
        super().__init__()

        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.baseline = baseline
        self.min_delta = abs(min_delta)
        self.wait = 0
        self.stopped_epoch = 0
        self.restore_best_weights = restore_best_weights
        self.best_weights = None
        self.start_from_epoch = start_from_epoch

        if mode not in ["auto", "min", "max"]:
            logging.warning(
                "EarlyStopping mode %s is unknown, falling back to auto "
                "mode.",
                mode,
            )
            mode = "auto"

        if mode == "min":
            self.monitor_op = np.less
        elif mode == "max":
            self.monitor_op = np.greater
        else:
            if (
                self.monitor.endswith("acc")
                or self.monitor.endswith("accuracy")
                or self.monitor.endswith("auc")
            ):
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less

        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1
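        # With this sign convention, `_is_improvement` treats a change
        # smaller than `abs(min_delta)` as no improvement: it checks
        # `current + abs(min_delta) < best` in "min" mode and
        # `current - abs(min_delta) > best` in "max" mode.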

    def on_train_begin(self, logs=None):
        # Allow instances to be re-used.
        self.wait = 0
        self.stopped_epoch = 0
        self.best = np.Inf if self.monitor_op == np.less else -np.Inf
        self.best_weights = None
        self.best_epoch = 0

    def on_epoch_end(self, epoch, logs=None):
        current = self.get_monitor_value(logs)
        if current is None or epoch < self.start_from_epoch:
            # If no monitor value exists or still in initial warm-up stage.
            return
        if self.restore_best_weights and self.best_weights is None:
            # If best weights were never set, save the current weights so
            # they can be restored later even if no progress is ever made.
            self.best_weights = self.model.get_weights()

        self.wait += 1
        if self._is_improvement(current, self.best):
            self.best = current
            self.best_epoch = epoch
            if self.restore_best_weights:
                self.best_weights = self.model.get_weights()
            # Only restart wait if we beat both the baseline and our
            # previous best.
            if self.baseline is None or self._is_improvement(
                current, self.baseline
            ):
                self.wait = 0
            return

        # Only check after the first epoch.
        if self.wait >= self.patience and epoch > 0:
            self.stopped_epoch = epoch
            self.model.stop_training = True
            if self.restore_best_weights and self.best_weights is not None:
                if self.verbose > 0:
                    io_utils.print_msg(
                        "Restoring model weights from "
                        "the end of the best epoch: "
                        f"{self.best_epoch + 1}."
                    )
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0 and self.verbose > 0:
            io_utils.print_msg(
                f"Epoch {self.stopped_epoch + 1}: early stopping"
            )

    def get_monitor_value(self, logs):
        logs = logs or {}
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            logging.warning(
                "Early stopping conditioned on metric `%s` "
                "which is not available. Available metrics are: %s",
                self.monitor,
                ",".join(list(logs.keys())),
            )
        return monitor_value

    def _is_improvement(self, monitor_value, reference_value):
        return self.monitor_op(monitor_value - self.min_delta, reference_value)


@keras_export("keras.callbacks.RemoteMonitor")
class RemoteMonitor(Callback):
    """Callback used to stream events to a server.

    Requires the `requests` library.
    Events are sent to `root + '/publish/epoch/end/'` by default. Calls are
    HTTP POST, with a `data` argument which is a
    JSON-encoded dictionary of event data.
    If `send_as_json=True`, the content type of the request will be
    `"application/json"`.
    Otherwise the serialized JSON will be sent within a form.

    Args:
        root: String; root url of the target server.
        path: String; path relative to `root` to which the events will be
            sent.
        field: String; JSON field under which the data will be stored.
            The field is used only if the payload is sent within a form
            (i.e. when `send_as_json=False`).
        headers: Dictionary; optional custom HTTP headers.
        send_as_json: Boolean; whether the request should be
            sent as `"application/json"`.
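
    A minimal usage sketch (the server URL is illustrative and must accept
    the POST requests described above):

    ```python
    monitor = tf.keras.callbacks.RemoteMonitor(
        root="http://localhost:9000", send_as_json=True
    )
    model.fit(x_train, y_train, epochs=5, callbacks=[monitor])
    ```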
    """

    def __init__(
        self,
        root="http://localhost:9000",
        path="/publish/epoch/end/",
        field="data",
        headers=None,
        send_as_json=False,
    ):
        super().__init__()

        self.root = root
        self.path = path
        self.field = field
        self.headers = headers
        self.send_as_json = send_as_json

    def on_epoch_end(self, epoch, logs=None):
        if requests is None:
            raise ImportError("RemoteMonitor requires the `requests` library.")
        logs = logs or {}
        send = {}
        send["epoch"] = epoch
        for k, v in logs.items():
            # np.ndarray and np.generic are not scalar types, therefore we
            # must unwrap their scalar values and pass them to the
            # JSON-serializable dict `send`.
            if isinstance(v, (np.ndarray, np.generic)):
                send[k] = v.item()
            else:
                send[k] = v
        try:
            if self.send_as_json:
                requests.post(
                    self.root + self.path, json=send, headers=self.headers
                )
            else:
                requests.post(
                    self.root + self.path,
                    {self.field: json.dumps(send)},
                    headers=self.headers,
                )
        except requests.exceptions.RequestException:
            logging.warning(
                "Could not reach RemoteMonitor root server at "
                + str(self.root)
            )


@keras_export("keras.callbacks.LearningRateScheduler")
class LearningRateScheduler(Callback):
    """Learning rate scheduler.

    At the beginning of every epoch, this callback gets the updated
    learning rate value from the `schedule` function provided at
    `__init__`, with the current epoch and current learning rate, and
    applies the updated learning rate on the optimizer.

    Args:
        schedule: a function that takes an epoch index (integer, indexed
            from 0) and current learning rate (float) as inputs and returns
            a new learning rate as output (float).
        verbose: int. 0: quiet, 1: update messages.

    Example:

    >>> # This function keeps the initial learning rate for the first ten
    >>> # epochs and decreases it exponentially after that.
    >>> def scheduler(epoch, lr):
    ...   if epoch < 10:
    ...     return lr
    ...   else:
    ...     return lr * tf.math.exp(-0.1)
    >>>
    >>> model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    >>> model.compile(tf.keras.optimizers.SGD(), loss='mse')
    >>> round(model.optimizer.lr.numpy(), 5)
    0.01

    >>> callback = tf.keras.callbacks.LearningRateScheduler(scheduler)
    >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
    ...                     epochs=15, callbacks=[callback], verbose=0)
    >>> round(model.optimizer.lr.numpy(), 5)
    0.00607

    """

    def __init__(self, schedule, verbose=0):
        super().__init__()
        self.schedule = schedule
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, "lr"):
            raise ValueError('Optimizer must have a "lr" attribute.')
        try:  # new API
            lr = float(backend.get_value(self.model.optimizer.lr))
            lr = self.schedule(epoch, lr)
        except TypeError:  # Support for old API for backward compatibility
            lr = self.schedule(epoch)
        if not isinstance(lr, (tf.Tensor, float, np.float32, np.float64)):
            raise ValueError(
                'The output of the "schedule" function '
                f"should be a float. Got: {lr}"
            )
        if isinstance(lr, tf.Tensor) and not lr.dtype.is_floating:
            raise ValueError(
                f"The dtype of `lr` Tensor should be float. Got: {lr.dtype}"
            )
        backend.set_value(self.model.optimizer.lr, backend.get_value(lr))
        if self.verbose > 0:
            io_utils.print_msg(
                f"\nEpoch {epoch + 1}: LearningRateScheduler setting learning "
                f"rate to {lr}."
            )

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs["lr"] = backend.get_value(self.model.optimizer.lr)


def keras_model_summary(name, data, step=None):
    """Writes a Keras model as JSON to a Summary.

    Writing the Keras model configuration allows the TensorBoard graph
    plugin to render a conceptual graph, as opposed to a graph of ops. If
    the model fails to serialize as JSON, the write is skipped and False is
    returned.

    Args:
        name: A name for this summary. The summary tag used for TensorBoard
            will be this name prefixed by any active name scopes.
        data: A Keras Model to write.
        step: Explicit `int64`-castable monotonic step value for this
            summary. If omitted, this defaults to
            `tf.summary.experimental.get_step()`, which must not be None.

    Returns:
        True on success, or False if no summary was written because no
        default summary writer was available.

    Raises:
        ValueError: if a default writer exists, but no step was provided
            and `tf.summary.experimental.get_step()` is None.
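
    A minimal usage sketch (the log directory and model are illustrative):

    ```python
    writer = tf.summary.create_file_writer("/tmp/logs")
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    with writer.as_default():
        keras_model_summary("keras", model, step=0)
    ```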
    """
    summary_metadata = tf.compat.v1.SummaryMetadata()
    # Hard coding a plugin name. Please refer to go/tb-plugin-name-hardcode
    # for the rationale.
    summary_metadata.plugin_data.plugin_name = "graph_keras_model"
    # version number = 1
    summary_metadata.plugin_data.content = b"1"

    try:
        json_string = data.to_json()
    except Exception as exc:
        # An exception should not break the model code.
        logging.warning(
            "Model failed to serialize as JSON. Ignoring... %s", exc
        )
        return False

    with tf.summary.experimental.summary_scope(
        name, "graph_keras_model", [data, step]
    ) as (tag, _):
        with tf.device("cpu:0"):
            tensor = tf.constant(json_string, dtype=tf.string)
            return tf.summary.write(
                tag=tag, tensor=tensor, step=step, metadata=summary_metadata
            )


@keras_export("keras.callbacks.TensorBoard", v1=[])
class TensorBoard(Callback, version_utils.TensorBoardVersionSelector):
    """Enable visualizations for TensorBoard.

    TensorBoard is a visualization tool provided with TensorFlow.

    This callback logs events for TensorBoard, including:

    * Metrics summary plots
    * Training graph visualization
    * Weight histograms
    * Sampled profiling

    When used in `Model.evaluate`, in addition to epoch summaries, there
    will be a summary that records evaluation metrics vs
    `Model.optimizer.iterations` written. The metric names will be
    prepended with `evaluation`, with `Model.optimizer.iterations` being
    the step in the visualized TensorBoard.

    If you have installed TensorFlow with pip, you should be able
    to launch TensorBoard from the command line:

    ```
    tensorboard --logdir=path_to_your_logs
    ```

    You can find more information about TensorBoard
    [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).

    Args:
        log_dir: the path of the directory where to save the log files to
            be parsed by TensorBoard. e.g.
            `log_dir = os.path.join(working_dir, 'logs')`.
            This directory should not be reused by any other callbacks.
        histogram_freq: frequency (in epochs) at which to compute weight
            histograms for the layers of the model. If set to 0, histograms
            won't be computed. Validation data (or split) must be specified
            for histogram visualizations.
        write_graph: whether to visualize the graph in TensorBoard. The log
            file can become quite large when `write_graph` is set to True.
        write_images: whether to write model weights to visualize as an
            image in TensorBoard.
        write_steps_per_second: whether to log the training steps per
            second into TensorBoard. This supports both epoch and batch
            frequency logging.
        update_freq: `'batch'` or `'epoch'` or integer. When using
            `'epoch'`, writes the losses and metrics to TensorBoard after
            every epoch. If using an integer, let's say `1000`, all metrics
            and losses (including custom ones added by `Model.compile`)
            will be logged to TensorBoard every 1000 batches. `'batch'` is
            a synonym for `1`, meaning that they will be written every
            batch. Note however that writing too frequently to TensorBoard
            can slow down your training, especially when used with
            `tf.distribute.Strategy` as it will incur additional
            synchronization overhead. Use with `ParameterServerStrategy`
            is not supported. Batch-level summary writing is also available
            via `train_step` override. Please see
            [TensorBoard Scalars tutorial](https://www.tensorflow.org/tensorboard/scalars_and_keras#batch-level_logging)  # noqa: E501
            for more details.
        profile_batch: Profile the batch(es) to sample compute
            characteristics. profile_batch must be a non-negative integer
            or a tuple of integers. A pair of positive integers signifies a
            range of batches to profile. By default, profiling is disabled.
        embeddings_freq: frequency (in epochs) at which embedding layers
            will be visualized. If set to 0, embeddings won't be
            visualized.
        embeddings_metadata: Dictionary which maps embedding layer names to
            the filename of a file in which to save metadata for the
            embedding layer. In case the same metadata file is to be used
            for all embedding layers, a single filename can be passed.

    Examples:

    Basic usage:

    ```python
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")
    model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
    # Then run the tensorboard command to view the visualizations.
    ```

    Custom batch-level summaries in a subclassed Model:

    ```python
    class MyModel(tf.keras.Model):

        def build(self, _):
            self.dense = tf.keras.layers.Dense(10)

        def call(self, x):
            outputs = self.dense(x)
            tf.summary.histogram('outputs', outputs)
            return outputs

    model = MyModel()
    model.compile('sgd', 'mse')

    # Make sure to set `update_freq=N` to log a batch-level summary every N
    # batches. In addition to any `tf.summary` contained in `Model.call`,
    # metrics added in `Model.compile` will be logged every N batches.
    tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
    model.fit(x_train, y_train, callbacks=[tb_callback])
    ```

    Custom batch-level summaries in a Functional API Model:

    ```python
    def my_summary(x):
        tf.summary.histogram('x', x)
        return x

    inputs = tf.keras.Input(10)
    x = tf.keras.layers.Dense(10)(inputs)
    outputs = tf.keras.layers.Lambda(my_summary)(x)
    model = tf.keras.Model(inputs, outputs)
    model.compile('sgd', 'mse')

    # Make sure to set `update_freq=N` to log a batch-level summary every N
    # batches. In addition to any `tf.summary` contained in `Model.call`,
    # metrics added in `Model.compile` will be logged every N batches.
    tb_callback = tf.keras.callbacks.TensorBoard('./logs', update_freq=1)
    model.fit(x_train, y_train, callbacks=[tb_callback])
    ```

    Profiling:

    ```python
    # Profile a single batch, e.g. the 5th batch.
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir='./logs', profile_batch=5)
    model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])

    # Profile a range of batches, e.g. from 10 to 20.
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir='./logs', profile_batch=(10,20))
    model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])
    ```
    """

    def __init__(
        self,
        log_dir="logs",
        histogram_freq=0,
        write_graph=True,
        write_images=False,
        write_steps_per_second=False,
        update_freq="epoch",
        profile_batch=0,
        embeddings_freq=0,
        embeddings_metadata=None,
        **kwargs,
    ):
        super().__init__()
        self._supports_tf_logs = True
        self._validate_kwargs(kwargs)

        self.log_dir = io_utils.path_to_string(log_dir)
        self.histogram_freq = histogram_freq
        self.write_graph = write_graph
        self.write_images = write_images
        self.write_steps_per_second = write_steps_per_second
        self.update_freq = 1 if update_freq == "batch" else update_freq
        self.embeddings_freq = embeddings_freq
        self.embeddings_metadata = embeddings_metadata
        self._init_profile_batch(profile_batch)
        self._global_train_batch = 0
        self._previous_epoch_iterations = 0
        self._train_accumulated_time = 0
        self._batch_start_time = 0

        # Lazily initialized in order to avoid creating event files when
        # not needed.
        self._writers = {}

        # Used to restore any existing `SummaryWriter` after training ends.
        self._prev_summary_state = []

    def _validate_kwargs(self, kwargs):
        """Handles arguments that were supported in V1."""
        if kwargs.get("write_grads", False):
            logging.warning(
                "`write_grads` will be ignored in TensorFlow 2.0 "
                "for the `TensorBoard` Callback."
            )
        if kwargs.get("batch_size", False):
            logging.warning(
                "`batch_size` is no longer needed in the "
                "`TensorBoard` Callback and will be ignored "
                "in TensorFlow 2.0."
            )
        if kwargs.get("embeddings_layer_names", False):
            logging.warning(
                "`embeddings_layer_names` is not supported in "
                "TensorFlow 2.0. Instead, all `Embedding` layers "
                "will be visualized."
            )
        if kwargs.get("embeddings_data", False):
            logging.warning(
                "`embeddings_data` is not supported in TensorFlow "
                "2.0. Instead, all `Embedding` variables will be "
                "visualized."
            )

        supported_kwargs = {
            "write_grads",
            "embeddings_layer_names",
            "embeddings_data",
            "batch_size",
        }
        unrecognized_kwargs = set(kwargs.keys()) - supported_kwargs

        # Only allow kwargs that were supported in V1.
        if unrecognized_kwargs:
            raise ValueError(
                "Unrecognized arguments in `TensorBoard` Callback: "
                f"{unrecognized_kwargs}. "
                f"Supported kwargs are: {supported_kwargs}"
            )

    def set_model(self, model):
        """Sets Keras model and writes graph if specified."""
        self.model = model
        self._log_write_dir = self._get_log_write_dir()

        self._train_dir = os.path.join(self._log_write_dir, "train")
        self._train_step = self.model._train_counter

        self._val_dir = os.path.join(self._log_write_dir, "validation")
        self._val_step = self.model._test_counter

        self._writers = {}  # Resets writers.

        self._should_write_train_graph = False
        if self.write_graph:
            self._write_keras_model_summary()
            self._should_write_train_graph = True
        if self.embeddings_freq:
            self._configure_embeddings()

    @property
    def _train_writer(self):
        if "train" not in self._writers:
            self._writers["train"] = tf.summary.create_file_writer(
                self._train_dir
            )
        return self._writers["train"]

    @property
    def _val_writer(self):
        if "val" not in self._writers:
            self._writers["val"] = tf.summary.create_file_writer(self._val_dir)
        return self._writers["val"]

    def _get_log_write_dir(self):
        """For multi-worker, only chief should write, others write to '/tmp'."""
        return distributed_file_utils.write_dirpath(
            self.log_dir, self.model.distribute_strategy
        )

    def _delete_tmp_write_dir(self):
        """Deletes tmp write directories for multi-worker."""
        distributed_file_utils.remove_temp_dirpath(
            self.log_dir, self.model.distribute_strategy
        )

    def _write_keras_model_train_graph(self):
        """Writes Keras model train_function graph to TensorBoard."""
        with self._train_writer.as_default():
            with tf.summary.record_if(True):
                train_fn = self.model.train_tf_function
                # If the train_function is a `tf.function`, we can write out
                # a graph.
                if hasattr(train_fn, "function_spec"):
                    # TODO(b/243822285): Use _variable_creation_fn directly.
                    if hasattr(train_fn, "_concrete_stateful_fn"):
                        tf.summary.graph(train_fn._concrete_stateful_fn.graph)
                    else:
                        tf.summary.graph(
                            train_fn._concrete_variable_creation_fn.graph
                        )

    def _write_keras_model_summary(self):
        """Writes Keras graph network summary to TensorBoard."""
        with self._train_writer.as_default():
            with tf.summary.record_if(True):
                summary_writable = (
                    self.model._is_graph_network
                    or self.model.__class__.__name__ == "Sequential"
                )
                if summary_writable:
                    keras_model_summary("keras", self.model, step=0)

    def _configure_embeddings(self):
        """Configures the Projector for embeddings."""
        # TODO(omalleyt): Add integration tests.
        from keras.layers import core
        from keras.protobuf import projector_config_pb2

        # isort: off
        from google.protobuf import text_format

        config = projector_config_pb2.ProjectorConfig()
        for layer in self.model.layers:
            if isinstance(layer, core.Embedding):
                embedding = config.embeddings.add()
                # Embeddings are always the first layer, so this naming
                # should be consistent in any Keras model's checkpoints.
                name = (
                    "layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE"
                )
                embedding.tensor_name = name

                if self.embeddings_metadata is not None:
                    if isinstance(self.embeddings_metadata, str):
                        embedding.metadata_path = self.embeddings_metadata
                    else:
                        if layer.name in self.embeddings_metadata.keys():
                            embedding.metadata_path = (
                                self.embeddings_metadata.pop(layer.name)
                            )

        if self.embeddings_metadata and not isinstance(
            self.embeddings_metadata, str
        ):
            raise ValueError(
                "Unrecognized `Embedding` layer names passed to "
                "`keras.callbacks.TensorBoard` `embeddings_metadata` "
                f"argument: {self.embeddings_metadata.keys()}"
            )

        config_pbtxt = text_format.MessageToString(config)
        path = os.path.join(self._log_write_dir, "projector_config.pbtxt")
        with tf.io.gfile.GFile(path, "w") as f:
            f.write(config_pbtxt)

    def _push_writer(self, writer, step):
        """Sets the default writer for custom batch-level summaries."""
        if self.update_freq == "epoch":
            return

        should_record = lambda: tf.equal(step % self.update_freq, 0)
        # TODO(b/151339474): Fix deadlock when not using .value() here.
        summary_context = (
            writer.as_default(step.value()),
            tf.summary.record_if(should_record),
        )
        self._prev_summary_state.append(summary_context)
        summary_context[0].__enter__()
        summary_context[1].__enter__()

    def _pop_writer(self):
        """Pops the current writer."""
        if self.update_freq == "epoch":
            return

        # See _push_writer for the content of the previous context, which
        # is a pair of context managers.
        previous_context = self._prev_summary_state.pop()
        previous_context[1].__exit__(*sys.exc_info())
        previous_context[0].__exit__(*sys.exc_info())

    def _close_writers(self):
        for writer in self._writers.values():
            writer.close()

    def _init_profile_batch(self, profile_batch):
        """Validates `profile_batch` and sets the range of batches to profile.

        Sets the values of the `_start_batch` and `_stop_batch` attributes,
        specifying the start and stop batches to profile.
        Setting `profile_batch=0` disables profiling.

        Args:
            profile_batch: The range of batches to profile. Should be a
                non-negative integer or a comma-separated string of a pair
                of positive integers. A pair of positive integers signifies
                a range of batches to profile.

        Raises:
            ValueError: If profile_batch is not an integer or a
                comma-separated pair of positive integers.
        """
        profile_batch_error_message = (
            "profile_batch must be a non-negative integer or "
            "2-tuple of positive "
            "integers. A pair of positive integers "
            "signifies a range of batches "
            f"to profile. Found: {profile_batch}"
        )

        # Support the legacy way of specifying "start,stop" or "start" as
        # str.
        if isinstance(profile_batch, str):
            profile_batch = str(profile_batch).split(",")
            profile_batch = tf.nest.map_structure(int, profile_batch)
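            # For example (illustrative): "10,20" -> ["10", "20"] -> [10, 20].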

        if isinstance(profile_batch, int):
            self._start_batch = profile_batch
            self._stop_batch = profile_batch
        elif (
            isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2
        ):
            self._start_batch, self._stop_batch = profile_batch
        else:
            raise ValueError(profile_batch_error_message)

        if self._start_batch < 0 or self._stop_batch < self._start_batch:
            raise ValueError(profile_batch_error_message)

        # True when the profiler was successfully started by this callback.
        # We track the status here to make sure callbacks do not interfere
        # with each other. The callback will only stop the profiler it
        # started.
        self._profiler_started = False
        if self._start_batch > 0:
            # Warm up and improve the profiling accuracy.
            self._start_profiler(logdir="")
            self._stop_profiler(save=False)
        # True when a trace is running.
        self._is_tracing = False

        # Setting `profile_batch=0` disables profiling.
        self._should_trace = not (
            self._start_batch == 0 and self._stop_batch == 0
        )

    def on_train_begin(self, logs=None):
        self._global_train_batch = 0
        self._previous_epoch_iterations = 0
        self._push_writer(self._train_writer, self._train_step)

    def on_train_end(self, logs=None):
        self._pop_writer()

        if self._is_tracing:
            self._stop_trace()

        self._close_writers()
        self._delete_tmp_write_dir()

    def on_test_begin(self, logs=None):
        self._push_writer(self._val_writer, self._val_step)

    def on_test_end(self, logs=None):
        if self.model.optimizer and hasattr(self.model.optimizer, "iterations"):
            with tf.summary.record_if(True), self._val_writer.as_default():
                for name, value in logs.items():
                    tf.summary.scalar(
                        "evaluation_" + name + "_vs_iterations",
                        value,
                        step=self.model.optimizer.iterations.read_value(),
                    )
        self._pop_writer()

    def _implements_train_batch_hooks(self):
        # Only call batch hooks when tracing or write_steps_per_second is
        # enabled.
        return self._should_trace or self.write_steps_per_second

    def on_train_batch_begin(self, batch, logs=None):
        self._global_train_batch += 1
        if self.write_steps_per_second:
            self._batch_start_time = time.time()
        if not self._should_trace:
            return

        if self._global_train_batch == self._start_batch:
            self._start_trace()

    def on_train_batch_end(self, batch, logs=None):
        if self._should_write_train_graph:
            self._write_keras_model_train_graph()
            self._should_write_train_graph = False
        if self.write_steps_per_second:
            batch_run_time = time.time() - self._batch_start_time
            tf.summary.scalar(
                "batch_steps_per_second",
                1.0 / batch_run_time,
                step=self._train_step,
            )

        # `logs` isn't necessarily always a dict. For example, when using
        # `tf.distribute.experimental.ParameterServerStrategy`, a
        # `tf.distribute.experimental.coordinator.RemoteValue` will be
        # passed. For now, we just disable `update_freq` in those cases.
        if isinstance(logs, dict):
            for name, value in logs.items():
                tf.summary.scalar("batch_" + name, value, step=self._train_step)

        if not self._should_trace:
            return

        if self._is_tracing and self._global_train_batch >= self._stop_batch:
            self._stop_trace()

    def on_epoch_begin(self, epoch, logs=None):
        # Keeps track of epoch for profiling.
        if self.write_steps_per_second:
            self._previous_epoch_iterations = (
                self.model.optimizer.iterations.numpy()
            )
        self._epoch_start_time = time.time()

    def on_epoch_end(self, epoch, logs=None):
        """Runs metrics and histogram summaries at epoch end."""
        self._log_epoch_metrics(epoch, logs)

        if self.histogram_freq and epoch % self.histogram_freq == 0:
            self._log_weights(epoch)

        if self.embeddings_freq and epoch % self.embeddings_freq == 0:
            self._log_embeddings(epoch)

    def _start_trace(self):
        tf.summary.trace_on(graph=True, profiler=False)
        self._start_profiler(logdir=self.log_dir)
        self._is_tracing = True

    def _stop_trace(self, batch=None):
        """Logs the trace graph to TensorBoard."""
        if batch is None:
            batch = self._stop_batch
        with self._train_writer.as_default():
            with tf.summary.record_if(True):
                # TODO(b/126388999): Remove step info in the summary name.
                tf.summary.trace_export(name="batch_%d" % batch, step=batch)
        self._stop_profiler()
        self._is_tracing = False

    def _collect_learning_rate(self, logs):
        if isinstance(self.model.optimizer, optimizer.Optimizer):
            lr_schedule = getattr(self.model.optimizer, "_learning_rate", None)
        else:
            lr_schedule = getattr(self.model.optimizer, "lr", None)
        if isinstance(lr_schedule, learning_rate_schedule.LearningRateSchedule):
            logs["learning_rate"] = lr_schedule(self.model.optimizer.iterations)
        return logs

    def _compute_steps_per_second(self):
        current_iteration = self.model.optimizer.iterations.numpy()
        time_since_epoch_begin = time.time() - self._epoch_start_time
        steps_per_second = (
            current_iteration - self._previous_epoch_iterations
        ) / time_since_epoch_begin
        return steps_per_second

    def _log_epoch_metrics(self, epoch, logs):
        """Writes epoch metrics out as scalar summaries.

        Args:
            epoch: Int. The global step to use for TensorBoard.
            logs: Dict. Keys are scalar summary names, values are scalars.
        """
        if not logs:
            return

        train_logs = {k: v for k, v in logs.items() if not k.startswith("val_")}
        val_logs = {k: v for k, v in logs.items() if k.startswith("val_")}
        train_logs = self._collect_learning_rate(train_logs)
        if self.write_steps_per_second:
            train_logs["steps_per_second"] = self._compute_steps_per_second()

        with tf.summary.record_if(True):
            if train_logs:
                with self._train_writer.as_default():
                    for name, value in train_logs.items():
                        tf.summary.scalar("epoch_" + name, value, step=epoch)
            if val_logs:
                with self._val_writer.as_default():
                    for name, value in val_logs.items():
                        name = name[4:]  # Remove 'val_' prefix.
                        tf.summary.scalar("epoch_" + name, value, step=epoch)

    def _log_weights(self, epoch):
        """Logs the weights of the Model to TensorBoard."""
        with self._train_writer.as_default():
            with tf.summary.record_if(True):
                for layer in self.model.layers:
                    for weight in layer.weights:
                        weight_name = weight.name.replace(":", "_")
                        # Add a suffix to prevent summary tag name
                        # collisions.
                        histogram_weight_name = weight_name + "/histogram"
                        tf.summary.histogram(
                            histogram_weight_name, weight, step=epoch
                        )
                        if self.write_images:
                            # Add a suffix to prevent summary tag name
                            # collisions.
                            image_weight_name = weight_name + "/image"
                            self._log_weight_as_image(
                                weight, image_weight_name, epoch
                            )
                self._train_writer.flush()

    def _log_weight_as_image(self, weight, weight_name, epoch):
        """Logs a weight as a TensorBoard image."""
        w_img = tf.squeeze(weight)
        shape = backend.int_shape(w_img)
        if len(shape) == 1:  # Bias case
            w_img = tf.reshape(w_img, [1, shape[0], 1, 1])
        elif len(shape) == 2:  # Dense layer kernel case
            if shape[0] > shape[1]:
                w_img = tf.transpose(w_img)
                shape = backend.int_shape(w_img)
            w_img = tf.reshape(w_img, [1, shape[0], shape[1], 1])
        elif len(shape) == 3:  # ConvNet case
            if backend.image_data_format() == "channels_last":
                # Switch to channels_first to display every kernel as a
                # separate image.
                w_img = tf.transpose(w_img, perm=[2, 0, 1])
                shape = backend.int_shape(w_img)
            w_img = tf.reshape(w_img, [shape[0], shape[1], shape[2], 1])

        shape = backend.int_shape(w_img)
        # Not possible to handle 3D convnets etc.
        if len(shape) == 4 and shape[-1] in [1, 3, 4]:
            tf.summary.image(weight_name, w_img, step=epoch)

    def _log_embeddings(self, epoch):
        embeddings_ckpt = os.path.join(
            self._log_write_dir,
            "train",
            f"keras_embedding.ckpt-{epoch}",
        )
        self.model.save_weights(embeddings_ckpt)

    def _start_profiler(self, logdir):
        """Starts the profiler if currently inactive.

        Args:
            logdir: Directory where profiler results will be saved.
        """
        if self._profiler_started:
            return
        try:
            tf.profiler.experimental.start(logdir=logdir)
            self._profiler_started = True
        except tf.errors.AlreadyExistsError as e:
            # Profiler errors should not be fatal.
            logging.error("Failed to start profiler: %s", e.message)

    def _stop_profiler(self, save=True):
        """Stops the profiler if currently active.

        Args:
            save: Whether to save the profiler results to TensorBoard.
        """
        if not self._profiler_started:
            return
        try:
            tf.profiler.experimental.stop(save=save)
        except tf.errors.UnavailableError as e:
            # Profiler errors should not be fatal.
            logging.error("Failed to stop profiler: %s", e.message)
        finally:
            self._profiler_started = False


@keras_export("keras.callbacks.ReduceLROnPlateau")
class ReduceLROnPlateau(Callback):
    """Reduce learning rate when a metric has stopped improving.

    Models often benefit from reducing the learning rate by a factor
    of 2-10 once learning stagnates. This callback monitors a
    quantity and if no improvement is seen for a 'patience' number
    of epochs, the learning rate is reduced.

    Example:

    ```python
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,
                                  patience=5, min_lr=0.001)
    model.fit(X_train, Y_train, callbacks=[reduce_lr])
    ```

    Args:
        monitor: quantity to be monitored.
        factor: factor by which the learning rate will be reduced.
            `new_lr = lr * factor`.
        patience: number of epochs with no improvement after which learning
            rate will be reduced.
        verbose: int. 0: quiet, 1: update messages.
        mode: one of `{'auto', 'min', 'max'}`. In `'min'` mode,
            the learning rate will be reduced when the
            quantity monitored has stopped decreasing; in `'max'` mode it
            will be reduced when the quantity monitored has stopped
            increasing; in `'auto'` mode, the direction is automatically
            inferred from the name of the monitored quantity.
        min_delta: threshold for measuring the new optimum, to only focus
            on significant changes.
        cooldown: number of epochs to wait before resuming normal operation
            after the lr has been reduced.
        min_lr: lower bound on the learning rate.
    """

    def __init__(
        self,
        monitor="val_loss",
        factor=0.1,
        patience=10,
        verbose=0,
        mode="auto",
        min_delta=1e-4,
        cooldown=0,
        min_lr=0,
        **kwargs,
    ):
        super().__init__()

        self.monitor = monitor
        if factor >= 1.0:
            raise ValueError(
                "ReduceLROnPlateau does not support "
                f"a factor >= 1.0. Got {factor}"
            )
        if "epsilon" in kwargs:
            min_delta = kwargs.pop("epsilon")
            logging.warning(
                "`epsilon` argument is deprecated and "
                "will be removed, use `min_delta` instead."
            )
        self.factor = factor
        self.min_lr = min_lr
        self.min_delta = min_delta
        self.patience = patience
        self.verbose = verbose
        self.cooldown = cooldown
        self.cooldown_counter = 0  # Cooldown counter.
        self.wait = 0
        self.best = 0
        self.mode = mode
        self.monitor_op = None
        self._reset()

    def _reset(self):
        """Resets wait counter and cooldown counter."""
        if self.mode not in ["auto", "min", "max"]:
            logging.warning(
                "Learning rate reduction mode %s is unknown, "
                "falling back to auto mode.",
                self.mode,
            )
            self.mode = "auto"
        if self.mode == "min" or (
            self.mode == "auto" and "acc" not in self.monitor
        ):
            self.monitor_op = lambda a, b: np.less(a, b - self.min_delta)
            self.best = np.Inf
        else:
            self.monitor_op = lambda a, b: np.greater(a, b + self.min_delta)
            self.best = -np.Inf
        self.cooldown_counter = 0
        self.wait = 0

    def on_train_begin(self, logs=None):
        self._reset()

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs["lr"] = backend.get_value(self.model.optimizer.lr)
        current = logs.get(self.monitor)
        if current is None:
            logging.warning(
                "Learning rate reduction is conditioned on metric `%s` "
                "which is not available. Available metrics are: %s",
                self.monitor,
                ",".join(list(logs.keys())),
            )
        else:
            if self.in_cooldown():
                self.cooldown_counter -= 1
                self.wait = 0

            if self.monitor_op(current, self.best):
                self.best = current
                self.wait = 0
            elif not self.in_cooldown():
                self.wait += 1
                if self.wait >= self.patience:
                    old_lr = backend.get_value(self.model.optimizer.lr)
                    if old_lr > np.float32(self.min_lr):
                        new_lr = old_lr * self.factor
                        new_lr = max(new_lr, self.min_lr)
                        backend.set_value(self.model.optimizer.lr, new_lr)
                        if self.verbose > 0:
                            io_utils.print_msg(
                                f"\nEpoch {epoch + 1}: "
                                "ReduceLROnPlateau reducing "
                                f"learning rate to {new_lr}."
                            )
                        self.cooldown_counter = self.cooldown
                        self.wait = 0

    def in_cooldown(self):
        return self.cooldown_counter > 0


@keras_export("keras.callbacks.CSVLogger")
class CSVLogger(Callback):
    """Callback that streams epoch results to a CSV file.

    Supports all values that can be represented as a string,
    including 1D iterables such as `np.ndarray`.

    Example:

    ```python
    csv_logger = CSVLogger('training.log')
    model.fit(X_train, Y_train, callbacks=[csv_logger])
    ```

    Args:
        filename: Filename of the CSV file, e.g. `'run/log.csv'`.
        separator: String used to separate elements in the CSV file.
        append: Boolean. True: append if file exists (useful for continuing
            training). False: overwrite existing file.
    """

    def __init__(self, filename, separator=",", append=False):
        self.sep = separator
        self.filename = io_utils.path_to_string(filename)
        self.append = append
        self.writer = None
        self.keys = None
        self.append_header = True
        super().__init__()

    def on_train_begin(self, logs=None):
        if self.append:
            if tf.io.gfile.exists(self.filename):
                with tf.io.gfile.GFile(self.filename, "r") as f:
                    self.append_header = not bool(len(f.readline()))
            mode = "a"
        else:
            mode = "w"
        self.csv_file = tf.io.gfile.GFile(self.filename, mode)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}

        def handle_value(k):
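            # For example (illustrative): a 1D array such as
            # np.array([0.1, 0.2]) is rendered as '"[0.1, 0.2]"', so the
            # whole list survives CSV parsing as a single quoted field.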
            is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
            if isinstance(k, str):
                return k
            elif (
                isinstance(k, collections.abc.Iterable)
                and not is_zero_dim_ndarray
            ):
                return f"\"[{', '.join(map(str, k))}]\""
            else:
                return k

        if self.keys is None:
            self.keys = sorted(logs.keys())

        if self.model.stop_training:
            # We set NA so that csv parsers do not fail for this last
            # epoch.
            logs = dict(
                (k, logs[k]) if k in logs else (k, "NA") for k in self.keys
            )

        if not self.writer:

            class CustomDialect(csv.excel):
                delimiter = self.sep

            fieldnames = ["epoch"] + self.keys

            self.writer = csv.DictWriter(
                self.csv_file, fieldnames=fieldnames, dialect=CustomDialect
            )
            if self.append_header:
                self.writer.writeheader()

        row_dict = collections.OrderedDict({"epoch": epoch})
        row_dict.update((key, handle_value(logs[key])) for key in self.keys)
        self.writer.writerow(row_dict)
        self.csv_file.flush()

    def on_train_end(self, logs=None):
        self.csv_file.close()
        self.writer = None


@keras_export("keras.callbacks.LambdaCallback")
class LambdaCallback(Callback):
    r"""Callback for creating simple, custom callbacks on-the-fly.

    This callback is constructed with anonymous functions that will be
    called at the appropriate time (during `Model.{fit | evaluate |
    predict}`). Note that the callbacks expect positional arguments, as:

    - `on_epoch_begin` and `on_epoch_end` expect two positional arguments:
      `epoch`, `logs`
    - `on_batch_begin` and `on_batch_end` expect two positional arguments:
      `batch`, `logs`
    - `on_train_begin` and `on_train_end` expect one positional argument:
      `logs`

    Args:
        on_epoch_begin: called at the beginning of every epoch.
        on_epoch_end: called at the end of every epoch.
        on_batch_begin: called at the beginning of every batch.
        on_batch_end: called at the end of every batch.
        on_train_begin: called at the beginning of model training.
        on_train_end: called at the end of model training.

    Example:

    ```python
    # Print the batch number at the beginning of every batch.
    batch_print_callback = LambdaCallback(
        on_batch_begin=lambda batch, logs: print(batch))

    # Stream the epoch loss to a file in JSON format. The file content
    # is not well-formed JSON but rather has a JSON object per line.
    import json
    json_log = open('loss_log.json', mode='wt', buffering=1)
    json_logging_callback = LambdaCallback(
        on_epoch_end=lambda epoch, logs: json_log.write(
            json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\n'),
        on_train_end=lambda logs: json_log.close()
    )

    # Terminate some processes after having finished model training.
    processes = ...
    cleanup_callback = LambdaCallback(
        on_train_end=lambda logs: [
            p.terminate() for p in processes if p.is_alive()])

    model.fit(...,
              callbacks=[batch_print_callback,
                         json_logging_callback,
                         cleanup_callback])
    ```
    """

    def __init__(
        self,
        on_epoch_begin=None,
        on_epoch_end=None,
        on_batch_begin=None,
        on_batch_end=None,
        on_train_begin=None,
        on_train_end=None,
        **kwargs,
    ):
        super().__init__()
        self.__dict__.update(kwargs)
        if on_epoch_begin is not None:
            self.on_epoch_begin = on_epoch_begin
        if on_epoch_end is not None:
            self.on_epoch_end = on_epoch_end
        if on_batch_begin is not None:
            self.on_batch_begin = on_batch_begin
        if on_batch_end is not None:
            self.on_batch_end = on_batch_end
        if on_train_begin is not None:
            self.on_train_begin = on_train_begin
        if on_train_end is not None:
            self.on_train_end = on_train_end