/* Copyright 2017 The OpenXLA Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // Call graph for an HLO module. #ifndef XLA_SERVICE_CALL_GRAPH_H_ #define XLA_SERVICE_CALL_GRAPH_H_ #include #include #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/functional/function_ref.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_module.h" namespace xla { // The context in which a computation is called by another computation. enum class CallContext { // In an embedded call context, the body of the function cannot allocate // buffers. kEmbedded, // A control flow call context can allocate buffers. kControlFlow, // A computation is called from both an embedded and control flow context. kBoth, // During call graph construction kNone is used to indicate that the context // has not been determined. This is the top value for the context // lattice. After construction, no call sites or call graph nodes should have // this value. kNone }; std::string CallContextToString(CallContext context); std::ostream& operator<<(std::ostream& out, const CallContext& context); CallContext GetInstructionCallContext(HloOpcode opcode); // Represents an HLO instruction which calls one or more computations. class CallSite { public: CallSite(HloInstruction* instruction, absl::Span called_computations, CallContext context) : instruction_(CHECK_NOTNULL(instruction)), called_computations_(called_computations.begin(), called_computations.end()), context_(context) {} // Returns the instruction associated with this call site. HloInstruction* instruction() const { return instruction_; } // Returns the computations called at this call site. absl::Span called_computations() const { return called_computations_; } // Returns the context in which computations are called at this call site. CallContext context() const { return context_; } std::string ToString() const; private: // The calling instruction. HloInstruction* instruction_; // The computations called by this callsite. const absl::InlinedVector called_computations_; // The context in which the computations are called. const CallContext context_; }; // A node in the call graph representing an HLO computation. class CallGraphNode { public: explicit CallGraphNode(HloComputation* computation); // Returns the computation represented by this call graph node. HloComputation* computation() const { return computation_; } // Returns the call sites in this computation. These are the instructions in // this computation which call other computations. absl::Span callsites() const { return callsites_; } // Returns the callsite associated with the given instruction. If this // instruction calls no computations nullptr is returned. // Prerequisite: instruction is in the computation associated with this call // graph node. const CallSite* GetCallSite(const HloInstruction* instruction) const; // Returns the computations called by this computation. absl::Span callees() const { return callees_; } // Returns the call sites in other computations which call this computation. absl::Span caller_callsites() const { return caller_callsites_; } // Returns the computations which call this computation. absl::Span callers() const { return callers_; } // Returns the context in which this computation is called. CallContext context() const { return context_; } // Returns the depth of this node in the call graph. The depth is defined as // the length of the longest call chain from a computation with no callers // (usually the entry computation node) to this node. int depth() const { return depth_; } absl::string_view ToString() const; CallGraphNode(const CallGraphNode&) = delete; CallGraphNode& operator=(const CallGraphNode&) = delete; CallGraphNode(CallGraphNode&&) = default; CallGraphNode& operator=(CallGraphNode&&) = default; private: // Only CallGraph can modify CallGraphNode. friend class CallGraph; // Sets the context in which this computation is called. void set_context(CallContext value) { context_ = value; } // Sets the depth of this node in the graph. void set_depth(int value) { depth_ = value; } // Adds a callsite which calls this computation. Updates callers to include // the calling computation. void AddCallerCallSite(const CallSite& caller_callsite); // If instruction calls any computations adds a call site for this instruction // to the call graph node. If the instruction calls no computations then no // call site is added. void AddCallSiteForInstruction( HloInstruction* instruction, const absl::flat_hash_set& execution_threads = {}); // Computation represented by this call graph node. HloComputation* computation_; // The computations called by this computation. The vector is used for a // stable ordering and the set enables fast membership testing. absl::InlinedVector callees_; absl::flat_hash_set callee_set_; // The computations which call this computation. The vector is used for a // stable ordering and the set enables fast membership testing. absl::InlinedVector callers_; absl::flat_hash_set caller_set_; // The call sites in this computation absl::InlinedVector callsites_; // The map from instruction to index in callsites_ for looking up the callsite // (if any) associated with a particular instruction in this computation. absl::flat_hash_map callsite_instructions_; // The call sites in other computations which call this computation. absl::InlinedVector caller_callsites_; // The context in which this computation is called. CallContext context_ = CallContext::kNone; // The depth of this node in the call graph. int depth_ = 0; }; // The call graph for an HLO module. The graph includes a node for each // computation in the module. class CallGraph { public: using VisitorFunction = absl::FunctionRef; // Builds and returns a call graph for the given HLO module. If a non-empty // execution_threads is provided, only computations that are in // execution_threads will be part of the returned call graph. static std::unique_ptr Build( const HloModule* module, const absl::flat_hash_set& execution_threads = {}); // Returns the node associated with the given computation. const CallGraphNode& GetNode(const HloComputation* computation) const; CallGraphNode& GetNode(const HloComputation* computation); // Returns the vector of all nodes in the call graph. const std::vector& nodes() const { return nodes_; } // Calls the given function on each node in the call graph. Nodes are visited // in post order (callees before callers). If visit_unreachable_nodes is true // then all nodes in the call graph are visited. Otherwise only those nodes // reachable from the entry computation are visited. Status VisitNodes(VisitorFunction visitor_func, bool visit_unreachable_nodes = true) const; // Returns true if 'a' dominates 'b' in the call graph. Computation 'a' // dominates computation 'b' iff all callgraph paths in the caller-to-callee // direction from a root computation to 'b' pass through computation // 'a'. Trivially, a computation dominates itself. bool Dominates(const HloComputation* a, const HloComputation* b) const; // Returns true if 'a' can reach 'b' in the call graph. 'a' can reach 'b' if // 'a' is 'b' or 'a' can reach one of the callers of 'b'. bool CanReach(const HloComputation* a, const HloComputation* b) const; // Returns whether 'instruction' is contained in 'computation' either directly // ('instruction->parent' is 'computation') or indirectly ('computation' // dominates 'instruction->parent' in the call graph). bool InstructionIsNestedIn(const HloInstruction* instruction, const HloComputation* computation) const { return Dominates(computation, instruction->parent()); } // Returns the nearest call graph ancestors of instructions 'a' and 'b' for // which the ancestors are in the same computation. An instruction is an call // graph ancestor of 'a' if the instruction calls the computation containing // 'a' either directly or transitively. Degeneratively an instruction is an // ancestor of itself. nullptr is returned if there is no common ancestor or // if the caller chain of 'a' or 'b' diverges (has multiple callers) before // the nearest common ancestor. // // Example: // // Entry computation: // %x = Call(A, {Constant(42.0)}) // %y = Call(B, {%x}) // // Computation A: // %a = Negate(Param()) // // Computation B: // %b = Exp(Param()); // // If called with %a and %b, this function would return (%x, %y). %x is an // ancestor of %a, and %y is an ancestor of %b, and %x and %y are in the same // computation. std::pair NearestAncestorsInSameComputation( HloInstruction* a, HloInstruction* b) const; // Given a set of instructions within a computation, returns nearest common // ancestors as Hlo instructions (There could be multiple nearest common // ancestors in a DAG). If the given instructions are not in the same // computation, this function would report FAILURE. // // Unlike the `NearestAncestorsInSameComputation` defined above, it: // // (1) Only compute the nearest common ancestors within a computation, instead // of across computations (that's the function // `ComputationsNearestCommonAncestors` that defined below). // // (2) Takes in **a set of** Hlo instructions, instead of two Hlo // instructions, and find their nearest common ancestors. // // Example: // // Computation A: // %p0 = Param(0) // %p1 = Param(1) // %p2 = Param(2) // %add0 = Add(%p0, %p1) // %mul0 = Mul(%p1, %p2) // %sub0 = Sub(%add0, %mul0) // // If called with {%p0, %p1}, this function would return {%add0}. // // Please check the detailed example in // `CallGraphTest.NearestCommonAncestorInstructions`. absl::flat_hash_set NearestCommonAncestorInstructions( std::vector instructions); // Given a set of computations within a module, returns nearest common // ancestors as Hlo computations (There could be multiple nearest common // ancestors in a DAG). // // Entry_computation: // %x = Call(A, {Constant(42.0)}) // %y = Call(B, {%x}) // // Computation_A: // %a = Negate(Param()) // // Computation_B: // %b = Exp(Param()); // // If called with {Computation_A, Computation_B}, this function would return // {Entry_computation}. // // Please check the detailed example in // `CallGraphTest.NearestCommonAncestorComputations`. absl::flat_hash_set NearestCommonAncestorComputations( std::vector computations); // A template helper function that computes the nearest common ancestors among // instructions/computations. `T` can be either `HloInstruction` or // `HloComputation`. Computing nearest common ancestors are basically the same // for HloInstruction and HloComputation. The only difference is that they // require different ways to access the ancestors of one node. Specifically, // the ancestors are users_instruction for instructions, and are // caller_computations for computations. // // The overall idea is to conduct BFS from the `starting_nodes`, and keep // track of the visited ancestors of each node. For each BFS step, we check if // there is a common node in all the visited ancestors, and if yes, that // common node is the nearest ancestor we are looking for. Note that, since we // are traversing DAG, there could be multiple nearest common ancestors. And // there must be at least one common ancestor (i.e., entry computations among // computations or root instruction among instructions). template absl::flat_hash_set NearestCommonAncestorsHelper( std::vector& starting_nodes); // Returns whether the call graph is flattened. A call graph is flattened if // every computation called in a sequential context (eg, kWhile or kCall) has // zero or one callsite, and no computation is called from both a parallel and // sequential context. The call graph of a module can be flattened with // FlattenCallGraph. bool IsFlattened() const; // Returns a vector of instructions calling the passed computation. // (Often a vector of size 1.) std::vector GetComputationCallers( const HloComputation* c) const; std::string ToString() const; private: explicit CallGraph( const HloModule* module, const absl::flat_hash_set& execution_threads = {}); // Not copyable. CallGraph(const CallGraph&) = delete; CallGraph& operator=(const CallGraph&) = delete; // Sets the call contexts for every node in the graph. void SetCallContexts(); // Sets the call node depths for every node in the graph. void SetNodeDepths(); // Helper method for VisitNodes(). Traverses the call graph from 'node' in DFS // post order (callee before caller) calling visitor_func on each node. Adds // nodes to 'visited' as each node is visited. Skips nodes already in // 'visited'. Status VisitNodesInternal( VisitorFunction visitor_func, const CallGraphNode& node, absl::flat_hash_set* visited) const; // Recursive helper for computing whether 'a' dominates 'b' in the call // graph. 'b_ancestor' is the currently visited node (which starts at 'b'), // and 'visited' is the set of computations which have been visited. bool DominatesHelper( const HloComputation* a, const HloComputation* b, absl::flat_hash_set* visited) const; // The HLO module represented by this call graph. const HloModule* module_ = nullptr; // Vector of all nodes in the call graph. std::vector nodes_; // Map from HLO computation to the index of the corresponding call graph node // in nodes_. absl::flat_hash_map node_indices_; // The execution threads that the call graph is built for. absl::flat_hash_set execution_threads_; }; } // namespace xla #endif // XLA_SERVICE_CALL_GRAPH_H_