439 lines
14 KiB
C
439 lines
14 KiB
C
|
/* Copyright 2018 The OpenXLA Authors.
|
||
|
|
||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
you may not use this file except in compliance with the License.
|
||
|
You may obtain a copy of the License at
|
||
|
|
||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
||
|
Unless required by applicable law or agreed to in writing, software
|
||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
See the License for the specific language governing permissions and
|
||
|
limitations under the License.
|
||
|
==============================================================================*/
|
||
|
|
||
|
#ifndef XLA_LAYOUT_H_
|
||
|
#define XLA_LAYOUT_H_
|
||
|
|
||
|
#include <cstdint>
|
||
|
#include <limits>
|
||
|
#include <memory>
|
||
|
#include <ostream>
|
||
|
#include <string>
|
||
|
|
||
|
#include "absl/container/inlined_vector.h"
|
||
|
#include "absl/types/span.h"
|
||
|
#include "xla/printer.h"
|
||
|
#include "xla/util.h"
|
||
|
#include "xla/xla_data.pb.h"
|
||
|
#include "tsl/platform/logging.h" // IWYU pragma: keep
|
||
|
|
||
|
namespace xla {
|
||
|
|
||
|
class Shape;
|
||
|
|
||
|
// Describes a tile used in tiling-based layout. Refer to
|
||
|
// g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
|
||
|
// details.
|
||
|
class Tile {
|
||
|
public:
|
||
|
Tile() = default;
|
||
|
explicit Tile(absl::Span<const int64_t> dimensions)
|
||
|
: dimensions_(dimensions.begin(), dimensions.end()) {}
|
||
|
|
||
|
// De/Serialize a Tile to and from a TileProto.
|
||
|
static Tile CreateFromProto(const TileProto& tile_proto) {
|
||
|
return Tile(tile_proto.dimensions());
|
||
|
}
|
||
|
TileProto ToProto() const;
|
||
|
|
||
|
bool operator==(const Tile& other) const {
|
||
|
return dimensions() == other.dimensions();
|
||
|
}
|
||
|
bool operator!=(const Tile& other) const { return !(*this == other); }
|
||
|
|
||
|
void Print(Printer* printer) const;
|
||
|
|
||
|
std::string ToString() const;
|
||
|
|
||
|
// Returns the bound of the tile in the given dimension index.
|
||
|
int64_t dimension(int i) const { return dimensions_[i]; }
|
||
|
|
||
|
// Returns the dimensions of the tile.
|
||
|
absl::Span<const int64_t> dimensions() const { return dimensions_; }
|
||
|
|
||
|
Tile& add_dimensions(int64_t value) {
|
||
|
dimensions_.push_back(value);
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
Tile& clear_dimensions() {
|
||
|
dimensions_.clear();
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
// This dimension size means the corresponding dimension in the shape is
|
||
|
// combined with the next minor dimension before tiling is applied.
|
||
|
static constexpr int64_t kCombineDimension =
|
||
|
std::numeric_limits<int64_t>::min();
|
||
|
|
||
|
template <typename H>
|
||
|
friend H AbslHashValue(H h, const Tile& t) {
|
||
|
return H::combine(std::move(h), t.dimensions_);
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
// The bounds of the tile.
|
||
|
absl::InlinedVector<int64_t, 2> dimensions_;
|
||
|
};
|
||
|
|
||
|
using TileVector = absl::InlinedVector<Tile, 3>;
|
||
|
|
||
|
// TODO: Rename the `dim_level_types` field to `lvl_types`, so that it
|
||
|
// matches `mlir::sparse_tensor::SparseTensorEncodingAttr`.
|
||
|
class Layout {
|
||
|
public:
|
||
|
Layout();
|
||
|
Layout(const Layout& other);
|
||
|
Layout(Layout&& other);
|
||
|
~Layout();
|
||
|
|
||
|
// Constructs a dense layout with the given minor-to-major order.
|
||
|
explicit Layout(absl::Span<const int64_t> minor_to_major);
|
||
|
|
||
|
// Constructs a dense tiled layout with the given minor-to-major order, dim
|
||
|
// level types, and tiles.
|
||
|
explicit Layout(absl::Span<const int64_t> minor_to_major,
|
||
|
absl::Span<const DimLevelType> dim_level_types,
|
||
|
absl::Span<const bool> dim_unique,
|
||
|
absl::Span<const bool> dim_ordered,
|
||
|
absl::Span<const Tile> tiles,
|
||
|
int64_t tail_padding_alignment_in_elements = 1,
|
||
|
PrimitiveType index_primitive_type = PRIMITIVE_TYPE_INVALID,
|
||
|
PrimitiveType element_primitive_type = PRIMITIVE_TYPE_INVALID,
|
||
|
int64_t element_size_in_bits = 0, int64_t memory_space = 0,
|
||
|
std::unique_ptr<Shape> physical_shape = nullptr,
|
||
|
int64_t dynamic_shape_metadata_prefix_bytes = 0);
|
||
|
|
||
|
Layout& operator=(const Layout& other);
|
||
|
Layout& operator=(Layout&& other);
|
||
|
|
||
|
// Construct a shape from a LayoutProto.
|
||
|
static Layout CreateFromProto(const LayoutProto& proto);
|
||
|
|
||
|
// Returns a LayoutProto representation of the Layout.
|
||
|
LayoutProto ToProto() const;
|
||
|
|
||
|
// Prints a human-readable string that represents this layout.
|
||
|
void Print(Printer* printer) const;
|
||
|
|
||
|
// Returns a human-readable string that represents this layout.
|
||
|
std::string ToString() const;
|
||
|
|
||
|
// Equal is a configurable functor to check the equality of two layouts.
|
||
|
//
|
||
|
// Examples:
|
||
|
//
|
||
|
// - Comparing two layouts ignoring their difference in tiles:
|
||
|
// Equal().IgnoreTiles()(layout1, layout2);
|
||
|
class Equal {
|
||
|
public:
|
||
|
Equal() = default;
|
||
|
|
||
|
bool operator()(const Layout& lhs, const Layout& rhs);
|
||
|
|
||
|
Equal& IgnoreTiles() {
|
||
|
ignore_tiles_ = true;
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
Equal& IgnoreTailPaddingAlignmentInElements() {
|
||
|
ignore_tail_padding_alignment_in_elements_ = true;
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
Equal& IgnoreIndexPrimitiveType() {
|
||
|
ignore_index_primitive_type_ = true;
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
Equal& IgnorePointerPrimitiveType() {
|
||
|
ignore_pointer_primitive_type_ = true;
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
Equal& IgnoreMemorySpace() {
|
||
|
ignore_memory_space_ = true;
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
Equal& IgnorePhysicalShape() {
|
||
|
ignore_physical_shape_ = true;
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
Equal& IgnoreElementSize() {
|
||
|
ignore_element_size_ = true;
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
Equal& MinorToMajorOnly() {
|
||
|
return IgnoreTiles()
|
||
|
.IgnoreIndexPrimitiveType()
|
||
|
.IgnorePointerPrimitiveType()
|
||
|
.IgnoreMemorySpace()
|
||
|
.IgnorePhysicalShape()
|
||
|
.IgnoreElementSize()
|
||
|
.IgnoreTailPaddingAlignmentInElements();
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
bool ignore_tiles_ = false;
|
||
|
bool ignore_tail_padding_alignment_in_elements_ = false;
|
||
|
bool ignore_element_size_ = false;
|
||
|
bool ignore_index_primitive_type_ = false;
|
||
|
bool ignore_pointer_primitive_type_ = false;
|
||
|
bool ignore_memory_space_ = false;
|
||
|
bool ignore_physical_shape_ = false;
|
||
|
};
|
||
|
|
||
|
bool operator==(const Layout& other) const;
|
||
|
bool operator!=(const Layout& other) const { return !(*this == other); }
|
||
|
|
||
|
// The following methods mirror the protobuf generated code interface for the
|
||
|
// message LayoutProto. This enabled easy migration of this data structure
|
||
|
// from a proto to a proper C++ class.
|
||
|
//
|
||
|
// TODO(b/29771030): Replace or augment these methods with a more ergonomic
|
||
|
// interface.
|
||
|
|
||
|
// Methods for accessing the DimLevelType array.
|
||
|
int dim_level_types_size() const { return n_dim_level_types_; }
|
||
|
DimLevelType dim_level_type(int index) const {
|
||
|
return dim_attributes_[index].dim_level_type;
|
||
|
}
|
||
|
Layout& set_dim_level_type(int index, DimLevelType dim_level_type) {
|
||
|
dim_attributes_[index].dim_level_type = dim_level_type;
|
||
|
return *this;
|
||
|
}
|
||
|
Layout& add_dim_level_type(DimLevelType dim_level_type) {
|
||
|
while (n_dim_level_types_ >= dim_attributes_.size()) {
|
||
|
dim_attributes_.push_back(DimInfo());
|
||
|
}
|
||
|
dim_attributes_[n_dim_level_types_].dim_level_type = dim_level_type;
|
||
|
n_dim_level_types_++;
|
||
|
return *this;
|
||
|
}
|
||
|
Layout& clear_dim_level_types() {
|
||
|
n_dim_level_types_ = 0;
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
// Methods for accessing the dim_unique array.
|
||
|
int dim_unique_size() const { return n_dim_unique_; }
|
||
|
bool dim_unique(int index) const { return dim_attributes_[index].dim_unique; }
|
||
|
Layout& set_dim_unique(int index, bool unique) {
|
||
|
dim_attributes_[index].dim_unique = unique;
|
||
|
return *this;
|
||
|
}
|
||
|
Layout& add_dim_unique(bool unique) {
|
||
|
while (n_dim_unique_ >= dim_attributes_.size()) {
|
||
|
dim_attributes_.push_back(DimInfo());
|
||
|
}
|
||
|
dim_attributes_[n_dim_unique_].dim_unique = unique;
|
||
|
n_dim_unique_++;
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
// Methods for accessing the dim_ordered array.
|
||
|
int dim_ordered_size() const { return n_dim_ordered_; }
|
||
|
bool dim_ordered(int index) const {
|
||
|
return dim_attributes_[index].dim_ordered;
|
||
|
}
|
||
|
Layout& set_dim_ordered(int index, bool ordered) {
|
||
|
dim_attributes_[index].dim_ordered = ordered;
|
||
|
return *this;
|
||
|
}
|
||
|
Layout& add_dim_ordered(bool ordered) {
|
||
|
while (n_dim_ordered_ >= dim_attributes_.size()) {
|
||
|
dim_attributes_.push_back(DimInfo());
|
||
|
}
|
||
|
dim_attributes_[n_dim_ordered_].dim_ordered = ordered;
|
||
|
n_dim_ordered_++;
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
// Methods for accessing the minor-to-major array.
|
||
|
int minor_to_major_size() const { return minor_to_major_.size(); }
|
||
|
int64_t minor_to_major(int index) const { return minor_to_major_[index]; }
|
||
|
Layout& set_minor_to_major(int index, int64_t value) {
|
||
|
minor_to_major_[index] = value;
|
||
|
return *this;
|
||
|
}
|
||
|
Layout& add_minor_to_major(int64_t value) {
|
||
|
minor_to_major_.push_back(value);
|
||
|
return *this;
|
||
|
}
|
||
|
Layout& clear_minor_to_major() {
|
||
|
minor_to_major_.clear();
|
||
|
return *this;
|
||
|
}
|
||
|
// Removes the given dimension from 'minor_to_major_', and adjusts the other
|
||
|
// dimensions accordingly. Also adjusts 'dim_level_types_', 'dim_ordered_' and
|
||
|
// 'dim_unique_' in case it is a sparse layout.
|
||
|
Layout& DeleteDimension(int64_t dim_to_delete);
|
||
|
absl::Span<const int64_t> minor_to_major() const { return minor_to_major_; }
|
||
|
DimensionVector* mutable_minor_to_major() { return &minor_to_major_; }
|
||
|
|
||
|
// Methods for accessing the tile field.
|
||
|
int64_t tiles_size() const { return tiles_.size(); }
|
||
|
const Tile& tiles(int index) const { return tiles_[index]; }
|
||
|
Tile* mutable_tiles(int index) { return &tiles_[index]; }
|
||
|
Tile* add_tiles() {
|
||
|
tiles_.push_back(Tile());
|
||
|
return &tiles_.back();
|
||
|
}
|
||
|
Layout& clear_tiles() {
|
||
|
tiles_.clear();
|
||
|
return *this;
|
||
|
}
|
||
|
absl::Span<const Tile> tiles() const { return tiles_; }
|
||
|
TileVector* mutable_tiles() { return &tiles_; }
|
||
|
|
||
|
int64_t element_size_in_bits() const { return element_size_in_bits_; }
|
||
|
Layout& set_element_size_in_bits(int64_t value) {
|
||
|
element_size_in_bits_ = value;
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
int64_t tail_padding_alignment_in_elements() const {
|
||
|
return tail_padding_alignment_in_elements_;
|
||
|
}
|
||
|
Layout& set_tail_padding_alignment_in_elements(int64_t value) {
|
||
|
tail_padding_alignment_in_elements_ = value;
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
PrimitiveType index_primitive_type() const { return index_primitive_type_; }
|
||
|
Layout& set_index_primitive_type(PrimitiveType value) {
|
||
|
index_primitive_type_ = value;
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
PrimitiveType pointer_primitive_type() const {
|
||
|
return pointer_primitive_type_;
|
||
|
}
|
||
|
Layout& set_pointer_primitive_type(PrimitiveType value) {
|
||
|
pointer_primitive_type_ = value;
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
static constexpr int64_t kDefaultMemorySpace = 0;
|
||
|
static constexpr int64_t kGenericFastMemorySpace = 1;
|
||
|
int64_t memory_space() const { return memory_space_; }
|
||
|
Layout& set_memory_space(int64_t value) {
|
||
|
memory_space_ = value;
|
||
|
return *this;
|
||
|
}
|
||
|
|
||
|
// Methods for accessing the physical shape.
|
||
|
bool has_physical_shape() const { return physical_shape_ != nullptr; }
|
||
|
const Shape& physical_shape() const {
|
||
|
CHECK(has_physical_shape());
|
||
|
return *physical_shape_;
|
||
|
}
|
||
|
Shape* mutable_physical_shape();
|
||
|
void clear_physical_shape();
|
||
|
|
||
|
int64_t dynamic_shape_metadata_prefix_bytes() const {
|
||
|
return dynamic_shape_metadata_prefix_bytes_;
|
||
|
}
|
||
|
void set_dynamic_shape_metadata_prefix_bytes(int64_t bytes) {
|
||
|
dynamic_shape_metadata_prefix_bytes_ = bytes;
|
||
|
}
|
||
|
|
||
|
void Swap(Layout* other) {
|
||
|
using std::swap;
|
||
|
swap(*this, *other);
|
||
|
}
|
||
|
|
||
|
void Clear() { *this = Layout(); }
|
||
|
|
||
|
template <typename H>
|
||
|
friend H AbslHashValue(H h, const Layout& l) {
|
||
|
return H::combine(std::move(h), l.minor_to_major_, l.tiles_,
|
||
|
l.element_size_in_bits_, l.index_primitive_type_,
|
||
|
l.pointer_primitive_type_, l.memory_space_,
|
||
|
l.tail_padding_alignment_in_elements_);
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
// We store a single inlined vector to hold
|
||
|
struct DimInfo {
|
||
|
DimInfo()
|
||
|
: dim_level_type(DIM_DENSE), dim_unique(false), dim_ordered(false) {}
|
||
|
|
||
|
DimLevelType dim_level_type : 6;
|
||
|
bool dim_unique : 1;
|
||
|
bool dim_ordered : 1;
|
||
|
};
|
||
|
absl::InlinedVector<DimInfo, InlineRank()> dim_attributes_;
|
||
|
|
||
|
uint8_t n_dim_level_types_ = 0;
|
||
|
uint8_t n_dim_unique_ = 0;
|
||
|
uint8_t n_dim_ordered_ = 0;
|
||
|
|
||
|
// The primitive type to use for sparse array indices and pointers. Each of
|
||
|
// these must either be INVALID, or an unsigned integer type.
|
||
|
PrimitiveType index_primitive_type_ : 8;
|
||
|
PrimitiveType pointer_primitive_type_ : 8;
|
||
|
|
||
|
// The number of bits used to store an individual array element.
|
||
|
// When the value is 0, default to ShapeUtil::ByteSizeOfPrimitiveType.
|
||
|
uint16_t element_size_in_bits_ = 0;
|
||
|
|
||
|
// The assigned memory space.
|
||
|
int8_t memory_space_ = 0;
|
||
|
|
||
|
// A map from physical dimension numbers to logical dimension numbers.
|
||
|
// The first element is the most minor physical dimension (fastest varying
|
||
|
// index) and the last the most major (slowest varying index). The contents of
|
||
|
// the vector are the indices of the *logical* dimensions in the shape.
|
||
|
//
|
||
|
// For example, in shape f32[8,100,100,3]{3,0,2,1}, the logical dimensions
|
||
|
// are [8,100,100,3] and minor_to_major_ is {3,0,2,1}.
|
||
|
// So, the most minor physical dimension is [8,100,100,3][3], which is size 3.
|
||
|
// The second most minor is [8,100,100,3][0], which is size 8.
|
||
|
// The third most minor is [8,100,100,3][2], which is size 100.
|
||
|
// And the major dim is [8,100,100,3][1], which is size 100.
|
||
|
DimensionVector minor_to_major_;
|
||
|
|
||
|
// The tiles used in tiling-based layout.
|
||
|
TileVector tiles_;
|
||
|
|
||
|
// The shape is padded at the end to multiple of, in terms of number of
|
||
|
// elements. This is useful when tiling does not bring the shape to certain
|
||
|
// desired granules. Tiling effectively pads/reshapes/transposes the shape
|
||
|
// to another shape. This field pads the total number of elements of that
|
||
|
// new shape to a multiple of certain number of elements. This is useful such
|
||
|
// as we want a layout which does not tile the data but still requires it to
|
||
|
// be padded to certain number of elements.
|
||
|
int64_t tail_padding_alignment_in_elements_ = 1;
|
||
|
|
||
|
// The physical on-device shape used to represent a sparse array.
|
||
|
std::unique_ptr<Shape> physical_shape_;
|
||
|
|
||
|
// The dynamic shape metadata size in bytes in front of the shape data. The
|
||
|
// field may be non-zero for a static shape whose associated buffer is for a
|
||
|
// dynamic shape, e.g. a result of SliceToDynamic.
|
||
|
int64_t dynamic_shape_metadata_prefix_bytes_ = 0;
|
||
|
};
|
||
|
|
||
|
std::ostream& operator<<(std::ostream& out, const Tile& Tile);
|
||
|
std::ostream& operator<<(std::ostream& out, const Layout& layout);
|
||
|
|
||
|
} // namespace xla
|
||
|
|
||
|
#endif // XLA_LAYOUT_H_
|