3RNN/Lib/site-packages/tensorflow/include/xla/service/compiler.h

469 lines
18 KiB
C
Raw Permalink Normal View History

2024-05-26 19:49:15 +02:00
/* Copyright 2017 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// The compiler API is used by the XLA service to generate executables that
// run on a given platform. This is a registry and abstract interface, for
// pluggability by the various platforms.
#ifndef XLA_SERVICE_COMPILER_H_
#define XLA_SERVICE_COMPILER_H_
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_module_group.h"
#include "xla/service/buffer_assignment.h"
#include "xla/service/buffer_value.h"
#include "xla/service/computation_placer.h"
#include "xla/service/executable.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/metrics_hook_interface.h"
#include "xla/statusor.h"
#include "xla/stream_executor/stream_executor.h"
#include "tsl/platform/protobuf.h"
#include "tsl/platform/threadpool.h"
namespace xla {
// The following types are used for ahead-of-time (AOT) compilation.
//
// Contains the object file data created as a result of ahead-of-time
// compilation, stored as a plain sequence of bytes.
using ObjectFileData = std::vector<char>;
// Forward declarations: AotCompilationResult::LoadExecutable takes a
// Compiler*, and Compiler::CompileAheadOfTime takes an AotCompilationOptions.
// Both classes are defined later in this header.
class Compiler;
class AotCompilationOptions;
// Base class for the artifact produced by an ahead-of-time compilation.
//
// The type is non-copyable and its default constructor is protected, so it can
// only be instantiated through a subclass. Subclasses must implement the two
// optimized-module accessors; the remaining virtuals have default bodies that
// report an Unimplemented error.
class AotCompilationResult {
 public:
  // Copying is disabled.
  AotCompilationResult(const AotCompilationResult&) = delete;
  AotCompilationResult& operator=(const AotCompilationResult&) = delete;

  virtual ~AotCompilationResult() = default;

  // Subclasses may override this to return a string form of the result.
  // The default implementation returns an Unimplemented error.
  virtual StatusOr<std::string> SerializeAsString() const {
    return Unimplemented("SerializeAsString unimplemented.");
  }

  // Subclasses may override this to produce an Executable from the result.
  // The default implementation does not look at either argument and returns
  // an Unimplemented error.
  virtual StatusOr<std::unique_ptr<Executable>> LoadExecutable(
      Compiler* compiler, const se::StreamExecutor* executor) const {
    return Unimplemented("LoadExecutable unimplemented.");
  }

  // Returns the optimized HLO module if one was computed and the implementation
  // supports it.
  virtual const HloModule* optimized_module() const = 0;

  // Returns the optimized HLO module as a unique_ptr, i.e. the caller becomes
  // its owner.
  // NOTE(review): presumably the result no longer holds the module afterwards
  // -- confirm against the concrete subclasses.
  virtual std::unique_ptr<HloModule> consume_optimized_module() = 0;

 protected:
  // Only subclasses may construct a result.
  AotCompilationResult() = default;
};
// Base class for metadata produced during ahead-of-time compilation.
//
// The type is non-copyable and its default constructor is protected, so it can
// only be instantiated through a subclass.
class AotCompilationMetadata {
 public:
  // Copying is disabled.
  AotCompilationMetadata(const AotCompilationMetadata&) = delete;
  AotCompilationMetadata& operator=(const AotCompilationMetadata&) = delete;

  // String form of the metadata. The base implementation yields an empty
  // string; subclasses may override it.
  virtual std::string ToString() const { return std::string(); }

  virtual ~AotCompilationMetadata() = default;

 protected:
  // Only subclasses may construct metadata objects.
  AotCompilationMetadata() = default;
};
// Abstract compiler interface that is subclassed for compilation on a
// particular platform.
//
// The compiler ties together high level optimization (HLO) and low level
// optimization (LLO) / codegen (CG) to generate efficient executables for the
// target platform.
//
// The platform-based compiler singletons are registered via module initializers
// in their corresponding XLA compiler libraries, and are registered via the
// RegisterCompilerFactory API below.
//
// Thread-safety: subclasses of Compiler must be thread-safe, as multiple
// XLA clients may be requesting compilation concurrently for a given
// platform.
class Compiler {
 public:
  // Description of a target device for compilation.
  struct TargetConfig {
    // Builds a target description from its proto form.
    explicit TargetConfig(const se::GpuTargetConfigProto& proto);
    // Builds a target description from a stream executor.
    explicit TargetConfig(se::StreamExecutor* s);

    // Converts this description to its proto form.
    se::GpuTargetConfigProto ToProto() const;

    // Two configs compare equal when their serialized protos are identical.
    bool operator==(const TargetConfig& other) const {
      // TODO(cheshire): More efficient comparator, this is currently just for
      // tests.
      return ToProto().SerializeAsString() ==
             other.ToProto().SerializeAsString();
    }

    se::DeviceDescription device_description;
    std::string platform_name;
    se::dnn::VersionInfo dnn_version_info;
    std::string device_description_str;
  };

  // Options shared by the compilation entry points below. The convenience
  // overloads in this class build one with `CompileOptions{device_allocator}`,
  // so `device_allocator` must remain the first member.
  struct CompileOptions {
    // If device_allocator is not null, the compiler may use it to allocate temp
    // space on the device for use during compilation.  For example, the
    // compiler may allocate buffers on the device and then run variants of a
    // given algorithm over those buffers, to see which variant is fastest.  Any
    // space allocated will be deallocated before the compilation returns.
    se::DeviceMemoryAllocator* device_allocator = nullptr;

    // An optional thread pool for parallel compilation.
    tsl::thread::ThreadPool* thread_pool = nullptr;

    // Optional callback that maps a module to a (vector of shapes, shape) pair;
    // empty by default.
    // NOTE(review): presumably (argument layouts, result layout) -- confirm.
    std::function<StatusOr<std::pair<std::vector<Shape>, Shape>>(
        const HloModule& module)>
        layout_canonicalization_callback = {};

    // NOTE(review): presumably set when the compilation is issued as part of
    // autotuning -- confirm against the backends that read it.
    bool is_autotuning_compilation = false;

    // AOT device description. If provided, used instead of querying the device
    // on which compilation is performed.
    std::optional<TargetConfig> target_config;
  };

  virtual ~Compiler() = default;

  // Returns the ID of the platform that this compiler targets.
  virtual se::Platform::Id PlatformId() const = 0;

  // Runs Hlo passes to optimize the given Hlo module, returns the optimized
  // module.
  virtual StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
      const CompileOptions& options) = 0;
  // Convenience overload: wraps `device_allocator` in a CompileOptions (all
  // other options defaulted) and forwards to the virtual overload above.
  StatusOr<std::unique_ptr<HloModule>> RunHloPasses(
      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
      se::DeviceMemoryAllocator* device_allocator) {
    return RunHloPasses(std::move(module), executor,
                        CompileOptions{device_allocator});
  }

  // Performs scheduling and buffer assignment and returns the buffer
  // assignments.
  // The returned 'BufferAssignment' retains a pointer to the 'HloModule', so
  // the module must live at least as long as the buffer assignments.
  // The default implementation returns an Unimplemented error.
  virtual StatusOr<std::unique_ptr<BufferAssignment>> AssignBuffers(
      HloModule* module, const se::StreamExecutor* executor) {
    return Unimplemented("This compiler does not support this method");
  }

  // Compiles the HLO module for execution on a device given by the executor,
  // and returns an executable object or an error status. No HLO passes are
  // applied to module. Generally a module should be passed through RunHloPasses
  // prior to calling this method because some HLO passes are required for
  // correctness. Takes ownership of the HLO module.
  //
  // The compiler may optionally specialize to the individual device
  // (not just type of device) indicated by the executor.
  virtual StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
      const CompileOptions& options) = 0;
  // Convenience overload: wraps `device_allocator` in a CompileOptions (all
  // other options defaulted) and forwards to the virtual overload above.
  StatusOr<std::unique_ptr<Executable>> RunBackend(
      std::unique_ptr<HloModule> module, se::StreamExecutor* executor,
      se::DeviceMemoryAllocator* device_allocator) {
    return RunBackend(std::move(module), executor,
                      CompileOptions{device_allocator});
  }

  // The following two interfaces are same as the above two, except they
  // facilitate the loading of buffer assignment from proto if available.
  // Note: The default implementation of the API here does not utilize the given
  // buffer assignment. Different backends are expected to override the
  // following method to achieve this functionality.
  virtual StatusOr<std::unique_ptr<Executable>> RunBackendWithBufferAssignment(
      std::unique_ptr<HloModule> module,
      const BufferAssignmentProto* /*buffer_assignment_proto*/,
      se::StreamExecutor* executor, const CompileOptions& options) {
    LOG(WARNING) << "Ignoring the buffer assignment proto provided.";
    return RunBackend(std::move(module), executor, options);
  }
  // Convenience overload: wraps `device_allocator` in a CompileOptions (all
  // other options defaulted) and forwards to the virtual overload above.
  StatusOr<std::unique_ptr<Executable>> RunBackendWithBufferAssignment(
      std::unique_ptr<HloModule> module,
      const BufferAssignmentProto* buffer_assignment_proto,
      se::StreamExecutor* executor,
      se::DeviceMemoryAllocator* device_allocator) {
    return RunBackendWithBufferAssignment(std::move(module),
                                          buffer_assignment_proto, executor,
                                          CompileOptions{device_allocator});
  }

  // Returns a (deserialized) AotCompilationResult from a serialized
  // AotCompilationResult.
  // The default implementation returns an Unimplemented error.
  virtual StatusOr<std::unique_ptr<AotCompilationResult>>
  LoadAotCompilationResult(const std::string& serialized_aot_result) {
    return Unimplemented("LoadAotCompilationResult unimplemented.");
  }

  // Compiles a set of HLO modules that can run in parallel, potentially
  // communicating data between the modules, and returns a corresponding
  // sequence of executable objects.
  //
  // TODO(b/68666782): Remove this method after adding support for multiple
  // modules to RunHloPasses and RunBackends.
  virtual StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
      std::unique_ptr<HloModuleGroup> module_group,
      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
      const CompileOptions& options) = 0;
  // Convenience overload: wraps `device_allocator` in a CompileOptions (all
  // other options defaulted) and forwards to the virtual overload above.
  StatusOr<std::vector<std::unique_ptr<Executable>>> Compile(
      std::unique_ptr<HloModuleGroup> module_group,
      std::vector<std::vector<se::StreamExecutor*>> stream_exec,
      se::DeviceMemoryAllocator* device_allocator) {
    // `stream_exec` is our own by-value copy; move it along instead of
    // deep-copying the nested vector a second time.
    return Compile(std::move(module_group), std::move(stream_exec),
                   CompileOptions{device_allocator});
  }

  // Returns the backend configurations that the backend will consider for the
  // given HLO. Returns no configurations if the backend does not support
  // configurations for the given HLO.
  //
  // The stream executor is passed in to provide information about the hardware
  // that the backend configurations would be targeting.
  virtual std::vector<std::unique_ptr<tsl::protobuf::Message>>
  ComputeBackendConfigs(const HloInstruction& hlo,
                        se::StreamExecutor* executor) const;

  // Returns the backend configuration that the backend chooses by default for
  // the given HLO. Returns no configuration if the backend does not support
  // configurations for the given HLO.
  //
  // The stream executor is passed in to provide information about the hardware
  // that the backend configurations would be targeting.
  virtual std::unique_ptr<tsl::protobuf::Message> ComputeDefaultBackendConfig(
      const HloInstruction& hlo, se::StreamExecutor* executor) const;

  // Compiles the HLO module group for ahead-of-time execution.  This is
  // intended for use in static compilation.
  virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                     const AotCompilationOptions& options) = 0;

  // Similar to CompileAheadOfTime above, but additionally takes an
  // AotCompilationMetadata argument that can be populated during compilation.
  virtual StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
  CompileAheadOfTime(std::unique_ptr<HloModuleGroup> module_group,
                     const AotCompilationOptions& options,
                     std::unique_ptr<AotCompilationMetadata>* metadata);

  /////
  // The Compiler class also serves as a point to register compiler objects
  // for the various platforms.

  // Factory that creates a Compiler instance.
  using CompilerFactory = std::function<std::unique_ptr<Compiler>()>;

  // Registers the factory used to create the compiler singleton for the
  // platform.
  //
  // Precondition: a platform kind must not be registered more than once.
  static void RegisterCompilerFactory(se::Platform::Id platform_id,
                                      CompilerFactory compiler_factory);

  // Returns the compiler singleton pointer if it is available for the given
  // platform, or an error status if it is not.
  static StatusOr<Compiler*> GetForPlatform(const se::Platform* platform);

  // Returns a function that computes the size in bytes of the logical
  // buffer that contains a shape.
  virtual HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const = 0;

  // Returns a function that computes the size in bytes of a given
  // logical buffer.
  std::function<int64_t(const BufferValue&)> BufferSizeBytesFunction() {
    // The shape-size function is captured by value, so the returned callable
    // does not refer back to this Compiler.
    return [shape_size = ShapeSizeBytesFunction()](const BufferValue& buffer) {
      return shape_size(buffer.shape());
    };
  }

  // Returns the shape to use for `shape` on the device. The default
  // implementation returns the shape unchanged.
  virtual Shape DefaultDeviceShapeRepresentation(const Shape& shape) const {
    return shape;
  }

  // Returns an AotCompilationResult of the executable for serialization.
  // The default implementation returns an Unimplemented error.
  virtual StatusOr<std::unique_ptr<AotCompilationResult>> Export(
      Executable* executable) const {
    return Unimplemented("Export unimplemented");
  }

  // Returns a MetricsHookInterface object used to instrument Compiler's
  // compilation stages.
  virtual std::unique_ptr<MetricsHookInterface> CreateMetricsHook(
      absl::string_view filename_prefix) const;

 private:
  // Mutex that guards the platform-compiler map.
  static absl::Mutex platform_compiler_mutex_;

  // Map from platform kind to compiler factory.
  static absl::flat_hash_map<se::Platform::Id, CompilerFactory>*
  GetPlatformCompilerFactories();

  // Map from platform kind to compiler instance, if we made one already (based
  // on the factories above).
  static absl::flat_hash_map<se::Platform::Id, std::unique_ptr<Compiler>>*
  GetPlatformCompilers();
};
// Abstract superclass describing options to an ahead-of-time compilation.
//
// The type is non-copyable. Most options are plain getter/setter pairs over the
// private members at the bottom of the class; the virtual accessors return
// fixed defaults that subclasses may override.
class AotCompilationOptions {
 public:
  AotCompilationOptions(const AotCompilationOptions&) = delete;
  AotCompilationOptions& operator=(AotCompilationOptions const&) = delete;

