# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers that use registered OpDefs to create Brain operations in a graph."""

from google.protobuf import text_format
from tensorflow.core.config import flags
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_callbacks
from tensorflow.python.framework import op_def_library_pybind
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import _pywrap_utils
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib


def _Attr(op_def, name):
  """Returns the attr definition called `name` from `op_def.attr`.

  Args:
    op_def: An OpDef proto.
    name: String name of the attr to look up.

  Returns:
    The first entry of `op_def.attr` whose `name` field equals `name`.

  Raises:
    TypeError: If `op_def` declares no attr called `name`.
  """
  for attr in op_def.attr:
    if attr.name == name:
      return attr
  raise TypeError(f"Inconsistent OpDef for '{op_def.name}', missing attr "
                  f"'{name}'")


def _AttrValue(attr_protos, name, op_type_name):
  """Returns `attr_protos[name]`.

  Args:
    attr_protos: Mapping from attr name to AttrValue proto.
    name: Attr name to look up.
    op_type_name: Op type name, used only for the error message.

  Raises:
    TypeError: If `name` is not a key of `attr_protos`.
  """
  if name in attr_protos:
    return attr_protos[name]
  raise TypeError(f"Inconsistent OpDef for '{op_type_name}', missing attr "
                  f"'{name}' from '{attr_protos}'.")


def _SatisfiesTypeConstraint(dtype, attr_def, param_name):
  """Checks `dtype` against the `allowed_values` list of `attr_def`, if any.

  Args:
    dtype: The DataType (enum or DType) being assigned to the attr.
    attr_def: The attr definition carrying an optional `allowed_values`.
    param_name: Parameter name, used only for the error message.

  Raises:
    TypeError: If `attr_def` restricts the allowed types and `dtype` is not
      among them.
  """
  if attr_def.HasField("allowed_values"):
    allowed_list = attr_def.allowed_values.list.type
    if dtype not in allowed_list:
      # Only build the (comparatively expensive) human-readable list when we
      # are actually going to report an error.
      allowed_values = ", ".join(dtypes.as_dtype(x).name for x in allowed_list)
      raise TypeError(
          f"Value passed to parameter '{param_name}' has DataType "
          f"{dtypes.as_dtype(dtype).name} not in list of allowed values: "
          f"{allowed_values}")


def _SatisfiesLengthConstraint(length, attr_def, param_name, op_type_name):
  """Raises ValueError if a list attr is shorter than `attr_def.minimum`."""
  if attr_def.has_minimum and length < attr_def.minimum:
    raise ValueError(f"Attr '{param_name}' of '{op_type_name}' Op passed list "
                     f"of length {length} less than minimum "
                     f"{attr_def.minimum}.")


def _SatisfiesAllowedStringsConstraint(value, attr_def, arg_name, op_type_name):
  """Raises ValueError if `value` is not in `attr_def.allowed_values.list.s`."""
  if value not in attr_def.allowed_values.list.s:
    allowed_values = '", "'.join(
        map(compat.as_text, attr_def.allowed_values.list.s))
    raise ValueError(f"Attr '{arg_name}' of '{op_type_name}' Op passed string "
                     f"'{compat.as_text(value)}' not in: \"{allowed_values}\".")


def _SatisfiesIntMinimumConstraint(value, attr_def, arg_name, op_type_name):
  """Raises ValueError if the int `value` is below `attr_def.minimum`."""
  if value < attr_def.minimum:
    raise ValueError(f"Attr '{arg_name}' of '{op_type_name}' Op passed {value} "
                     f"less than minimum {attr_def.minimum}.")


def _IsListParameter(arg):
  """Returns True if `arg` takes a list (it has a number or type-list attr)."""
  return bool(arg.number_attr or arg.type_list_attr)


def _NumTypeFields(arg):
  """Counts how many of `type`, `type_attr`, `type_list_attr` `arg` sets."""
  return sum([
      arg.type != types_pb2.DT_INVALID,
      bool(arg.type_attr),
      bool(arg.type_list_attr),
  ])


def _IsListValue(v):
  """Returns True if `v` is a Python list or tuple."""
  return isinstance(v, (list, tuple))


def _Flatten(l):
  """Converts [1, 2, [3, 4], [5]] to [1, 2, 3, 4, 5].

  Only one level of nesting is removed; only lists and tuples are treated as
  nested sequences.
  """
  # [1, 2, [3, 4], [5]] -> [[1], [2], [3, 4], [5]]
  l_of_l = [x if _IsListValue(x) else [x] for x in l]
  # [[1], [2], [3, 4], [5]] -> [1, 2, 3, 4, 5]
  return [item for sublist in l_of_l for item in sublist]


