1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_
18 
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_module.h"
30 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
31 #include "tensorflow/compiler/xla/status.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/core/lib/core/status.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace xla {
37 
38 // Class for bookkeeping the information on the given modules, in particular on
39 // the interaction between computations.
40 //
41 // Companion instructions are one piece of information collected as we build the
42 // metadata. For example, for each While instruction, companion instructions
43 // refer to a set of While instructions in other computations that communicate
44 // with each other.
45 // In the example below with 3 modules, {While_0, While_2, While_5}, {While_1,
46 // While_4}, {While_3, While_6} are companion sets.
47 //
48 // <Module 0>               <Module 1>                 <Module 2>
49 // While_0() {              While_2() {                While_5() {
50 //   While_1() { Send(0) }    While_3() { Send(1) }      While_6() { Recv(1) }
51 // }                          While_4() { Recv(0) }
52 //                          }
53 //
54 // Each instruction can belong to at most one companion set: While_0 and While_5
55 // are in the same set even though they don't communicate with each other,
56 // because they both communicate with While_2.
57 //
58 // A send and the matching recv must both have the same level of nesting of
59 // companion instructions.
60 //
61 // Companion instructions are used to detect cycles in the graph and also for
62 // global scheduling.
63 class HloModuleGroupMetadata {
64  public:
65   // The kind of companion computation a given instruction can be within.
66   enum class ComputationKind {
67     kInvalid,
68     kWhileCondition,
69     kWhileBody,
70     kConditionalBranch,
71     kCallFunction,
72   };
73 
74   // Tracks the instruction mapped to a given computation, and the computation
75   // kind.
76   // For example, a body computation of a while instruction, will generate a
77   // TrackedInstruction with instruction being the while instruction, and
78   // kind being ComputationKind::kWhileBody.
79   class TrackedInstruction {
80    public:
81     TrackedInstruction() = default;
82     TrackedInstruction(HloInstruction* instruction, ComputationKind kind,
83                        int index = -1)
instruction_(instruction)84         : instruction_(instruction), kind_(kind), index_(index) {}
85 
86     bool operator==(const TrackedInstruction& rhs) const {
87       return instruction_->opcode() == rhs.instruction_->opcode() &&
88              kind_ == rhs.kind_ && index_ == rhs.index_;
89     }
90     bool operator!=(const TrackedInstruction& rhs) const {
91       return !operator==(rhs);
92     }
93 
instruction()94     HloInstruction* instruction() const { return instruction_; }
95 
96     string ToString() const;
97 
98    private:
99     HloInstruction* instruction_ = nullptr;
100     ComputationKind kind_ = ComputationKind::kInvalid;
101     int index_ = -1;
102   };
103 
104   // Represents a channel and the instructions that form the channel.
105   struct Channel {
106     int64 id = -1;
107     HloInstruction* send = nullptr;
108     HloInstruction* recv = nullptr;
109     HloInstruction* send_done = nullptr;
110     HloInstruction* recv_done = nullptr;
111   };
112 
HloModuleGroupMetadata(absl::Span<HloModule * const> modules)113   explicit HloModuleGroupMetadata(absl::Span<HloModule* const> modules)
114       : modules_(modules.begin(), modules.end()) {}
115 
116   ~HloModuleGroupMetadata() = default;
117 
118   // Build and return the metadata for the given modules.
119   static StatusOr<std::unique_ptr<HloModuleGroupMetadata>> Build(
120       absl::Span<HloModule* const> modules);
121 
122   // Returns true if the instruction is one of the 4 channel instructions (Send,
123   // Recv, SendDone, RecvDone).
124   bool IsChannelInstruction(const HloInstruction* instruction) const;
125 
126   // Returns true if the instruction is a companion instruction. See the class
127   // comment above on companion instructions.
128   bool IsCompanionInstruction(HloInstruction* hlo) const;
129 
130   // Returns true if the instruction is either a channel instruction, a
131   // cross-module all-reduce instruction, or a companion instruction.
132   bool InstructionCommunicates(HloInstruction* hlo) const;
133 
134   // Returns the Channel instance for the given channel id.
135   const Channel& GetChannel(int64 channel_id) const;
136 
137   // Returns if the given channel id exists in metadata.
138   bool HasChannel(int64 channel_id) const;
139 
140   // Returns the all-reduce instructions with the same all_reduce_id.
141   const std::vector<HloInstruction*>& GetAllReduceGroup(
142       int64 all_reduce_id) const;
143 
144   // Returns the computation that contains the peer channel instructions for
145   // the given instruction.
146   //
147   // Precondition: IsChannelInstruction(instruction) is true.
148   HloComputation* PeerComputation(const HloInstruction* instruction) const;
149 
150   // Returns the path of the nested companion instructions, in terms of HLO
151   // instructions. The path goes from inner to outer companions.
152   // The returned path does not include the input hlo instruction, in case it
153   // is a companion instruction.
154   std::vector<TrackedInstruction> GetCompanionsPath(
155       const HloInstruction* hlo) const;
156 
157   // Checks whether two companion paths (as returned by the GetCompanionsPath()
158   // API) are compatible. The two paths are compatible if the sequence of
159   // opcodes, and the companion kinds, of the two paths matches.
160   bool CheckCompanionPathsCompatibility(
161       const std::vector<TrackedInstruction>& path0,
162       const std::vector<TrackedInstruction>& path1) const;
163 
164   // Returns the unique integer for each module. The returned id is the index of
165   // the module in the module vector.
166   int64 GetModuleId(const HloModule* module) const;
167 
168   // Retrieves the device an instruction is assigned to. Either from the
169   // sharding information, or from the ordinal of the module the instruction
170   // is in.
171   absl::optional<int64> GetInstructionDevice(
172       const HloInstruction& instruction) const;
173 
174   // Returns the number of modules for devices (excluding the host module).
175   int64 GetDeviceModulesCount() const;
176 
177   // Returns the companion set for the given instruction, including the
178   // instruction itself.
179   //
180   // Precondition: IsCompanionWhile(instruction) is true.
Companions(const HloInstruction * instruction)181   const std::vector<HloInstruction*>& Companions(
182       const HloInstruction* instruction) const {
183     CHECK(companion_set_index_.contains(instruction));
184     return companion_set(companion_set_index_.at(instruction));
185   }
186 
187   // Returns the companion set at the given index.
companion_set(int64 index)188   const std::vector<HloInstruction*>& companion_set(int64 index) const {
189     CHECK_LT(index, companion_sets_.size());
190     return *companion_sets_[index];
191   }
192 
193   // Returns the companion set index of the given instruction.
companion_set_index(HloInstruction * instruction)194   int64 companion_set_index(HloInstruction* instruction) const {
195     return companion_set_index_.at(instruction);
196   }
197 
198   // Returns the list of all companion sets in the HLO module group.
199   const std::vector<std::unique_ptr<std::vector<HloInstruction*>>>&
companion_sets()200   companion_sets() const {
201     return companion_sets_;
202   }
203 
204   // Returns all channels in the module group.
channels()205   const std::vector<Channel>& channels() const { return channels_; }
206 
207   // Returns the maximum channel id or all_reduce_id used in the module group.
max_channel_id()208   int64 max_channel_id() const { return max_channel_id_; }
209 
points_to_analysis(HloModule * module)210   TuplePointsToAnalysis* points_to_analysis(HloModule* module) const {
211     return points_to_analyses_.at(module).get();
212   }
213 
214  private:
215   Status Build();
216 
217   // Record all channel instructions, cross-module AllReduce instructions, and
218   // While/Conditional/Call instructions.
219   Status RecordInstructions();
220 
221   // Verifies the given HloModules are well-formed and follow the specification,
222   // in particular with respect to using channel instructions.
223   //
224   // * Each channel has all 4 instructions (Send, Recv, SendDone, RecvDone).
225   // * The shape of channel instructions match.
226   // * The nest level of channel instructions match.
227   // * Channel instructions are used in allowed computations, i.e., in the
228   //   entry computation of the module or condition/body of While computations.
229   Status VerifyChannelInstructions();
230 
231   // Adds metadata that the given two instructions are companions.
232   Status AddCompanion(HloInstruction* instruction1,
233                       HloInstruction* instruction2);
234 
235   // Checks whether a communicating instruction is placed in a valid position
236   // within the graph.
237   Status CheckCommunicatingInstruction(HloInstruction* instruction) const;
238 
239   // Performs a consistency check on the companion sets built for the input
240   // modules. Checks that each instruction in a companion set is in a different
241   // module/device.
242   Status VerifyCompanionSets() const;
243 
244   // Retrieves a pointer to the stored TrackedInstruction associated with a
245   // tracked computation, or nullptr in case such computation is not tracked.
GetTrackedInstruction(const HloComputation * computation)246   const TrackedInstruction* GetTrackedInstruction(
247       const HloComputation* computation) const {
248     auto it = tracked_instructions_.find(computation);
249     return it != tracked_instructions_.end() ? &it->second : nullptr;
250   }
251 
252   // Dump all the collected module group statistics to the logs.
253   void DumpCollectedStats() const;
254 
255   // List of all companion instructions sets in the module.
256   std::vector<std::unique_ptr<std::vector<HloInstruction*>>> companion_sets_;
257 
258   // Map from each companion while instruction to the index into companion_set_.
259   absl::flat_hash_map<const HloInstruction*, int64> companion_set_index_;
260 
261   // Map from computation to the instruction using it (a kWhile, kConditional).
262   absl::flat_hash_map<const HloComputation*, TrackedInstruction>
263       tracked_instructions_;
264 
265   // Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of
266   // communicating instructions within the proper called computation(s).
267   absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>>
268       tracked_instructions_comms_;
269 
270   // All channels in the module.
271   std::vector<Channel> channels_;
272 
273   // Map from channel ids to the index in channels_.
274   absl::flat_hash_map<int64, int64> channel_id_map_;
275 
276   // Map from all-reduce ids to the all reduce instructions.
277   absl::flat_hash_map<int64, std::vector<HloInstruction*>> all_reduce_map_;
278 
279   // The maximum channel id used in the module group.
280   int64 max_channel_id_ = -1;
281 
282   // The modules that this metadata was built from.
283   const std::vector<HloModule*> modules_;
284 
285   absl::flat_hash_map<HloModule*, std::unique_ptr<TuplePointsToAnalysis>>
286       points_to_analyses_;
287 };
288 
289 }  // namespace xla
290 
291 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_
292