3RNN/Lib/site-packages/tensorflow/include/xla/service/compilation_environments.h
2024-05-26 19:49:15 +02:00

162 lines
5.9 KiB
C++

/* Copyright 2022 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_SERVICE_COMPILATION_ENVIRONMENTS_H_
#define XLA_SERVICE_COMPILATION_ENVIRONMENTS_H_
#include <cstdint>
#include <functional>
#include <memory>
#include <string_view>
#include <typeindex>
#include <utility>
#include "absl/container/flat_hash_map.h"
#include "xla/statusor.h"
#include "xla/xla.pb.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/protobuf.h"
namespace xla {
// A class for holding CompilationEnvironments, i.e., protos holding the values
// of command line flags and environment variables that affect compilation.
//
// CompilationEnvironments uses lazy initialization (see GetEnv() for more
// details). Lazy initialization is used so we can avoid:
// A) Requiring every code path to explicitly construct all needed compilation
// environments, particularly when the default constructed environment is
// all we need AND
// B) Requiring CompilationEnvironments to implicitly construct all needed
// environments, thereby requiring it to statically know the types of all
// such environments
//
// CompilationEnvironments is not thread-safe.
class CompilationEnvironments {
 public:
  // Hook invoked for every environment that is added (see
  // RegisterProcessNewEnvFn()). It receives the incoming env, which may be
  // null, and returns the environment that CompilationEnvironments will store.
  using ProcessNewEnvFn =
      std::function<StatusOr<std::unique_ptr<tsl::protobuf::Message>>(
          std::unique_ptr<tsl::protobuf::Message>)>;

  CompilationEnvironments() = default;
  // Copies `rhs` by delegating to the copy-assignment operator (defined out of
  // line).
  CompilationEnvironments(const CompilationEnvironments& rhs) { *this = rhs; }
  CompilationEnvironments& operator=(const CompilationEnvironments& rhs);
  // User-declaring the copy operations above suppresses the implicitly
  // declared move operations, so without these every std::move() of this
  // class would silently fall back to the copy path. The only data member
  // owns its environments through std::unique_ptr, so transferring it is
  // safe. The moved-from object is left in a valid but unspecified state;
  // call Clear() before reusing it.
  CompilationEnvironments(CompilationEnvironments&&) = default;
  CompilationEnvironments& operator=(CompilationEnvironments&&) = default;
  ~CompilationEnvironments() = default;

  // Deserializes the given CompilationEnvironments proto.
  static StatusOr<std::unique_ptr<CompilationEnvironments>> CreateFromProto(
      const CompilationEnvironmentsProto& proto);

  // Whenever an environment is added to CompilationEnvironments, even when
  // GetEnv() adds a lazily initialized one, it is passed to the function
  // registered by this method, corresponding to the environment's proto
  // descriptor. The result is the environment that is used by
  // CompilationEnvironments. This allows environment authors to
  // do things like populate missing fields in an added environment.
  //
  // Users of CompilationEnvironments must register their `ProcessNewEnvFn`
  // function via this method for each type of CompilationEnvironment they wish
  // to use in code.
  //
  // The input env to a ProcessNewEnvFn may be null.
  //
  // REQUIRES:
  // - The output is *not* allowed to be null, even for null input.
  static void RegisterProcessNewEnvFn(
      const tsl::protobuf::Descriptor* descriptor,
      ProcessNewEnvFn process_new_env);

  // Adds env to the list of CompilationEnvironments. If an environment with
  // the same proto descriptor has already been added, env will replace it.
  //
  // All added environments are processed via registered ProcessNewEnvFns. If
  // such a function was not registered for env's proto descriptor or env's
  // proto type is unknown, an error will be returned.
  Status AddEnv(std::unique_ptr<tsl::protobuf::Message> env);

  // Returns the CompilationEnvironment corresponding to T. If such an
  // environment has not been added, ProcessNewEnvFn(nullptr) will be added and
  // returned.
  //
  // GetMutableEnv()/GetEnv() are not const because they can perform lazy
  // initialization, thereby modifying the CompilationEnvironments object's
  // data members.
  template <typename T>
  T& GetMutableEnv();
  template <typename T>
  const T& GetEnv();
  // Returns whether an environment for T is currently held. Unlike
  // GetEnv()/GetMutableEnv(), this only performs a lookup and never lazily
  // creates an environment.
  template <typename T>
  bool HasEnv();

  // Removes all added environments.
  void Clear() { environments_.clear(); }

  // Serializes this CompilationEnvironments into a protobuf message.
  CompilationEnvironmentsProto ToProto() const;

 private:
  // Returns the ProcessNewEnvFn for the given env type. Returns nullptr if no
  // ProcessNewEnvFn has been registered for the env type.
  static ProcessNewEnvFn GetProcessNewEnvFn(
      const tsl::protobuf::Descriptor& descriptor);

  // Called by GetEnv(), when it lazily creates a new environment, to globally
  // track stats about how many such environments are created by
  // CompilationEnvironments.
  static void DefaultEnvCreatedByCompilationEnvironments(
      std::string_view env_type);

  // Called by AddEnv(), to globally track stats about how many environments
  // are added to CompilationEnvironments.
  static void EnvAdded(std::string_view env_type);

  // Shared implementation behind AddEnv() and the lazy creation path in
  // GetMutableEnv(), which passes a null env.
  Status AddEnvImpl(const tsl::protobuf::Descriptor& descriptor,
                    std::unique_ptr<tsl::protobuf::Message> env);

  // Held environments, keyed by their proto descriptor (at most one per type).
  absl::flat_hash_map<const tsl::protobuf::Descriptor*,
                      std::unique_ptr<tsl::protobuf::Message>>
      environments_;
};
// ----- Template implementation below -----
template <typename T>
T& CompilationEnvironments::GetMutableEnv() {
  const auto* const env_descriptor = T::descriptor();
  auto env_it = environments_.find(env_descriptor);
  if (env_it != environments_.end()) {
    // Fast path: an environment of type T is already held.
    return tensorflow::down_cast<T&>(*env_it->second);
  }
  // Slow path: lazily create the environment by running the registered
  // ProcessNewEnvFn on a null input. Failure here is fatal (TF_CHECK_OK).
  TF_CHECK_OK(AddEnvImpl(*env_descriptor, nullptr));
  DefaultEnvCreatedByCompilationEnvironments(env_descriptor->full_name());
  // Look the entry up again now that AddEnvImpl has populated it.
  env_it = environments_.find(env_descriptor);
  return tensorflow::down_cast<T&>(*env_it->second);
}
template <typename T>
const T& CompilationEnvironments::GetEnv() {
  // Read-only view of the environment. It shares GetMutableEnv()'s lazy
  // creation, which is why this method itself cannot be const.
  const T& env = GetMutableEnv<T>();
  return env;
}
template <typename T>
bool CompilationEnvironments::HasEnv() {
  // Pure membership probe keyed by T's descriptor. It never lazily creates
  // an environment.
  return environments_.contains(T::descriptor());
}
} // namespace xla
#endif // XLA_SERVICE_COMPILATION_ENVIRONMENTS_H_