# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental library that exposes XLA operations directly in TensorFlow.

It is sometimes useful to be able to build HLO programs directly from
TensorFlow. This file provides Tensorflow operators that mirror the semantics of
HLO operators as closely as possible.

Note: Most of the operators defined in this module are used by the jax2tf
converter (see go/jax2tf for details) and are used in SavedModel produced
by jax2tf. Hence, we need to maintain backwards compatibility for these
operators. Please reach out to the JAX team if you want to make changes.
"""

from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops.numpy_ops import np_utils

# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
# ops include:
# infeed/outfeed (available via tf.contrib.tpu)
# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
# conditional
# gather/scatter
# collapse

# Several public names below deliberately shadow Python builtins so that they
# match XLA's operator names (e.g. `xla.max`). Keep private handles to the
# original builtins before they are shadowed.
# pylint: disable=redefined-builtin
_max, _min, _slice = max, min, slice
# pylint: disable=invalid-name
constant = constant_op.constant

# Unary operators.

# For most arithmetic operators there is a TensorFlow operator
# that exactly corresponds to each XLA operator. Rather than defining
# XLA-specific variants, we reuse the corresponding TensorFlow operator.
# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1
# wrap every HLO operator, because that would allow us to be confident that the
# semantics match.


def _unary_op(fn):
  """Adapts a TensorFlow op to the `(x, name=None)` XLA unary-op signature.

  Args:
    fn: a callable accepting a single operand plus a `name` keyword.

  Returns:
    A function taking exactly `x` and an optional `name`, forwarding both to
    `fn` (so extra optional parameters of `fn` cannot be passed through).
  """

  def unary_op_wrapper(x, name=None):
    # `name` is forwarded by keyword; no other arguments reach `fn`.
    forwarded = fn(x, name=name)
    return forwarded

  return unary_op_wrapper


abs = _unary_op(math_ops.abs)
# TODO(phawkins): implement clz.
conj = _unary_op(math_ops.conj)
cos = _unary_op(math_ops.cos)
ceil = _unary_op(math_ops.ceil)
digamma = _unary_op(math_ops.digamma)
erf = _unary_op(math_ops.erf)
erfc = _unary_op(math_ops.erfc)
erfinv = _unary_op(math_ops.erfinv)
ndtri = _unary_op(math_ops.ndtri)
exp = _unary_op(math_ops.exp)
expm1 = _unary_op(math_ops.expm1)
floor = _unary_op(math_ops.floor)
imag = _unary_op(math_ops.imag)
is_finite = _unary_op(math_ops.is_finite)
lgamma = _unary_op(math_ops.lgamma)
log = _unary_op(math_ops.log)
log1p = _unary_op(math_ops.log1p)
logical_not = _unary_op(math_ops.logical_not)
neg = _unary_op(math_ops.neg)
real = _unary_op(math_ops.real)
# TODO(phawkins): unlike xla::Round, which rounds halfway cases away from zero,
# this rounds halfway cases to the nearest even integer.
round = _unary_op(math_ops.round)
sin = _unary_op(math_ops.sin)
sign = _unary_op(math_ops.sign)
tan = _unary_op(math_ops.tan)
tanh = _unary_op(math_ops.tanh)

# Bessel
bessel_i0e = _unary_op(special_math_ops.bessel_i0e)
bessel_i1e = _unary_op(special_math_ops.bessel_i1e)

# Binary operators

# The main difference between TensorFlow and XLA binary ops is the broadcasting
# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA
# requires an explicit specification of which dimensions to broadcast if the
# arguments have different ranks.


def _broadcasting_binary_op(fn):
  """Wraps a binary TensorFlow operator and performs XLA-style broadcasting."""

  def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
    """Applies a binary operator after XLA-style broadcasting of its operands.

    Args:
      x: the left-hand-side operand.
      y: the right-hand-side operand.
      broadcast_dims: optional sequence of integer dimensions describing how
        the operands are broadcast when their ranks differ; converted to an
        int64 tensor and passed to the XlaBroadcastHelper op. Defaults to an
        empty list.
      name: an optional name for the operator.

    Returns:
      The result of the wrapped operator applied to the broadcast operands.
    """
    # Compare against None explicitly: `broadcast_dims or []` would evaluate
    # the truth value of `broadcast_dims`, which raises for multi-element
    # NumPy arrays.
    if broadcast_dims is None:
      broadcast_dims = []
    broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
    # Rather than relying on having static shape information in the TensorFlow
    # graph, we use an XlaBroadcastHelper op that can compute the correct shapes
    # at JIT compilation time.
    x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
    return fn(x, y, name=name)

  return broadcasting_binary_op_wrapper


# Map from TF signed types to TF unsigned types.
_SIGNED_TO_UNSIGNED_TABLE = {
    dtypes.int8: dtypes.uint8,
    dtypes.int16: dtypes.uint16,
    dtypes.int32: dtypes.uint32,
    dtypes.int64: dtypes.uint64,
}

# Map from TF unsigned types to TF signed types.
_UNSIGNED_TO_SIGNED_TABLE = {
    dtypes.uint8: dtypes.int8,
    dtypes.uint16: dtypes.int16,
    dtypes.uint32: dtypes.int32,
    dtypes.uint64: dtypes.int64,
}


def _shift_right_logical_helper(x, y, name=None):
  """Performs an integer right logical shift irrespective of input type.

  Signed operands are cast to the unsigned type of the same width, shifted,
  and cast back. This relies on TF's RightShift being a logical shift for
  unsigned dtypes.
  """
  assert y.dtype == x.dtype
  dtype = x.dtype
  signed = dtype in _SIGNED_TO_UNSIGNED_TABLE
  if signed:
    unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype]
    x = math_ops.cast(x, unsigned_dtype)
    y = math_ops.cast(y, unsigned_dtype)
  output = bitwise_ops.right_shift(x, y, name=name)
  if signed:
    output = math_ops.cast(output, dtype)
  return output


def _shift_right_arithmetic_helper(x, y, name=None):
  """Performs an integer right arithmetic shift irrespective of input type.

  Unsigned operands are cast to the signed type of the same width, shifted,
  and cast back. This relies on TF's RightShift being an arithmetic shift for
  signed dtypes.
  """
  assert y.dtype == x.dtype
  dtype = x.dtype
  unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE
  if unsigned:
    signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype]
    x = math_ops.cast(x, signed_dtype)
    y = math_ops.cast(y, signed_dtype)
  output = bitwise_ops.right_shift(x, y, name=name)
  if unsigned:
    output = math_ops.cast(output, dtype)
  return output


add = _broadcasting_binary_op(math_ops.add)
sub = _broadcasting_binary_op(math_ops.sub)
mul = _broadcasting_binary_op(math_ops.mul)
div = _broadcasting_binary_op(math_ops.div)
rem = _broadcasting_binary_op(gen_math_ops.mod)
max = _broadcasting_binary_op(math_ops.maximum)
min = _broadcasting_binary_op(math_ops.minimum)
atan2 = _broadcasting_binary_op(math_ops.atan2)
complex = _broadcasting_binary_op(math_ops.complex)
logical_and = _broadcasting_binary_op(math_ops.logical_and)
logical_or = _broadcasting_binary_op(math_ops.logical_or)
logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
eq = _broadcasting_binary_op(math_ops.equal)
ne = _broadcasting_binary_op(math_ops.not_equal)
ge = _broadcasting_binary_op(math_ops.greater_equal)
gt = _broadcasting_binary_op(math_ops.greater)
le = _broadcasting_binary_op(math_ops.less_equal)
lt = _broadcasting_binary_op(math_ops.less)
pow = _broadcasting_binary_op(math_ops.pow)
shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)

igamma = _broadcasting_binary_op(math_ops.igamma)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
igammac = _broadcasting_binary_op(math_ops.igammac)
polygamma = _broadcasting_binary_op(math_ops.polygamma)
zeta = _broadcasting_binary_op(math_ops.zeta)


def _binary_op(fn):
  """Wrapper that restricts `fn` to have the correct signature."""

  def binary_op_wrapper(x, y, name=None):
    return fn(x, y, name=name)

  return binary_op_wrapper


transpose = _binary_op(array_ops.transpose)
rev = _binary_op(array_ops.reverse)

bitcast_convert_type = array_ops.bitcast


def broadcast(x, dims, name=None):
  """Broadcasts `x` by prepending the dimensions `dims` to its shape.

  Args:
    x: the operand; converted to a tensor.
    dims: a sequence of integer sizes for the new leading dimensions. May be
      empty, in which case `x` is returned with its own shape.
    name: an optional name for the operator.

  Returns:
    A tensor of shape `dims + shape(x)`.
  """
  x = ops.convert_to_tensor(x)
  x_shape = array_ops.shape(x)
  # Pin the dtype of `dims` to that of the shape tensor: without it,
  # `constant([])` defaults to float32 and the concat below fails.
  leading_dims = constant_op.constant(dims, dtype=x_shape.dtype)
  shape = array_ops.concat([leading_dims, x_shape], axis=0)
  return array_ops.broadcast_to(x, shape, name=name)


def clamp(a, x, b, name=None):
  """Clamps `x` to the range [`a`, `b`], computing `min(max(a, x), b)`.

  `max` and `min` here are this module's XLA-style broadcasting binary ops,
  not the Python builtins.

  Args:
    a: the lower bound.
    x: the operand to clamp.
    b: the upper bound.
    name: an optional name, used for both the max and min operators.

  Returns:
    The clamped tensor.
  """
  return min(max(a, x, name=name), b, name=name)


concatenate = array_ops.concat


def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         preferred_element_type=None,
         name=None,
         use_v2=False,
         batch_group_count=1):
  """Wraps the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general form of XLA convolution and is
  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimension
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: a `xla.PrecisionConfig` proto.
    preferred_element_type: the result `dtype`.
    name: an optional name for the operator.
    use_v2: an optional request to use the XlaConvV2 op even if not necessary.
    batch_group_count: number of batch groups or grouped filters.

  Returns:
    A tensor representing the output of the convolution.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  # Only the V2 op supports an explicit result type, mixed operand dtypes and
  # batch grouping; otherwise fall back to the original op.
  needs_v2 = (
      preferred_element_type or (lhs.dtype != rhs.dtype) or
      batch_group_count > 1)
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  if needs_v2 or use_v2:
    return gen_xla_ops.xla_conv_v2(
        lhs,
        rhs,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  return gen_xla_ops.xla_conv(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)


convert_element_type = math_ops.cast


def dot(lhs, rhs, name=None):
  """Computes `tensordot(lhs, rhs, axes=1)`.

  Contracts the last dimension of `lhs` with the first dimension of `rhs`.
  """
  return math_ops.tensordot(lhs, rhs, axes=1, name=name)


DotDimensionNumbers = xla_data_pb2.DotDimensionNumbers
PrecisionConfig = xla_data_pb2.PrecisionConfig


def dot_general(lhs,
                rhs,
                dimension_numbers,
                precision_config=None,
                preferred_element_type=None,
                name=None,
                use_v2=False):
  """Wraps the XLA DotGeneral operator via the XlaDot/XlaDotV2 ops.

  Args:
    lhs: the left-hand-side tensor.
    rhs: the right-hand-side tensor.
    dimension_numbers: a `DotDimensionNumbers` proto.
    precision_config: an optional `xla.PrecisionConfig` proto.
    preferred_element_type: the result `dtype`. When omitted it is computed
      with `np_utils.result_type(lhs.dtype, rhs.dtype)`; it is only passed to
      the op when the V2 op is used.
    name: an optional name for the operator.
    use_v2: an optional request to use the XlaDotV2 op even if not necessary.

  Returns:
    A tensor representing the output of the dot product.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  # The V2 op is required for an explicit result type or mixed operand dtypes.
  needs_v2 = preferred_element_type or (lhs.dtype != rhs.dtype)
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  if needs_v2 or use_v2:
    return gen_xla_ops.xla_dot_v2(
        lhs,
        rhs,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  return gen_xla_ops.xla_dot(
      lhs,
      rhs,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)


def self_adjoint_eig(a, lower, max_iter, epsilon):
  """Thin wrapper over the XlaSelfAdjointEig op (arguments passed through)."""
  return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)


def svd(a, max_iter, epsilon, precision_config=None):
  """Thin wrapper over the XlaSvd op.

  `precision_config`, if given, is serialized before being passed to the op;
  otherwise an empty string is passed.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  return gen_xla_ops.xla_svd(a, max_iter, epsilon, precision_config_proto)


dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
einsum = gen_xla_ops.xla_einsum

# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
# the XLA-specific pad operator.
pad = gen_xla_ops.xla_pad


def random_normal(mu, sigma, dims, name=None):
  """Samples a tensor of shape `dims` via `random_ops.random_normal`.

  The result dtype is the dtype of `mu` (after tensor conversion).
  """
  mu = ops.convert_to_tensor(mu)
  return random_ops.random_normal(
      dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name)


def random_uniform(minval, maxval, dims, name=None):
  """Samples a tensor of shape `dims` via `random_ops.random_uniform`.

  The result dtype is the dtype of `minval` (after tensor conversion).
  """
  minval = ops.convert_to_tensor(minval)
  return random_ops.random_uniform(
      dims, minval, maxval, dtype=minval.dtype, name=name)


def rng_bit_generator(algorithm, initial_state, shape, dtype):
  """Stateless PRNG bit generator.

  Wraps the XLA RngBitGenerator operator, documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.

  Args:
    algorithm: The PRNG algorithm to use, one of
      tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
    initial_state: Initial state for the PRNG algorithm. For THREEFRY, it
      should be a u64[2] and for PHILOX a u64[3].
    shape: The output shape of the generated data.
    dtype: The type of the tensor.

  Returns:
    a tuple with a new state and generated data of the given shape.
  """
  alg_int = stateless_random_ops.convert_alg_to_int(algorithm)
  return gen_xla_ops.xla_rng_bit_generator(
      alg_int, initial_state, shape, dtype=dtype)


recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
variadic_reduce = gen_xla_ops.xla_variadic_reduce_v2

ops.no_gradient("XlaVariadicReduce")


def _default_if_omitted(value, default):
  """Returns `default` when `value` is None or empty, else `value`.

  Unlike `value or default`, this never evaluates the truth value of `value`,
  which raises for multi-element NumPy arrays.
  """
  if value is None or len(value) == 0:
    return default
  return value


def reduce_window(operand,
                  init,
                  reducer,
                  window_dimensions,
                  window_strides=None,
                  base_dilations=None,
                  window_dilations=None,
                  padding=None,
                  name=None):
  """Wraps the XLA ReduceWindow operator.

  ReduceWindow is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

  Args:
    operand: the input tensor
    init: a scalar tensor representing the initial value for the reduction
    reducer: a reduction function that combines a pair of scalars.
    window_dimensions: shape of the window, as a list of integers
    window_strides: inter-window strides, as a list of integers. Optional; if
      omitted, defaults to strides of 1.
    base_dilations: dilation to apply between elements of `operand`, as a list
      of integers. Optional; if omitted, defaults to dilations of 1.
    window_dilations: dilation to apply between elements of the window, as a
      list of integers. Optional; if omitted, defaults to dilations of 1.
    padding: padding to apply to 'operand'. List of (low, high) pairs of
      integers that specify the padding to apply before and after each
      dimension. Optional; if omitted, defaults to no padding.
    name: the operator name, or None.

  Returns:
    A tensor that represents the output of the reduce_window operator.
  """
  rank = len(window_dimensions)
  # None and empty sequences both mean "use the default", as before.
  window_strides = _default_if_omitted(window_strides, [1] * rank)
  base_dilations = _default_if_omitted(base_dilations, [1] * rank)
  window_dilations = _default_if_omitted(window_dilations, [1] * rank)
  padding = _default_if_omitted(padding, [(0, 0)] * rank)
  return gen_xla_ops.xla_reduce_window(
      input=operand,
      init_value=init,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilations=base_dilations,
      window_dilations=window_dilations,
      padding=padding,
      computation=reducer,
      name=name)


replica_id = gen_xla_ops.xla_replica_id

# Set a static bound for the given input value as a hint to Xla compiler,
# returns the same value.
# Usage:
# def f(t, p):
#   p = xla.set_bound(p, 3) # Tells xla the constraint that p <= 3.
#   return t[:p]            # xla knows the bound of the slice is 3.
set_bound = gen_xla_ops.xla_set_bound

# Make a static dimension into a xla bounded dynamic dimension. The current
# static dimension size will become the bound and the second operand becomes the
# dynamic size of the dimension.
#
# This should mostly be used for testing.
#
# def f():
#   array = tf.convert_to_tensor([[1, 2, 3, 4, 5]])
#   # Tells xla the valid size of the array is 3.
#   dim = 0
#   p = xla.set_dynamic_dimension_size(array, dim, 3)
#   assert(reduce_sum(p) == 6) # xla knows only the first 3 elements are valid.
set_dynamic_dimension_size = gen_xla_ops.xla_set_dynamic_dimension_size

# Inverse of xla_set_dynamic_dimension_size. Make an xla bounded dynamic
# dimension into a static dimension. The bound of the size of dimension
# `dim_index` becomes the static dimension size.
remove_dynamic_dimension_size = gen_xla_ops.xla_remove_dynamic_dimension_size


def reshape(x, new_sizes, dimensions=None, name=None):
  """Reshapes `x` to `new_sizes`, optionally permuting its dimensions first.

  Args:
    x: the operand.
    new_sizes: the target shape.
    dimensions: optional permutation applied with `array_ops.transpose` before
      the reshape.
    name: an optional name for the reshape operator.

  Returns:
    The reshaped tensor.
  """
  permuted = x if dimensions is None else array_ops.transpose(x, dimensions)
  return array_ops.reshape(permuted, new_sizes, name=name)


def select(condition, x, y, name=None):
  """Element-wise select between `x` and `y` using `array_ops.where`."""
  return array_ops.where(condition, x, y, name=name)


select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send


def slice(x, start_dims, limit_dims, strides):
  """Strided slice of `x`, one `start:limit:stride` triple per dimension.

  Triples are formed by zipping the three sequences, so iteration stops at
  the shortest one.
  """
  # `_slice` is the Python builtin `slice`, captured before it was shadowed.
  per_dim = map(_slice, start_dims, limit_dims, strides)
  return x[tuple(per_dim)]


sharding = gen_xla_ops.xla_sharding


@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  """Gradient for XlaSharding op.

  Applies the forward op's `sharding` attribute to the incoming gradient and
  also records it as the `_XlaSharding` attr on the new op.
  """
  forward_sharding = op.get_attr("sharding")
  sharded_grad = gen_xla_ops.xla_sharding(
      grad,
      sharding=forward_sharding,
      unspecified_dims=op.get_attr("unspecified_dims"))
  # pylint: disable=protected-access
  sharded_grad.op._set_attr(
      "_XlaSharding", attr_value_pb2.AttrValue(s=forward_sharding))
  return [sharded_grad]


spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape


@ops.RegisterGradient("XlaSpmdFullToShardShape")
def _spmd_full_to_shard_shape_grad(op, grad):
  """Gradient of XlaSpmdFullToShardShape: the matching shard-to-full op.

  The forward op's input shape is used as `full_shape`.
  """
  manual_sharding = op.get_attr("manual_sharding")
  full_shape = op.inputs[0].shape.as_list()
  dim = op.get_attr("dim")
  unspecified_dims = op.get_attr("unspecified_dims")
  return [
      gen_xla_ops.xla_spmd_shard_to_full_shape(
          grad,
          manual_sharding=manual_sharding,
          full_shape=full_shape,
          dim=dim,
          unspecified_dims=unspecified_dims)
  ]


@ops.RegisterGradient("XlaSpmdShardToFullShape")
def _spmd_shard_to_full_shape_grad(op, grad):
  """Gradient of XlaSpmdShardToFullShape: the matching full-to-shard op."""
  manual_sharding = op.get_attr("manual_sharding")
  dim = op.get_attr("dim")
  unspecified_dims = op.get_attr("unspecified_dims")
  return [
      gen_xla_ops.xla_spmd_full_to_shard_shape(
          grad,
          manual_sharding=manual_sharding,
          dim=dim,
          unspecified_dims=unspecified_dims)
  ]


sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
variadic_sort = gen_xla_ops.xla_variadic_sort
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize
custom_call = gen_xla_ops.xla_custom_call


def custom_call_v2(
    call_target_name,
    operands,
    result_specs,
    backend_config=None,
    has_side_effect=None,
    name=None,
):
  """Emits an HLO `CustomCall` operation with multiple outputs.

  See `CustomCall` specification at
    https://tensorflow.org/xla/operation_semantics#customcall,
  and `mhlo.custom_call` specification at
    https://tensorflow.org/mlir/hlo_ops#mhlocustom_call_mlirmhlocustomcallop.

  Args:
    call_target_name: Name of the user function. The function signature must
      conform to version 3 of the API, see `API_VERSION_STATUS_RETURNING_UNIFIED`.
      All operands and results assumed to be in the default layout.
    operands: A sequence of tensors with possibly different types.
    result_specs: A sequence of tensor specs for all results.
    backend_config: A string that encodes a metadata for the backend. Empty
      string by default.
    has_side_effect: Indicates whether the custom call has side effects. `False`
      by default.
    name: Optional name of the operation.

  Returns:
    A tuple of output tensors.
  """
  if backend_config is None:
    backend_config = ""
  if has_side_effect is None:
    has_side_effect = False
  # Two separate passes over `result_specs`, dtypes first, then shapes.
  result_dtypes = tuple(spec.dtype for spec in result_specs)
  result_shapes = tuple(spec.shape for spec in result_specs)
  return gen_xla_ops.xla_custom_call_v2(
      operands=operands,
      call_target_name=call_target_name,
      backend_config=backend_config,
      has_side_effect=has_side_effect,
      result_dtypes=result_dtypes,
      result_shapes=result_shapes,
      name=name,
  )


def call_module(args, *, version=2, module, Tout, Sout, dim_args_spec=()):
  """Emits an XlaCallModule op.

  See documentation for the XlaCallModule op; all keyword arguments are
  forwarded unchanged.
  """
  op_attrs = dict(
      version=version,
      module=module,
      dim_args_spec=dim_args_spec,
      Tout=Tout,
      Sout=Sout)
  return gen_xla_ops.xla_call_module(args, **op_attrs)


def gather(operand,
           start_indices,
           dimension_numbers,
           slice_sizes,
           indices_are_sorted=False,
           name=None):
  """Wraps the XlaGather op; `dimension_numbers` is a proto, serialized here."""
  serialized_dims = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_gather(
      operand,
      start_indices,
      slice_sizes=slice_sizes,
      dimension_numbers=serialized_dims,
      indices_are_sorted=indices_are_sorted,
      name=name)


def scatter(operand,
            scatter_indices,
            updates,
            update_computation,
            dimension_numbers,
            indices_are_sorted=False,
            name=None):
  """Wraps the XlaScatter op; `dimension_numbers` is a proto, serialized here."""
  serialized_dims = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_scatter(
      operand,
      scatter_indices,
      updates,
      update_computation=update_computation,
      dimension_numbers=serialized_dims,
      indices_are_sorted=indices_are_sorted,
      name=name)


def optimization_barrier(*args):
  """Passes all positional arguments, as one tuple, to XlaOptimizationBarrier."""
  return gen_xla_ops.xla_optimization_barrier(args)


def reduce_precision(operand, exponent_bits, mantissa_bits):
  """Thin wrapper over the XlaReducePrecision op (arguments passed through)."""
  return gen_xla_ops.xla_reduce_precision(operand, exponent_bits, mantissa_bits)