/* Copyright 2017 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Shape inference is used by the XLA service to compute result shapes as the
// user builds up computation requests.

#ifndef XLA_SERVICE_SHAPE_INFERENCE_H_
|
|
#define XLA_SERVICE_SHAPE_INFERENCE_H_
|
|
|
|
#include <cstdint>
#include <optional>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/statusor.h"
#include "xla/types.h"
#include "xla/xla_data.pb.h"
|
|
|
|
namespace xla {
|
|
|
|
// For a given operation and input shapes, infers what the resulting shape is
|
|
// for the operation. With this functionality, the user does not need to specify
|
|
// the expected result type for computations that are built up via the API --
|
|
// the shape that results from an operation is inferred. Some methods have
|
|
// overloads for inferring shape at the HLO level.
|
|
//
|
|
// TODO(b/73352135): Shape inference does not issue very good error messages, in
|
|
// part because HloInstruction::ToString() is not available since shape
|
|
// inference runs before the HloInstruction object is created. We need a
|
|
// solution for this.
|
|
class ShapeInference {
|
|
public:
|
|
// Infers the shape produced by applying the given unary operation to the
|
|
// given input shape.
|
|
static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
|
|
const Shape& shape);
|
|
static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
|
|
const HloInstruction* operand);
|
|
|
|
// Infers the shape produced by applying the given binary operation to the
|
|
// given input shapes.
|
|
static StatusOr<Shape> InferBinaryOpShape(
|
|
HloOpcode opcode, const Shape& lhs, const Shape& rhs,
|
|
absl::Span<const int64_t> broadcast_dimensions);
|
|
static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
|
|
const HloInstruction* lhs,
|
|
const HloInstruction* rhs);
|
|
|
|
// Infers the shape produced by applying the given ternary operation to the
|
|
// given input shapes.
|
|
static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, const Shape& lhs,
|
|
const Shape& rhs,
|
|
const Shape& ehs);
|
|
static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode,
|
|
const HloInstruction* lhs,
|
|
const HloInstruction* rhs,
|
|
const HloInstruction* ehs);
|
|
|
|
// Infers the shape produced by applying the given variadic operation to the
|
|
// given input operand shapes.
|
|
static StatusOr<Shape> InferVariadicOpShape(
|
|
HloOpcode opcode, absl::Span<const Shape* const> operand_shapes);
|
|
static StatusOr<Shape> InferVariadicOpShape(
|
|
HloOpcode opcode, absl::Span<const HloInstruction* const> operands);
|
|
|
|
// Infers the shape produced by applying the given mapping computation shape
|
|
// to the given operand shapes.
|
|
static StatusOr<Shape> InferMapShape(
|
|
absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
|
|
absl::Span<const int64_t> dimensions);
|
|
|
|
// Infers the shape produced by InferBatchNormTraining with the given
|
|
// operands.
|
|
static StatusOr<Shape> InferBatchNormTrainingShape(const Shape& operand_shape,
|
|
const Shape& scale_shape,
|
|
const Shape& offset_shape,
|
|
int64_t feature_index);
|
|
|
|
// Infers the shape produced by InferBatchNormInference with the given
|
|
// operands.
|
|
static StatusOr<Shape> InferBatchNormInferenceShape(
|
|
const Shape& operand_shape, const Shape& scale_shape,
|
|
const Shape& offset_shape, const Shape& mean_shape,
|
|
const Shape& variance_shape, int64_t feature_index);
|
|
|
|
// Infers the shape produced by InferBatchNormGrad with the given operands.
|
|
static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape,
|
|
const Shape& scale_shape,
|
|
const Shape& mean_shape,
|
|
const Shape& var_shape,
|
|
const Shape& output_grad_shape,
|
|
int64_t feature_index);
|
|
|
|
// Infers the shape produced by applying the given convolutional filter (rhs)
|
|
// to lhs in the way specified by the fields on window. An optional
|
|
// preferred_element_type can be specified to upcast the element type.
|
|
static StatusOr<Shape> InferConvolveShape(
|
|
const Shape& lhs, const Shape& rhs, int64_t feature_group_count,
|
|
int64_t batch_group_count, const Window& window,
|
|
const ConvolutionDimensionNumbers& dimension_numbers,
|
|
std::optional<PrimitiveType> preferred_element_type);
|
|
|
|
// Infers the shape produced by the given FFT type on the given operand.
|
|
static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
|
|
absl::Span<const int64_t> fft_length);
|
|
|
|
// Infers the shape produced by the given triangular solve operation.
|
|
static StatusOr<Shape> InferTriangularSolveShape(
|
|
const Shape& a, const Shape& b, const TriangularSolveOptions& options);
|
|
|
|
// Infers the shape produced by the given triangular solve operation.
|
|
static StatusOr<Shape> InferCholeskyShape(const Shape& a);
|
|
|
|
// Infers the shape produced by an all-gather with the given operand shape,
|
|
// concat dimension, and shard count.
|
|
static StatusOr<Shape> InferAllGatherShape(
|
|
absl::Span<const Shape* const> operand_shapes,
|
|
int64_t all_gather_dimension, int64_t shard_count);
|
|
|
|
// Infers the shape produced by an all-gather-start with the given operand
|
|
// shape, concat dimension, and shard count.
|
|
static StatusOr<Shape> InferAllGatherStartShape(
|
|
absl::Span<const Shape* const> operand_shapes,
|
|
int64_t all_gather_dimension, int64_t shard_count);
|
|
|
|
// Infers the shape produced by an all-gather-done given a certain
|
|
// all-gather-start shape.
|
|
static StatusOr<Shape> InferAllGatherDoneShape(
|
|
const Shape& all_gather_start_shape);
|
|
|
|
// Infers the shape produced by a cross replica sum with the given operand
|
|
// shapes.
|
|
static StatusOr<Shape> InferAllReduceShape(
|
|
absl::Span<const Shape* const> operand_shapes);
|
|
|
|
// Infers the shape produced by a reduce-scatter with the given operand
|
|
// shape, scatter dimension, and shard count.
|
|
static StatusOr<Shape> InferReduceScatterShape(
|
|
absl::Span<const Shape* const> operand_shapes, int64_t scatter_dimension,
|
|
int64_t shard_count);
|
|
|
|
// Infers the shape produced by a cross replica sum start.
|
|
static StatusOr<Shape> InferAllReduceStartShape(
|
|
absl::Span<const Shape* const> operand_shapes);
|
|
|
|
// Infers the shape produced by a cross replica sum done.
|
|
static StatusOr<Shape> InferAllReduceDoneShape(const Shape& operand_shape);
|
|
|
|
// Infers final shape of an Alltoall operation that is created by the xla
|
|
// builder.
|
|
static StatusOr<Shape> InferAllToAllShape(const Shape& shape,
|
|
int64_t split_dimension,
|
|
int64_t concat_dimension,
|
|
int64_t split_count);
|
|
|
|
// Infers the shape of an HLO all-to-all instruction.
|
|
static StatusOr<Shape> InferAllToAllTupleShape(
|
|
absl::Span<const Shape* const> operand_shapes);
|
|
|
|
// Infers the shape of a collective permute operation.
|
|
static StatusOr<Shape> InferCollectivePermuteShape(
|
|
absl::Span<const Shape* const> operand_shapes);
|
|
|
|
// Infers the shape of a collective permute start operation.
|
|
static StatusOr<Shape> InferCollectivePermuteStartShape(
|
|
absl::Span<const Shape* const> operand_shapes,
|
|
absl::Span<const Shape> context_shapes);
|
|
|
|
// Infers the shape of a collective permute operation.
|
|
static StatusOr<Shape> InferCollectivePermuteDoneShape(
|
|
const Shape& operand_shape);
|
|
|
|
// Infers the shape produced by applying the given reduction computation
|
|
// shape to the given input operand shape.
|
|
//
|
|
// If pass_index is true, the reduce function is invoked with the element
|
|
// index as the leading parameter, and the program shape should match
|
|
// accordingly (or an error will result).
|
|
static StatusOr<Shape> InferReduceShape(
|
|
absl::Span<const Shape* const> arg_shapes,
|
|
absl::Span<const int64_t> dimensions_to_reduce,
|
|
const ProgramShape& to_apply);
|
|
|
|
// Infers the shape produced by applying the given computation to the operand
|
|
// shape with the given window and stride dimensions.
|
|
static StatusOr<Shape> InferReduceWindowShape(
|
|
const Shape& operand_shape, const Shape& init_value, const Window& window,
|
|
const ProgramShape& to_apply_shape);
|
|
static StatusOr<Shape> InferReduceWindowShape(const Shape& operand_shape,
|
|
const Shape& init_value,
|
|
const Window& window);
|
|
static StatusOr<Shape> InferReduceWindowShape(
|
|
absl::Span<const Shape* const> operands,
|
|
absl::Span<const Shape* const> init_values, const Window& window,
|
|
const ProgramShape& to_apply_shape);
|
|
|
|
static StatusOr<Shape> InferReduceWindowShape(
|
|
absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
|
|
const Window& window);
|
|
|
|
// Infers the shape produced by scattering the given source shape to the
|
|
// selected indices of each window on the operand shape.
|
|
static StatusOr<Shape> InferSelectAndScatterShape(
|
|
const Shape& operand_shape, const ProgramShape& select_shape,
|
|
const Window& window, const Shape& source_shape,
|
|
const Shape& init_value_shape, const ProgramShape& scatter_shape);
|
|
|
|
// Infers the shape produced by a reverse operation that reverses the order
|
|
// of the elements in the given dimensions.
|
|
static StatusOr<Shape> InferReverseShape(
|
|
const Shape& operand_shape, absl::Span<const int64_t> dimensions);
|
|
|
|
// Infers the shape produced by a slice operation spanning from the starts to
|
|
// the limits in the original shape's dimensions.
|
|
//
|
|
// e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16]
|
|
static StatusOr<Shape> InferSliceShape(const Shape& arg,
|
|
absl::Span<const int64_t> starts,
|
|
absl::Span<const int64_t> limits,
|
|
absl::Span<const int64_t> strides);
|
|
|
|
// Infers the shape produced by a dynamic slice operation of size specified
|
|
// in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'.
|
|
static StatusOr<Shape> InferDynamicSliceShape(
|
|
const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
|
|
absl::Span<const int64_t> slice_sizes, bool allow_scalar_indices = true);
|
|
|
|
// Infers the shape produced by a dynamic update slice operation based
|
|
// on the shape of operand and update.
|
|
static StatusOr<Shape> InferDynamicUpdateSliceShape(
|
|
const Shape& operand_shape, const Shape& update_shape,
|
|
absl::Span<const Shape> start_index_shapes,
|
|
bool allow_scalar_indices = true);
|
|
|
|
// Infers the shape produced by doing a compile-time-constant indexing into
|
|
// the given input shape. This is essential for operations on tuples, because
|
|
// it is impossible to infer the type that comes out of the tuple indexing if
|
|
// it is not a compile time constant.
|
|
static StatusOr<Shape> InferGetTupleElementShape(const Shape& arg,
|
|
int64_t index);
|
|
|
|
// Infers the shape produced from a while node. condition and body are the
|
|
// shapes of computations for the condition and the body of a while node, and
|
|
// init is the shape of data initially passed in to the body as an argument.
|
|
// The shapes must match; condition: T -> PRED, body: T -> T, init: T
|
|
static StatusOr<Shape> InferWhileShape(const ProgramShape& condition,
|
|
const ProgramShape& body,
|
|
const Shape& init);
|
|
|
|
// Infers the shape produced by a predicated or indexed conditional operation.
|
|
static StatusOr<Shape> InferConditionalShape(
|
|
const Shape& branch_index,
|
|
absl::Span<const ProgramShape> branch_computations,
|
|
absl::Span<const Shape> branch_operands);
|
|
|
|
// Infers the shape produced by a broadcast operation.
|
|
static StatusOr<Shape> InferBroadcastShape(
|
|
const Shape& operand, absl::Span<const int64_t> broadcast_sizes);
|
|
|
|
// Checks whether the given parameters can form a broadcast. Returns the same
|
|
// output_shape if it's legal.
|
|
static StatusOr<Shape> InferBroadcastShape(
|
|
const Shape& operand_shape, const Shape& output_shape,
|
|
absl::Span<const int64_t> broadcast_dimensions);
|
|
|
|
// Infers the shape produced by a reshape operation from the element type of
|
|
// its operand and the new dimension sizes specified.
|
|
static StatusOr<Shape> InferReshapeShape(const Shape& operand,
|
|
absl::Span<const int64_t> dimensions,
|
|
absl::Span<const int64_t> new_sizes,
|
|
int64_t inferred_dimension);
|
|
|
|
// Infers the shape produced by a dynamic reshape operation from the element
|
|
// type of its operand and the new dimension sizes specified. The result shape
|
|
// will have dynamic dimensions as specific in `dim_is_dynamic` and bound
|
|
// `new_size_bounds`.
|
|
static StatusOr<Shape> InferDynamicReshapeShape(
|
|
const Shape& operand, absl::Span<const Shape* const> dim_size_shapes,
|
|
absl::Span<const int64_t> new_size_bounds,
|
|
const std::vector<bool>& dims_are_dynamic);
|
|
|
|
// Infers the shape produced by a transpose operation from the element type of
|
|
// its operand and its dimensions field.
|
|
static StatusOr<Shape> InferTransposeShape(
|
|
const Shape& operand, absl::Span<const int64_t> dimensions);
|
|
|
|
// Helper that infers the shape produced by performing a concatenate operation
|
|
// with the given operand shapes.
|
|
static StatusOr<Shape> InferConcatOpShape(
|
|
absl::Span<const Shape* const> arg_shapes, int64_t dimension);
|
|
|
|
// Helper that validates the given operand shape can be converted to the
|
|
// target output_shape via a convert instruction -- the requirement is that
|
|
// the shape is identical except for the element type.
|
|
static StatusOr<Shape> InferConvertShape(const Shape& operand_shape,
|
|
PrimitiveType new_element_type);
|
|
|
|
// Helper that validates the given operand shape can be bitcast converted to
|
|
// the target output_shape via a bitcast convert instruction -- the
|
|
// requirement is that the shape is identical except for the element type and
|
|
// the element types have identical bit-widths.
|
|
static StatusOr<Shape> InferBitcastConvertShape(
|
|
const Shape& operand_shape, PrimitiveType new_element_type);
|
|
|
|
// Helper that validates the given operand shape can be converted to the
|
|
// target output_shape via a stochastic convert instruction -- the requirement
|
|
// is that the shape is identical except for the element type.
|
|
static StatusOr<Shape> InferStochasticConvertShape(
|
|
const Shape& operand_shape, const Shape& random_shape,
|
|
PrimitiveType new_element_type);
|
|
|
|
// Helper that validates the input data type for a reduce-precision operation,
|
|
// and returns the result shape.
|
|
static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape,
|
|
const int exponent_bits,
|
|
const int mantissa_bits);
|
|
|
|
// Helper that infers the shape produced by a pad operation based on the
|
|
// padding configuration.
|
|
static StatusOr<Shape> InferPadShape(const Shape& operand_shape,
|
|
const Shape& padding_value_shape,
|
|
const PaddingConfig& padding_config);
|
|
|
|
// Helper that validates the given arg_shapes are compatible with the shape of
|
|
// the to_apply parameters, and returns the to_apply result shape.
|
|
static StatusOr<Shape> InferCallShape(
|
|
absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply);
|
|
|
|
// Helper that infers the shape produced by performing a dot operation with
|
|
// the given LHS and RHS shapes. An optional preferred_element_type can be
|
|
// specified to upcast the element type.
|
|
static StatusOr<Shape> InferDotOpShape(
|
|
const Shape& lhs, const Shape& rhs,
|
|
const DotDimensionNumbers& dimension_numbers,
|
|
std::optional<PrimitiveType> preferred_element_type);
|
|
|
|
// Helper that infers the shape of the tensor produced by a gather operation
|
|
// with the given input shape, gather indices shape and gather dimension
|
|
// numbers.
|
|
static StatusOr<Shape> InferGatherShape(
|
|
const Shape& input_shape, const Shape& start_indices_shape,
|
|
const GatherDimensionNumbers& gather_dim_numbers,
|
|
absl::Span<const int64_t> slice_sizes);
|
|
|
|
// Helper that validates the given input shape, scatter indices shape, updates
|
|
// shape, and scatter dimension numbers that constitute a scatter operation,
|
|
// and returns the result shape of the scatter operation.
|
|
static StatusOr<Shape> InferScatterShape(
|
|
absl::Span<const Shape* const> arg_shapes,
|
|
const ProgramShape& to_apply_shape,
|
|
const ScatterDimensionNumbers& scatter_dim_numbers);
|
|
|
|
// Helper that validates the given input shape to GetDimensionSize.
|
|
static StatusOr<Shape> InferGetDimensionSizeShape(const Shape& shape,
|
|
int64_t dimension);
|
|
|
|
// Helper that validates the given input shape to SetDimensionSize.
|
|
static StatusOr<Shape> InferSetDimensionSizeShape(const Shape& operand_shape,
|
|
const Shape& val_shape,
|
|
int64_t dimension);
|
|
|
|
static StatusOr<Shape> InferTopKShape(const Shape& operand_shape, int64_t k);
|
|
|
|
// Helper function for creating a Window proto from user-supplied data.
|
|
// Returns error if the user-supplied data was invalid.
|
|
static StatusOr<Window> InferWindowFromDimensions(
|
|
absl::Span<const int64_t> window_dimensions,
|
|
absl::Span<const int64_t> window_strides,
|
|
absl::Span<const std::pair<int64_t, int64_t>> padding,
|
|
absl::Span<const int64_t> lhs_dilation,
|
|
absl::Span<const int64_t> rhs_dilation,
|
|
std::optional<std::vector<bool>> window_reversal = std::nullopt);
|
|
|
|
private:
|
|
// Helper that infers the shape produced by performing an element-wise binary
|
|
// operation with the given LHS and RHS shapes.
|
|
// Note: By "element-wise" we mean operations that look at a single element in
|
|
// the LHS and a single element in the RHS to produce a single output element,
|
|
// even in the presence of broadcasting of one of the operands over the other.
|
|
static StatusOr<Shape> InferElementwiseBinaryOpShape(
|
|
HloOpcode operation, const Shape& lhs, const Shape& rhs,
|
|
absl::Span<const int64_t> broadcast_dimensions);
|
|
|
|
// Helper for inferring the shape of Clamp ops.
|
|
static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
|
|
const Shape& max);
|
|
|
|
// Helper for inferring the shape of Select ops.
|
|
static StatusOr<Shape> InferSelectShape(const Shape& pred,
|
|
const Shape& on_true,
|
|
const Shape& on_false);
|
|
|
|
// Helper for inferring shapes of binary operations which use degenerate
|
|
// dimension broadcasting (a dimension of size 1 in one operand is broadcast
|
|
// up to match the size of the dimension in the other operand).
|
|
static StatusOr<Shape> InferDegenerateDimensionBroadcastShape(
|
|
HloOpcode operation, const Shape& lhs, const Shape& rhs);
|
|
|
|
// Helper for inferring shapes of binary operations using "InDim"
|
|
// broadcasting. This is the broadcasting used in the *InDim binary operations
|
|
// (for example ComputationBuilder::AddInDim). smaller_shape must be a
|
|
// lower-rank shape than larger_shape. Returns the shape that the
|
|
// smaller_shape is broadcast to.
|
|
static StatusOr<Shape> InferInDimBroadcastShape(
|
|
const Shape& smaller_shape, const Shape& larger_shape,
|
|
absl::Span<const int64_t> broadcast_dimensions);
|
|
|
|
ShapeInference(const ShapeInference&) = delete;
|
|
ShapeInference& operator=(const ShapeInference&) = delete;
|
|
};
|
|
|
|
} // namespace xla
|
|
|
|
#endif // XLA_SERVICE_SHAPE_INFERENCE_H_
|