1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_ 18 19 #include <memory> 20 #include <set> 21 #include <string> 22 #include <unordered_set> 23 #include <vector> 24 25 #include "absl/container/flat_hash_map.h" 26 #include "absl/types/optional.h" 27 #include "tensorflow/compiler/xla/service/hlo_computation.h" 28 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 29 #include "tensorflow/compiler/xla/service/hlo_module.h" 30 #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" 31 #include "tensorflow/compiler/xla/status.h" 32 #include "tensorflow/compiler/xla/statusor.h" 33 #include "tensorflow/core/lib/core/status.h" 34 #include "tensorflow/core/platform/types.h" 35 36 namespace xla { 37 38 // Class for bookkeeping the information on the given modules, in particular on 39 // the interaction between computations. 40 // 41 // Companion instructions are one piece of information collected as we build the 42 // metadata. For example, for each While instruction, companion instructions 43 // refer to a set of While instructions in other computations that communicate 44 // with each other. 45 // In the example below with 3 modules, {While_0, While_2, While_5}, {While_1, 46 // While_4}, {While_3, While_6} are companion sets. 47 // 48 // <Module 0> <Module 1> <Module 2> 49 // While_0() { While_2() { While_5() { 50 // While_1() { Send(0) } While_3() { Send(1) } While_6() { Recv(1) } 51 // } While_4() { Recv(0) } 52 // } 53 // 54 // Each instruction can belong to at most one companion set: While_0 and While_5 55 // are in the same set even though they don't communicate with each other, 56 // because they both communicate with While_2. 57 // 58 // A send and the matching recv must both have the same level of nesting of 59 // companion instructions. 60 // 61 // Companion instructions are used to detect cycles in the graph and also for 62 // global scheduling. 63 class HloModuleGroupMetadata { 64 public: 65 // The kind of companion computation a given instruction can be within. 66 enum class ComputationKind { 67 kInvalid, 68 kWhileCondition, 69 kWhileBody, 70 kConditionalBranch, 71 kCallFunction, 72 }; 73 74 // Tracks the instruction mapped to a given computation, and the computation 75 // kind. 76 // For example, a body computation of a while instruction, will generate a 77 // TrackedInstruction with instruction being the while instruction, and 78 // kind being ComputationKind::kWhileBody. 79 class TrackedInstruction { 80 public: 81 TrackedInstruction() = default; 82 TrackedInstruction(HloInstruction* instruction, ComputationKind kind, 83 int index = -1) instruction_(instruction)84 : instruction_(instruction), kind_(kind), index_(index) {} 85 86 bool operator==(const TrackedInstruction& rhs) const { 87 return instruction_->opcode() == rhs.instruction_->opcode() && 88 kind_ == rhs.kind_ && index_ == rhs.index_; 89 } 90 bool operator!=(const TrackedInstruction& rhs) const { 91 return !operator==(rhs); 92 } 93 instruction()94 HloInstruction* instruction() const { return instruction_; } 95 96 string ToString() const; 97 98 private: 99 HloInstruction* instruction_ = nullptr; 100 ComputationKind kind_ = ComputationKind::kInvalid; 101 int index_ = -1; 102 }; 103 104 // Represents a channel and the instructions that form the channel. 105 struct Channel { 106 int64 id = -1; 107 HloInstruction* send = nullptr; 108 HloInstruction* recv = nullptr; 109 HloInstruction* send_done = nullptr; 110 HloInstruction* recv_done = nullptr; 111 }; 112 HloModuleGroupMetadata(absl::Span<HloModule * const> modules)113 explicit HloModuleGroupMetadata(absl::Span<HloModule* const> modules) 114 : modules_(modules.begin(), modules.end()) {} 115 116 ~HloModuleGroupMetadata() = default; 117 118 // Build and return the metadata for the given modules. 119 static StatusOr<std::unique_ptr<HloModuleGroupMetadata>> Build( 120 absl::Span<HloModule* const> modules); 121 122 // Returns true if the instruction is one of the 4 channel instructions (Send, 123 // Recv, SendDone, RecvDone). 124 bool IsChannelInstruction(const HloInstruction* instruction) const; 125 126 // Returns true if the instruction is a companion instruction. See the class 127 // comment above on companion instructions. 128 bool IsCompanionInstruction(HloInstruction* hlo) const; 129 130 // Returns true if the instruction is either a channel instruction, a 131 // cross-module all-reduce instruction, or a companion instruction. 132 bool InstructionCommunicates(HloInstruction* hlo) const; 133 134 // Returns the Channel instance for the given channel id. 135 const Channel& GetChannel(int64 channel_id) const; 136 137 // Returns if the given channel id exists in metadata. 138 bool HasChannel(int64 channel_id) const; 139 140 // Returns the all-reduce instructions with the same all_reduce_id. 141 const std::vector<HloInstruction*>& GetAllReduceGroup( 142 int64 all_reduce_id) const; 143 144 // Returns the computation that contains the peer channel instructions for 145 // the given instruction. 146 // 147 // Precondition: IsChannelInstruction(instruction) is true. 148 HloComputation* PeerComputation(const HloInstruction* instruction) const; 149 150 // Returns the path of the nested companion instructions, in terms of HLO 151 // instructions. The path goes from inner to outer companions. 152 // The returned path does not include the input hlo instruction, in case it 153 // is a companion instruction. 154 std::vector<TrackedInstruction> GetCompanionsPath( 155 const HloInstruction* hlo) const; 156 157 // Checks whether two companion paths (as returned by the GetCompanionsPath() 158 // API) are compatible. The two paths are compatible if the sequence of 159 // opcodes, and the companion kinds, of the two paths matches. 160 bool CheckCompanionPathsCompatibility( 161 const std::vector<TrackedInstruction>& path0, 162 const std::vector<TrackedInstruction>& path1) const; 163 164 // Returns the unique integer for each module. The returned id is the index of 165 // the module in the module vector. 166 int64 GetModuleId(const HloModule* module) const; 167 168 // Retrieves the device an instruction is assigned to. Either from the 169 // sharding information, or from the ordinal of the module the instruction 170 // is in. 171 absl::optional<int64> GetInstructionDevice( 172 const HloInstruction& instruction) const; 173 174 // Returns the number of modules for devices (excluding the host module). 175 int64 GetDeviceModulesCount() const; 176 177 // Returns the companion set for the given instruction, including the 178 // instruction itself. 179 // 180 // Precondition: IsCompanionWhile(instruction) is true. Companions(const HloInstruction * instruction)181 const std::vector<HloInstruction*>& Companions( 182 const HloInstruction* instruction) const { 183 CHECK(companion_set_index_.contains(instruction)); 184 return companion_set(companion_set_index_.at(instruction)); 185 } 186 187 // Returns the companion set at the given index. companion_set(int64 index)188 const std::vector<HloInstruction*>& companion_set(int64 index) const { 189 CHECK_LT(index, companion_sets_.size()); 190 return *companion_sets_[index]; 191 } 192 193 // Returns the companion set index of the given instruction. companion_set_index(HloInstruction * instruction)194 int64 companion_set_index(HloInstruction* instruction) const { 195 return companion_set_index_.at(instruction); 196 } 197 198 // Returns the list of all companion sets in the HLO module group. 199 const std::vector<std::unique_ptr<std::vector<HloInstruction*>>>& companion_sets()200 companion_sets() const { 201 return companion_sets_; 202 } 203 204 // Returns all channels in the module group. channels()205 const std::vector<Channel>& channels() const { return channels_; } 206 207 // Returns the maximum channel id or all_reduce_id used in the module group. max_channel_id()208 int64 max_channel_id() const { return max_channel_id_; } 209 points_to_analysis(HloModule * module)210 TuplePointsToAnalysis* points_to_analysis(HloModule* module) const { 211 return points_to_analyses_.at(module).get(); 212 } 213 214 private: 215 Status Build(); 216 217 // Record all channel instructions, cross-module AllReduce instructions, and 218 // While/Conditional/Call instructions. 219 Status RecordInstructions(); 220 221 // Verifies the given HloModules are well-formed and follow the specification, 222 // in particular with respect to using channel instructions. 223 // 224 // * Each channel has all 4 instructions (Send, Recv, SendDone, RecvDone). 225 // * The shape of channel instructions match. 226 // * The nest level of channel instructions match. 227 // * Channel instructions are used in allowed computations, i.e., in the 228 // entry computation of the module or condition/body of While computations. 229 Status VerifyChannelInstructions(); 230 231 // Adds metadata that the given two instructions are companions. 232 Status AddCompanion(HloInstruction* instruction1, 233 HloInstruction* instruction2); 234 235 // Checks whether a communicating instruction is placed in a valid position 236 // within the graph. 237 Status CheckCommunicatingInstruction(HloInstruction* instruction) const; 238 239 // Performs a consistency check on the companion sets built for the input 240 // modules. Checks that each instruction in a companion set is in a different 241 // module/device. 242 Status VerifyCompanionSets() const; 243 244 // Retrieves a pointer to the stored TrackedInstruction associated with a 245 // tracked computation, or nullptr in case such computation is not tracked. GetTrackedInstruction(const HloComputation * computation)246 const TrackedInstruction* GetTrackedInstruction( 247 const HloComputation* computation) const { 248 auto it = tracked_instructions_.find(computation); 249 return it != tracked_instructions_.end() ? &it->second : nullptr; 250 } 251 252 // Dump all the collected module group statistics to the logs. 253 void DumpCollectedStats() const; 254 255 // List of all companion instructions sets in the module. 256 std::vector<std::unique_ptr<std::vector<HloInstruction*>>> companion_sets_; 257 258 // Map from each companion while instruction to the index into companion_set_. 259 absl::flat_hash_map<const HloInstruction*, int64> companion_set_index_; 260 261 // Map from computation to the instruction using it (a kWhile, kConditional). 262 absl::flat_hash_map<const HloComputation*, TrackedInstruction> 263 tracked_instructions_; 264 265 // Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of 266 // communicating instructions within the proper called computation(s). 267 absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>> 268 tracked_instructions_comms_; 269 270 // All channels in the module. 271 std::vector<Channel> channels_; 272 273 // Map from channel ids to the index in channels_. 274 absl::flat_hash_map<int64, int64> channel_id_map_; 275 276 // Map from all-reduce ids to the all reduce instructions. 277 absl::flat_hash_map<int64, std::vector<HloInstruction*>> all_reduce_map_; 278 279 // The maximum channel id used in the module group. 280 int64 max_channel_id_ = -1; 281 282 // The modules that this metadata was built from. 283 const std::vector<HloModule*> modules_; 284 285 absl::flat_hash_map<HloModule*, std::unique_ptr<TuplePointsToAnalysis>> 286 points_to_analyses_; 287 }; 288 289 } // namespace xla 290 291 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_GROUP_METADATA_H_ 292