Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/keras/mixed_precision/loss_scale.py

59 lines
2.0 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains keras-specific LossScale functionality.
These functions cannot be in the non-keras loss_scale.py file since they depend
on keras, and files outside of keras should not depend on files inside keras.
"""
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
def serialize(loss_scale):
  """Serialize a loss scale object.

  Thin wrapper that forwards to `generic_utils.serialize_keras_object`.

  Args:
    loss_scale: The object to serialize.

  Returns:
    Whatever `generic_utils.serialize_keras_object` returns for `loss_scale`.
  """
  serialized = generic_utils.serialize_keras_object(loss_scale)
  return serialized
def deserialize(config, custom_objects=None):
  """Build a loss scale object from `config`.

  The lookup table handed to `generic_utils.deserialize_keras_object` exposes
  the two loss scale classes referenced here: `FixedLossScale` and
  `DynamicLossScale`.

  Args:
    config: The serialized form to deserialize; forwarded unchanged.
    custom_objects: Optional extra objects, forwarded unchanged to
      `generic_utils.deserialize_keras_object`.

  Returns:
    Whatever `generic_utils.deserialize_keras_object` returns.
  """
  # Names resolvable during deserialization, keyed by class name.
  known_loss_scales = {
      'FixedLossScale': loss_scale_module.FixedLossScale,
      'DynamicLossScale': loss_scale_module.DynamicLossScale,
  }
  restored = generic_utils.deserialize_keras_object(
      config,
      module_objects=known_loss_scales,
      custom_objects=custom_objects,
      printable_module_name='loss scale')
  return restored
def get(identifier):
  """Get a loss scale object.

  Args:
    identifier: One of the following:
      * `None`, which is returned as `None`.
      * A dict, which is passed to `deserialize`.
      * An int or float, which is wrapped in
        `loss_scale_module.FixedLossScale`.
      * The string `'dynamic'`, which yields a
        `loss_scale_module.DynamicLossScale` built with no arguments.
      * A `loss_scale_module.LossScale` instance, which is returned as is.

  Returns:
    The loss scale object selected by `identifier`, or `None` if
    `identifier` is `None`.

  Raises:
    ValueError: If `identifier` is none of the accepted forms.
  """
  # None matches none of the branches below, so checking it first is
  # equivalent to checking it last and keeps the pass-through case obvious.
  if identifier is None:
    return None
  if isinstance(identifier, dict):
    return deserialize(identifier)
  if isinstance(identifier, (int, float)):
    return loss_scale_module.FixedLossScale(identifier)
  if identifier == 'dynamic':
    return loss_scale_module.DynamicLossScale()
  if isinstance(identifier, loss_scale_module.LossScale):
    return identifier
  # Wrap the identifier in a 1-tuple: with a bare `% identifier`, a tuple
  # identifier would be unpacked as the format arguments and raise TypeError
  # (or be printed as its single element) instead of the intended ValueError.
  raise ValueError('Could not interpret loss scale identifier: %s' %
                   (identifier,))