3RNN/Lib/site-packages/tensorflow/include/xla/runtime/symbolic_shape.h
2024-05-26 19:49:15 +02:00

118 lines
4.8 KiB
C++

/* Copyright 2022 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_RUNTIME_SYMBOLIC_SHAPE_H_
#define XLA_RUNTIME_SYMBOLIC_SHAPE_H_
#include <optional>
#include "absl/status/statusor.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallVector.h"
#include "xla/runtime/arguments.h"
#include "xla/runtime/constraints.h"
namespace xla {
namespace runtime {
// Symbolic shapes resolver computes the symbolic shapes of the arguments based
// on the function signature, and concrete shapes of the arguments at runtime.
//
// Example: dimensions that have the same symbolic shape at runtime.
//
// signature:       func @compute(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>)
//                                              ^                     ^
// arguments:                           memref<123xf32>       memref<123xf32>
//                                              ^                     ^
// symbolic shapes:                           [-2xf32]              [-2xf32]
//
// Each unknown dimension in the function signature will be assigned a symbolic
// dimension. If multiple shaped arguments have unknown dimensions that are the
// same at runtime, they will be assigned the same symbolic dimensions value
// (e.g. `-2` in the example above).
//
// If an unknown dimension at runtime is equal to some statically known
// dimension in the function signature (of any shaped argument), it will be
// resolved to that statically known constant value:
//
// Example: the unknown dimension of `arg0` is replaced with a static `32`.
//
// signature:       func @compute(%arg0: tensor<?xf32>, %arg1: tensor<32xf32>)
//                                              ^                     ^
// arguments:                            memref<32xf32>        memref<32xf32>
//                                              ^                     ^
// symbolic shapes:                            [32xf32]              [32xf32]
//
// Unknown dimensions that are `1` at runtime are always materialized as a
// statically known `1` in the symbolic shape.
class SymbolicShapesResolver {
public:
// Shape whose dimension sizes can be symbolic (<= -2) or static.
using SymbolicShape = llvm::SmallVector<int64_t>;
// Shape whose dimension sizes can be dynamic (ShapedType::kDynamic) or static.
using StaticShape = llvm::SmallVector<int64_t>;
// Builds a resolver from the function `signature` and the per-argument
// `constraints`.
// NOTE(review): the definition is out of line; presumably it fills the private
// members declared below from `signature` -- confirm in the implementation.
SymbolicShapesResolver(const FunctionType& signature,
absl::Span<const ArgumentConstraint> constraints);
// Resolves symbolic shapes from the runtime arguments. Returns failure if
// runtime dimensions do not match the statically known dimensions.
absl::StatusOr<llvm::SmallVector<SymbolicShape>> Resolve(
ArgumentsRef arguments) const;
// Resolves symbolic shapes and computes the hash value from the runtime
// arguments. Returns failure if runtime dimensions do not match the
// statically known dimensions.
//
// This function might not return the same hash value as calling `Resolve` and
// then `Hash`, because it might use a more efficient hashing algorithm.
absl::StatusOr<llvm::hash_code> ResolveHash(ArgumentsRef arguments) const;
// Replaces all symbolic dimensions with the dynamic dimension.
static StaticShape Normalize(const SymbolicShape& shape);
// Computes a hash value of the symbolic shapes.
static llvm::hash_code Hash(absl::Span<const SymbolicShape> symbolic_shapes);
// Returns the constraint of the argument at `index` (by value).
// NOTE(review): presumably reads `constraints_[index]`; whether `index` is
// bounds-checked is not visible from this header -- confirm.
ArgumentConstraint constraint(size_t index) const;
// Returns the number of arguments the resolver was built for.
// NOTE(review): presumably the size of `arguments_sizes_` -- confirm.
size_t num_arguments() const;
// Tells whether static sizes are recorded for the argument at `index`.
// NOTE(review): presumably true when `arguments_sizes_[index]` holds a value,
// i.e. when the argument is shaped -- confirm.
bool has_argument_sizes(size_t index) const;
// Returns a reference to the recorded sizes of the argument at `index`.
// NOTE(review): presumably only valid when `has_argument_sizes(index)` is
// true; the reference points into this resolver -- confirm.
const StaticShape& argument_sizes(size_t index) const;
// Tells whether `dim` was recorded as a statically known dimension size.
// NOTE(review): presumably a lookup in `seen_static_sizes_`; note that `dim`
// is `size_t` while the set stores `int64_t` -- confirm the conversion is
// intended.
bool seen_static_size(size_t dim) const;
private:
// Constraints on the function arguments.
llvm::SmallVector<ArgumentConstraint> constraints_;
// Statically known sizes of shaped arguments from the function signature. For
// non-shaped arguments (e.g. opaque pointers) we keep an empty optional.
llvm::SmallVector<std::optional<StaticShape>> arguments_sizes_;
// Values of statically known dimension sizes in the function signature.
llvm::DenseSet<int64_t> seen_static_sizes_;
// The iteration order for the arguments when resolving symbolic shapes.
llvm::SmallVector<size_t> iteration_order_;
// The iteration order for the arguments when resolving symbolic shapes hash.
llvm::SmallVector<size_t> hash_iteration_order_;
};
} // namespace runtime
} // namespace xla
#endif // XLA_RUNTIME_SYMBOLIC_SHAPE_H_