// 3RNN/Lib/site-packages/tensorflow/include/xla/runtime/executable.h
// 2024-05-26 19:49:15 +02:00
//
// 448 lines
// 19 KiB
// C++

/* Copyright 2022 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_RUNTIME_EXECUTABLE_H_
#define XLA_RUNTIME_EXECUTABLE_H_
#include <cassert>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/MemoryBuffer.h"
#include "xla/runtime/arguments.h"
#include "xla/runtime/async_runtime.h"
#include "xla/runtime/custom_call.h"
#include "xla/runtime/custom_call_registry.h"
#include "xla/runtime/diagnostics.h"
#include "xla/runtime/execution_engine.h"
#include "xla/runtime/logical_result.h"
#include "xla/runtime/memory_mapper.h"
#include "xla/runtime/results.h"
#include "xla/runtime/type_id.h"
#include "xla/runtime/types.h"
namespace xla {
namespace runtime {
struct ExecutionContext;

// Deleter that `ExecutionReference` invokes on a non-null execution context.
// Only declared here; the definition lives out of line.
struct DestroyExecutionContext {
  void operator()(ExecutionContext* ctx);
};

// Owning handle to an execution context.
//
// When the executable has async results, the handle returned by `Execute`
// keeps that execution context alive. For sync executables `Execute` always
// hands back a handle holding nullptr.
class ExecutionReference
    : public std::unique_ptr<ExecutionContext, DestroyExecutionContext> {
  using OwningPtr = std::unique_ptr<ExecutionContext, DestroyExecutionContext>;

 public:
  // Reuse every std::unique_ptr constructor as-is.
  using OwningPtr::OwningPtr;
};
// Forward declarations: `FunctionRef` is defined later in this header;
// `JitCompiler` lives in `mlir/runtime/transforms/jit_compiler.h` and is
// befriended by `Executable`.
class FunctionRef;
class JitCompiler;
// Returns a symbols binding for running an XLA executable with custom symbols
// provided by the user in `custom_binding`.
//
// NOTE(review): presumably the custom symbols are combined with the runtime's
// own symbols — confirm against the out-of-line definition.
ExecutionEngine::SymbolsBinding RuntimeSymbolsBinding(
ExecutionEngine::SymbolsBinding custom_binding);
// Converts direct custom call registration and custom type id name
// registration functions (for types required by the library) into an
// execution engine symbols binding. Both callbacks default to empty.
//
// The returned symbols binding always includes type id symbols for all
// canonical types supported by the XLA runtime custom calls.
ExecutionEngine::SymbolsBinding ToSymbolsBinding(
std::function<void(DirectCustomCallRegistry&)> custom_calls = {},
std::function<void(TypeIDNameRegistry&)> types = {});
class Executable {
public:
// Forward declare types defined below.
struct ArgumentsMemoryLayout;
struct ResultsMemoryLayout;
struct CallFrame;
struct ExecuteOpts;
// Initializes call frame by adding all arguments according to the exported
// function ABI. Also allocates storage for the returned values according to
// the results memory layout.
//
// If `verify_arguments` is true (in debug mode it's always on, independent of
// the argument value) this function also verifies that arguments passed at
// run time matches the exported function signature (e.g. all statically known
// dimensions of the memrefs matches the arguments). Returns an error if finds
// a mismatch.
//
// This function leaves the execution context argument (the first argument of
// an exported function) uninitialized. It will be initialized in the
// `Execute` function right before the actual execution.
absl::Status InitializeCallFrame(unsigned ordinal, ArgumentsRef arguments,
CallFrame* call_frame,
bool verify_arguments = true) const;
absl::Status InitializeCallFrame(ArgumentsRef arguments,
CallFrame* call_frame,
bool verify_arguments = true) const {
return InitializeCallFrame(0, arguments, call_frame, verify_arguments);
}
// Converts returned values owned by the call frame using provided result
// converter. If exported function execution finished with an error (error
// flag is `true` in the call frame) returns error for all results (see
// `ResultConverter::ReturnError` documentation).
absl::Status ReturnResults(unsigned ordinal, const ResultConverter& results,
CallFrame* call_frame) const;
absl::Status ReturnResults(const ResultConverter& results,
CallFrame* call_frame) const {
return ReturnResults(0, results, call_frame);
}
// Executes exported function exported with given arguments.
//
// If `verify_arguments` is true (in debug mode it's always on, independent of
// the argument value) this function also verifies that arguments passed at
// run time matches the exported function signature. If some of the
// arguments do not match the expected type, this function allocates error
// async values for all results and returns an error.
//
// Returns exported function results via the user-provided results converter.
// If execution completed in the error state, returns error for all results.
absl::StatusOr<ExecutionReference> Execute(
unsigned ordinal, ArgumentsRef arguments, const ResultConverter& results,
const ExecuteOpts& opts, bool verify_arguments = true) const;
absl::StatusOr<ExecutionReference> Execute(
ArgumentsRef arguments, const ResultConverter& results,
const ExecuteOpts& opts, bool verify_arguments = true) const {
return Execute(0, arguments, results, opts, verify_arguments);
}
// Executes exported function using user provided call frame.
//
// It is the caller responsibility to handle the compiled function results
// stored in the call frame.
ExecutionReference Execute(unsigned ordinal, CallFrame& call_frame,
const ExecuteOpts& opts) const;
void Execute(CallFrame& call_frame, const ExecuteOpts& opts) const {
Execute(0, call_frame, opts);
}
std::string_view name() const { return name_; }
std::string&& take_ir_module_string() { return std::move(ir_module_string_); }
std::optional<size_t> specialization() const { return specialization_; }
// Returns the number of exported functions. Functions are indexed by their
// ordinal number in the [0, num_functions) range.
size_t num_functions() const { return functions_.size(); }
// Returns a function reference to an exported function with given ordinal.
FunctionRef function_ref(unsigned ordinal) const;
// Returns true if exported function with given ordinal has async results.
bool IsAsync(unsigned ordinal) const;
bool IsAsync() const { return IsAsync(0); }
// Returns the name of the exported function with the given ordinal.
std::string_view function_name(unsigned ordinal) const;
std::string_view function_name() const { return function_name(0); }
// Returns the number of results of the exported function with given ordinal.
unsigned num_results(unsigned ordinal) const;
unsigned num_results() const { return num_results(0); }
// Signature of the exported function with the given ordinal before lowering
// to the runtime dialects. See JitExecutable::Function's `signature` for
// more details.
const FunctionType& signature(unsigned ordinal) const;
const FunctionType& signature() const { return signature(0); }
// Signature of the exported function with the given ordinal after lowering it
// from high level dialects to the dialects supported by the XLA runtime. See
// JitExecutable::Function's `signature` for more details.
const FunctionType& runtime_signature(unsigned ordinal) const;
const FunctionType& runtime_signature() const { return runtime_signature(0); }
std::chrono::milliseconds time_to_compile() const;
// Get the object file behind this executable (on linux for example, it will
// be https://en.wikipedia.org/wiki/Executable_and_Linkable_Format
// executable). Can be null.
std::unique_ptr<llvm::MemoryBuffer> obj_file() const;
// CallFrame provides a pointer-stable storage for packed function arguments
// and storage for returned values.
struct CallFrame {
// Pointers to executable arguments.
llvm::SmallVector<void*, 32> args;
// We use single block of memory to store executable results. We need to be
// able to store pointers to async values and tokens, and strided memrefs
// which at runtime are represented as StridedMemrefType<T, rank>.
//
// Currently we only need to provide result storage for pointers and memref
// sizes and strides (int64_t type). If we'll need to support more complex
// return types we'll have to be more careful about alignment requirements.
static_assert(sizeof(uintptr_t) == sizeof(int64_t),
"uintptr_t size must be the same as int64_t");
// Memory where the executable will write its results.
llvm::SmallVector<uint8_t, 128> results;
// Tracks whether any of the outputs were set.
bool has_set_outputs = false;
// Indicates whether the execution finished with an error.
bool is_error = false;
// The error message which is available only if `is_error` is true. The
// assumption is that the error message string is owned by the compiled
// binary and the call frame can safely keep a non-owning pointer.
std::string_view error;
};
// Requirements for passing arguments to the compiled function.
struct ArgumentsMemoryLayout {
size_t num_args_ptrs = 0; // total number of required pointers
llvm::SmallVector<size_t> num_ptrs; // num_ptrs for each argument
llvm::SmallVector<size_t> offsets; // offsets into the args array
};
// Requirements for the contiguous block of memory to store compiled function
// results. When we invoke a compiled fuction we allocate a block of memory,
// and pass pointers to pre-computed offsets as output arguments to the
// function.
struct ResultsMemoryLayout {
bool has_async_results = false; // true iff returns async results
size_t size = 0; // number of bytes required
llvm::SmallVector<size_t> offsets; // offsets in the block of memory
};
struct ExecuteOpts {
// Async task runner for executing async runtime tasks. Typically it
// schedules async tasks into the underlying thread pool. It's the caller's
// responsibility to guarantee that it will outlive the execution of all
// async tasks started by the executable.
AsyncTaskRunner* async_task_runner = nullptr;
// A container for passing arbitrary user-provided data to the custom call
// handlers. Must outlive all async tasks launched by this executable.
const CustomCall::UserData* custom_call_data = nullptr;
// Dynamically registered custom calls library. These custom calls resolved
// at run time by name. In contrast to custom calls defined by the
// `DirectCustomCallRegistry` which are linked directly with the executable
// at compile time.
const DynamicCustomCallRegistry* custom_call_registry = nullptr;
// Diagnostic engine is responsible for passing runtime diagnostics back
// to the caller through the diagnostic handler.
const DiagnosticEngine* diagnostic_engine = nullptr;
};
// Function specification for loading from the object file.
struct LoadFunction {
std::string name;
FunctionType signature;
FunctionType runtime_signature;
};
// Loads executable from an object file. It is the caller responsibility to
// guarantee that signatures do match the compiled function in the object
// file, otherwise it will surely lead to crash.
static absl::StatusOr<Executable> LoadFromObjFile(
std::string_view name, std::unique_ptr<llvm::MemoryBuffer> obj_file,
std::vector<LoadFunction> load_functions,
ExecutionEngine::SymbolsBinding symbols_binding = {},
std::string_view memory_region_name = "");
// Verifies that all arguments types in the exported function signature are
// supported at run time. Returns a pre-computed layout for the function
// arguments. If some arguments are not supported returns an error.
static absl::StatusOr<ArgumentsMemoryLayout> GetArgumentsMemoryLayout(
const FunctionType& signature);
// Verifies that all results types in the exported function signature are
// supported at run time . Returns a pre-computed layout for the function
// results. If some results are not supported returns an error.
static absl::StatusOr<ResultsMemoryLayout> GetResultsMemoryLayout(
const FunctionType& signature);
// TODO(ezhulenev): The following three functions should be decoupled from
// the executable header file (maybe move them to runtime.h?) so that custom
// call implementations do not have to depend on the `executable` target.
// Returns the user data passed via the ExecuteOpts to the executable.
static const CustomCall::UserData* GetUserData(ExecutionContext* ctx);
// Returns the diagnostic engine passed via the ExecuteOpts to the executable.
static const DiagnosticEngine* GetDiagnosticEngine(ExecutionContext* ctx);
// Calls the custom call handler with the given runtime context, arguments,
// attributes and results.
static LogicalResult Call(ExecutionContext* ctx, CustomCall& call,
void** args, void** attrs, void** rets);
bool RequiresBlas(int ordinal) const {
return functions_[ordinal].requires_blas;
}
private:
friend class JitCompiler; // see `mlir/runtime/transforms/jit_compiler.h`
// Executable exports multiple functions available for users to call into. At
// run time they are referenced by their ordinal, so that we don't depend on
// expensive by-name lookup on the hot path. We keep function name only for
// debugging. Function ordinal is defined by its index in the `functions_`
// vector.
struct Function {
Function(std::string_view name, ExecutionEngine::ExportedFunctionPtr fptr,
FunctionType signature, FunctionType runtime_signature,
ArgumentsMemoryLayout arguments_memory_layout,
ResultsMemoryLayout results_memory_layout, bool requires_blas)
: name(name),
fptr(std::move(fptr)),
signature(std::move(signature)),
runtime_signature(std::move(runtime_signature)),
arguments_memory_layout(std::move(arguments_memory_layout)),
results_memory_layout(std::move(results_memory_layout)),
requires_blas(requires_blas) {}
Function(const Function&) = delete;
Function(Function&&) = default;
// Exported function name.
std::string name;
// Pointer to an exported function owned by the execution engine.
ExecutionEngine::ExportedFunctionPtr fptr;
// Signature of the exported function function before lowering to the
// runtime dialects (see JitExecutable::Function's `signature`).
FunctionType signature;
// Signature of the exported function after lowering it from high level
// dialects to the dialects supported by the XLA runtime.
//
// - Operands and results types converted to the types with well-defined ABI
// (e.g. tensors converted to memrefs).
//
// - First argument is always an execution context added to the function by
// the lowering pipeline.
//
// From this signatur, Executable infers how to pack runtime arguments
// according to the expected memory layout, and how to convert results
// returned from the JIT-compiled function into high level types (e.g. how
// to convert StridedMemrefType into Tensorflow Tensor).
//
// To infer the type of the returned value, Executable looks at the type
// defined by the `runtime_signature` to get the memory layout of the
// returned value, and at the type defined by the `signature` to get the
// type expected by the runtime.
FunctionType runtime_signature;
// Memory layout required for passing function arguments.
ArgumentsMemoryLayout arguments_memory_layout;
// Memory layout for returning function results.
ResultsMemoryLayout results_memory_layout;
// If this flag is true, then this function is outlined for cuda graph, and
// cuBlas should be initiated when capturing the cuda graph.
bool requires_blas;
};
Executable(std::string_view name,
std::unique_ptr<XlaRuntimeMemoryMapper> memory_mapper,
std::unique_ptr<ExecutionEngine> engine,
std::vector<Function> functions,
std::optional<size_t> specialization,
std::chrono::milliseconds time_to_compile,
std::string&& ir_module_string = "")
: name_(name),
memory_mapper_(std::move(memory_mapper)),
engine_(std::move(engine)),
functions_(std::move(functions)),
specialization_(specialization),
time_to_compile_(time_to_compile),
ir_module_string_(ir_module_string) {
// All exported functions must have a non-null function pointer.
assert(llvm::all_of(functions_, [](const Function& f) { return f.fptr; }));
}
std::string name_; // name of the compiled executable
// Called by `engine_`'s destructor; must appear before it.
std::unique_ptr<XlaRuntimeMemoryMapper> memory_mapper_; // optional
// XLA runtime execution engine owns the LLVM ORC jit compilation stack.
std::unique_ptr<ExecutionEngine> engine_;
// Functions exported by this executable, indexed by function ordinal.
std::vector<Function> functions_;
// Specialization id if this executable is a specialization, or an empty
// optional if this executable is a default one.
std::optional<size_t> specialization_;
// The time it took to compile this binary.
std::chrono::milliseconds time_to_compile_;
// The (optional) string containing the LLVM module, if requested by
// compilation or set explicitly.
std::string ir_module_string_;
};
// Function reference provides a function-like API for a function exported from
// the executabled with the given ordinal.
class FunctionRef {
public:
FunctionRef(const Executable* executable, unsigned ordinal);
absl::StatusOr<ExecutionReference> operator()(
ArgumentsRef arguments, const ResultConverter& results,
const Executable::ExecuteOpts& opts, bool verify_arguments = true) const;
bool RequiresBlas() const { return executable_->RequiresBlas(ordinal_); }
unsigned ordinal() const { return ordinal_; }
private:
const Executable* executable_;
unsigned ordinal_;
};
// Escapes slashes by substituting each '/' with a double underscore, producing
// a memory region name for the XlaRuntimeMemoryMapper.
//
// The profiler's UI might interpret slashes as callchain separators,
// whereas we want the region name to be shown in full.
//
// Example: "module/fn/region" -> "module__fn__region".
//
// Uses only the standard library, so this inline helper no longer relies on a
// transitively included `llvm/ADT/StringExtras.h` for `llvm::split`/`join`.
inline std::string EscapeMemRegionName(std::string_view memory_region_name) {
  std::string escaped;
  // Lower bound on the output size; each '/' adds one more character.
  escaped.reserve(memory_region_name.size());
  for (char c : memory_region_name) {
    if (c == '/') {
      escaped += "__";
    } else {
      escaped.push_back(c);
    }
  }
  return escaped;
}
} // namespace runtime
} // namespace xla
#endif // XLA_RUNTIME_EXECUTABLE_H_