# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions called by the generated code to execute an eager-mode op."""

from google.protobuf import text_format
from tensorflow.core.framework import tensor_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import core
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.types import core as core_types
from tensorflow.python.util import compat


def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Execute a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch. (Explicitly
      provided instead of being inferred for performance reasons).
    inputs: A list of inputs to the operation. Each entry should be a Tensor, or
      a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    ctx: The value of context.context().
    name: Customized name for the operation.

  Returns:
    List of output Tensor objects. The list is empty if there are no outputs.

  Raises:
    An exception on error.
  """
  device_name = ctx.device_name
  # pylint: disable=protected-access
  try:
    ctx.ensure_initialized()
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
                                        inputs, attrs, num_outputs)
  except core._NotOkStatusException as e:
    if name is not None:
      e.message += " name: " + name
    raise core._status_to_exception(e) from None
  except TypeError as e:
    keras_symbolic_tensors = [x for x in inputs if _is_keras_symbolic_tensor(x)]
    if keras_symbolic_tensors:
      raise core._SymbolicException(
          "Inputs to eager execution function cannot be Keras symbolic "
          "tensors, but found {}".format(keras_symbolic_tensors))
    raise e
  # pylint: enable=protected-access
  return tensors
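
# Illustrative usage sketch (for documentation only; the names below mirror how
# the generated op wrappers call this module, but the exact call site is an
# assumption). For a unary op such as Identity:
#
#   from tensorflow.python.eager import context
#   from tensorflow.python.framework import constant_op
#
#   ctx = context.context()
#   x = constant_op.constant([1.0, 2.0])
#   (y,) = quick_execute(b"Identity", num_outputs=1, inputs=[x],
#                        attrs=("T", x.dtype.as_datatype_enum), ctx=ctx)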


def execute_with_cancellation(op_name,
                              num_outputs,
                              inputs,
                              attrs,
                              ctx,
                              cancellation_manager,
                              name=None):
  """Execute a TensorFlow operation that can be cancelled.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch. (Explicitly
      provided instead of being inferred for performance reasons).
    inputs: A list of inputs to the operation. Each entry should be a Tensor, or
      a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    ctx: The value of context.context().
    cancellation_manager: a `CancellationManager` object that can be used to
      cancel the operation.
    name: Customized name for the operation.

  Returns:
    List of output Tensor objects. The list is empty if there are no outputs.

  Raises:
    An exception on error.
  """
  device_name = ctx.device_name
  # pylint: disable=protected-access
  try:
    ctx.ensure_initialized()
    tensors = pywrap_tfe.TFE_Py_ExecuteCancelable(ctx._handle, device_name,
                                                  op_name, inputs, attrs,
                                                  cancellation_manager._impl,
                                                  num_outputs)
  except core._NotOkStatusException as e:
    if name is not None:
      e.message += " name: " + name
    raise core._status_to_exception(e) from None
  except TypeError as e:
    keras_symbolic_tensors = [x for x in inputs if _is_keras_symbolic_tensor(x)]
    if keras_symbolic_tensors:
      raise core._SymbolicException(
          "Inputs to eager execution function cannot be Keras symbolic "
          "tensors, but found {}".format(keras_symbolic_tensors))
    raise e
  # pylint: enable=protected-access
  return tensors
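
# Illustrative sketch (hypothetical setup): a CancellationManager from
# `tensorflow.python.eager.cancellation` can be used to abort a long-running op
# from another thread. With an eager tensor `x` and attrs as in the sketch
# after `quick_execute` above:
#
#   from tensorflow.python.eager import cancellation, context
#
#   manager = cancellation.CancellationManager()
#   # ... hand `manager.start_cancel` to a watchdog thread ...
#   outputs = execute_with_cancellation(
#       b"Identity", num_outputs=1, inputs=[x],
#       attrs=("T", x.dtype.as_datatype_enum),
#       ctx=context.context(), cancellation_manager=manager)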


def execute_with_callbacks(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Monkey-patch replacement for `execute` that runs the context's op callbacks."""
  tensors = quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
  for callback in ctx.op_callbacks:
    callback(op_name, tuple(inputs), attrs, tensors, name)

  return tensors


execute = quick_execute
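
# `execute` is the entry point the generated op code calls; the op-callbacks
# machinery is expected to monkey-patch it to `execute_with_callbacks` while
# callbacks are registered. Illustrative sketch (assumes the public
# `op_callbacks.add_op_callback` API; the callback below matches the
# positional call made in `execute_with_callbacks` above):
#
#   from tensorflow.python.framework import op_callbacks
#
#   def trace(op_type, inputs, attrs, outputs, op_name=None, graph=None):
#     print("executed", op_type)
#
#   op_callbacks.add_op_callback(trace)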


def must_record_gradient():
  """Import backprop if you want gradients recorded."""
  return False


def record_gradient(unused_op_name, unused_inputs, unused_attrs,
                    unused_outputs):
  """Import backprop if you want gradients recorded."""
  pass


def make_float(v, arg_name):
  if not isinstance(v, compat.real_types):
    raise TypeError("Expected float for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return float(v)
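
# For example (values illustrative): make_float(3, "alpha") returns 3.0, while
# make_float("3", "alpha") raises TypeError because strings are not real types.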


def make_int(v, arg_name):
  if isinstance(v, str):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
  try:
    return int(v)
  except (ValueError, TypeError):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
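
# For example (values illustrative): make_int(3.0, "axis") returns 3, while
# make_int("3", "axis") raises TypeError; strings are rejected up front.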


def make_str(v, arg_name):
  if not isinstance(v, compat.bytes_or_text_types):
    raise TypeError("Expected string for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return compat.as_bytes(v)  # Convert unicode strings to bytes.
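
# For example (value illustrative): make_str("SAME", "padding") returns
# b"SAME"; bytes values pass through unchanged.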


def make_bool(v, arg_name):
  if not isinstance(v, bool):
    raise TypeError("Expected bool for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return v
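
# For example (values illustrative): make_bool(True, "transpose_a") returns
# True, but make_bool(1, "transpose_a") raises TypeError; ints are not coerced.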


def make_type(v, arg_name):
  try:
    v = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return v.as_datatype_enum
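
# For example (values illustrative): make_type("float32", "T") and
# make_type(dtypes.float32, "T") both return dtypes.float32.as_datatype_enum
# (the DT_FLOAT wire enum); reference dtypes are mapped to their base dtype
# first.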


def make_shape(v, arg_name):
  """Convert v into a list.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    None if the rank is unknown, otherwise a list of ints (or Nones in the
    positions where the dimension is unknown).
  """
  try:
    shape = tensor_shape.as_shape(v)
  except TypeError as e:
    raise TypeError("Error converting %s to a TensorShape: %s." % (arg_name, e))
  except ValueError as e:
    raise ValueError("Error converting %s to a TensorShape: %s." %
                     (arg_name, e))
  if shape.ndims is None:
    return None
  else:
    return shape.as_list()
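
# For example (values illustrative): make_shape([1, None, 3], "shape") returns
# [1, None, 3], and make_shape(tensor_shape.TensorShape(None), "shape") returns
# None because the rank is unknown.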


def make_tensor(v, arg_name):
  """Ensure v is a TensorProto."""
  if isinstance(v, tensor_pb2.TensorProto):
    return v
  elif isinstance(v, str):
    pb = tensor_pb2.TensorProto()
    text_format.Merge(v, pb)
    return pb
  raise TypeError(
      "Don't know how to convert %s to a TensorProto for argument '%s'." %
      (repr(v), arg_name))
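
# For example (text-format value illustrative):
#   make_tensor(
#       "dtype: DT_FLOAT tensor_shape { dim { size: 1 } } float_val: 5.0",
#       "value")
# parses the string into a TensorProto; anything that is neither a TensorProto
# nor a str raises TypeError.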


def args_to_matching_eager(l, ctx, allowed_dtypes, default_dtype=None):
  """Convert sequence `l` to eager same-type Tensors."""
  del ctx  # Unused
  if (not l) and (default_dtype is not None):
    return default_dtype, []  # List is empty; assume default dtype.
  for x in l:
    if not isinstance(x, core_types.Value):
      break
  else:  # Note: intentional for-else; every input is already an eager tensor.
    return l[0]._datatype_enum(), l  # pylint: disable=protected-access

  # Is some input already a Tensor with a dtype?
  dtype = None
  for t in l:
    if isinstance(t, core_types.Value):
      dtype = t.dtype
      break

  if dtype is None:
    # Infer a dtype based on the first value, and use that dtype for the
    # remaining values.

    ret = []
    for t in l:
      tensor = None
      # First see if we can get a valid dtype with the default conversion
      # and see if it matches an allowed dtype. Some ops like ConcatV2 may
      # not list allowed dtypes, in which case we should skip this.
      if dtype is None and allowed_dtypes:
        tensor = tensor_conversion_registry.convert(t)
        # If we did not match an allowed dtype, try again with the default
        # dtype. This could be because we have an empty tensor and thus we
        # picked the wrong type.
        if tensor.dtype not in allowed_dtypes:
          tensor = None

      if tensor is None:
        tensor = tensor_conversion_registry.convert(
            t, dtype, preferred_dtype=default_dtype
        )

      ret.append(tensor)
      if dtype is None:
        dtype = tensor.dtype
  else:
    ret = [tensor_conversion_registry.convert(t, dtype) for t in l]

  # TODO(slebedev): consider removing this as it leaks a Keras concept.
  # pylint: disable=protected-access
  keras_symbolic_tensors = [x for x in ret if _is_keras_symbolic_tensor(x)]
  if keras_symbolic_tensors:
    raise core._SymbolicException(
        "Using symbolic output of a Keras layer during eager execution "
        "{}".format(keras_symbolic_tensors))
  # pylint: enable=protected-access
  return dtype.as_datatype_enum, ret
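
# Illustrative sketch (hypothetical values): with no eager tensors among the
# inputs, the first element fixes the common dtype and the rest are converted
# to it.
#
#   enum, tensors = args_to_matching_eager(
#       [1.0, 2, 3], ctx=None, allowed_dtypes=[dtypes.float32, dtypes.int32])
#   # enum == dtypes.float32.as_datatype_enum and all three tensors are
#   # float32, because 1.0 converted to float32 first.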


def convert_to_mixed_eager_tensors(values, ctx):
  """Converts each value in `values` to an eager tensor with its own dtype."""
  del ctx  # Unused
  v = [tensor_conversion_registry.convert(t) for t in values]
  types = [t._datatype_enum() for t in v]  # pylint: disable=protected-access
  return types, v
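
# Illustrative sketch (hypothetical values): each value keeps its own dtype, so
# convert_to_mixed_eager_tensors([1, 2.0, "a"], ctx) returns the datatype enums
# for (int32, float32, string) along with the three converted tensors.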


def args_to_mixed_eager_tensors(lists, ctx):
  """Converts a list of same-length lists of values to eager tensors."""
  del ctx  # Unused
  assert len(lists) > 1

  # Generate an error if len(lists[i]) is not the same for all i.
  lists_ret = [[]]
  for l in lists[1:]:
    if len(l) != len(lists[0]):
      raise ValueError(
          "Expected list arguments to be the same length: %d != %d (%r vs. %r)."
          % (len(lists[0]), len(l), lists[0], l))
    lists_ret.append([])

  # Convert the first element of each list first, then the second element, etc.
  types = []
  for i in range(len(lists[0])):
    dtype = None
    # If any list has a Tensor, use that dtype
    for l in lists:
      if isinstance(l[i], core_types.Value):
        dtype = l[i].dtype
        break
    if dtype is None:
      # Convert the first one and use its dtype.
      lists_ret[0].append(tensor_conversion_registry.convert(lists[0][i]))
      dtype = lists_ret[0][i].dtype
      for j in range(1, len(lists)):
        lists_ret[j].append(
            tensor_conversion_registry.convert(lists[j][i], dtype=dtype)
        )
    else:
      # Convert everything to the found dtype.
      for j in range(len(lists)):
        lists_ret[j].append(
            tensor_conversion_registry.convert(lists[j][i], dtype=dtype)
        )
    types.append(dtype.as_datatype_enum)
  return types, lists_ret
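
# Illustrative sketch (hypothetical values): with lists = [[1, 2.0], [3, 4.0]],
# column 0 is converted to int32 (taken from the first list's element) and
# column 1 to float32, so `types` holds two dtype enums and `lists_ret` mirrors
# the input structure with eager tensors.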


def _is_keras_symbolic_tensor(x):
  return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph"