1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_ 18 19 #include <functional> 20 #include <memory> 21 #include <vector> 22 23 #include "absl/container/flat_hash_map.h" 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/service/hlo_computation.h" 26 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 27 #include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h" 28 #include "tensorflow/compiler/xla/service/hlo_reachability.h" 29 #include "tensorflow/compiler/xla/status.h" 30 #include "tensorflow/compiler/xla/statusor.h" 31 #include "tensorflow/core/lib/core/status.h" 32 33 namespace xla { 34 35 // Collection of utilities for handling HloModuleGroups. 36 class HloModuleGroupUtil { 37 public: HloModuleGroupUtil(const HloModuleGroupMetadata & metadata)38 explicit HloModuleGroupUtil(const HloModuleGroupMetadata& metadata) 39 : metadata_(metadata) {} 40 41 // Returns all unique predecessors of the instruction. This includes: 42 // * predecessors in the same computation: operands and control predecessors 43 // * Recv is a predecessor of Send 44 // * Send is a predecessor of RecvDone 45 // * predecessors of companions (if the instruction is a companion while) 46 // * predecessors' companions (for any predecessor that is a companion while) 47 std::vector<HloInstruction*> GlobalPredecessors(HloInstruction* instruction); 48 49 // Returns all unique successors of the instruction. This includes: 50 // * successors in the same computation: users and control successors 51 // * Send is a successor of Recv 52 // * RecvDone is a successor of Send 53 // * successors of companions (if the instruction is a companion while) 54 // * successors' companions (for any successor that is a companion while) 55 std::vector<HloInstruction*> GlobalSuccessors(HloInstruction* instruction); 56 57 // Returns the root instructions of the computations. 58 std::vector<HloInstruction*> RootInstructions( 59 absl::Span<HloComputation* const> computations); 60 61 // Visit state of each instruction during DFS traversal. 62 enum VisitState { 63 kNotVisited = 0, 64 kVisiting, 65 kVisited, 66 }; 67 68 // Function called on each instruction group during the DFS traversal. See the 69 // comment for VisitTopologicalOrder()). 70 using VisitFunction = std::function<Status( 71 HloInstruction* hlo, 72 const std::vector<HloInstruction*>& instruction_group)>; 73 74 // Given the hlo instruction as the root, recursively visits all its 75 // predecessor instructions in DFS order to visit nodes in topological order. 76 // 77 // Note that the DFS traversal does not only visit nodes in the same 78 // computation (parent of the root instruction), but also visits nodes in 79 // different computations connected via communication instructions. During the 80 // traversal, companion While instructions (see the class comment in 81 // HloModuleGroupMetadata) are treated as a single instruction (called 82 // instruction group, which contains only a single instruction if the visiting 83 // node is not a companion while) -- visiting one of the instructions in the 84 // group effectively visits all other instructions in the group, and then all 85 // predecessor instructions of the group are visited. 86 // 87 // * visit_state: map from each instruction to its visit state. 88 // * visit_function: function called when each instruction group. 89 // * root: the root instruction of the traversal. 90 using VisitStates = absl::flat_hash_map<HloInstruction*, VisitState>; 91 Status VisitTopologicalOrder(VisitStates* visit_state, 92 const VisitFunction& visit_function, 93 HloInstruction* root); 94 95 // Verifies that the computations are well-formed (e.g., no cycles). 96 Status VerifyComputations(absl::Span<HloComputation* const> computations); 97 98 // Below Reachability utils resemble those in HloComputation, except that 99 // they can handle instructions across multiple computations. 100 // 101 // Creates the reachability map for the instructions in the computations. 102 StatusOr<std::unique_ptr<HloReachabilityMap>> ComputeReachability( 103 absl::Span<HloComputation* const> computations); 104 105 // Updates the reachability of the given instruction, taking the global 106 // predeccessorss and successors into account. 107 void UpdateReachabilityThroughInstruction( 108 HloInstruction* instruction, HloReachabilityMap* reachability_map); 109 110 private: 111 string CycleToString(HloInstruction* instruction); 112 113 const HloModuleGroupMetadata& metadata_; 114 }; 115 116 } // namespace xla 117 118 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_ 119