def _Restructure(l, structure):
  """Returns the elements of list l structured according to the given structure.

  A structure is represented by a list whose elements are either
  `None` or a non-negative integer. `None` corresponds to a single
  element in the output, and an integer N corresponds to a nested
  slice of `l` of length N.

  The function returns a tuple whose shape is given by `structure`, and
  whose elements are taken from `l`. If `structure` is a singleton, the
  function returns the single data structure implied by the 0th element
  of `structure` instead of a 1-tuple. For example:

      _Restructure(["foo", "bar", "baz", "qux"], [None, 2, None])
        -> ("foo", ["bar", "baz"], "qux")

      _Restructure(["foo"], [None]) -> "foo"

      _Restructure(["foo"], [1]) -> ["foo"]

      _Restructure([], [0]) -> []

  Args:
    l: A list.
    structure: A list whose elements are either `None` or a non-negative
      integer.

  Returns:
    The elements of `l`, restructured according to `structure`, as a tuple.
    If `structure` is a list of length 1, this function returns the single
    data structure implied by `structure[0]`.
  """
  result = []
  current_index = 0
  for element in structure:
    if element is None:
      result.append(l[current_index])
      current_index += 1
    else:
      result.append(l[current_index:current_index+element])
      current_index += element

  if len(result) == 1:
    return result[0]
  else:
    return tuple(result)


def _MakeFloat(v, arg_name):
  """Returns `float(v)`; raises TypeError unless `v` is a real number."""
  if not isinstance(v, compat.real_types):
    raise TypeError(f"Expected float for argument '{arg_name}' not {repr(v)}.")
  return float(v)


def _MakeInt(v, arg_name):
  """Returns `int(v)`; raises TypeError for strings or unconvertible values."""
  if isinstance(v, str):
    raise TypeError(f"Expected int for argument '{arg_name}' not {repr(v)}.")
  try:
    return int(v)
  except (ValueError, TypeError) as e:
    raise TypeError(
        f"Expected int for argument '{arg_name}' not {repr(v)}.") from e


def _MakeStr(v, arg_name):
  """Returns `v` as bytes; raises TypeError unless `v` is bytes or text."""
  if not isinstance(v, compat.bytes_or_text_types):
    raise TypeError(f"Expected string for argument '{arg_name}' not {repr(v)}.")
  return compat.as_bytes(v)  # Convert unicode strings to bytes.
def _MakeBool(v, arg_name):
  """Returns `v` unchanged; raises TypeError unless it is a Python bool."""
  if not isinstance(v, bool):
    raise TypeError(f"Expected bool for argument '{arg_name}' not {repr(v)}.")
  return v


def _MakeType(v, arg_name):
  """Converts `v` to the DataType enum of its base dtype.

  Args:
    v: Anything accepted by `dtypes.as_dtype`.
    arg_name: String, for error messages.

  Returns:
    The `as_datatype_enum` of `dtypes.as_dtype(v).base_dtype`.

  Raises:
    TypeError: If `dtypes.as_dtype` rejects `v`.
  """
  try:
    v = dtypes.as_dtype(v).base_dtype
  except TypeError as e:
    # `v` still refers to the caller's value here since the assignment failed.
    raise TypeError(f"Expected DataType for argument '{arg_name}' not "
                    f"{repr(v)}.") from e
  return v.as_datatype_enum


def _MakeShape(v, arg_name):
  """Convert v into a TensorShapeProto.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    A TensorShapeProto. A TensorShapeProto input is returned as-is (a warning
    is logged if any of its dimensions is named).

  Raises:
    TypeError: If `tensor_shape.as_shape` cannot convert `v`.
  """
  if isinstance(v, tensor_shape_pb2.TensorShapeProto):
    if any(d.name for d in v.dim):
      logging.warning("Warning: TensorShapeProto with a named dimension: %s",
                      str(v))
    return v
  try:
    return tensor_shape.as_shape(v).as_proto()
  except (TypeError, ValueError) as e:
    raise TypeError(f"Error converting {repr(v)} (arg name = {arg_name}) to a "
                    f"TensorShape: {e}") from e


def _MakeTensor(v, arg_name):
  """Ensure v is a TensorProto; raises TypeError for any other value."""
  if isinstance(v, tensor_pb2.TensorProto):
    return v
  raise TypeError(
      f"Don't know how to convert {repr(v)} to a TensorProto for argument "
      f"'{arg_name}'")


def _MakeFunc(v, arg_name):
  """Ensure v is a func, returned as a NameAttrList.

  Accepts a NameAttrList (returned as-is), a function name (bytes or text),
  or an object with an `add_to_graph` method, which is first added to the
  default graph.

  Raises:
    TypeError: If `v` is none of the accepted kinds.
  """
  if isinstance(v, attr_value_pb2.NameAttrList):
    return v
  if isinstance(v, compat.bytes_or_text_types):
    fn_attr = attr_value_pb2.NameAttrList(name=v)
  elif hasattr(v, "add_to_graph"):
    v.add_to_graph(ops.get_default_graph())
    if hasattr(v, "_as_name_attr_list"):
      fn_attr = v._as_name_attr_list  # pylint: disable=protected-access
    else:
      fn_attr = attr_value_pb2.NameAttrList(name=v.name)
  else:
    raise TypeError(f"Don't know how to convert {repr(v)} to a func for "
                    f"argument {arg_name}")
  return fn_attr


