3RNN/Lib/site-packages/tensorflow/include/xla/runtime/diagnostics.h
2024-05-26 19:49:15 +02:00

175 lines
5.8 KiB
C++

/* Copyright 2022 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_RUNTIME_DIAGNOSTICS_H_
#define XLA_RUNTIME_DIAGNOSTICS_H_
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include "absl/status/status.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "xla/runtime/logical_result.h"
#include "tsl/platform/logging.h"
namespace xla {
namespace runtime {
// Forward declare.
class DiagnosticEngine;
// XLA runtime diagnostic engine enables XLA runtime custom calls and compiled
// programs to pass diagnostic information (e.g. detailed run time error
// information attached to the absl::Status) to the caller via the side channel,
// because the API (and ABI) of the compiled executable is very simple, and
// doesn't allow to pass complex C++ tyopes.
//
// XLA runtime diagnostics borrows a lot of ideas from the MLIR compile time
// diagnostics (which is largely based on the Swift compiler diagnostics),
// however in contrast to MLIR compilation pipelines we need to emit diagnostics
// for the run time events (vs compile time) and correlate them back to the
// location in the input module.
//
// See MLIR Diagnostics documentation: https://mlir.llvm.org/docs/Diagnostics.
//===----------------------------------------------------------------------===//
// Diagnostic
//===----------------------------------------------------------------------===//
// TODO(ezhulenev): Add location tracking to diagnostic, so that we can
// correlate emitted diagnostic to the location in the input module, and from
// there rely on the MLIR location to correlate events back to the user program
// (e.g. original JAX program written in Python).
class Diagnostic {
public:
explicit Diagnostic(absl::Status status) : status_(std::move(status)) {}
Diagnostic(Diagnostic &&) = default;
Diagnostic &operator=(Diagnostic &&) = default;
absl::Status status() const { return status_; }
private:
Diagnostic(const Diagnostic &rhs) = delete;
Diagnostic &operator=(const Diagnostic &rhs) = delete;
absl::Status status_;
};
//===----------------------------------------------------------------------===//
// InFlightDiagnostic
//===----------------------------------------------------------------------===//
// RAII wrapper around constructed, but but not yet emitted diagnostic. In
// flight diagnostic gives an opportunity to build a diagnostic before reporting
// it to the engine, similar to the builder pattern.
class InFlightDiagnostic {
public:
InFlightDiagnostic(InFlightDiagnostic &&other)
: engine_(other.engine_), diagnostic_(std::move(other.diagnostic_)) {
other.diagnostic_.reset();
other.Abandon();
}
~InFlightDiagnostic() {
if (IsInFlight()) Report();
}
void Report();
void Abandon();
// Allow a diagnostic to be converted to 'failure'.
//
// Example:
//
// LogicalResult call(DiagnosticEngine diag, ...) {
// if (<check failed>) return diag.EmitError(Internal("oops"));
// ...
// }
//
operator LogicalResult() const { // NOLINT
return failure();
}
template <typename T>
operator FailureOr<T>() const { // NOLINT
return failure();
}
private:
friend class DiagnosticEngine;
InFlightDiagnostic(const DiagnosticEngine *engine, Diagnostic diagnostic)
: engine_(engine), diagnostic_(std::move(diagnostic)) {}
InFlightDiagnostic &operator=(const InFlightDiagnostic &) = delete;
InFlightDiagnostic &operator=(InFlightDiagnostic &&) = delete;
bool IsActive() const { return diagnostic_.has_value(); }
bool IsInFlight() const { return engine_ != nullptr; }
// Diagnostic engine that will report this diagnostic once its ready.
const DiagnosticEngine *engine_ = nullptr;
std::optional<Diagnostic> diagnostic_;
};
//===----------------------------------------------------------------------===//
// DiagnosticEngine
//===----------------------------------------------------------------------===//
// Diagnostic engine is responsible for passing diagnostics to the user.
//
// XLA runtime users must set up diagnostic engine to report errors back to the
// caller, e.g. the handler can collect all of the emitted diagnostics into the
// string message, and pass it to the caller as the async error.
//
// Unhandled error diagnostics will be logged with the warning level.
class DiagnosticEngine {
public:
// Diagnostic handler must return success if it consumed the diagnostic, and
// failure if the engine should pass it to the next registered handler.
using HandlerTy = std::function<LogicalResult(Diagnostic &)>;
// Returns the default instance of the diagnostic engine.
static const DiagnosticEngine *DefaultDiagnosticEngine();
InFlightDiagnostic EmitError(absl::Status status) const {
return InFlightDiagnostic(this, Diagnostic(std::move(status)));
}
void AddHandler(HandlerTy handler) {
handlers_.push_back(std::move(handler));
}
void Emit(Diagnostic diagnostic) const {
for (auto &handler : llvm::reverse(handlers_)) {
if (succeeded(handler(diagnostic))) return;
}
// Log unhandled errors to the warning log.
LOG(WARNING) << "XLA runtime error: " << diagnostic.status();
}
private:
llvm::SmallVector<HandlerTy> handlers_;
};
} // namespace runtime
} // namespace xla
#endif // XLA_RUNTIME_DIAGNOSTICS_H_