  // Creates options for `platform_id`. The debug options are initialized from
  // GetDebugOptionsFromFlags().
  explicit AotCompilationOptions(se::Platform::Id platform_id)
      : platform_id_(platform_id), debug_options_(GetDebugOptionsFromFlags()) {}
  virtual ~AotCompilationOptions() = default;

  // Returns the ID of the platform to which these options apply.
  virtual se::Platform::Id PlatformId() const { return platform_id_; }

  // Overridable settings; the base class returns the defaults shown here.
  virtual int64_t replica_count() const { return 0; }
  virtual int64_t num_cores() const { return 0; }
  virtual bool use_spmd_partitioning() const { return false; }
  virtual bool use_auto_spmd_partitioning() const { return false; }
  virtual std::vector<int64_t> auto_spmd_partitioning_mesh_shape() const {
    return {};
  }
  virtual std::vector<int64_t> auto_spmd_partitioning_mesh_ids() const {
    return {};
  }
  virtual bool deduplicate_hlo() const { return false; }
  virtual PrecisionConfig::Precision matrix_unit_operand_precision() const {
    return PrecisionConfig::DEFAULT;
  }

  // Optional allocator that may be used for allocating temp space on the device
  // during compilation.
  se::DeviceMemoryAllocator* device_allocator() const {
    return device_allocator_;
  }
  void set_device_allocator(se::DeviceMemoryAllocator* device_allocator) {
    device_allocator_ = device_allocator;
  }