# pylint: disable=g-doc-return-or-yield
@tf_contextlib.contextmanager
def _MaybeColocateWith(inputs):
  """A context manager for (maybe) colocating with a list of input tensors.

  Args:
    inputs: A list of `Tensor` or `Operation` objects.

  Returns:
    A context manager. With an empty `inputs` it adds no constraint.
  """
  if not inputs:
    yield
  else:
    # NOTE(mrry): The `ops.colocate_with()` function accepts only a single
    # op or tensor, so we create one context manager per element in the list.
    with ops.colocate_with(inputs[0]), _MaybeColocateWith(inputs[1:]):
      yield
# pylint: enable=g-doc-return-or-yield


def apply_op(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Add a node invoking a registered Op to a graph.

  Example usage:
     # input1 and input2 can be Tensors or anything ops.convert_to_tensor()
     # will convert to a Tensor.
     op_def_library.apply_op("op", input1=input1, input2=input2)
     # Can specify a node name.
     op_def_library.apply_op("op", input1=input1, name="node_name")
     # Must use keyword arguments, with the names specified in the OpDef.
     op_def_library.apply_op("op", input_name=input, attr_name=attr)

  All attrs must either be inferred from an input or specified.
  (If inferred, the attr must not be specified.)  If an attr has a default
  value specified in the Op's OpDef, then you may pass None as the value
  of that attr to get the default.

  Args:
    op_type_name: string. Must match the name field of a registered Op.
    name: string. Optional name of the created op.
    **keywords: input Tensor and attr arguments specified by name, and optional
      parameters to pass when constructing the Operation.

  Returns:
    The Tensor(s) representing the output of the operation, or the Operation
    itself if there are no outputs. A stateful op whose restructured output is
    an empty list also returns the Operation.

  Raises:
    RuntimeError: On some errors.
    TypeError: On some errors.
    ValueError: On some errors.
  """
  output_structure, is_stateful, op, outputs = _apply_op_helper(
      op_type_name, name, **keywords)
  if output_structure:
    res = _Restructure(ops.convert_n_to_tensor(outputs), output_structure)
    if isinstance(res, list) and not res and is_stateful:
      return op
    else:
      return res
  else:
    return op


# This is temporary Python/C++ code duplication until all of it can be ported
# over to C++.
# LINT.IfChange
def _ExtractAttrProto(op_type_name, op_def, attrs, attr_protos):
  """Extracts `attr_protos`. For use in _apply_op_helper.

  For every attr in `op_def`, converts `attrs[name]` to an AttrValue proto
  (or copies the OpDef default when the value is None and a default exists),
  validates it against the attr's constraints, and stores it in
  `attr_protos[name]`.

  Raises:
    TypeError/ValueError: If a value cannot be converted or violates a
      length, allowed-value, minimum, or type constraint.
  """
  for attr_def in op_def.attr:
    key = attr_def.name
    value = attrs[key]

    if attr_def.HasField("default_value") and value is None:
      attr_value = attr_value_pb2.AttrValue()
      attr_value.CopyFrom(attr_def.default_value)
      attr_protos[key] = attr_value
      continue

    attr_value = value_to_attr_value(value, attr_def.type, key)
    if attr_def.type.startswith("list("):
      _SatisfiesLengthConstraint(len(value), attr_def, key, op_type_name)
    if attr_def.HasField("allowed_values"):
      if attr_def.type == "string":
        _SatisfiesAllowedStringsConstraint(attr_value.s, attr_def, key,
                                           op_type_name)
      elif attr_def.type == "list(string)":
        for s in attr_value.list.s:
          _SatisfiesAllowedStringsConstraint(s, attr_def, key, op_type_name)
    if attr_def.has_minimum and attr_def.type == "int":
      _SatisfiesIntMinimumConstraint(attr_value.i, attr_def, key, op_type_name)
    if attr_def.type == "type":
      _SatisfiesTypeConstraint(attr_value.type, attr_def, key)
    if attr_def.type == "list(type)":
      for t in attr_value.list.type:
        _SatisfiesTypeConstraint(t, attr_def, key)

    attr_protos[key] = attr_value


def _ExtractOutputStructure(op_type_name, op_def, attr_protos,
                            output_structure):
  """Extracts `output_structure`. For use in _apply_op_helper.

  Appends one entry per output arg to `output_structure`: the list length for
  number-attr and type-list-attr outputs, and None for single-tensor outputs
  (the format consumed by `_Restructure`).

  Raises:
    TypeError: If an attr referenced by an output arg is missing from
      `attr_protos`.
  """
  for arg in op_def.output_arg:
    if arg.number_attr:
      n = _AttrValue(attr_protos, arg.number_attr, op_type_name).i
      output_structure.append(n)
    elif arg.type_attr:
      # Called only to verify the attr is present; the value is not needed.
      _AttrValue(attr_protos, arg.type_attr, op_type_name)
      output_structure.append(None)
    elif arg.type_list_attr:
      t = _AttrValue(attr_protos, arg.type_list_attr, op_type_name)
      output_structure.append(len(t.list.type))
    else:
      output_structure.append(None)


def _CanExtractAttrsFastPath(op_def, keywords):
  """Check if the fast path for _apply_op_helper is applicable.

  Returns True only if every input arg is passed as an `ops.Tensor` and no
  attr has type `func` or `list(func)`.
  """
  # Check if all inputs are already tf.Tensor
  for input_arg in op_def.input_arg:
    value = keywords.get(input_arg.name, None)
    if not isinstance(value, ops.Tensor):
      return False

  # Check that attrs are not `func` or `list(func)` type.
  for attr_def in op_def.attr:
    if attr_def.type == "func" or attr_def.type == "list(func)":
      return False

  return True


def _CheckOpDeprecation(op_type_name, op_def, producer):
  """Raises NotImplementedError if the op was removed at or before `producer`."""
  deprecation_version = op_def.deprecation.version
  if deprecation_version and producer >= deprecation_version:
    raise NotImplementedError(
        f"Op {op_type_name} is not available in GraphDef version {producer}. "
        f"It has been removed in version {deprecation_version}. "
        f"{op_def.deprecation.explanation}.")


def _ExtractDefaultTypesAndAllowedTypes(op_def, default_type_attr_map,
                                        allowed_list_attr_map):
  """Extracts the `default_type_attr_map` and `allowed_list_attr_map`.

  Only attrs of type "type" are considered: their default (as a DType) goes
  into `default_type_attr_map` and their allowed-type list into
  `allowed_list_attr_map`, when present.
  """
  # TODO(b/31302892): Currently the defaults don't work in the right
  # way if you have two inputs, one of whose type resolution depends
  # on the other.  Handling this will require restructuring this code
  # significantly.
  for attr_def in op_def.attr:
    if attr_def.type != "type":
      continue
    key = attr_def.name
    if attr_def.HasField("default_value"):
      default_type_attr_map[key] = dtypes.as_dtype(
          attr_def.default_value.type)
    if attr_def.HasField("allowed_values"):
      allowed_list_attr_map[key] = attr_def.allowed_values.list.type


def _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map,
                           keywords, default_type_attr_map, attrs, inputs,
                           input_types):
  """Extracts `attrs`, `inputs`, and `input_types` in _apply_op_helper.

  Pops each input arg from `keywords` (also accepting `name + "_"`), converts
  it to tensor(s), appends them to `inputs`, appends their (ref or base)
  dtypes to `input_types`, and records any attrs inferred from the inputs
  (number attrs, type attrs, type-list attrs) in `attrs`.
  """
  # Maps an inferred attr name to the input name (or "Default in OpDef")
  # it was inferred from; used only in error messages.
  inferred_from = {}
  for input_arg in op_def.input_arg:
    input_name = input_arg.name
    if input_name in keywords:
      values = keywords.pop(input_name)
    elif input_name + "_" in keywords:
      # Handle the case where the name is a keyword or built-in
      # for Python so we use the name + _ instead.
      input_name += "_"
      values = keywords.pop(input_name)
    else:
      raise TypeError(f"No argument for input {input_name} found in {op_def}")

    # Goals:
    # * Convert values to Tensors if it contains constants.
    # * Verify that values is a list if that matches the input_arg's
    #   type.
    # * If the input_arg's type is determined by attrs, either set
    #   those attrs and validate those attr values are legal (if
    #   they have not yet been set) or validate the input matches
    #   the type indicated by the attrs (if they have already been
    #   inferred via an earlier input).
    # * If the input_arg has an explicit type, make sure the input
    #   conforms.

    if _IsListParameter(input_arg):
      if not _IsListValue(values):
        raise TypeError(
            f"Expected list for '{input_name}' argument to '{op_type_name}' "
            f"Op, not {values}.")
      # In cases where we expect all elements of the list to have the
      # same dtype, try to cast non-Tensor elements to that type.
      dtype = None
      default_dtype = None
      if input_arg.type != types_pb2.DT_INVALID:
        dtype = input_arg.type
      elif input_arg.number_attr:
        if input_arg.type_attr in attrs:
          dtype = attrs[input_arg.type_attr]
        else:
          for t in values:
            if isinstance(t, ops.Tensor):
              dtype = t.dtype
              break

        # dtype still not found, prefer using the default dtype
        # from the attr.
        if dtype is None and input_arg.type_attr in default_type_attr_map:
          default_dtype = default_type_attr_map[input_arg.type_attr]

      try:
        if not input_arg.is_ref and dtype:
          dtype = dtypes.as_dtype(dtype).base_dtype
        values = ops.internal_convert_n_to_tensor(
            values,
            name=input_arg.name,
            dtype=dtype if dtype else None,
            preferred_dtype=default_dtype,
            as_ref=input_arg.is_ref)
        all_types = set(v.dtype.base_dtype for v in values)
        if input_arg.number_attr and len(all_types) > 1:
          # All types should match.
          raise TypeError(f"Not all types matched for {input_arg.name} for "
                          f"{op_type_name}. Got {all_types}")
      except (TypeError, ValueError):
        # What types does the conversion function think values have?
        observed_types = []
        for value in values:
          try:
            converted_value = ops.convert_to_tensor(
                value, as_ref=input_arg.is_ref)
            observed_types.append(converted_value.dtype.base_dtype.name)
          except (TypeError, ValueError):
            observed_types.append("")
        observed = ", ".join(observed_types)
        prefix = (
            "Tensors in list passed to '%s' of '%s' Op have types [%s]" %
            (input_name, op_type_name, observed))
        if input_arg.number_attr:
          # `dtype` may still be a raw DataType enum here (e.g. for ref
          # inputs it is never passed through `as_dtype`), so normalize it
          # before reading `.name`.
          if input_arg.type != types_pb2.DT_INVALID:
            raise TypeError(f"{prefix} that do not match expected type "
                            f"{dtypes.as_dtype(dtype).name}.")
          elif input_arg.type_attr in attrs:
            raise TypeError(f"{prefix} that do not match type "
                            f"{dtypes.as_dtype(dtype).name} "
                            "inferred from earlier arguments.")
          else:
            raise TypeError(f"{prefix} that don't all match.")
        else:
          raise TypeError(f"{prefix} that are invalid. Tensors: {values}")

      types = [x.dtype for x in values]
      inputs.extend(values)
    else:
      # In cases where we have an expected type, try to convert non-Tensor
      # arguments to that type.
      dtype = None
      default_dtype = None
      allowed_list = None
      if input_arg.type != types_pb2.DT_INVALID:
        dtype = input_arg.type
      elif input_arg.type_attr in attrs:
        dtype = attrs[input_arg.type_attr]
      elif input_arg.type_attr in default_type_attr_map:
        # The dtype could not be inferred solely from the inputs,
        # so we prefer the attr's default, so code that adds a new attr
        # with a default is backwards compatible.
        default_dtype = default_type_attr_map[input_arg.type_attr]
        allowed_list = allowed_list_attr_map.get(input_arg.type_attr)

      try:
        # First see if we can get a valid dtype with the default conversion
        # and see if it matches an allowed dtypes. Some ops like ConcatV2 may
        # not list allowed dtypes, in which case we should skip this.
        if dtype is None and allowed_list:
          inferred = None
          try:
            inferred = ops.convert_to_tensor(
                values, name=input_arg.name, as_ref=input_arg.is_ref)
          except TypeError:
            # When converting a python object such as a list of Dimensions, we
            # need a dtype to be specified, thus tensor conversion may throw
            # an exception which we will ignore and try again below.
            pass

          # If we did not match an allowed dtype, try again with the default
          # dtype. This could be because we have an empty tensor and thus we
          # picked the wrong type.
          if inferred is not None and inferred.dtype in allowed_list:
            values = inferred
          else:
            values = ops.convert_to_tensor(
                values,
                name=input_arg.name,
                as_ref=input_arg.is_ref,
                preferred_dtype=default_dtype)
        else:
          values = ops.convert_to_tensor(
              values,
              name=input_arg.name,
              dtype=dtype,
              as_ref=input_arg.is_ref,
              preferred_dtype=default_dtype)
      except TypeError as err:
        if dtype is None:
          raise
        else:
          raise TypeError(
              f"Expected {dtypes.as_dtype(dtype).name} passed to parameter "
              f"'{input_arg.name}' of op '{op_type_name}', got "
              f"{repr(values)} of type '{type(values).__name__}' instead. "
              f"Error: {err}")
      except ValueError:
        # What type does convert_to_tensor think it has?
        try:
          observed = ops.convert_to_tensor(
              values, as_ref=input_arg.is_ref).dtype.name
        except ValueError as err:
          raise ValueError(
              f"Tried to convert '{input_name}' to a tensor and failed. "
              f"Error: {err}")
        prefix = ("Input '%s' of '%s' Op has type %s that does not match" %
                  (input_name, op_type_name, observed))
        if input_arg.type != types_pb2.DT_INVALID:
          raise TypeError(f"{prefix} expected type of "
                          f"{dtypes.as_dtype(input_arg.type).name}.")
        else:
          # Update the maps with the default, if needed.
          k = input_arg.type_attr
          if k in default_type_attr_map:
            if k not in attrs:
              attrs[k] = default_type_attr_map[k]
              if k not in inferred_from:
                inferred_from[k] = "Default in OpDef"

          raise TypeError(
              f"{prefix} type "
              f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
              f"argument '{inferred_from[input_arg.type_attr]}'.")

      types = [values.dtype]
      inputs.append(values)
    base_types = [x.base_dtype for x in types]

    if input_arg.number_attr:
      # <number-attr> * <type> or <number-attr> * <type-attr>
      if input_arg.number_attr in attrs:
        if len(values) != attrs[input_arg.number_attr]:
          raise ValueError(
              f"List argument '{input_name}' to '{op_type_name}' Op with "
              f"length {len(values)} must match length "
              f"{attrs[input_arg.number_attr]} of argument "
              f"'{inferred_from[input_arg.number_attr]}'.")
      else:
        attrs[input_arg.number_attr] = len(values)
        inferred_from[input_arg.number_attr] = input_name
        num_attr = _Attr(op_def, input_arg.number_attr)
        if num_attr.has_minimum and len(values) < num_attr.minimum:
          raise ValueError(
              f"List argument '{input_name}' to '{op_type_name}' Op with "
              f"length {len(values)} shorter than minimum length "
              f"{num_attr.minimum}.")
      # All tensors must have the same base type.
      if any(bt != base_types[0] for bt in base_types):
        raise TypeError(
            f"All tensors passed to '{input_name}' of '{op_type_name}' Op "
            f"must have the same type. Got {base_types} instead.")

      if input_arg.type != types_pb2.DT_INVALID:
        # <number-attr> * <type> case
        if base_types and base_types[0] != input_arg.type:
          assert False, "Unreachable"
      elif input_arg.type_attr in attrs:
        # <number-attr> * <type-attr> case, where <type-attr> already
        # has an inferred value.
        if base_types and base_types[0] != attrs[input_arg.type_attr]:
          assert False, "Unreachable"
      else:
        # <number-attr> * <type-attr> case, where we are now setting
        # the <type-attr> based on this input
        if not base_types:
          # If it's in default_type_attr_map, then wait to set it
          # (in "process remaining attrs", below).
          if input_arg.type_attr not in default_type_attr_map:
            raise TypeError(
                "Don't know how to infer type variable from empty input "
                f"list passed to input '{input_name}' of '{op_type_name}' "
                "Op.")
        else:
          attrs[input_arg.type_attr] = base_types[0]
          inferred_from[input_arg.type_attr] = input_name
          type_attr = _Attr(op_def, input_arg.type_attr)
          _SatisfiesTypeConstraint(
              base_types[0], type_attr, param_name=input_name)
    elif input_arg.type_attr:
      # <type-attr>
      attr_value = base_types[0]
      if input_arg.type_attr in attrs:
        if attrs[input_arg.type_attr] != attr_value:
          raise TypeError(
              f"Input '{input_name}' of '{op_type_name}' Op has type "
              f"{dtypes.as_dtype(attr_value).name} that does not match type "
              f"{dtypes.as_dtype(attrs[input_arg.type_attr]).name} of "
              f"argument '{inferred_from[input_arg.type_attr]}'.")
      else:
        for base_type in base_types:
          _SatisfiesTypeConstraint(
              base_type,
              _Attr(op_def, input_arg.type_attr),
              param_name=input_name)
        attrs[input_arg.type_attr] = attr_value
        inferred_from[input_arg.type_attr] = input_name
    elif input_arg.type_list_attr:
      # <type-list-attr>
      attr_value = base_types
      if input_arg.type_list_attr in attrs:
        if attrs[input_arg.type_list_attr] != attr_value:
          actual_types = ", ".join(dtypes.as_dtype(x).name for x in attr_value)
          expected_types = ", ".join(
              dtypes.as_dtype(x).name for x in attrs[input_arg.type_list_attr])
          raise TypeError(
              f"Input '{input_name}' of '{op_type_name}' Op has type list of "
              f"{actual_types} that does not match type list {expected_types}"
              f" of argument '{inferred_from[input_arg.type_list_attr]}'.")
      else:
        for base_type in base_types:
          _SatisfiesTypeConstraint(
              base_type,
              _Attr(op_def, input_arg.type_list_attr),
              param_name=input_name)
        attrs[input_arg.type_list_attr] = attr_value
        inferred_from[input_arg.type_list_attr] = input_name
    else:
      # single Tensor with specified type
      if base_types[0] != input_arg.type:
        assert False, "Unreachable"

    if input_arg.is_ref:
      if not all(x._is_ref_dtype for x in types):  # pylint: disable=protected-access
        raise TypeError(
            f"'{op_type_name}' Op requires that input '{input_name}' be a "
            "mutable tensor (e.g.: a tf.Variable)")
      input_types.extend(types)
    else:
      input_types.extend(base_types)


def _ExtractRemainingAttrs(op_type_name, op_def, keywords,
                           default_type_attr_map, attrs):
  """Extracts the remaining attributes into `attrs` in _apply_op_helper.

  Attrs not already inferred are popped from `keywords` (also accepting
  `name + "_"`) or taken from `default_type_attr_map`.

  Raises:
    TypeError: If an inferred attr is also passed explicitly, or an attr has
      no value from any source.
  """
  for attr in op_def.attr:
    # Skip attrs that have already had their values inferred
    if attr.name in attrs:
      if attr.name in keywords:
        raise TypeError(
            f"Should not specify value for inferred attr '{attr.name}' for "
            f"{op_type_name}.")
      continue
    if attr.name in keywords:
      attrs[attr.name] = keywords.pop(attr.name)
    elif attr.name + "_" in keywords:
      # Attrs whose names match Python keywords have an extra '_'
      # appended, so we must check for that as well.
      attrs[attr.name] = keywords.pop(attr.name + "_")
    elif attr.name in default_type_attr_map:
      attrs[attr.name] = default_type_attr_map[attr.name]
    else:
      raise TypeError(f"No argument found for attr {attr.name} for "
                      f"{op_type_name}")


def _GetOpDef(op_type_name, keywords):
  """Returns the OpDef, Graph and Producer. For use in _apply_op_helper.

  Raises:
    RuntimeError: If the op is not registered, or no single graph can be
      determined from the (flattened) keyword values.
  """
  op_def = op_def_registry.get(op_type_name)
  if op_def is None:
    raise RuntimeError(f"Unrecognized Op name {op_type_name}")

  # Determine the graph context.
  try:
    # Need to flatten all the arguments into a list.
    # pylint: disable=protected-access
    g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    producer = g.graph_def_versions.producer
    # pylint: enable=protected-access
  except AssertionError as e:
    # Python 3 exceptions have no `.message` attribute; format the exception
    # itself so the intended RuntimeError is raised instead of AttributeError.
    raise RuntimeError(
        f"Cannot determine graph for Op '{op_type_name}' due to: {e}") from e

  return op_def, g, producer


def _CheckAllInputsUsed(op_type_name, keywords):
  """Ensures all inputs passed into _apply_op_helper were used.

  Raises:
    TypeError: If `keywords` is non-empty, listing the leftover names.
  """
  if keywords:
    all_keywords = ", ".join(sorted(keywords.keys()))
    raise TypeError(f"{op_type_name} got unexpected keyword arguments: "
                    f"{all_keywords}.")


def _apply_op_helper(op_type_name, name=None, **keywords):  # pylint: disable=invalid-name
  """Implementation of apply_op.

  Returns:
    A tuple `(output_structure, is_stateful, op, outputs)` where
    `output_structure` is in the format consumed by `_Restructure`,
    `is_stateful` is `op_def.is_stateful`, `op` is the created Operation and
    `outputs` are its outputs (possibly replaced by op callbacks).
  """
  op_def, g, producer = _GetOpDef(op_type_name, keywords)
  name = name if name else op_type_name

  attrs, attr_protos = {}, {}
  default_type_attr_map, allowed_list_attr_map = {}, {}
  inputs, input_types, output_structure = [], [], []
  fallback = True

  if (_CanExtractAttrsFastPath(op_def, keywords) and
      flags.config().graph_building_optimization.value()):
    fallback = False
    attr_protos, inputs, input_types, output_structure = (
        op_def_library_pybind.process_inputs(op_type_name, producer,
                                             keywords))

  if fallback:
    _CheckOpDeprecation(op_type_name, op_def, producer)
    _ExtractDefaultTypesAndAllowedTypes(op_def, default_type_attr_map,
                                        allowed_list_attr_map)

  # Requires that op_def has passed validation (using the C++
  # ValidateOpDef() from ../framework/op_def_util.h).
  with g.as_default(), ops.name_scope(name) as scope:
    if fallback:
      _ExtractInputsAndAttrs(op_type_name, op_def, allowed_list_attr_map,
                             keywords, default_type_attr_map, attrs, inputs,
                             input_types)
      _ExtractRemainingAttrs(op_type_name, op_def, keywords,
                             default_type_attr_map, attrs)
      _ExtractAttrProto(op_type_name, op_def, attrs, attr_protos)
      del attrs  # attrs is no longer authoritative, use attr_protos instead
      _ExtractOutputStructure(op_type_name, op_def, attr_protos,
                              output_structure)
      _CheckAllInputsUsed(op_type_name, keywords)

    # NOTE(mrry): We add an explicit colocation constraint between
    # the newly created op and any of its reference-typed inputs.
    must_colocate_inputs = [val for arg, val in zip(op_def.input_arg, inputs)
                            if arg.is_ref]
    with _MaybeColocateWith(must_colocate_inputs):
      # Add Op to graph
      # pylint: disable=protected-access
      op = g._create_op_internal(op_type_name, inputs, dtypes=None,
                                 name=scope, input_types=input_types,
                                 attrs=attr_protos, op_def=op_def)

    # `outputs` is returned as a separate return value so that the output
    # tensors and the `op` per se can be decoupled, which lets the
    # `op_callbacks` function properly. See framework/op_callbacks.py
    # for more details.
    outputs = op.outputs
    # Conditionally invoke tfdbg v2's op callback(s).
    if op_callbacks.should_invoke_op_callbacks():
      callback_outputs = op_callbacks.invoke_op_callbacks(
          op.node_def.op, tuple(op.inputs), attr_protos, tuple(outputs),
          op_name=op.name, graph=g)
      if callback_outputs is not None:
        outputs = callback_outputs

    return output_structure, op_def.is_stateful, op, outputs


def value_to_attr_value(value, attr_type, arg_name):  # pylint: disable=invalid-name
  """Encodes a Python value as an `AttrValue` proto message.

  Args:
    value: The value to convert.
    attr_type: The value type (string) -- see the AttrValue proto definition for
      valid strings.
    arg_name: Argument name (for error messages).

  Returns:
    An AttrValue proto message that encodes `value`.

  Raises:
    TypeError: If a `list(...)` type is given a non-list value, `attr_type`
      is unrecognized, or an element cannot be converted.
  """
  attr_value = attr_value_pb2.AttrValue()

  if attr_type.startswith("list("):
    if not _IsListValue(value):
      raise TypeError(f"Expected list for attr {arg_name}, obtained "
                      f"{type(value).__name__} instead.")

  if attr_type == "string":
    attr_value.s = _MakeStr(value, arg_name)
  elif attr_type == "list(string)":
    attr_value.list.s.extend([_MakeStr(x, arg_name) for x in value])
  elif attr_type == "int":
    attr_value.i = _MakeInt(value, arg_name)
  elif attr_type == "list(int)":
    attr_value.list.i.extend([_MakeInt(x, arg_name) for x in value])
  elif attr_type == "float":
    attr_value.f = _MakeFloat(value, arg_name)
  elif attr_type == "list(float)":
    attr_value.list.f.extend([_MakeFloat(x, arg_name) for x in value])
  elif attr_type == "bool":
    attr_value.b = _MakeBool(value, arg_name)
  elif attr_type == "list(bool)":
    attr_value.list.b.extend([_MakeBool(x, arg_name) for x in value])
  elif attr_type == "type":
    attr_value.type = _MakeType(value, arg_name)
  elif attr_type == "list(type)":
    attr_value.list.type.extend([_MakeType(x, arg_name) for x in value])
  elif attr_type == "shape":
    attr_value.shape.CopyFrom(_MakeShape(value, arg_name))
  elif attr_type == "list(shape)":
    attr_value.list.shape.extend([_MakeShape(x, arg_name) for x in value])
  elif attr_type == "tensor":
    attr_value.tensor.CopyFrom(_MakeTensor(value, arg_name))
  elif attr_type == "list(tensor)":
    attr_value.list.tensor.extend([_MakeTensor(x, arg_name) for x in value])
  elif attr_type == "func":
    attr_value.func.CopyFrom(_MakeFunc(value, arg_name))
  elif attr_type == "list(func)":
    attr_value.list.func.extend([_MakeFunc(x, arg_name) for x in value])
  else:
    raise TypeError(f"Unrecognized Attr type {attr_type} for {arg_name}.")
  return attr_value
# LINT.ThenChange(//tensorflow/python/framework/op_def_library_pybind.cc)


# The following symbols are used by op_def_util.cc.
# Register each (key, object) pair with `_pywrap_utils.RegisterPyObject`, in
# this exact order. The loop variables are deleted afterwards so the module
# namespace is left unchanged.
for _registered_key, _registered_obj in (
    ("tf.dtypes.DType", dtypes.DType),
    ("tf.dtypes.as_dtype", dtypes.as_dtype),
    ("tf.TensorShape", tensor_shape.TensorShape),
    ("tf.as_shape", tensor_shape.as_shape),
    ("tf.TensorProto", tensor_pb2.TensorProto),
    ("text_format.Parse", text_format.Parse),
    ("tf.convert_to_tensor", ops.convert_to_tensor),
):
  _pywrap_utils.RegisterPyObject(_registered_key, _registered_obj)
del _registered_key, _registered_obj