3RNN/Lib/site-packages/tensorflow/include/xla/service/instruction_fusion.h
2024-05-26 19:49:15 +02:00

340 lines
14 KiB
C++

/* Copyright 2017 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_SERVICE_INSTRUCTION_FUSION_H_
#define XLA_SERVICE_INSTRUCTION_FUSION_H_
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
// The source_location.h is not available in open source.
#if defined(PLATFORM_GOOGLE)
#include "absl/types/source_location.h"
#endif // PLATFORM_GOOGLE
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_reachability.h"
#include "xla/service/fusion_queue.h"
#include "xla/service/hlo_pass_interface.h"
namespace xla {
// Propagating explanation of fusion decisions: if something could not be fused,
// explain the reason.
class FusionDecision {
public:
// Can not be fused: explain why. Implicit conversion due to optional-like
// semantics: waiver granted in cl/419938611.
FusionDecision(absl::string_view explanation) // NOLINT
: explanation_(explanation) {}
// Same constructor as string_view, to allow implicit string conversion (can't
// implicitly convert both char* to string_view and string_view to
// FusionDecision).
FusionDecision(const char* explanation) // NOLINT
: explanation_(explanation) {}
// If condition is `true` means that we CAN fuse. In that case, explanation is
// discarded.
FusionDecision(bool condition, absl::string_view explanation) {
if (!condition) {
explanation_ = std::string(explanation);
}
}
#if defined(PLATFORM_GOOGLE)
// We can fuse iff. the decision is `true`. The source location indicates
// where an instance was created, making debugging easier without a need to
// provide explicit explanation.
FusionDecision( // NOLINT
bool decision,
absl::SourceLocation source_location = absl::SourceLocation::current());
#endif // PLATFORM_GOOGLE
// Can be fused.
FusionDecision() = default;
// Returns whether it can be fused.
explicit operator bool() const { return CanFuse(); }
// Whether the fusion decision is positive.
bool CanFuse() const { return !explanation_.has_value(); }
// Connects two decisions with a disjunction. This is different than just
// picking one, as we also have to propagate both explanations if only one of
// them is false to show why fusion wasn't performed.
FusionDecision Or(const FusionDecision& decision) const {
if (CanFuse() || decision.CanFuse()) {
return {};
}
return {absl::StrCat(explanation_.value_or(""), " ; ", decision.Explain())};
}
// Connects two fusion decision with a conjunction. Unlike disjunction,
// propagates only one explanation (as it is enough to show that fusion could
// not be done).
FusionDecision And(const FusionDecision& decision) const {
if (CanFuse()) {
return decision;
}
if (decision.CanFuse()) {
return *this;
}
// Both conditions were violated: returning either is valid.
return *this;
}
// Appends to explanation, or turns the decision negative.
FusionDecision operator<<(absl::string_view explanation) const {
return {absl::StrCat(explanation_.value_or(""), explanation)};
}
// Appends to explanation, or turns the decision negative.
FusionDecision operator<<(int64_t explanation) const {
return {absl::StrCat(explanation_.value_or(""), explanation)};
}
// Explains why the fusion could not be performed.
std::string Explain() const { return *explanation_; }
private:
// Empty IFF fusion is possible (explanation provided for negative cases).
std::optional<std::string> explanation_;
};
#define RETURN_IF_NOT_FUSIBLE(...) \
do { \
::xla::FusionDecision _decision = (__VA_ARGS__); \
if (TF_PREDICT_FALSE(!_decision.CanFuse())) { \
return _decision; \
} \
} while (0)
// HLO pass which performs instruction fusion. Instructions are fused
// "vertically", meaning producing instructions are fused into their consumers
// with the intent that the loops which compute their values will be fused in
// code generation. Derived classes define ShouldFuse method to select which
// instructions to fuse.
class InstructionFusion : public HloModulePass {
public:
explicit InstructionFusion(
std::function<bool(const HloInstruction& instruction)> is_expensive,
bool may_duplicate = true,
FusionConfigCollection config_collection_mode =
FusionConfigCollection::kOff)
: is_expensive_(is_expensive),
may_duplicate_(may_duplicate),
config_collection_mode_(config_collection_mode) {}
~InstructionFusion() override = default;
absl::string_view name() const override { return "fusion"; }
// Run instruction fusion on the given computation. Returns whether the
// computation was changed (instructions were fused).
using HloPassInterface::Run;
StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
// Returns true if the computation of the given instruction is significantly
// more expensive than just writing all the values of the instructions' result
// array. Expensive operations will not be duplicated.
static bool IsExpensive(const HloInstruction& instruction);
// Returns true if it's legal to fuse the producer instruction into consumer
// with regard to in-place semantics of the consumer. For example, it is
// illegal to fuse a slice into a dynamic-update-slice if the slice output is
// used as the update and if slice and dynamic-update-slice indices cannot be
// proven to be the same.
static FusionDecision ShouldFuseInPlaceOp(const HloInstruction* producer,
const HloInstruction* consumer);
protected:
// Returns a list of computations on which Fusion is performed.
virtual std::vector<HloComputation*> GetFusionComputations(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads);
// Returns a FusionQueue that implements custom order of instructions being
// fused. The default implementation processes consumers in reverse post
// order.
virtual std::unique_ptr<FusionQueue> GetFusionQueue(
HloComputation* computation);
// Returns whether the given producer instruction should be fused into the
// given consumer instruction. producer is necessarily an operand of consumer.
// Derived classes should define this method to specify which instructions
// should be fused. `operand_index` is which operand of the consumer the
// producer is.
//
// Instructions are traversed in reverse post order (computation root to
// leaves). This method is called for each operand of the instruction (where
// the operand is 'producer' and the instruction is 'consumer')
//
// Subtypes can override this with target-specific heuristics.
virtual FusionDecision ShouldFuse(HloInstruction* consumer,
int64_t operand_index);
// Returns whether multi-output fusion can be applied to fuse `producer` into
// `consumer`. In contrast to "regular" fusion, the `producer` is not
// duplicated by multi-output fusion.
virtual FusionDecision ShouldFuseIntoMultiOutput(HloInstruction* consumer,
int64_t operand_index) {
return "multi-output fusion not supported by this pass";
}
// Chooses a fusion kind for `producer` and `consumer`.
// Default method chooses `kLoop`.
virtual HloInstruction::FusionKind ChooseKind(const HloInstruction* producer,
const HloInstruction* consumer);
// Fuses 'producer' into 'fusion_instruction'. 'fusion_instruction' needs to
// be a fusion instruction. Returns the newly created clone of 'producer'
// which is part of the fusion computation.
virtual HloInstruction* FuseInstruction(HloInstruction* fusion_instruction,
HloInstruction* producer);
// Fuses producer into consumer. Returns the fusion instruction.
virtual HloInstruction* Fuse(HloInstruction* producer,
HloInstruction* consumer,
HloComputation* computation);
// Creates a new fusion instruction containing `producer` and `consumer`. A
// tuple is added as the fusion instruction's root, which consumes from both,
// `producer` and `consumer`. This style of fusion is referred to as
// multi-output fusion.
virtual HloInstruction* FuseIntoMultiOutput(HloInstruction* producer,
HloInstruction* consumer,
HloComputation* computation);
// An "effectively unary" operation is one that has at most one "large"
// input with the others being negligible in terms of memory usage.
// We use "has a smaller true rank than the output" as a heuristic
// for "negligible" memory usage.
bool EffectivelyAtMostUnary(HloInstruction* hlo);
// Returns true if fusing producer into consumer would cause producer to be
// duplicated. This is the case if producer has uses other than consumer.
bool FusionWouldDuplicate(const HloInstruction& producer,
const HloInstruction& consumer) {
return !(producer.users().size() == 1 && consumer.IsUserOf(&producer));
}
bool is_expensive(const HloInstruction& instruction) {
return is_expensive_(instruction);
}
// Overwrites the originally initialized is_expensive function.
void set_is_expensive(
std::function<bool(const HloInstruction& instruction)> is_expensive) {
is_expensive_ = is_expensive;
}
// Whether multi-output fusion would introduce a cycle into the HLO graph.
bool MultiOutputFusionCreatesCycle(HloInstruction* producer,
HloInstruction* consumer,
const HloReachabilityMap& reachability);
FusionConfigCollection config_collection_mode() {
return config_collection_mode_;
}
// Returns whether 'consumer' may reuse elements of its `operand_index`th
// operand.
bool ReusesOperandElements(const HloInstruction* consumer,
int64_t operand_index);
// The set of producers whose consumers we cannot fuse into.
using HloInstructionSet = absl::flat_hash_set<HloInstruction*>;
// Computes the set of nodes that we do not want to fuse into any of their
// consumers based on a global analysis of the HLO graph.
virtual HloInstructionSet ComputeGloballyUnfusible(
absl::Span<HloInstruction* const> post_order,
const HloReachabilityMap& reachability);
private:
// Returns the reused operands of `instruction` from reused_fusion_operands_,
// computing them if they have not previously been computed for that
// instruction.
// The returned value has pointer stability, assuming entries are not deleted
// from reused_fusion_operands_.
absl::flat_hash_set<const HloInstruction*>& ReusedOperandsOf(
const HloInstruction* instruction);
// Updates reused_fusion_operands_ for a fusion when we are about to fuse
// `producer` into `fusion_instruction`.
void UpdateReusedOperandsForFusion(HloInstruction* producer,
HloInstruction* fusion_instruction);
HloInstruction* AddFusionInstruction(HloInstruction* producer,
HloInstruction* consumer,
HloComputation* computation);
// Whether or not we can fuse producer into consumer on all paths
// from the producer to the consumer where nodes are HLOs and edges are uses.
//
// A map from <producer, consumer> to a bool is required as the result cache
// to store and query the results of calls to this function, in order to avoid
// repeated computations.
bool CanFuseOnAllPaths(
HloInstruction* producer, HloInstruction* consumer,
const HloInstructionSet& do_not_fuse,
const HloReachabilityMap& reachability,
absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
result_cache);
// Used to determine if an HLO is expensive. Expensive operations will not be
// duplicated.
std::function<bool(const HloInstruction& instruction)> is_expensive_;
// Dumps the state of computation before fusion.
void DumpPreFusionState(HloComputation* computation, HloInstruction* consumer,
HloInstruction* producer, bool is_mof = false);
// Dumps the state of computation and the reason why the fusion was not
// performed.
void DumpNotFusingState(HloComputation* computation, HloInstruction* consumer,
HloInstruction* producer, FusionDecision decision);
// Dumps the state of computation after fusion happened.
void DumpStateAfterFusion(HloComputation* computation,
HloInstruction* fusion_instruction,
const std::string& producer_name);
// Returns whether we may duplicate an instruction if we want to fuse it.
bool may_duplicate_;
// Configuration mode.
FusionConfigCollection config_collection_mode_;
// Caches which operands are reused inside fusion computations.
absl::flat_hash_map<
const HloInstruction*,
std::unique_ptr<absl::flat_hash_set<const HloInstruction*>>>
reused_fusion_operands_;
InstructionFusion(const InstructionFusion&) = delete;
InstructionFusion& operator=(const InstructionFusion&) = delete;
};
} // namespace xla
#endif // XLA_SERVICE_INSTRUCTION_FUSION_H_