/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_replication_analysis.h"

#include <memory>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

namespace {

// Determines whether an HLO instruction is replicated at index based on
// current knowledge in hlo_replication.
bool DetermineHloInstructionIsReplicated(
    const HloInstruction* hlo, const ShapeIndex& index,
    bool cross_partition_spmd,
    const absl::flat_hash_map<const HloInstruction*, ShapeTree<bool>>&
        hlo_replication) {
  // Returns true if all operands are known to be replicated.
  const auto all_operands_replicated =
      [&hlo_replication](const HloInstruction* inst) {
        for (auto operand : inst->operands()) {
          auto operand_it = hlo_replication.find(operand);
          if (operand_it == hlo_replication.end() ||
              !operand_it->second.element({})) {
            return false;
          }
        }
        return true;
      };

  if (hlo->opcode() == HloOpcode::kAllReduce ||
      hlo->opcode() == HloOpcode::kAllGather) {
    // All-reduce/all-gather returns the same values across partitions/replicas
    // as long as its operands are replicated.
    if (all_operands_replicated(hlo)) {
      return true;
    }
    if (!hlo->channel_id().has_value()) {
      // This is cross-replica-only.
      if (cross_partition_spmd) {
        return false;
      }
      // Only an all-reduce/all-gather across all cores is replicated, which
      // means there is exactly one subgroup.
      return hlo->replica_groups().empty() || hlo->replica_groups().size() == 1;
    } else {
      bool global_id;
      if (hlo->opcode() == HloOpcode::kAllReduce) {
        global_id = Cast<HloAllReduceInstruction>(hlo)->use_global_device_ids();
      } else {
        global_id = Cast<HloAllGatherInstruction>(hlo)->use_global_device_ids();
      }
      if (global_id) {
        bool replicated_across_partitions = true;
        bool replicated_across_replicas = true;
        const int64 num_partitions =
            hlo->GetModule()->config().num_partitions();
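        // With use_global_device_ids, each id in a replica group is a
        // flattened global device id laid out replica-major, i.e.,
        // id = replica_id * num_partitions + partition_id. Illustrative
        // example: with num_partitions = 4, id 6 denotes replica 1,
        // partition 2.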
        for (const auto& group : hlo->replica_groups()) {
          absl::flat_hash_set<int64> visited_partitions;
          absl::flat_hash_set<int64> visited_replicas;
          for (int64 id : group.replica_ids()) {
            int64 rid = id / num_partitions;
            int64 pid = id % num_partitions;
            visited_partitions.insert(pid);
            visited_replicas.insert(rid);
          }
          replicated_across_partitions &=
              visited_partitions.size() == num_partitions;
          replicated_across_replicas &=
              visited_replicas.size() ==
              hlo->GetModule()->config().replica_count();
        }
        return cross_partition_spmd ? replicated_across_partitions
                                    : replicated_across_replicas;
      }
      return cross_partition_spmd ? true
                                  : hlo->replica_groups().empty() ||
                                        hlo->replica_groups().size() == 1;
    }
  }
  if (hlo->HasSideEffectNoRecurse()) {
    return false;
  }
  if (hlo->opcode() == HloOpcode::kReplicaId) {
    // ReplicaId returns the same value for all partitions in each replica.
    return cross_partition_spmd;
  }
  if (hlo->opcode() == HloOpcode::kPartitionId) {
    // PartitionId returns the same value for all replicas in each partition.
    return !cross_partition_spmd;
  }
  auto it = hlo_replication.find(hlo);
  if (hlo->opcode() == HloOpcode::kParameter) {
    // Parameters should have been processed.
    return it != hlo_replication.end() && it->second.element(index);
  }
  if (it != hlo_replication.end() && !it->second.element(index)) {
    // The HLO is already marked as non-replicated.
    return false;
  }
  if (hlo->opcode() == HloOpcode::kConstant) {
    return true;
  }

  if (hlo->opcode() == HloOpcode::kCustomCall &&
      (hlo->custom_call_target() == "X64SplitLow" ||
       hlo->custom_call_target() == "X64SplitHigh" ||
       hlo->custom_call_target() == "X64Combine")) {
    return all_operands_replicated(hlo);
  }

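  // These opcodes compute a deterministic function of their operands, so their
  // results are replicated whenever all of their operands are replicated.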
  if (hlo->IsElementwise() ||                             //
      hlo->opcode() == HloOpcode::kConcatenate ||         //
      hlo->opcode() == HloOpcode::kConvolution ||         //
      hlo->opcode() == HloOpcode::kDot ||                 //
      hlo->opcode() == HloOpcode::kReduce ||              //
      hlo->opcode() == HloOpcode::kBroadcast ||           //
      hlo->opcode() == HloOpcode::kTranspose ||           //
      hlo->opcode() == HloOpcode::kReshape ||             //
      hlo->opcode() == HloOpcode::kBitcast ||             //
      hlo->opcode() == HloOpcode::kReverse ||             //
      hlo->opcode() == HloOpcode::kGather ||              //
      hlo->opcode() == HloOpcode::kScatter ||             //
      hlo->opcode() == HloOpcode::kIota ||                //
      hlo->opcode() == HloOpcode::kPad ||                 //
      hlo->opcode() == HloOpcode::kSlice ||               //
      hlo->opcode() == HloOpcode::kDynamicSlice ||        //
      hlo->opcode() == HloOpcode::kDynamicUpdateSlice ||  //
      hlo->opcode() == HloOpcode::kReduceWindow ||        //
      hlo->opcode() == HloOpcode::kCopy) {
    return all_operands_replicated(hlo);
  }
  return false;
}

}  // namespace

bool HloReplicationAnalysis::ComputeHloReplicationOnComputation(
    const HloComputation* computation, bool mark_everything_not_replicated) {
  bool changed = false;
  for (HloInstruction* inst : computation->MakeInstructionPostOrder()) {
    // Assigns the shape tree to dest if dest doesn't have one yet, or combines
    // it with the existing one by and'ing them. Returns whether anything was
    // updated.
    auto assign_or_combine_shapetree = [&](ShapeTree<bool>&& to_combine,
                                           const HloInstruction* dest) {
      auto it = hlo_replication_.find(dest);
      if (it == hlo_replication_.end()) {
        hlo_replication_[dest] = std::move(to_combine);
        return true;
      }
      bool updated = false;
      it->second.ForEachMutableElement(
          [&](const ShapeIndex& index, bool* element) {
            if (*element && !to_combine.element(index)) {
              *element = false;
              updated = true;
            }
          });
      return updated;
    };
    // Assigns or combines source's shape tree to dest. Returns whether
    // anything was updated.
    auto propagate_shapetree = [&](const HloInstruction* source,
                                   const HloInstruction* dest) {
      auto source_it = hlo_replication_.find(source);
      if (source_it == hlo_replication_.end()) {
        return false;
      }
      return assign_or_combine_shapetree(ShapeTree<bool>(source_it->second),
                                         dest);
    };
    // For the opcodes below that we handle specially, we don't need to
    // explicitly check mark_everything_not_replicated because, if it is set,
    // the operands should already be marked as not replicated.
    if (inst->opcode() == HloOpcode::kWhile) {
      // Since the while body's input and output alias each other, we need to
      // run the body multiple times until a fixed point is reached.
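      // Termination is guaranteed because replication bits only ever flip from
      // true to false (see assign_or_combine_shapetree), so the tracked state
      // decreases monotonically with each iteration.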
      while (true) {
        // First, propagate the input's and body root's shape trees to the
        // parameters of the body and condition.
        bool updated = propagate_shapetree(
            inst->operand(0),
            inst->while_condition()->parameter_instruction(0));
        updated |= propagate_shapetree(
            inst->while_body()->root_instruction(),
            inst->while_condition()->parameter_instruction(0));
        updated |= propagate_shapetree(
            inst->operand(0), inst->while_body()->parameter_instruction(0));
        updated |=
            propagate_shapetree(inst->while_body()->root_instruction(),
                                inst->while_body()->parameter_instruction(0));
        // Compute the condition.
        updated |= ComputeHloReplicationOnComputation(
            inst->while_condition(), mark_everything_not_replicated);
        // Compute the body. If the condition is not replicated, the loop may
        // run a different number of iterations on each replica, so everything
        // computed in the body must be treated as non-replicated.
        if (!ContainsKey(loops_known_with_same_iterations_, inst) &&
            !hlo_replication_[inst->while_condition()->root_instruction()]
                 .element({})) {
          updated |= ComputeHloReplicationOnComputation(
              inst->while_body(), /*mark_everything_not_replicated=*/true);
        } else {
          updated |= ComputeHloReplicationOnComputation(
              inst->while_body(), mark_everything_not_replicated);
        }
        if (!updated) {
          break;
        }
        changed = true;
      }
      // Propagate the input's and body root's shape trees to the while HLO.
      changed |= propagate_shapetree(inst->operand(0), inst);
      changed |=
          propagate_shapetree(inst->while_body()->root_instruction(), inst);
    } else if (inst->opcode() == HloOpcode::kCall ||
               inst->opcode() == HloOpcode::kFusion) {
      auto called = inst->called_computations().front();
      for (int64 i = 0; i < inst->operand_count(); ++i) {
        changed |= propagate_shapetree(inst->operand(i),
                                       called->parameter_instruction(i));
      }
      changed |= ComputeHloReplicationOnComputation(
          called, mark_everything_not_replicated);
      changed |= propagate_shapetree(called->root_instruction(), inst);
    } else if (inst->opcode() == HloOpcode::kConditional) {
      // Propagate inputs' shape trees to the called computations' parameters.
      for (int64 i = 0; i < inst->called_computations().size(); ++i) {
        changed |= propagate_shapetree(
            inst->operand(i + 1),
            inst->called_computations()[i]->parameter_instruction(0));
      }
      // If the condition is not replicated, different replicas may take
      // different branches, so the conditional's result may differ across
      // replicas.
      if (!hlo_replication_[inst->operand(0)].element({})) {
        for (auto called : inst->called_computations()) {
          changed |= ComputeHloReplicationOnComputation(
              called,
              /*mark_everything_not_replicated=*/true);
        }
        changed |= assign_or_combine_shapetree(
            ShapeTree<bool>(inst->shape(), false), inst);
      } else {
        for (auto called : inst->called_computations()) {
          changed |= ComputeHloReplicationOnComputation(
              called, mark_everything_not_replicated);
          changed |= propagate_shapetree(called->root_instruction(), inst);
        }
      }
    } else if (inst->opcode() == HloOpcode::kTupleSelect) {
      if (!hlo_replication_[inst->operand(0)].element({})) {
        // The predicate is not replicated, so the result may differ across
        // replicas.
        changed |= assign_or_combine_shapetree(
            ShapeTree<bool>(inst->shape(), false), inst);
      } else {
        changed |= propagate_shapetree(inst->operand(1), inst);
        changed |= propagate_shapetree(inst->operand(2), inst);
      }
    } else if (inst->opcode() == HloOpcode::kTuple) {
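      // Illustrative example: for t = tuple(a, b), t's replication status at
      // index {1} is b's status at {}; a later get-tuple-element(t, 1) copies
      // that subtree back to its own root index.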
      ShapeTree<bool> shape_tree(inst->shape(), true);
      for (int64 i = 0; i < inst->operand_count(); ++i) {
        shape_tree.CopySubtreeFrom(hlo_replication_[inst->operand(i)], {}, {i});
      }
      changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
    } else if (inst->opcode() == HloOpcode::kGetTupleElement) {
      ShapeTree<bool> shape_tree(inst->shape(), true);
      shape_tree.CopySubtreeFrom(hlo_replication_[inst->operand(0)],
                                 {inst->tuple_index()}, {});
      changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
    } else if (inst->opcode() == HloOpcode::kInfeed && cross_partition_spmd_) {
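      // In SPMD mode, an infeed result is considered replicated across
      // partitions only where its sharding (if present) marks the
      // corresponding leaf as replicated.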
      ShapeTree<bool> shape_tree(inst->shape(), false);
      if (inst->has_sharding()) {
        auto sharding = inst->sharding().GetAsShapeTree(inst->shape());
        shape_tree.ForEachMutableElement(
            [&sharding](const ShapeIndex& index, bool* data) {
              *data = sharding.element(index).IsReplicated();
            });
      }
      changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
    } else {
      if (mark_everything_not_replicated) {
        changed |= assign_or_combine_shapetree(
            ShapeTree<bool>(inst->shape(), false), inst);
      } else {
        ShapeTree<bool> shape_tree(inst->shape(), true);
        ShapeUtil::ForEachSubshape(
            inst->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
              *shape_tree.mutable_element(index) =
                  DetermineHloInstructionIsReplicated(
                      inst, index, cross_partition_spmd_, hlo_replication_);
              return Status::OK();
            });
        changed |= assign_or_combine_shapetree(std::move(shape_tree), inst);
      }
    }
  }
  return changed;
}

void HloReplicationAnalysis::ComputeHloReplication() {
  // Seed the replication state of the entry parameters from user annotations.
  // Replicated modules read from `parameter_replicated_at_leaf_buffers`,
  // whereas SPMD-partitioned modules read from HloSharding attributes.
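  // Illustrative example: a parameter with shape (f32[], (f32[], f32[])) has
  // three leaf buffers, visited in pre-order, so leaf_index 0, 1, 2 map to
  // shape indices {0}, {1, 0}, {1, 1} respectively.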
  auto entry = module_->entry_computation();
  for (int i = 0; i < entry->num_parameters(); ++i) {
    auto param = entry->parameter_instruction(i);
    ShapeTree<bool> shape_tree(param->shape(), false);
    if (cross_partition_spmd_ && param->has_sharding()) {
      auto sharding_tree =
          param->sharding().AsShapeTree(param->shape()).ValueOrDie();
      ShapeUtil::ForEachSubshape(
          param->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
            if (!ShapeUtil::IsLeafIndex(param->shape(), index)) {
              return Status::OK();
            }
            *shape_tree.mutable_element(index) =
                sharding_tree.element(index).IsReplicated();
            return Status::OK();
          });
    } else if (!cross_partition_spmd_) {
      const auto& replication = param->parameter_replicated_at_leaf_buffers();
      int leaf_index = 0;
      ShapeUtil::ForEachSubshape(
          param->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
            if (!ShapeUtil::IsLeafIndex(param->shape(), index)) {
              return Status::OK();
            }
            if (replication && replication->at(leaf_index)) {
              *shape_tree.mutable_element(index) = true;
            }
            ++leaf_index;
            return Status::OK();
          });
    }
    hlo_replication_[param] = std::move(shape_tree);
  }
  ComputeHloReplicationOnComputation(entry,
                                     /*mark_everything_not_replicated=*/false);
}

bool HloReplicationAnalysis::HloInstructionIsReplicatedAt(
    const HloInstruction* inst, const ShapeIndex& index) const {
  auto it = hlo_replication_.find(inst);
  if (it == hlo_replication_.end()) {
    return false;
  }
  return it->second.element(index);
}

/* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
HloReplicationAnalysis::Run(const HloModule* module,
                            bool cross_partition_spmd) {
  const absl::flat_hash_set<const HloInstruction*> empty;
  return Run(module, cross_partition_spmd, &empty);
}
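
// Example usage (illustrative sketch; `module` is assumed to be a valid
// HloModule* and `inst` an instruction within it):
//   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloReplicationAnalysis> analysis,
//                       HloReplicationAnalysis::Run(
//                           module, /*cross_partition_spmd=*/false));
//   if (analysis->HloInstructionIsReplicatedAt(inst, /*index=*/{})) {
//     // `inst` produces the same value on every replica.
//   }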

/* static */ StatusOr<std::unique_ptr<HloReplicationAnalysis>>
HloReplicationAnalysis::Run(const HloModule* module, bool cross_partition_spmd,
                            const absl::flat_hash_set<const HloInstruction*>*
                                loops_known_with_same_iterations) {
  auto analysis = absl::WrapUnique(new HloReplicationAnalysis(
      module, cross_partition_spmd, loops_known_with_same_iterations));
  analysis->ComputeHloReplication();
  return analysis;
}

}  // namespace xla