1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17 
18 #include <algorithm>
19 #include <queue>
20 #include <string>
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/types/optional.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
34 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
35 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
36 #include "tensorflow/compiler/xla/service/hlo_module.h"
37 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
38 #include "tensorflow/compiler/xla/service/hlo_value.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/platform/logging.h"
45 
46 namespace xla {
47 namespace {
48 // CalculatePostOrderSchedule traverses a module and assign a ordinal to each
49 // instruction based the postorder dependency.
CalculatePostOrderScheduleHelper(const HloComputation * comp,int64 start_ordinal,absl::flat_hash_map<HloInstruction *,int64> * ordinal_map)50 int64 CalculatePostOrderScheduleHelper(
51     const HloComputation* comp, int64 start_ordinal,
52     absl::flat_hash_map<HloInstruction*, int64>* ordinal_map) {
53   int64 ordinal = start_ordinal;
54   for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
55     if (instruction->opcode() == HloOpcode::kCall ||
56         instruction->opcode() == HloOpcode::kConditional) {
57       for (const HloComputation* called_computation :
58            instruction->called_computations()) {
59         ordinal = CalculatePostOrderScheduleHelper(called_computation, ordinal,
60                                                    ordinal_map);
61       }
62     }
63     if (instruction->opcode() == HloOpcode::kWhile) {
64       ordinal = CalculatePostOrderScheduleHelper(instruction->while_condition(),
65                                                  ordinal, ordinal_map);
66       ordinal = CalculatePostOrderScheduleHelper(instruction->while_body(),
67                                                  ordinal, ordinal_map);
68     }
69     // It's possible that in some unit tests the computation graph is not
70     // flatten (meaning we could have multiple callers for one computation). In
71     // that case the oridinal_map will see the instruction multiple times. We
72     // consider that case to be ok as it only shows up in unit tests.
73     ordinal_map->insert({instruction, ordinal++});
74   }
75   return ordinal;
76 }
77 
CalculatePostOrderSchedule(const HloModule & module)78 absl::flat_hash_map<HloInstruction*, int64> CalculatePostOrderSchedule(
79     const HloModule& module) {
80   absl::flat_hash_map<HloInstruction*, int64> map;
81   CalculatePostOrderScheduleHelper(module.entry_computation(), 0, &map);
82   return map;
83 }
84 
85 }  // namespace
86 using absl::StrAppend;
87 using absl::StrCat;
88 
HloDataflowAnalysis(const HloModule & module,bool ssa_form,bool bitcast_defines_value,const CanShareBuffer & can_share_buffer)89 HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
90                                          bool bitcast_defines_value,
91                                          const CanShareBuffer& can_share_buffer)
92     : module_(module),
93       ssa_form_(ssa_form),
94       bitcast_defines_value_(bitcast_defines_value),
95       call_graph_(CallGraph::Build(&module)),
96       can_share_buffer_(can_share_buffer) {}
97 
AreTransitiveUsesElementwiseOrTuple(const HloInstruction * inst)98 bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
99     const HloInstruction* inst) {
100   absl::flat_hash_set<const HloInstruction*> visited;
101   absl::InlinedVector<const HloInstruction*, 4> stack;
102   stack.push_back(inst);
103   while (!stack.empty()) {
104     const HloInstruction* current = stack.back();
105     stack.pop_back();
106     visited.insert(current);
107     for (const HloInstruction* user : current->users()) {
108       // Found a user that is non-elementwise on current instruction.
109       for (const int64 use_index : user->OperandIndices(current)) {
110         if (!user->IsElementwiseOnOperand(use_index) &&
111             user->opcode() != HloOpcode::kTuple) {
112           return false;
113         }
114       }
115       if (!visited.contains(user)) {
116         stack.push_back(user);
117       }
118     }
119   }
120   return true;
121 }
122 
ValueIsDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const123 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
124                                            const ShapeIndex& index) const {
125   const HloValueSet& value_set = GetValueSet(instruction, index);
126   if (value_set.values().size() != 1) {
127     return false;
128   }
129   return value_set.GetUniqueValue().defining_instruction() == instruction;
130 }
131 
// Returns the value defined at (instruction, index). CHECK-fails, logging the
// instruction, if ValueIsDefinedAt(instruction, index) is false.
const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString();
  return GetUniqueValueAt(instruction, index);
}
137 
GetValueDefinedAt(const HloInstruction * instruction,const ShapeIndex & index)138 HloValue& HloDataflowAnalysis::GetValueDefinedAt(
139     const HloInstruction* instruction, const ShapeIndex& index) {
140   CHECK(ValueIsDefinedAt(instruction, index));
141   return GetUniqueValueAt(instruction, index);
142 }
143 
NewHloValue(HloInstruction * instruction,const ShapeIndex & index,bool is_phi)144 HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
145                                            const ShapeIndex& index,
146                                            bool is_phi) {
147   const int64 value_id = next_value_id_++;
148   auto emplaced = values_.emplace(
149       std::piecewise_construct, std::forward_as_tuple(value_id),
150       std::forward_as_tuple(value_id, instruction, index, is_phi));
151   CHECK(emplaced.second);
152 
153   VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString();
154 
155   return &emplaced.first->second;
156 }
157 
MarkValueForDeletion(HloValue::Id value_id)158 void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
159   HloValue& value = values_.at(value_id);
160   VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
161 
162   value_ids_to_delete_.push_back(value_id);
163 }
164 
DeleteMarkedValues()165 void HloDataflowAnalysis::DeleteMarkedValues() {
166   // Use a set to prevent deleting an id twice.
167   absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
168                                            value_ids_to_delete_.end());
169 #ifndef NDEBUG
170   // Verify that no marked-for-deletion values are in any of the value sets.
171   for (const auto& pair : value_sets_) {
172     const HloInstruction* instruction = pair.first;
173     const InstructionValueSet& instruction_value_set = pair.second;
174     for (const auto& index_value_set : instruction_value_set) {
175       const HloValueSet& value_set = index_value_set.second;
176       for (const HloValue* value : value_set.values()) {
177         DCHECK(!ContainsKey(id_set, value->id()))
178             << "Value " << value->ToShortString()
179             << " marked for deletion, but still exists in value set for "
180                "instruction "
181             << instruction->name();
182       }
183     }
184   }
185 #endif
186 
187   for (HloValue::Id value_id : id_set) {
188     values_.erase(value_id);
189   }
190   value_ids_to_delete_.clear();
191 }
192 
ToString() const193 string HloDataflowAnalysis::ToString() const {
194   string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
195   StrAppend(&out, "  Instruction value sets:\n");
196   for (const HloComputation* computation : module_.computations()) {
197     for (const HloInstruction* instruction : computation->instructions()) {
198       StrAppend(&out, "Instruction: \n  ", instruction->name(), ":\n");
199       if (instruction->shape().IsTuple()) {
200         GetInstructionValueSet(instruction)
201             .ForEachElement([this, &instruction, &out](
202                                 const ShapeIndex& index,
203                                 const HloValueSet& value_set) {
204               StrAppend(&out, "      tuple index ", index.ToString(), ":\n");
205               for (const HloValue* value : value_set.values()) {
206                 StrAppend(&out, "        ", value->ToShortString(),
207                           ValueIsDefinedAt(instruction, index) ? " (def)" : "",
208                           "\n");
209               }
210             });
211       } else {
212         const HloValueSet& top_level_value_set =
213             GetValueSet(instruction, /*index=*/{});
214         for (const HloValue* value : top_level_value_set.values()) {
215           StrAppend(&out, "      ", value->ToShortString(),
216                     ValueIsDefinedAt(instruction) ? " (def)" : "", "\n");
217         }
218       }
219     }
220   }
221   StrAppend(&out, "  HloValues:\n");
222   for (const HloValue* value : values()) {
223     StrAppend(&out, value->ToString(/*indent=*/4));
224   }
225   return out;
226 }
227 
// Recomputes, in SSA form, the value set at every position of `instruction`
// from the corresponding positions of `inputs`, the value-set trees supplied
// by the caller. For each ShapeIndex:
//   * no input values     -> the position must still be empty (CHECKed);
//   * exactly one input id -> that value is placed in the set. If a different
//                             value defined at this position was there, it is
//                             marked for deletion;
//   * several input ids   -> the set holds a phi value defined at this
//                             position. A new phi is created and registered in
//                             phi_graph_ unless one already exists here, in
//                             which case only its phi_graph_ inputs are
//                             updated.
// Returns true if any value set changed or an existing phi got new inputs.
// Must only be called when ssa_form_ is set.
bool HloDataflowAnalysis::Phi(
    HloInstruction* instruction,
    absl::Span<const InstructionValueSet* const> inputs) {
  CHECK(ssa_form_);
  VLOG(4) << "Phi(" << instruction->name() << ")";
  VLOG(5) << "instruction value set = "
          << GetInstructionValueSet(instruction).ToString();
  for (const InstructionValueSet* input : inputs) {
    VLOG(5) << "input value set = " << input->ToString();
  }

  // Debug-only shape sanity checks on the inputs. When bitcasts do not define
  // values, UpdateBitcastValueSet forwards operand values through bitcasts, so
  // presumably shapes may legitimately differ. Only element type and (for
  // arrays) element count are compared in that case.
  if (bitcast_defines_value_) {
    absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
      DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
    });
  } else {
    const Shape& shape = instruction->shape();
    PrimitiveType ty = shape.element_type();
    bool is_array = shape.IsArray();
    absl::c_for_each(inputs, [&](const InstructionValueSet* input) {
      DCHECK(ty == input->shape().element_type() &&
             (!is_array || ShapeUtil::ElementsIn(shape) ==
                               ShapeUtil::ElementsIn(input->shape())));
    });
  }

  bool changed = false;
  for (auto& pair : GetInstructionValueSet(instruction)) {
    const ShapeIndex& index = pair.first;
    HloValueSet& value_set = pair.second;

    // Positions with phi values should never have more than one value in the
    // value set.
    CHECK_LE(value_set.values().size(), 1);
    const HloValue* current_value =
        value_set.values().size() == 1 ? value_set.values()[0] : nullptr;

    // Construct a vector of value IDs of the inputs.
    // NOTE(review): ids are not deduplicated, so one value reaching this
    // position through several inputs yields several entries.
    std::vector<HloValue::Id> input_value_ids;
    for (const InstructionValueSet* input : inputs) {
      for (const HloValue* value : input->element(index).values()) {
        input_value_ids.push_back(value->id());
      }
    }

    // Remove the existing phi value (if it exists). The phi can be its own
    // input, for example, in while body parameters where the body passes
    // through the parameter value.
    // (This flag records whether the current value is defined at exactly this
    // position; the removal itself happens in the branches below.)
    bool current_value_defined_here =
        (current_value != nullptr &&
         current_value->defining_instruction() == instruction &&
         current_value->defining_index() == index);

    VLOG(5) << "after input_value_ids.size = " << input_value_ids.size();
    if (input_value_ids.empty()) {
      // A value set which has at least one element should never have its value
      // set reduced to zero elements. During dataflow value sets only can go
      // from empty to non-empty, not the reverse.
      CHECK_EQ(value_set.values().size(), 0)
          << "Instruction " << instruction->name() << " at index " << index
          << " previously had non-empty value set. Value set: " << value_set;
    } else if (input_value_ids.size() == 1) {
      // Only a single value reaches this point. There should be no phi, and
      // this value set should contain this single value.
      const HloValue& new_value = GetValue(input_value_ids[0]);
      if (current_value == nullptr) {
        value_set.Clear();
        value_set.AddValue(&new_value);
        changed = true;
      } else if (current_value != &new_value) {
        if (current_value_defined_here) {
          // Remove the existing phi.
          MarkValueForDeletion(current_value->id());
        }
        value_set.Clear();
        value_set.AddValue(&new_value);
        changed = true;
      }
    } else {
      // Multiple distinct values reach this point. A phi value is
      // necessary.
      // (More precisely: more than one input id, which may repeat, see above.)
      CHECK_GT(input_value_ids.size(), 1);
      bool phi_defined_here =
          current_value_defined_here && current_value->is_phi();
      if (current_value == nullptr || !phi_defined_here) {
        // No phi of ours is here yet: create one and register its inputs.
        value_set.Clear();
        value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));

        // NOTE(review): this local shadows the `inputs` parameter.
        std::vector<HloValue*> inputs;
        inputs.reserve(input_value_ids.size());
        for (HloValue::Id id : input_value_ids) {
          inputs.push_back(&GetValue(id));
        }
        // Register the phi into phi graph.
        phi_graph_.RegisterPhi(*value_set.values()[0], inputs);
        changed = true;
      } else if (phi_defined_here) {
        // Always true when reached (the preceding condition failed). The phi
        // already defined here is kept, and its inputs are refreshed only if
        // they changed.
        std::vector<HloValue*> new_inputs;
        new_inputs.reserve(input_value_ids.size());
        for (HloValue::Id id : input_value_ids) {
          new_inputs.push_back(&GetValue(id));
        }

        if (!phi_graph_.InputsEqualTo(*current_value, new_inputs)) {
          VLOG(1) << current_value->ToShortString() << " has new phi inputs: ";
          // Update phi inputs.
          phi_graph_.RegisterPhi(*current_value, new_inputs);
          changed = true;
        }
      }
    }
  }
  return changed;
}
342 
// Returns the HloValue with id `value_id`. The lookup uses values_.at(), so an
// id not present in values_ is an error rather than an insertion.
const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
  return values_.at(value_id);
}
346 
// Mutable overload: returns the HloValue with id `value_id`. The lookup uses
// values_.at(), so an id not present in values_ is an error rather than an
// insertion.
HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) {
  return values_.at(value_id);
}
350 
GetFlattenedValueSet(const HloInstruction * instruction) const351 HloValueSet HloDataflowAnalysis::GetFlattenedValueSet(
352     const HloInstruction* instruction) const {
353   HloValueSet value_set;
354 
355   const InstructionValueSet& value_set_tree =
356       GetInstructionValueSet(instruction);
357 
358   std::vector<const HloValueSet*> all_sets;
359   for (auto& pair : value_set_tree) {
360     const HloValueSet& value_set = pair.second;
361     all_sets.push_back(&value_set);
362   }
363   value_set.AssignUnionOf(all_sets);
364 
365   return value_set;
366 }
367 
// Returns (read-only) the value set at `index` of `instruction`'s output.
const HloValueSet& HloDataflowAnalysis::GetValueSet(
    const HloInstruction* instruction, const ShapeIndex& index) const {
  return GetInstructionValueSet(instruction).element(index);
}
372 
// Mutable overload: returns the value set at `index` of `instruction`'s
// output through mutable_element(), so callers can update it in place.
HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction,
                                              const ShapeIndex& index) {
  return *GetInstructionValueSet(instruction).mutable_element(index);
}
377 
// Convenience overload: returns (read-only) the value set at
// (position.instruction, position.index).
const HloValueSet& HloDataflowAnalysis::GetValueSet(
    const HloPosition& position) const {
  return GetValueSet(position.instruction, position.index);
}
382 
// Mutable convenience overload: returns the value set at
// (position.instruction, position.index).
HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
  return GetValueSet(position.instruction, position.index);
}
386 
UpdateBitcastValueSet(HloInstruction * bitcast)387 bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
388   CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
389   const InstructionValueSet& operand_set =
390       GetInstructionValueSet(bitcast->operand(0));
391   InstructionValueSet& bitcast_set = GetInstructionValueSet(bitcast);
392   if (!bitcast_defines_value_ && operand_set != bitcast_set) {
393     bitcast_set = operand_set;
394     return true;
395   }
396   return false;
397 }
398 
UpdateSetDimensionSizeValueSet(HloInstruction * set_dimension_size)399 bool HloDataflowAnalysis::UpdateSetDimensionSizeValueSet(
400     HloInstruction* set_dimension_size) {
401   CHECK_EQ(set_dimension_size->opcode(), HloOpcode::kSetDimensionSize);
402   const InstructionValueSet& operand_set =
403       GetInstructionValueSet(set_dimension_size->operand(0));
404   InstructionValueSet& set_dimension_size_set =
405       GetInstructionValueSet(set_dimension_size);
406   if (operand_set != set_dimension_size_set) {
407     set_dimension_size_set = operand_set;
408     return true;
409   }
410   return false;
411 }
412 
UpdateSendValueSet(HloInstruction * send)413 bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
414   CHECK_EQ(send->opcode(), HloOpcode::kSend);
415   bool changed = false;
416   // Send forwards the operand value to the output tuple at {0}.
417   for (auto& pair : GetInstructionValueSet(send->operand(0))) {
418     const ShapeIndex& operand_index = pair.first;
419     const HloValueSet& operand_value_set = pair.second;
420 
421     ShapeIndex index = {0};
422     for (int64 i : operand_index) {
423       index.push_back(i);
424     }
425 
426     HloValueSet& value_set = GetValueSet(send, index);
427     if (value_set != operand_value_set) {
428       value_set = operand_value_set;
429       changed = true;
430     }
431   }
432   return changed;
433 }
434 
UpdateCustomCallValueSet(HloInstruction * custom_call)435 bool HloDataflowAnalysis::UpdateCustomCallValueSet(
436     HloInstruction* custom_call) {
437   CHECK_EQ(custom_call->opcode(), HloOpcode::kCustomCall);
438   bool changed = false;
439   for (const auto& aliasing : Cast<HloCustomCallInstruction>(custom_call)
440                                   ->output_to_operand_aliasing()) {
441     const HloValueSet& operand_value_set = GetValueSet(
442         custom_call->operand(aliasing.second.first), aliasing.second.second);
443     HloValueSet& value_set = GetValueSet(custom_call, aliasing.first);
444     if (value_set != operand_value_set) {
445       value_set = operand_value_set;
446       changed = true;
447     }
448   }
449   return changed;
450 }
451 
UpdateCopyStartValueSet(HloInstruction * copy_start)452 bool HloDataflowAnalysis::UpdateCopyStartValueSet(HloInstruction* copy_start) {
453   CHECK_EQ(copy_start->opcode(), HloOpcode::kCopyStart);
454   bool changed = false;
455   // CopyStart forwards the operand value to element {1} of its output.
456   const HloValueSet& operand_value_set = GetValueSet(copy_start->operand(0));
457   HloValueSet& value_set = GetValueSet(copy_start, {1});
458   if (value_set != operand_value_set) {
459     value_set = operand_value_set;
460     changed = true;
461   }
462   return changed;
463 }
464 
UpdateCopyDoneValueSet(HloInstruction * copy_done)465 bool HloDataflowAnalysis::UpdateCopyDoneValueSet(HloInstruction* copy_done) {
466   CHECK_EQ(copy_done->opcode(), HloOpcode::kCopyDone);
467   bool changed = false;
468   // CopyDone forwards the operand value at {0} to element {} of its output.
469   const HloValueSet& operand_value_set =
470       GetValueSet(copy_done->operand(0), {0});
471   HloValueSet& value_set = GetValueSet(copy_done);
472   if (value_set != operand_value_set) {
473     value_set = operand_value_set;
474     changed = true;
475   }
476   return changed;
477 }
478 
UpdateRecvDoneValueSet(HloInstruction * recv_done)479 bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
480   CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
481   bool changed = false;
482   // RecvDone forwards the operand value at {0} to element {0} of its output.
483   for (auto& pair : GetInstructionValueSet(recv_done)) {
484     ShapeIndex& index = pair.first;
485     HloValueSet& value_set = pair.second;
486 
487     if (index.empty() || index[0] != 0) {
488       continue;
489     }
490 
491     const HloValueSet& operand_value_set =
492         GetValueSet(recv_done->operand(0), index);
493     if (value_set != operand_value_set) {
494       value_set = operand_value_set;
495       changed = true;
496     }
497   }
498   return changed;
499 }
500 
UpdateCallValueSet(HloInstruction * call)501 bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
502   CHECK_EQ(call->opcode(), HloOpcode::kCall);
503   InstructionValueSet& value_set = GetInstructionValueSet(call);
504   InstructionValueSet& root_value_set =
505       GetInstructionValueSet(call->to_apply()->root_instruction());
506   if (value_set != root_value_set) {
507     value_set = root_value_set;
508     return true;
509   }
510   return false;
511 }
512 
UpdateConditionalValueSet(HloInstruction * conditional)513 bool HloDataflowAnalysis::UpdateConditionalValueSet(
514     HloInstruction* conditional) {
515   CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
516   std::vector<const InstructionValueSet*> inputs(conditional->branch_count());
517   for (int j = 0; j < conditional->branch_count(); ++j) {
518     inputs[j] = &GetInstructionValueSet(
519         conditional->branch_computation(j)->root_instruction());
520   }
521   if (ssa_form_) {
522     return Phi(conditional, inputs);
523   } else {
524     return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
525   }
526 }
527 
UpdateCopyValueSet(HloInstruction * copy)528 bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
529   CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
530   bool changed = false;
531   for (auto& pair : GetInstructionValueSet(copy)) {
532     const ShapeIndex& index = pair.first;
533     if (index.empty()) {
534       // kCopy shallow copies and thus defines the top-level value so nothing to
535       // update.
536       continue;
537     }
538 
539     HloValueSet& value_set = pair.second;
540     HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index);
541     if (value_set != operand_value_set) {
542       value_set = operand_value_set;
543       changed = true;
544     }
545   }
546   return changed;
547 }
548 
UpdateDomainValueSet(HloInstruction * domain)549 bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) {
550   // Domain instructions just forward their operand. Given that domains can have
551   // a tuple operand, we iterate through its indexes, like for copies.
552   // Unlike copies though we also propagate the top-level value.
553   CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
554   bool changed = false;
555   for (auto& pair : GetInstructionValueSet(domain)) {
556     const ShapeIndex& index = pair.first;
557     HloValueSet& value_set = pair.second;
558     HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
559     if (value_set != operand_value_set) {
560       value_set = operand_value_set;
561       changed = true;
562     }
563   }
564   return changed;
565 }
566 
UpdateAddDependencyValueSet(HloInstruction * add_dependency)567 bool HloDataflowAnalysis::UpdateAddDependencyValueSet(
568     HloInstruction* add_dependency) {
569   // AddDependency just forwards the value of its zero-th operand.
570   CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency);
571   const InstructionValueSet& operand_set =
572       GetInstructionValueSet(add_dependency->operand(0));
573   InstructionValueSet& add_dependency_set =
574       GetInstructionValueSet(add_dependency);
575   if (operand_set != add_dependency_set) {
576     add_dependency_set = operand_set;
577     return true;
578   }
579   return false;
580 }
581 
UpdateGetTupleElementValueSet(HloInstruction * gte)582 bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
583   CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
584   bool changed = false;
585   // The GetTupleElement instruction forwards the values from the specified
586   // tuple element.
587   for (auto& pair : GetInstructionValueSet(gte)) {
588     const ShapeIndex& index = pair.first;
589     HloValueSet& value_set = pair.second;
590 
591     // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex
592     // with the tuple element number prefixed.
593     ShapeIndex operand_index = {gte->tuple_index()};
594     for (int64 i : index) {
595       operand_index.push_back(i);
596     }
597 
598     HloValueSet& operand_value_set =
599         GetValueSet(gte->operand(0), operand_index);
600     if (value_set != operand_value_set) {
601       value_set = operand_value_set;
602       changed = true;
603     }
604   }
605   return changed;
606 }
607 
// Recomputes the value sets of a kParameter from the operands of the
// instructions that call its computation. Computations with a parallel call
// context, or with no caller callsites, get no dataflow from callers, and the
// method returns false for them. Otherwise every caller must be a kCall,
// kWhile or kConditional (anything else is LOG(FATAL)). The gathered inputs
// are merged with Phi() in SSA form when a kWhile or kConditional caller was
// seen, and by a plain union otherwise. Returns true if the parameter's value
// sets changed.
bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
  CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
  const CallGraphNode& call_graph_node =
      call_graph_->GetNode(parameter->parent());

  // Subcomputations called in a parallel context (eg, map) do not have dataflow
  // from the caller operands.
  if (call_graph_node.context() == CallContext::kParallel ||
      call_graph_node.caller_callsites().empty()) {
    return false;
  }
  CHECK_EQ(call_graph_node.context(), CallContext::kSequential);

  std::vector<const InstructionValueSet*> inputs;
  // Set when some caller is a kWhile or kConditional. Only then is a phi
  // (rather than a plain union) used in SSA form.
  bool need_phi = false;
  for (const CallSite& callsite : call_graph_node.caller_callsites()) {
    if (callsite.instruction()->opcode() == HloOpcode::kCall) {
      // The operand values of a call instruction are forwarded to the
      // respective parameter instruction of the subcomputation.
      inputs.push_back(&GetInstructionValueSet(
          callsite.instruction()->operand(parameter->parameter_number())));
    } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
      // In a while instruction, the while operand (ie, the init value) and the
      // backedge are dataflow inputs to the parameter instruction. This is the
      // case for parameters of both the body and condition computations.
      CHECK_EQ(parameter->parameter_number(), 0);
      inputs.push_back(
          &GetInstructionValueSet(callsite.instruction()->operand(0)));
      // If the parameter *is not* the root, parameter state would be
      // updated by the root, otherwise don't consider its current state
      // (InstructionValueSet) as we are recomputing its current state.
      if (parameter !=
          callsite.instruction()->while_body()->root_instruction()) {
        inputs.push_back(&GetInstructionValueSet(
            callsite.instruction()->while_body()->root_instruction()));
      }
      need_phi = true;
    } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
      CHECK_EQ(parameter->parameter_number(), 0);
      auto conditional = callsite.instruction();
      // Conditional has branch_count+1 operands. Operand 0 is the branch_index,
      // operands 1 and onward are the arguments to the branch computations.
      //
      // If the parameter belongs to conditional's branch 0 computation, then
      // operand 1 is forwarded to this parameter instruction. If the parameter
      // belongs to conditional's branch 5 computation, then operand 6 is
      // forwarded to this parameter instruction.
      bool found_parent = false;
      for (int j = 0; j < conditional->branch_count(); ++j) {
        if (parameter->parent() == conditional->branch_computation(j)) {
          inputs.push_back(
              &GetInstructionValueSet(conditional->operand(j + 1)));
          found_parent = true;
          break;
        }
      }
      CHECK(found_parent);
      need_phi = true;
    } else {
      LOG(FATAL) << "CallContext::kSequential computations should only be "
                    "called from call, while, or conditional instructions";
    }
  }
  if (ssa_form_ && need_phi) {
    return Phi(parameter, inputs);
  } else {
    return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
  }
}
677 
UpdateTupleSelectValueSet(HloInstruction * select)678 bool HloDataflowAnalysis::UpdateTupleSelectValueSet(HloInstruction* select) {
679   CHECK_EQ(select->opcode(), HloOpcode::kTupleSelect);
680   // A phi value is not defined at a kTupleSelect instruction because
681   // kTupleSelect does not create a new value. Rather it forwards a value from
682   // its operands. This contrasts with kWhile instruction (which does define a
683   // phi value) which has in-place update semantics.
684   bool changed = false;
685   for (auto& pair : GetInstructionValueSet(select)) {
686     const ShapeIndex& index = pair.first;
687     if (index.empty()) {
688       // kTupleSelect copies (not forwards) the top-level value.
689       continue;
690     }
691     HloValueSet& value_set = pair.second;
692     changed |=
693         value_set.AssignUnionOf({&GetValueSet(select->operand(1), index),
694                                  &GetValueSet(select->operand(2), index)});
695   }
696   return changed;
697 }
698 
UpdateTupleValueSet(HloInstruction * tuple)699 bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) {
700   CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
701   bool changed = false;
702   for (int64 i = 0; i < tuple->operands().size(); ++i) {
703     // Copy the value set(s) of each operand into the respective position in the
704     // kTuple instruction's value sets.
705     for (auto& pair : GetInstructionValueSet(tuple->operand(i))) {
706       const ShapeIndex& operand_index = pair.first;
707       HloValueSet& operand_value_set = pair.second;
708 
709       ShapeIndex index = {i};
710       for (int64 op_index : operand_index) {
711         index.push_back(op_index);
712       }
713       HloValueSet& value_set = GetValueSet(tuple, index);
714 
715       if (value_set != operand_value_set) {
716         value_set = operand_value_set;
717         changed = true;
718       }
719     }
720   }
721   return changed;
722 }
723 
UpdateWhileValueSet(HloInstruction * xla_while)724 bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
725   CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
726   const InstructionValueSet* const inputs[] = {
727       &GetInstructionValueSet(xla_while->while_body()->root_instruction()),
728       &GetInstructionValueSet(xla_while->operand(0))};
729   if (ssa_form_) {
730     return Phi(xla_while, inputs);
731   } else {
732     return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
733   }
734 }
735 
UpdateCollectivePermuteStartValueSet(HloInstruction * collective_permute_start)736 bool HloDataflowAnalysis::UpdateCollectivePermuteStartValueSet(
737     HloInstruction* collective_permute_start) {
738   CHECK_EQ(collective_permute_start->opcode(),
739            HloOpcode::kCollectivePermuteStart);
740   bool changed = false;
741   // CollectivePermuteStart forwards the operand value to element {0} of its
742   // output.
743   const HloValueSet& operand_value_set =
744       GetValueSet(collective_permute_start->operand(0));
745   HloValueSet& value_set = GetValueSet(collective_permute_start, {0});
746   if (value_set != operand_value_set) {
747     value_set = operand_value_set;
748     changed = true;
749   }
750   return changed;
751 }
752 
UpdateCollectivePermuteDoneValueSet(HloInstruction * collective_permute_done)753 bool HloDataflowAnalysis::UpdateCollectivePermuteDoneValueSet(
754     HloInstruction* collective_permute_done) {
755   CHECK_EQ(collective_permute_done->opcode(),
756            HloOpcode::kCollectivePermuteDone);
757   bool changed = false;
758   // CollectivePermuteDone forwards the operand value at {1} to its output.
759   const HloValueSet& operand_value_set =
760       GetValueSet(collective_permute_done->operand(0), {1});
761   HloValueSet& value_set = GetValueSet(collective_permute_done);
762   if (value_set != operand_value_set) {
763     value_set = operand_value_set;
764     changed = true;
765   }
766   return changed;
767 }
768 
UpdateInstructionValueSet(HloInstruction * instruction)769 bool HloDataflowAnalysis::UpdateInstructionValueSet(
770     HloInstruction* instruction) {
771   // Recompute from operands.
772   switch (instruction->opcode()) {
773     case HloOpcode::kAddDependency:
774       return UpdateAddDependencyValueSet(instruction);
775     case HloOpcode::kBitcast:
776       return UpdateBitcastValueSet(instruction);
777     case HloOpcode::kCustomCall:
778       return UpdateCustomCallValueSet(instruction);
779     case HloOpcode::kSetDimensionSize:
780       return UpdateSetDimensionSizeValueSet(instruction);
781     case HloOpcode::kDomain:
782       return UpdateDomainValueSet(instruction);
783     case HloOpcode::kCopy:
784       return UpdateCopyValueSet(instruction);
785     case HloOpcode::kGetTupleElement:
786       return UpdateGetTupleElementValueSet(instruction);
787     case HloOpcode::kTupleSelect:
788       return UpdateTupleSelectValueSet(instruction);
789     case HloOpcode::kTuple:
790       return UpdateTupleValueSet(instruction);
791     case HloOpcode::kParameter:
792       return UpdateParameterValueSet(instruction);
793     case HloOpcode::kCall:
794       return UpdateCallValueSet(instruction);
795     case HloOpcode::kWhile:
796       return UpdateWhileValueSet(instruction);
797     case HloOpcode::kSend:
798       return UpdateSendValueSet(instruction);
799     case HloOpcode::kRecvDone:
800       return UpdateRecvDoneValueSet(instruction);
801     case HloOpcode::kCopyStart:
802       return UpdateCopyStartValueSet(instruction);
803     case HloOpcode::kCopyDone:
804       return UpdateCopyDoneValueSet(instruction);
805     case HloOpcode::kConditional:
806       return UpdateConditionalValueSet(instruction);
807     case HloOpcode::kCollectivePermuteStart:
808       return UpdateCollectivePermuteStartValueSet(instruction);
809     case HloOpcode::kCollectivePermuteDone:
810       return UpdateCollectivePermuteDoneValueSet(instruction);
811     default:
812       // Instruction does not forward HloValues (it defines all values in its
813       // output). No update is necessary.
814       return false;
815   }
816 }
817 
Propagate()818 void HloDataflowAnalysis::Propagate() {
819   using Work = std::pair<int64, HloInstruction*>;
820   // Avoid duplicating work by preferring work items early in the post order
821   // schedule. Intuitively, we start from entry parameters and propagate buffers
822   // updates throughout the module only once.
823   std::priority_queue<Work, std::vector<Work>, std::greater<Work>> worklist;
824   absl::flat_hash_set<HloInstruction*> workset;
825   auto priority_map = CalculatePostOrderSchedule(module_);
826   auto add_to_worklist = [&priority_map, &worklist,
827                           &workset](HloInstruction* instruction) {
828     if (workset.insert(instruction).second) {
829       worklist.emplace(priority_map[instruction], instruction);
830     }
831   };
832 
833   auto comps = module_.MakeComputationPostOrder();
834   for (HloComputation* computation : comps) {
835     for (HloInstruction* instruction :
836          computation->MakeInstructionPostOrder()) {
837       add_to_worklist(instruction);
838     }
839   }
840   VLOG(1) << "SSA_FORM_: " << ssa_form_;
841 
842   while (!worklist.empty()) {
843     HloInstruction* instruction = worklist.top().second;
844     auto add_to_worklist = [&](HloInstruction* todo) {
845       if (workset.insert(todo).second) {
846         VLOG(1) << "  Adding todo : " << todo->name();
847         worklist.emplace(priority_map[todo], todo);
848       }
849     };
850     worklist.pop();
851 
852     workset.erase(workset.find(instruction));
853 
854     VLOG(3) << "Worklist top: " << instruction->name();
855     VLOG(3) << ToString();
856 
857     if (!UpdateInstructionValueSet(instruction)) {
858       // No change to the instruction's value set.
859       VLOG(4) << "No change.";
860       continue;
861     }
862 
863     VLOG(4) << "New value set for " << instruction->name() << ": "
864             << GetInstructionValueSet(instruction);
865 
866     // Instruction value was updated. Add users to work list if we haven't
867     // already.
868     for (HloInstruction* user : instruction->users()) {
869       add_to_worklist(user);
870 
871       // If user sequentially calls a computation, then the respective
872       // parameter(s) of the computation need to be updated.
873       if (user->opcode() == HloOpcode::kConditional) {
874         // If operand 0 is the use of instruction, then no parameters need to be
875         // updated, since that is the branch_index of the conditional.
876         // If operand n+1 is the use of instruction, then the branch_computation
877         // n's parameter need to be updated.
878         //
879         // Note that the same instruction can be used in multiple branches'
880         // operands.
881         for (int j = 0; j < user->branch_count(); ++j) {
882           if (user->operand(j + 1) == instruction) {
883             add_to_worklist(
884                 user->branch_computation(j)->parameter_instruction(0));
885           }
886         }
887       } else {
888         for (HloComputation* called_computation : user->called_computations()) {
889           const CallGraphNode& call_graph_node =
890               call_graph_->GetNode(called_computation);
891           if (call_graph_node.context() == CallContext::kSequential) {
892             for (int64 operand_number : user->OperandIndices(instruction)) {
893               add_to_worklist(
894                   called_computation->parameter_instruction(operand_number));
895             }
896           }
897         }
898       }
899     }
900 
901     // If instruction is a root instruction, then propagate out to any calling
902     // instruction and across any while backedge.
903     if (instruction == instruction->parent()->root_instruction()) {
904       const CallGraphNode& call_graph_node =
905           call_graph_->GetNode(instruction->parent());
906       for (const CallSite& callsite : call_graph_node.caller_callsites()) {
907         if (callsite.instruction()->opcode() == HloOpcode::kCall ||
908             callsite.instruction()->opcode() == HloOpcode::kConditional) {
909           add_to_worklist(callsite.instruction());
910         } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
911           // Add the while itself, and the body and condition parameters.
912           add_to_worklist(callsite.instruction());
913           add_to_worklist(
914               callsite.instruction()->while_body()->parameter_instruction(0));
915           add_to_worklist(
916               callsite.instruction()->while_condition()->parameter_instruction(
917                   0));
918         }
919       }
920     }
921   }
922 }
923 
GetInstructionValueSet(const HloInstruction * instruction) const924 const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
925     const HloInstruction* instruction) const {
926   return value_sets_.at(instruction);
927 }
928 
GetInstructionValueSet(const HloInstruction * instruction)929 InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
930     const HloInstruction* instruction) {
931   return value_sets_.at(instruction);
932 }
933 
InitializeInstructionValueSets()934 Status HloDataflowAnalysis::InitializeInstructionValueSets() {
935   for (const HloComputation* computation : module_.computations()) {
936     const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
937     for (HloInstruction* instruction : computation->instructions()) {
938       // Create an empty shape tree.
939       value_sets_.emplace(std::piecewise_construct,
940                           std::forward_as_tuple(instruction),
941                           std::forward_as_tuple(instruction->shape()));
942 
943       // For each sub-shape of the instruction shape, add a new HloValue to its
944       // HloValueSet.
945       auto define_all_values = [this, &instruction]() {
946         for (auto& pair : GetInstructionValueSet(instruction)) {
947           const ShapeIndex& index = pair.first;
948           HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
949           GetValueSet(instruction, index).AddValue(value);
950         }
951       };
952 
953       // Add a new HloValue to the HloValueSet corresponding to the given index
954       // of the instruction shape.
955       auto define_value_at = [this, &instruction](const ShapeIndex& index) {
956         HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
957         GetValueSet(instruction, index).AddValue(value);
958       };
959 
960       switch (instruction->opcode()) {
961         case HloOpcode::kBitcast:
962           if (bitcast_defines_value_) {
963             define_all_values();
964           }
965           break;
966         case HloOpcode::kSetDimensionSize:
967         case HloOpcode::kAddDependency:
968         case HloOpcode::kWhile:
969         case HloOpcode::kCall:
970         case HloOpcode::kConditional:
971         case HloOpcode::kGetTupleElement:
972         case HloOpcode::kDomain:
973           // These instructions define no values. The values in their output
974           // flow from their operands or from cross computation dataflow.
975           break;
976         case HloOpcode::kParameter:
977           if (call_graph_node.context() == CallContext::kBoth) {
978             // We do not support a subcomputation that is called from both a
979             // parallel and sequential context. In this case, the parameter
980             // would both define a value and propagate a value from its
981             // caller. This limitation is not really a problem because the call
982             // graph is typically flattened.
983             return Unimplemented(
984                 "Computation %s is called in both a parallel (eg, kMap) and "
985                 "sequential (eg, kCall) context",
986                 computation->name());
987           }
988           if (call_graph_node.caller_callsites().empty() ||
989               call_graph_node.context() == CallContext::kParallel) {
990             // Parameters of computations called in a parallel context (eg, map
991             // and reduce) as well as parameters of dead computations define all
992             // values in their output. Otherwise the values of the parameter
993             // come from the caller (eg, operands to the kCall instruction).
994             define_all_values();
995           }
996           break;
997         case HloOpcode::kCopy:
998         case HloOpcode::kTupleSelect:
999         case HloOpcode::kTuple:
1000           // These instructions only define their top-level values. Any other
1001           // values flow from their operands.
1002           define_value_at(/*index=*/{});
1003           break;
1004         case HloOpcode::kCopyStart:
1005           // CopyStart produces a tuple of {destination buffer, aliased operand,
1006           // U32 context}.
1007           define_value_at(/*index=*/{});
1008           define_value_at(/*index=*/{0});
1009           define_value_at(/*index=*/{2});
1010           break;
1011         case HloOpcode::kCopyDone:
1012           // CopyDone consumes a tuple produced by CopyStart and produces an
1013           // element. Its output aliases its input tuple element {0}.
1014           break;
1015         case HloOpcode::kCollectivePermuteStart:
1016           // CollectivePermuteStart produces a tuple of
1017           // {aliased operand, destination buffer, U32 context, U32 context}.
1018           define_value_at(/*index=*/{});
1019           define_value_at(/*index=*/{1});
1020           define_value_at(/*index=*/{2});
1021           define_value_at(/*index=*/{3});
1022           break;
1023         case HloOpcode::kCollectivePermuteDone:
1024           // CollectivePermuteDone's output aliases its input tuple element {1}.
1025           break;
1026         case HloOpcode::kRecvDone:
1027           // RecvDone produces a two-element tuple. Element zero aliases its
1028           // input tuple element {0}; element one is a token.
1029           define_value_at(/*index=*/{});
1030           define_value_at(/*index=*/{1});
1031           break;
1032         case HloOpcode::kSend:
1033           // Send produces a tuple of {aliased operand, U32 context, token},
1034           // therefore only defines the top-level tuple and the tuple elements
1035           // at {1} and {2}.
1036           define_value_at(/*index=*/{});
1037           define_value_at(/*index=*/{1});
1038           define_value_at(/*index=*/{2});
1039           break;
1040         case HloOpcode::kCustomCall: {
1041           absl::flat_hash_set<ShapeIndex> aliasing_indices;
1042           for (const auto& aliasing :
1043                Cast<HloCustomCallInstruction>(instruction)
1044                    ->output_to_operand_aliasing()) {
1045             aliasing_indices.insert(aliasing.first);
1046           }
1047           ShapeUtil::ForEachSubshape(
1048               instruction->shape(),
1049               [&](const Shape& /*subshape*/, const ShapeIndex& index) {
1050                 if (!aliasing_indices.contains(index)) {
1051                   define_value_at(index);
1052                 }
1053               });
1054           break;
1055         }
1056         default:
1057           define_all_values();
1058           break;
1059       }
1060     }
1061   }
1062 
1063   return Status::OK();
1064 }
1065 
OptimizePhiValues()1066 void HloDataflowAnalysis::OptimizePhiValues() {
1067   // Only applicable to SSA form where phis are defined.
1068   if (!ssa_form_) {
1069     return;
1070   }
1071 
1072   VLOG(1) << "Before phi graph optimization";
1073   XLA_VLOG_LINES(1, phi_graph_.ToString());
1074   phi_graph_.Optimize();
1075   VLOG(1) << "After phi graph optimization";
1076   XLA_VLOG_LINES(1, phi_graph_.ToString());
1077 
1078   for (const HloComputation* computation : module_.computations()) {
1079     for (HloInstruction* instruction : computation->instructions()) {
1080       InstructionValueSet& instruction_value_set =
1081           GetInstructionValueSet(instruction);
1082       VLOG(1) << "inst: " << instruction->name();
1083       VLOG(1) << instruction_value_set.ToString();
1084       instruction_value_set.ForEachMutableElement(
1085           [&](const xla::ShapeIndex& index, HloValueSet* value_set) {
1086             auto values = value_set->values();
1087             if (!(values.size() == 1 && values[0]->is_phi())) {
1088               return;
1089             }
1090             HloValue::Id phi_id = values[0]->id();
1091             HloValue::Id new_id = phi_graph_.FindOptimizedValue(phi_id);
1092             if (new_id != phi_id) {
1093               VLOG(1) << "Replacing " << values[0]->ToString() << " with "
1094                       << GetValue(new_id).ToString();
1095               value_set->Clear();
1096               const HloValue& new_value = GetValue(new_id);
1097               value_set->AddValue(&new_value);
1098               MarkValueForDeletion(phi_id);
1099             }
1100           });
1101     }
1102   }
1103 }
1104 
1105 /* static */
Run(const HloModule & module,bool ssa_form,bool bitcast_defines_value,const CanShareBuffer & can_share_buffer)1106 StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
1107     const HloModule& module, bool ssa_form, bool bitcast_defines_value,
1108     const CanShareBuffer& can_share_buffer) {
1109   VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
1110   XLA_VLOG_LINES(2, module.ToString());
1111 
1112   auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
1113       module, ssa_form, bitcast_defines_value, can_share_buffer));
1114 
1115   TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
1116   dataflow_analysis->Propagate();
1117   dataflow_analysis->OptimizePhiValues();
1118 
1119   // Delete all values marked for deletion.
1120   dataflow_analysis->DeleteMarkedValues();
1121 
1122   // Gather and set all non-definition positions of all values. Value deletion
1123   // is rare, so just use a vector indexed by Value::Id rather than a map from
1124   // Value::Id to positions. There should be very few holes in the vector, and
1125   // lookup is faster.
1126   std::vector<std::vector<HloPosition>> value_positions(
1127       dataflow_analysis->next_value_id_);
1128   for (const HloComputation* computation : module.computations()) {
1129     for (HloInstruction* instruction : computation->instructions()) {
1130       for (const auto& pair :
1131            dataflow_analysis->GetInstructionValueSet(instruction)) {
1132         const ShapeIndex& index = pair.first;
1133         const HloValueSet& value_set = pair.second;
1134         for (const HloValue* value : value_set.values()) {
1135           if (value->defining_instruction() != instruction) {
1136             value_positions[value->id()].push_back(
1137                 HloPosition{instruction, index});
1138           }
1139         }
1140       }
1141     }
1142   }
1143   for (auto& pair : dataflow_analysis->values_) {
1144     HloValue::Id value_id = pair.first;
1145     HloValue& value = pair.second;
1146     value.SetPositionsAndComputeUses(value_positions[value_id]);
1147   }
1148 
1149   // Construct vector of values.
1150   dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
1151   for (auto& pair : dataflow_analysis->values_) {
1152     dataflow_analysis->values_vector_.push_back(&pair.second);
1153   }
1154   absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan);
1155 
1156   TF_DCHECK_OK(dataflow_analysis->Verify());
1157 
1158   XLA_VLOG_LINES(1, dataflow_analysis->ToString());
1159 
1160   return std::move(dataflow_analysis);
1161 }
1162 
Verify() const1163 Status HloDataflowAnalysis::Verify() const {
1164   // Verify each HloValue appears in the value sets that the value's positions()
1165   // indicate.
1166   for (const HloValue* value : values()) {
1167     for (const HloPosition& position : value->positions()) {
1168       const HloValueSet& value_set = GetValueSet(position);
1169       TF_RET_CHECK(absl::c_linear_search(value_set.values(), value))
1170           << "Value set at position " << position << " does not contain value "
1171           << value->ToShortString();
1172     }
1173   }
1174 
1175   // For each value in each value set, verify that the value set's position
1176   // appears in the value's positions().
1177   for (const auto& computation : module_.computations()) {
1178     for (const auto& instruction : computation->instructions()) {
1179       for (const auto& pair : GetInstructionValueSet(instruction)) {
1180         const ShapeIndex& index = pair.first;
1181         const HloValueSet& value_set = pair.second;
1182         const HloPosition position{instruction, index};
1183         for (const HloValue* value : value_set.values()) {
1184           TF_RET_CHECK(absl::c_linear_search(value->positions(), position))
1185               << "Value set at position " << position
1186               << " unexpectedly contains value " << value->ToShortString();
1187         }
1188       }
1189     }
1190   }
1191 
1192   return Status::OK();
1193 }
1194 
DoesNotUseOperandBuffer(const HloInstruction * operand,const ShapeIndex & index,const HloInstruction * user) const1195 bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
1196     const HloInstruction* operand, const ShapeIndex& index,
1197     const HloInstruction* user) const {
1198   // Return false if no value at 'operand' and 'index' is used at 'user'.
1199   for (const HloValue* value : GetValueSet(operand, index).values()) {
1200     for (const HloUse& use : value->uses()) {
1201       if (use.instruction == user) {
1202         if (user->IsLoopFusion()) {
1203           HloInstruction* fusion_param =
1204               user->fused_parameter(use.operand_number);
1205           const HloValue& value =
1206               GetValueDefinedAt(fusion_param, use.operand_index);
1207           return value.uses().empty();
1208         }
1209         return false;
1210       }
1211     }
1212   }
1213   return true;
1214 }
1215 
IsInPlaceOperation(HloOpcode opcode)1216 /*static*/ bool HloDataflowAnalysis::IsInPlaceOperation(HloOpcode opcode) {
1217   return opcode == HloOpcode::kDynamicUpdateSlice ||
1218          opcode == HloOpcode::kScatter;
1219 }
1220 
1221 /*static*/ std::vector<std::pair<HloUse, ShapeIndex>>
GetInPlaceInputOutputPairs(HloInstruction * instruction)1222 HloDataflowAnalysis::GetInPlaceInputOutputPairs(HloInstruction* instruction) {
1223   if (IsInPlaceOperation(instruction->opcode())) {
1224     return {{HloUse{instruction, 0, {}}, {}}};
1225   } else if (instruction->opcode() != HloOpcode::kFusion) {
1226     return {};
1227   }
1228   std::vector<std::pair<HloUse, ShapeIndex>> input_output_pairs;
1229   for (auto& indexed_shape : ShapeUtil::GetLeafShapes(instruction->shape())) {
1230     const HloInstruction* hlo_generating_output =
1231         instruction->fused_expression_root();
1232     for (int64 i = 0; i < indexed_shape.index.size(); ++i) {
1233       if (hlo_generating_output->opcode() == HloOpcode::kTuple) {
1234         hlo_generating_output =
1235             hlo_generating_output->operand(indexed_shape.index[i]);
1236       } else {
1237         CHECK_EQ(i, indexed_shape.index.size() - 1);
1238       }
1239     }
1240 
1241     if (IsInPlaceOperation(hlo_generating_output->opcode())) {
1242       ShapeIndex operand_index;
1243       const HloInstruction* fusion_parameter =
1244           hlo_generating_output->operand(0);
1245       while (fusion_parameter->opcode() == HloOpcode::kGetTupleElement) {
1246         operand_index.push_front(fusion_parameter->tuple_index());
1247         fusion_parameter = fusion_parameter->operand(0);
1248       }
1249 
1250       if (fusion_parameter->opcode() == HloOpcode::kParameter) {
1251         input_output_pairs.emplace_back(
1252             HloUse{instruction, fusion_parameter->parameter_number(),
1253                    operand_index},
1254             indexed_shape.index);
1255       }
1256     }
1257   }
1258   return input_output_pairs;
1259 }
1260 
CanShareOperandBufferWithUser(HloInstruction * operand,const ShapeIndex & operand_index,HloInstruction * user,const ShapeIndex & user_index) const1261 bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
1262     HloInstruction* operand, const ShapeIndex& operand_index,
1263     HloInstruction* user, const ShapeIndex& user_index) const {
1264   CHECK(user->IsUserOf(operand))
1265       << "user: " << user->ToString() << " operand: " << operand->ToString();
1266   if (operand->opcode() == HloOpcode::kConstant) {
1267     return false;
1268   }
1269   const Shape& operand_subshape =
1270       ShapeUtil::GetSubshape(operand->shape(), operand_index);
1271   const Shape& user_subshape =
1272       ShapeUtil::GetSubshape(user->shape(), user_index);
1273 
1274   // Check that operand and user emit the same shape and layout.
1275   if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
1276     return false;
1277   }
1278 
1279   // Must-alias relationship returns true for in-place operations (DUS and DUS
1280   // fusions), regardless of the backend.
1281   for (const auto& operand_and_output_index :
1282        GetInPlaceInputOutputPairs(user)) {
1283     if (operand_and_output_index.second != user_index) {
1284       continue;
1285     }
1286     for (const HloUse& use : GetUniqueValueAt(operand, operand_index).uses()) {
1287       if (use == operand_and_output_index.first) {
1288         return true;
1289       }
1290     }
1291   }
1292 
1293   if (can_share_buffer_ != nullptr) {
1294     if (absl::optional<bool> hint =
1295             can_share_buffer_(user, operand, user_index)) {
1296       return *hint;
1297     }
1298   }
1299 
1300   if (user->opcode() == HloOpcode::kFusion) {
1301     HloInstruction* fusion_param =
1302         user->fused_parameter(user->operand_index(operand));
1303     const HloValue& fusion_param_value =
1304         GetValueDefinedAt(fusion_param, operand_index);
1305 
1306     if (user->IsLoopFusion() || user->IsInputFusion()) {
1307       return AreTransitiveUsesElementwiseOrTuple(fusion_param);
1308     }
1309 
1310     if (user->IsOutputFusion() &&
1311         user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
1312       // Output fusion with kAdd fused root.
1313 
1314       // Check if one operand of kAdd fused root is kDot or kConvolution.
1315       auto* add = user->fused_expression_root();
1316       auto add_operand_it =
1317           absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
1318             return operand->opcode() == HloOpcode::kConvolution ||
1319                    operand->opcode() == HloOpcode::kDot;
1320           });
1321       if (add_operand_it == add->operands().end()) {
1322         return false;
1323       }
1324       auto* matched_add_operand = *add_operand_it;
1325       // Calculate operand index of 'add' operand which was not matched above.
1326       const int64 other_add_operand_index =
1327           matched_add_operand == add->operand(0) ? 1 : 0;
1328       // Returns true iff there is exactly one use of 'operand' at shape index
1329       // 'operand_index', and this singleton use is the fused root (at operand
1330       // index 'other_add_operand_index').
1331       if (fusion_param_value.uses().size() == 1) {
1332         const HloUse& use = fusion_param_value.uses()[0];
1333         return use.instruction == user->fused_expression_root() &&
1334                use.operand_number == other_add_operand_index;
1335       }
1336       return false;
1337     }
1338   }
1339 
1340   if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
1341       user->opcode() == HloOpcode::kScatter ||
1342       user->opcode() == HloOpcode::kTriangularSolve ||
1343       user->opcode() == HloOpcode::kWhile) {
1344     // We eliminated other users in HloOrdering::LiveRangeStrictlyBefore
1345     // so here we just need to check that the use is at the right operand index.
1346     const auto operand_indices = user->OperandIndices(operand);
1347     int64 operand_no = user->opcode() == HloOpcode::kTriangularSolve ? 1 : 0;
1348     return operand_indices.size() == 1 && operand_indices[0] == operand_no;
1349   }
1350   if (user->opcode() == HloOpcode::kSort) {
1351     // Only valid if there are no other users.
1352     if (operand->users().size() != 1) {
1353       return false;
1354     }
1355     // If we only sort keys, the output of sort is not a tuple, so we can always
1356     // share the buffer.
1357     if (user->operand_count() == 1) {
1358       return true;
1359     }
1360     CHECK(!user_index.empty());
1361     // Only share with the right tuple element buffer.
1362     const auto operand_indices = user->OperandIndices(operand);
1363     return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
1364   }
1365   if (user->opcode() == HloOpcode::kCall) {
1366     // Get all uses of value defined by 'operand' at 'operand_index'.
1367     const auto& uses = GetValueDefinedAt(operand, operand_index).uses();
1368     // Return true iff:
1369     // *) There exists two uses of 'operand'.
1370     // *) One use is by 'user' (caller).
1371     // *) One use is by root instruction of called computation (callee root).
1372     //    (Note: we check the root of the called computation, because the
1373     //     root result buffer is required to alias with the Call result buffer).
1374     // *) The root instruction of the called computation is element-wise on
1375     //    'operand'.
1376     const bool found_caller_use =
1377         absl::c_find_if(uses, [user](const HloUse& use) {
1378           return use.instruction == user;
1379         }) != uses.end();
1380     auto* callee_root = user->to_apply()->root_instruction();
1381     const bool found_elementwise_callee_use =
1382         absl::c_find_if(uses, [callee_root](const HloUse& use) {
1383           return use.instruction == callee_root &&
1384                  callee_root->IsElementwiseOnOperand(use.operand_number);
1385         }) != uses.end();
1386     return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
1387   }
1388 
1389   // Loop fusions that contain transposing copies won't reach here as they have
1390   // different layouts, which fails the check in the beginning of this function.
1391   return user->IsElementwiseOnOperand(user->operand_index(operand));
1392 }
1393 
1394 }  // namespace xla
1395