1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_
18 
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
32 
33 namespace xla {
34 
35 // Collection of utilities for handling HloModuleGroups.
36 class HloModuleGroupUtil {
37  public:
HloModuleGroupUtil(const HloModuleGroupMetadata & metadata)38   explicit HloModuleGroupUtil(const HloModuleGroupMetadata& metadata)
39       : metadata_(metadata) {}
40 
41   // Returns all unique predecessors of the instruction. This includes:
42   // * predecessors in the same computation: operands and control predecessors
43   // * Recv is a predecessor of Send
44   // * Send is a predecessor of RecvDone
45   // * predecessors of companions (if the instruction is a companion while)
46   // * predecessors' companions (for any predecessor that is a companion while)
47   std::vector<HloInstruction*> GlobalPredecessors(HloInstruction* instruction);
48 
49   // Returns all unique successors of the instruction. This includes:
50   // * successors in the same computation: users and control successors
51   // * Send is a successor of Recv
52   // * RecvDone is a successor of Send
53   // * successors of companions (if the instruction is a companion while)
54   // * successors' companions (for any successor that is a companion while)
55   std::vector<HloInstruction*> GlobalSuccessors(HloInstruction* instruction);
56 
57   // Returns the root instructions of the computations.
58   std::vector<HloInstruction*> RootInstructions(
59       absl::Span<HloComputation* const> computations);
60 
61   // Visit state of each instruction during DFS traversal.
62   enum VisitState {
63     kNotVisited = 0,
64     kVisiting,
65     kVisited,
66   };
67 
68   // Function called on each instruction group during the DFS traversal. See the
69   // comment for VisitTopologicalOrder()).
70   using VisitFunction = std::function<Status(
71       HloInstruction* hlo,
72       const std::vector<HloInstruction*>& instruction_group)>;
73 
74   // Given the hlo instruction as the root, recursively visits all its
75   // predecessor instructions in DFS order to visit nodes in topological order.
76   //
77   // Note that the DFS traversal does not only visit nodes in the same
78   // computation (parent of the root instruction), but also visits nodes in
79   // different computations connected via communication instructions. During the
80   // traversal, companion While instructions (see the class comment in
81   // HloModuleGroupMetadata) are treated as a single instruction (called
82   // instruction group, which contains only a single instruction if the visiting
83   // node is not a companion while) -- visiting one of the instructions in the
84   // group effectively visits all other instructions in the group, and then all
85   // predecessor instructions of the group are visited.
86   //
87   // * visit_state: map from each instruction to its visit state.
88   // * visit_function: function called when each instruction group.
89   // * root: the root instruction of the traversal.
90   using VisitStates = absl::flat_hash_map<HloInstruction*, VisitState>;
91   Status VisitTopologicalOrder(VisitStates* visit_state,
92                                const VisitFunction& visit_function,
93                                HloInstruction* root);
94 
95   // Verifies that the computations are well-formed (e.g., no cycles).
96   Status VerifyComputations(absl::Span<HloComputation* const> computations);
97 
98   // Below Reachability utils resemble those in HloComputation, except that
99   // they can handle instructions across multiple computations.
100   //
101   // Creates the reachability map for the instructions in the computations.
102   StatusOr<std::unique_ptr<HloReachabilityMap>> ComputeReachability(
103       absl::Span<HloComputation* const> computations);
104 
105   // Updates the reachability of the given instruction, taking the global
106   // predeccessorss and successors into account.
107   void UpdateReachabilityThroughInstruction(
108       HloInstruction* instruction, HloReachabilityMap* reachability_map);
109 
110  private:
111   string CycleToString(HloInstruction* instruction);
112 
113   const HloModuleGroupMetadata& metadata_;
114 };
115 
116 }  // namespace xla
117 
118 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_UTIL_H_
119