  const DebugOptions& debug_options() const { return debug_options_; }
  DebugOptions* mutable_debug_options() { return &debug_options_; }

  // Optional static device assignment. static_device_assignment() CHECK-fails
  // when none has been set, so test has_static_device_assignment() first.
  bool has_static_device_assignment() const {
    return static_device_assignment_.has_value();
  }
  const DeviceAssignment& static_device_assignment() const {
    CHECK(static_device_assignment_.has_value());
    return *static_device_assignment_;
  }
  void set_static_device_assignment(const DeviceAssignment& device_assignment) {
    static_device_assignment_ = device_assignment;
  }

  // Fusion config collection mode; defaults to FusionConfigCollection::kOff.
  FusionConfigCollection fusion_config_collection() const {
    return fusion_config_collection_;
  }
  void set_fusion_config_collection(
      FusionConfigCollection fusion_config_collection) {
    fusion_config_collection_ = fusion_config_collection;
  }

  const std::vector<std::vector<bool>>& fusion_config() const {
    return fusion_config_;
  }
  void set_fusion_config(const std::vector<std::vector<bool>>& fusion_config) {
    fusion_config_ = fusion_config;
  }

  // Optional stream executor; null unless set_executor() was called.
  se::StreamExecutor* executor() const { return executor_; }
  void set_executor(se::StreamExecutor* executor) { executor_ = executor; }

