78 lines
2.9 KiB
Python
78 lines
2.9 KiB
Python
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
# ==============================================================================
|
||
|
"""SignatureDef utility functions implementation."""
|
||
|
|
||
|
from tensorflow.python.keras.saving.utils_v1 import unexported_constants
|
||
|
from tensorflow.python.saved_model import signature_def_utils
|
||
|
from tensorflow.python.saved_model import utils_impl as utils
|
||
|
|
||
|
|
||
|
# LINT.IfChange
|
||
|
def supervised_train_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Builds a SignatureDef describing a supervised training step.

  Thin wrapper around `_supervised_signature_def` that pins the method name to
  `unexported_constants.SUPERVISED_TRAIN_METHOD_NAME`.

  Args:
    inputs: dict of string to `Tensor`; must be neither `None` nor empty.
    loss: dict of string to `Tensor` representing computed loss (may be
      `None`, in which case it contributes no outputs).
    predictions: optional dict of string to `Tensor` for output predictions.
    metrics: optional dict of string to `Tensor` for metric ops.

  Returns:
    The SignatureDef produced by `_supervised_signature_def`.

  Raises:
    ValueError: Propagated from `_supervised_signature_def` when `inputs` is
      `None` or empty.
  """
  train_method = unexported_constants.SUPERVISED_TRAIN_METHOD_NAME
  return _supervised_signature_def(
      train_method,
      inputs,
      loss=loss,
      predictions=predictions,
      metrics=metrics,
  )
|
||
|
|
||
|
|
||
|
def supervised_eval_signature_def(
    inputs, loss, predictions=None, metrics=None):
  """Builds a SignatureDef describing a supervised evaluation step.

  Thin wrapper around `_supervised_signature_def` that pins the method name to
  `unexported_constants.SUPERVISED_EVAL_METHOD_NAME`.

  Args:
    inputs: dict of string to `Tensor`; must be neither `None` nor empty.
    loss: dict of string to `Tensor` representing computed loss (may be
      `None`, in which case it contributes no outputs).
    predictions: optional dict of string to `Tensor` for output predictions.
    metrics: optional dict of string to `Tensor` for metric ops.

  Returns:
    The SignatureDef produced by `_supervised_signature_def`.

  Raises:
    ValueError: Propagated from `_supervised_signature_def` when `inputs` is
      `None` or empty.
  """
  eval_method = unexported_constants.SUPERVISED_EVAL_METHOD_NAME
  return _supervised_signature_def(
      eval_method,
      inputs,
      loss=loss,
      predictions=predictions,
      metrics=metrics,
  )
|
||
|
|
||
|
|
||
|
def _supervised_signature_def(
|
||
|
method_name, inputs, loss=None, predictions=None,
|
||
|
metrics=None):
|
||
|
"""Creates a signature for training and eval data.
|
||
|
|
||
|
This function produces signatures that describe the inputs and outputs
|
||
|
of a supervised process, such as training or evaluation, that
|
||
|
results in loss, metrics, and the like. Note that this function only requires
|
||
|
inputs to be not None.
|
||
|
|
||
|
Args:
|
||
|
method_name: Method name of the SignatureDef as a string.
|
||
|
inputs: dict of string to `Tensor`.
|
||
|
loss: dict of string to `Tensor` representing computed loss.
|
||
|
predictions: dict of string to `Tensor` representing the output predictions.
|
||
|
metrics: dict of string to `Tensor` representing metric ops.
|
||
|
|
||
|
Returns:
|
||
|
A train- or eval-flavored signature_def.
|
||
|
|
||
|
Raises:
|
||
|
ValueError: If inputs or outputs is `None`.
|
||
|
"""
|
||
|
if inputs is None or not inputs:
|
||
|
raise ValueError('{} inputs cannot be None or empty.'.format(method_name))
|
||
|
|
||
|
signature_inputs = {key: utils.build_tensor_info(tensor)
|
||
|
for key, tensor in inputs.items()}
|
||
|
|
||
|
signature_outputs = {}
|
||
|
for output_set in (loss, predictions, metrics):
|
||
|
if output_set is not None:
|
||
|
sig_out = {key: utils.build_tensor_info(tensor)
|
||
|
for key, tensor in output_set.items()}
|
||
|
signature_outputs.update(sig_out)
|
||
|
|
||
|
signature_def = signature_def_utils.build_signature_def(
|
||
|
signature_inputs, signature_outputs, method_name)
|
||
|
|
||
|
return signature_def
|
||
|
# LINT.ThenChange(//tensorflow/python/keras/saving/utils_v1/signature_def_utils.py)
|