/* Copyright 2020 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #ifndef XLA_PJRT_UTILS_H_ #define XLA_PJRT_UTILS_H_ #include #include #include #include #include #include "absl/container/flat_hash_map.h" #include "absl/types/span.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "xla/client/executable_build_options.h" #include "xla/client/xla_computation.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/pjrt/layout_mode.h" #include "xla/service/computation_placer.h" #include "xla/shape.h" #include "xla/status.h" #include "xla/statusor.h" #include "xla/xla_data.pb.h" namespace xla { // Returns the num_replicas, num_partitions and device assignment given a // ExecutableBuildOptions and whether we want a portable executable. Status ParseDeviceAssignmentCompileOptions( bool compile_portable_executable, ExecutableBuildOptions* build_options, std::function(int, int)> GetDefaultDeviceAssignmentFunction, int* num_replicas, int* num_partitions, std::shared_ptr* device_assignment); // Returns the LayoutMode for each argument of the main function in the // module. Checks for the "mhlo.layout_mode" attr, and if not present, assumes // LayoutMode::Mode::kDefault. StatusOr> GetArgLayoutModes(mlir::ModuleOp module); // Returns the LayoutMode for each output of the main function in the // module. Checks for the "mhlo.layout_mode" attr, and if not present, assumes // LayoutMode::Mode::kDefault. StatusOr> GetOutputLayoutModes(mlir::ModuleOp module); // Populates the frontend attributes "arg_layout_mode" and "out_layout_mode" in // xla_computation based on `module`. This function must be called before the // LayoutMode getters below work correctly on `computation`. Status AddLayoutModesToFrontendAttrs(mlir::ModuleOp module, XlaComputation& xla_computation); // Returns the LayoutMode for each argument of the computations. Checks for the // "arg_layout_mode" frontend attribute, and if not present, assumes // LayoutMode::Mode::kDefault. StatusOr> GetArgLayoutModes( const XlaComputation& computation); // Returns the LayoutMode for each argument of the computations. Checks for the // "out_layout_mode" frontend attribute, and if not present, assumes // LayoutMode::Mode::kDefault. StatusOr> GetOutputLayoutModes( const XlaComputation& computation); // Returns (arg shapes, output shape) with properly-set Layouts that can // be passed to XLA to reflect arg_layout_modes and out_layout_modes. StatusOr, Shape>> LayoutModesToXlaShapes( const XlaComputation& computation, std::vector arg_layout_modes, std::vector out_layout_modes, std::function(Shape)> choose_compact_layout_for_shape_function); // Generates useful data structures for communciating desired layouts to XLA: // * Returns a vector of argument xla::Shapes with properly-set Layouts // * Returns vector of pointers to those Shapes to create HloModuleConfig // * Modifies `build_options` to have the correct result_layout set or unset StatusOr, std::vector>> LayoutModesToXla(const XlaComputation& computation, std::vector arg_layout_modes, std::vector out_layout_modes, std::function(Shape)> choose_compact_layout_for_shape_function, ExecutableBuildOptions& build_options); // Returns pointers to the argument layouts given an XlaComputation and // ExecutableBuildOptions. Status DetermineArgumentLayoutsFromCompileOptions( const XlaComputation& computation, std::function(Shape)> choose_compact_layout_for_shape_function, std::optional>& argument_layouts, ExecutableBuildOptions* build_options, std::vector* argument_layout_pointers); // Executables can donate buffers so that buffers can be aliased from inputs // to outputs. This function returns a sorted vector of parameters that must be // donated when executable is run. tuple_inputs reflects the option that // executable was compiled with. StatusOr> ComputeParametersThatMustBeDonated( const HloModule& hlo_module, bool tuple_inputs); // Return max parallelism level. int DefaultThreadPoolSize(); // Returns true if the striding of an array corresponds to a major-to-minor // layout. bool HasMajorToMinorLayout(PrimitiveType type, absl::Span dims, absl::Span byte_strides); // Constructs a new dense array shape with the given byte strides. Supports only // trivial (compact) byte_strides that represents a transposition of a dense // buffer. StatusOr MakeShapeWithTrivialByteStrides( PrimitiveType element_type, absl::Span dimensions, absl::Span byte_strides); // If a buffer `is_donated`, then it can only be used once. This function // records the use into donation_clashes and tests for incompatible uses. // Multiple uses are valid iff they are all not donations. The provided map // stores the opaque buffer identity, a bool to denote if the previous use is a // donation, and the index of the previous use for better error messages. Status TestBufferDonationClashes( void* opaque_key, absl::flat_hash_map>& donation_clashes, bool is_donated, int arg_idx, int replica, int partition); } // namespace xla #endif // XLA_PJRT_UTILS_H_