  // Optional profile_version and cache key may be used to trigger recompilation
  // when a compilation cache is used.
  int64_t profile_version() const { return profile_version_; }
  void set_profile_version(int64_t profile_version) {
    profile_version_ = profile_version;
  }

  // The returned view points into this object's own copy of the key, which
  // set_cache_key() replaces.
  absl::string_view cache_key() const { return cache_key_; }
  void set_cache_key(absl::string_view cache_key) {
    cache_key_ = std::string(cache_key);
  }

  bool run_backend_only() const { return run_backend_only_; }
  void set_run_backend_only(bool run_backend_only) {
    run_backend_only_ = run_backend_only;
  }

  bool sanitize_dataflow() const { return sanitize_dataflow_; }
  void set_sanitize_dataflow(bool sanitize_dataflow) {
    sanitize_dataflow_ = sanitize_dataflow;
  }

  const std::vector<std::string>& sanitize_abilists_dataflow() const {
    return sanitize_abilists_dataflow_;
  }
  void set_sanitize_abilists_dataflow(
      const std::vector<std::string>& abilists) {
    sanitize_abilists_dataflow_ = abilists;
  }

  // Optional target description; empty unless set_target_config() was called.
  const std::optional<Compiler::TargetConfig>& target_config() const {
    return target_config_;
  }
  void set_target_config(const Compiler::TargetConfig& target_config) {
    // The argument is a const reference, so this is necessarily a copy; a
    // std::move here could not move anything and would only mislead readers.
    target_config_ = target_config;
  }

 protected:
  // Default constructor for subclasses (defined out of line).
  AotCompilationOptions();

 private:
  se::Platform::Id platform_id_;
  se::DeviceMemoryAllocator* device_allocator_ = nullptr;
  DebugOptions debug_options_;
  std::optional<DeviceAssignment> static_device_assignment_;
  std::vector<std::vector<bool>> fusion_config_;
  FusionConfigCollection fusion_config_collection_ =
      FusionConfigCollection::kOff;
  se::StreamExecutor* executor_ = nullptr;
  int64_t profile_version_ = 0;
  std::string cache_key_;
  bool run_backend_only_ = false;
  bool sanitize_dataflow_ = false;
  std::vector<std::string> sanitize_abilists_dataflow_;
  // Contains target-specific information required by AOT compilation.
  std::optional<Compiler::TargetConfig> target_config_;
};
} // namespace xla
#endif // XLA_SERVICE_COMPILER_H_