Intelegentny_Pszczelarz/.venv/Lib/site-packages/tensorflow/python/keras/saving/utils_v1/signature_def_utils.py

78 lines
2.9 KiB
Python
Raw Normal View History

2023-06-19 00:49:18 +02:00
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SignatureDef utility functions implementation."""
from tensorflow.python.keras.saving.utils_v1 import unexported_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import utils_impl as utils
# LINT.IfChange
def supervised_train_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Builds a training-flavored SignatureDef.

  Thin wrapper around `_supervised_signature_def` that fixes the method name
  to `unexported_constants.SUPERVISED_TRAIN_METHOD_NAME`.

  Args:
    inputs: dict of string to `Tensor`; must be non-empty.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: optional dict of string to `Tensor` of output predictions.
    metrics: optional dict of string to `Tensor` of metric ops.

  Returns:
    The SignatureDef produced by `_supervised_signature_def`.

  Raises:
    ValueError: Propagated when `inputs` is `None` or empty.
  """
  train_method = unexported_constants.SUPERVISED_TRAIN_METHOD_NAME
  return _supervised_signature_def(
      train_method,
      inputs,
      loss=loss,
      predictions=predictions,
      metrics=metrics,
  )
def supervised_eval_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Builds an evaluation-flavored SignatureDef.

  Thin wrapper around `_supervised_signature_def` that fixes the method name
  to `unexported_constants.SUPERVISED_EVAL_METHOD_NAME`.

  Args:
    inputs: dict of string to `Tensor`; must be non-empty.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: optional dict of string to `Tensor` of output predictions.
    metrics: optional dict of string to `Tensor` of metric ops.

  Returns:
    The SignatureDef produced by `_supervised_signature_def`.

  Raises:
    ValueError: Propagated when `inputs` is `None` or empty.
  """
  eval_method = unexported_constants.SUPERVISED_EVAL_METHOD_NAME
  return _supervised_signature_def(
      eval_method,
      inputs,
      loss=loss,
      predictions=predictions,
      metrics=metrics,
  )
def _supervised_signature_def(
    method_name, inputs, loss=None, predictions=None,
    metrics=None):
  """Creates a signature for training and eval data.

  This function produces signatures that describe the inputs and outputs
  of a supervised process, such as training or evaluation, that
  results in loss, metrics, and the like. Only `inputs` is validated: it
  must be non-`None` and non-empty. Each of `loss`, `predictions` and
  `metrics` may be `None`, in which case it contributes no outputs.

  Args:
    method_name: Method name of the SignatureDef as a string.
    inputs: dict of string to `Tensor`.
    loss: dict of string to `Tensor` representing computed loss.
    predictions: dict of string to `Tensor` representing the output predictions.
    metrics: dict of string to `Tensor` representing metric ops.

  Returns:
    A train- or eval-flavored signature_def.

  Raises:
    ValueError: If `inputs` is `None` or empty. The output dicts are not
      validated.
  """
  # `not inputs` is true for both `None` and an empty mapping, so a separate
  # `inputs is None` test is redundant.
  if not inputs:
    raise ValueError('{} inputs cannot be None or empty.'.format(method_name))

  signature_inputs = {key: utils.build_tensor_info(tensor)
                      for key, tensor in inputs.items()}

  # loss, predictions and metrics are merged into one flat output dict in
  # that order. A key present in more than one of them is silently
  # overwritten by the later dict.
  signature_outputs = {}
  for output_set in (loss, predictions, metrics):
    if output_set is not None:
      signature_outputs.update(
          {key: utils.build_tensor_info(tensor)
           for key, tensor in output_set.items()})

  return signature_def_utils.build_signature_def(
      signature_inputs, signature_outputs, method_name)
# LINT.ThenChange(//tensorflow/python/keras/saving/utils_v1/signature_def_utils.py)