/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h"

#include <memory>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

namespace {

// Determines whether an HLO instruction is replicated at index based on
// current knowledge in hlo_replication.
bool DetermineHloInstructionIsReplicated(
    const HloInstruction* hlo, const ShapeIndex& index,
    bool cross_partition_spmd,
    const absl::flat_hash_map<const HloInstruction*, ShapeTree<bool>>&
        hlo_replication) {
  // Returns true if all operands are known to be replicated.
  const auto all_operands_replicated =
      [&hlo_replication](const HloInstruction* inst) {
        for (auto operand : inst->operands()) {
          auto operand_it = hlo_replication.find(operand);
          if (operand_it == hlo_replication.end() ||
              !operand_it->second.element({})) {
            return false;
          }
        }
        return true;
      };

  if (hlo->opcode() == HloOpcode::kAllReduce ||
      hlo->opcode() == HloOpcode::kAllGather) {
    // All-reduce/all-gather returns the same values across partitions/replicas
    // as long as its operands are replicated.
    if (all_operands_replicated(hlo)) {
      return true;
    }
    if (!hlo->channel_id().has_value()) {
      // This is cross-replica-only.
      if (cross_partition_spmd) {
        return false;
      }
      // Only all-reduce/all-gather across all cores are replicated, which
      // means there is only one subgroup.
      return hlo->replica_groups().empty() ||
             hlo->replica_groups().size() == 1;
    } else {
      bool global_id;
      if (hlo->opcode() == HloOpcode::kAllReduce) {
        global_id = Cast<HloAllReduceInstruction>(hlo)->use_global_device_ids();
      } else {
        global_id = Cast<HloAllGatherInstruction>(hlo)->use_global_device_ids();
      }
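      // With use_global_device_ids, replica_groups contain flattened ids of
      // the form replica_id * num_partitions + partition_id, so decompose each
      // id to check whether every group covers all partitions (or all
      // replicas).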
      if (global_id) {
        bool replicated_across_partitions = true;
        bool replicated_across_replicas = true;
        const int64 num_partitions =
            hlo->GetModule()->config().num_partitions();
        for (const auto& group : hlo->replica_groups()) {
          absl::flat_hash_set<int64> visited_partitions;
          absl::flat_hash_set<int64> visited_replicas;
          for (int64 id : group.replica_ids()) {
            int64 rid = id / num_partitions;
            int64 pid = id % num_partitions;
            visited_partitions.insert(pid);
            visited_replicas.insert(rid);
          }
          replicated_across_partitions &=
              visited_partitions.size() == num_partitions;
          replicated_across_replicas &=
              visited_replicas.size() ==
              hlo->GetModule()->config().replica_count();
        }
        return cross_partition_spmd ? replicated_across_partitions
                                    : replicated_across_replicas;
      }
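      // A channel_id without use_global_device_ids means the collective also
      // applies across all partitions, so in SPMD mode the result is
      // replicated across partitions; across replicas it is replicated only
      // when there is a single subgroup.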
      return cross_partition_spmd ? true
                                  : hlo->replica_groups().empty() ||
                                        hlo->replica_groups().size() == 1;
    }
  }
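  // Conservatively treat any other side-effecting instruction as not
  // replicated.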
  if (hlo->HasSideEffectNoRecurse()) {
    return false;
  }
  if (hlo->opcode() == HloOpcode::kReplicaId) {
    // ReplicaId returns the same value for all partitions in each replica.
    return cross_partition_spmd;
  }
  if (hlo->opcode() == HloOpcode::kPartitionId) {
    // PartitionId returns the same value for all replicas in each partition.
    return !cross_partition_spmd;
  }
  auto it = hlo_replication.find(hlo);
  if (hlo->opcode() == HloOpcode::kParameter) {
    // Parameters should have been processed.
    return it != hlo_replication.end() && it->second.element(index);
  }
  if (it != hlo_replication.end() && !it->second.element(index)) {
    // The HLO is already marked as non-replicated.
    return false;
  }
  if (hlo->opcode() == HloOpcode::kConstant) {
    return true;
  }

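  // These custom calls (which split/combine 64-bit values into/from 32-bit
  // halves) are pure functions of their operands, so they preserve
  // replication.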
  if (hlo->opcode() == HloOpcode::kCustomCall &&
      (hlo->custom_call_target() == "X64SplitLow" ||
       hlo->custom_call_target() == "X64SplitHigh" ||
       hlo->custom_call_target() == "X64Combine")) {
    return all_operands_replicated(hlo);
  }

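  // The ops below are deterministic functions of their operands, so their
  // results are replicated whenever all of their operands are.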
  if (hlo->IsElementwise() ||                             //
      hlo->opcode() == HloOpcode::kConcatenate ||         //
      hlo->opcode() == HloOpcode::kConvolution ||         //
      hlo->opcode() == HloOpcode::kDot ||                 //
      hlo->opcode() == HloOpcode::kReduce ||              //
      hlo->opcode() == HloOpcode::kBroadcast ||           //
      hlo->opcode() == HloOpcode::kTranspose ||           //
      hlo->opcode() == HloOpcode::kReshape ||             //
      hlo->opcode() == HloOpcode::kBitcast ||             //
      hlo->opcode() == HloOpcode::kReverse ||             //
      hlo->opcode() == HloOpcode::kGather ||              //
      hlo->opcode() == HloOpcode::kScatter ||             //
      hlo->opcode() == HloOpcode::kIota ||                //
      hlo->opcode() == HloOpcode::kPad ||                 //
      hlo->opcode() == HloOpcode::kSlice ||               //
      hlo->opcode() == HloOpcode::kDynamicSlice ||        //
      hlo->opcode() == HloOpcode::kDynamicUpdateSlice ||  //
      hlo->opcode() == HloOpcode::kReduceWindow ||        //
      hlo->opcode() == HloOpcode::kCopy) {
    return all_operands_replicated(hlo);
  }
  return false;
}

}  // namespace

bool HloReplicationAnalysis::ComputeHloReplicationOnComputation(
    const HloComputation* computation, bool mark_everything_not_replicated) {
  bool changed = false;
  for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
    // Assigns the shape tree to dest if dest doesn't have one yet, or combines
    // it with the existing one by and'ing them. Returns true if anything was
    // updated.
    auto assign_or_combine_shapetree = [&](ShapeTree<bool>&& to_combine,
                                           const HloInstruction* dest) {
      auto it = hlo_replication_.find(dest);
      if (it == hlo_replication_.end()) {
        hlo_replication_[dest] = std::move(to_combine);
        return true;
      }
      bool updated = false;
      it->second.ForEachMutableElement(
          [&](const ShapeIndex& index, bool* element) {
            if (*element && !to_combine.element(index)) {
              *element = false;
              updated = true;
            }
          });
      return updated;
    };
    // Assigns or combines source's shape tree to dest. Returns true if
    // anything was updated.
    auto propagate_shapetree = [&](const HloInstruction* source,
                                   const HloInstruction* dest) {
      auto source_it = hlo_replication_.find(source);
      if (source_it == hlo_replication_.end()) {
        return false;
      }
      return assign_or_combine_shapetree(ShapeTree<bool>(source_it->second),
                                         dest);
    };
    // For the opcodes given special handling below, we don't need to check
    // mark_everything_not_replicated explicitly: if it is set, their operands
    // will already have been marked as not replicated.
    if (inst->opcode() == HloOpcode::kWhile) {
      // Since the while body's input and output alias each other, we need to
      // iterate until a fixed point is reached.
      while (true) {
        // First, propagate the input's and body root's shape trees to the
        // parameters of the body and condition.
        bool updated = propagate_shapetree(
            inst->operand(0),
            inst->while_condition()->parameter_instruction(0));
        updated |= propagate_shapetree(
            inst->while_body()->root_instruction(),
            inst->while_condition()->parameter_instruction(0));
        updated |= propagate_shapetree(
            inst->operand(0), inst->while_body()->parameter_instruction(0));
        updated |=
            propagate_shapetree(inst->while_body()->root_instruction(),
                                inst->while_body()->parameter_instruction(0));
        // Compute the condition.
        updated |= ComputeHloReplicationOnComputation(
            inst->while_condition(), mark_everything_not_replicated);
        // Compute the body. If the condition is not replicated, different
        // replicas may run different numbers of iterations, so everything in
        // the body must be treated as not replicated, unless the loop is known
        // to run the same number of iterations everywhere.
        if (!ContainsKey(loops_known_with_same_iterations_, inst) &&
            !hlo_replication_[inst->while_condition()->root_instruction()]
                 .element({})) {
          updated |= ComputeHloReplicationOnComputation(
              inst->while_body(), /*mark_everything_not_replicated=*/true);
        } else {
          updated |= ComputeHloReplicationOnComputation(
              inst->while_body(), mark_everything_not_replicated);
        }
        if (!updated) {
          break;
        }
        changed = true;
      }
      // Propagate the input's and body root's shape trees to the while HLO.
      changed |= propagate_shapetree(inst->operand(0), inst);
      changed |=
          propagate_shapetree(inst->while_body()->root_instruction(), inst);
    } else if (inst->opcode() == HloOpcode::kCall ||
               inst->opcode() == HloOpcode::kFusion) {
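      // kCall and kFusion call a single computation: propagate operand
      // replication into its parameters, compute it, then propagate its root
      // back to the caller.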
      auto called = inst->called_computations().front();
      for (int64 i = 0; i < inst->operand_count(); ++i) {
        changed |= propagate_shapetree(inst->operand(i),
                                       called->parameter_instruction(i));
      }
      changed |= ComputeHloReplicationOnComputation(
          called, mark_everything_not_replicated);
      changed |= propagate_shapetree(called->root_instruction(), inst);
    } else if (inst->opcode() == HloOpcode::kConditional) {
      // Propagate inputs' shape trees to the called computations' parameters.
      for (int64 i = 0; i < inst->called_computations().size(); ++i) {
        changed |= propagate_shapetree(
            inst->operand(i + 1),
            inst->called_computations()[i]->parameter_instruction(0));
      }
      // If the condition is not replicated, the conditional result may differ
      // across replicas.
      if (!hlo_replication_[inst->operand(0)].element({})) {
        for (auto called : inst->called_computations()) {
          changed |= ComputeHloReplicationOnComputation(
              called,
              /*mark_everything_not_replicated=*/true);
        }
        changed |= assign_or_combine_shapetree(
            ShapeTree<bool>(inst->shape(), false), inst);
      } else {
        for (auto called : inst->called_computations()) {
          changed |= ComputeHloReplicationOnComputation(
              called, mark_everything_not_replicated);
          changed |= propagate_shapetree(called->root_instruction(), inst);
        }
      }
    } else if (inst->opcode() == HloOpcode::kTupleSelect) {
      if (!hlo_replication_[inst->operand(0)].element({})) {
        // The predicate is not replicated, so the result may differ across
        // replicas.
        changed |= assign_or_combine_shapetree(
            ShapeTree<bool>(inst->shape(), false), inst);
      } else {
        changed |= propagate_shapetree(inst->operand(1), inst);
        changed |= propagate_shapetree(inst->operand(2), inst);
      }
    } else if (inst->opcode() == HloOpcode::kTuple) {
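      // Build the tuple's replication status subtree-by-subtree from its
      // operands.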
      ShapeTree<bool> shape_tree(inst->shape(), true);
      for (int64 i = 0; i < inst->operand_count(); ++i) {
        shape_tree.CopySubtreeFrom(hlo_replication_[inst->operand(i)], {}, {i});
      }
      changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
    } else if (inst->opcode() == HloOpcode::kGetTupleElement) {
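      // The element's replication status is the operand's subtree at
      // tuple_index.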
      ShapeTree<bool> shape_tree(inst->shape(), true);
      shape_tree.CopySubtreeFrom(hlo_replication_[inst->operand(0)],
                                 {inst->tuple_index()}, {});
      changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
    } else if (inst->opcode() == HloOpcode::kInfeed && cross_partition_spmd_) {
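      // In SPMD mode, infeed data is replicated only at indices whose sharding
      // is marked replicated; without a sharding annotation it is
      // conservatively treated as not replicated.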
      ShapeTree<bool> shape_tree(inst->shape(), false);
      if (inst->has_sharding()) {
        auto sharding = inst->sharding().GetAsShapeTree(inst->shape());
        shape_tree.ForEachMutableElement(
            [&sharding](const ShapeIndex& index, bool* data) {
              *data = sharding.element(index).IsReplicated();
            });
      }
      changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
    } else {
      if (mark_everything_not_replicated) {
        changed |= assign_or_combine_shapetree(
            ShapeTree<bool>(inst->shape(), false), inst);
      } else {
        ShapeTree<bool> shape_tree(inst->shape(), true);
        ShapeUtil::ForEachSubshape(
            inst->shape(),
            [&](const Shape& subshape, const ShapeIndex& index) {
              *shape_tree.mutable_element(index) =
                  DetermineHloInstructionIsReplicated(
                      inst, index, cross_partition_spmd_, hlo_replication_);
              return Status::OK();
            });
        changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
      }
    }
  }
  return changed;
}

void HloReplicationAnalysis::ComputeHloReplication() {
  // Seed the analysis with the entry computation's parameters according to
  // user annotations: replicated modules read
  // `parameter_replicated_at_leaf_buffers` whereas SPMD-partitioned modules
  // read HloSharding attributes.
  auto entry = module_->entry_computation();
  for (int i = 0; i < entry->num_parameters(); ++i) {
    auto param = entry->parameter_instruction(i);
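    // Leaf buffers start as not replicated unless an annotation below says
    // otherwise.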
    ShapeTree<bool> shape_tree(param->shape(), false);
    if (cross_partition_spmd_ && param->has_sharding()) {
      auto sharding_tree =
          param->sharding().AsShapeTree(param->shape()).ValueOrDie();
      ShapeUtil::ForEachSubshape(
          param->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
            if (!ShapeUtil::IsLeafIndex(param->shape(), index)) {
              return Status::OK();
            }
            *shape_tree.mutable_element(index) =
                sharding_tree.element(index).IsReplicated();
            return Status::OK();
          });
    } else if (!cross_partition_spmd_) {
      const auto& replication = param->parameter_replicated_at_leaf_buffers();
      int leaf_index = 0;
      ShapeUtil::ForEachSubshape(
          param->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
            if (!ShapeUtil::IsLeafIndex(param->shape(), index)) {
              return Status::OK();
            }
            if (replication && replication->at(leaf_index)) {
              *shape_tree.mutable_element(index) = true;
            }
            ++leaf_index;
            return Status::OK();
          });
    }
    hlo_replication_[param] = std::move(shape_tree);
  }
  ComputeHloReplicationOnComputation(entry,
                                     /*mark_everything_not_replicated=*/false);
}


bool HloReplicationAnalysis::HloInstructionIsReplicatedAt(
    const HloInstruction* inst, const ShapeIndex& index) const {
  auto it = hlo_replication_.find(inst);
  if (it == hlo_replication_.end()) {
    return false;
  }
  return it->second.element(index);
}

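// Example usage (a sketch, not part of the API; assumes `module` points to an
// already-built HloModule and `inst` is an instruction in it):
//   TF_ASSIGN_OR_RETURN(
//       std::unique_ptr<HloReplicationAnalysis> analysis,
//       HloReplicationAnalysis::Run(module, /*cross_partition_spmd=*/true));
//   if (analysis->HloInstructionIsReplicatedAt(inst, /*index=*/{})) {
//     ...
//   }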
/* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
HloReplicationAnalysis::Run(const HloModule* module,
                            bool cross_partition_spmd) {
  const absl::flat_hash_set<const HloInstruction*> empty;
  return Run(module, cross_partition_spmd, &empty);
}

/* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
HloReplicationAnalysis::Run(const HloModule* module, bool cross_partition_spmd,
                            const absl::flat_hash_set<const HloInstruction*>*
                                loops_known_with_same_iterations) {
  auto analysis = absl::WrapUnique(new HloReplicationAnalysis(
      module, cross_partition_spmd, loops_known_with_same_iterations));
  analysis->ComputeHloReplication();
  return analysis;
}

}  // namespace xla