1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
17 
18 #include <algorithm>
19 #include <queue>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/container/inlined_vector.h"
24 #include "absl/memory/memory.h"
25 #include "absl/strings/str_cat.h"
26 #include "tensorflow/compiler/xla/map_util.h"
27 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
28 #include "tensorflow/compiler/xla/service/hlo_computation.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status.h"
34 #include "tensorflow/compiler/xla/types.h"
35 #include "tensorflow/compiler/xla/util.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/platform/logging.h"
38 
39 namespace xla {
40 
41 using absl::StrAppend;
42 using absl::StrCat;
43 
HloDataflowAnalysis(const HloModule & module,bool ssa_form,bool bitcast_defines_value,const FusionCanShareBufferFunction & fusion_can_share_buffer)44 HloDataflowAnalysis::HloDataflowAnalysis(
45     const HloModule& module, bool ssa_form, bool bitcast_defines_value,
46     const FusionCanShareBufferFunction& fusion_can_share_buffer)
47     : module_(module),
48       ssa_form_(ssa_form),
49       bitcast_defines_value_(bitcast_defines_value),
50       call_graph_(CallGraph::Build(&module)),
51       fusion_can_share_buffer_(fusion_can_share_buffer) {}
52 
AreTransitiveUsesElementwiseOrTuple(const HloInstruction * inst)53 bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
54     const HloInstruction* inst) {
55   absl::flat_hash_set<const HloInstruction*> visited;
56   absl::InlinedVector<const HloInstruction*, 4> stack;
57   stack.push_back(inst);
58   while (!stack.empty()) {
59     const HloInstruction* current = stack.back();
60     stack.pop_back();
61     visited.insert(current);
62     for (const HloInstruction* user : current->users()) {
63       // Found a user that is non-elementwise on current instruction.
64       for (const int64 use_index : user->OperandIndices(current)) {
65         if (!user->IsElementwiseOnOperand(use_index) &&
66             user->opcode() != HloOpcode::kTuple) {
67           return false;
68         }
69       }
70       if (!visited.contains(user)) {
71         stack.push_back(user);
72       }
73     }
74   }
75   return true;
76 }
77 
ValueIsDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const78 bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
79                                            const ShapeIndex& index) const {
80   const HloValueSet& value_set = GetValueSet(instruction, index);
81   if (value_set.values().size() != 1) {
82     return false;
83   }
84   return value_set.GetUniqueValue().defining_instruction() == instruction;
85 }
86 
GetValueDefinedAt(const HloInstruction * instruction,const ShapeIndex & index) const87 const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
88     const HloInstruction* instruction, const ShapeIndex& index) const {
89   CHECK(ValueIsDefinedAt(instruction, index)) << instruction->ToString();
90   return GetUniqueValueAt(instruction, index);
91 }
92 
GetValueDefinedAt(const HloInstruction * instruction,const ShapeIndex & index)93 HloValue& HloDataflowAnalysis::GetValueDefinedAt(
94     const HloInstruction* instruction, const ShapeIndex& index) {
95   CHECK(ValueIsDefinedAt(instruction, index));
96   return GetUniqueValueAt(instruction, index);
97 }
98 
NewHloValue(HloInstruction * instruction,const ShapeIndex & index,bool is_phi)99 HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
100                                            const ShapeIndex& index,
101                                            bool is_phi) {
102   const int64 value_id = next_value_id_++;
103   auto emplaced = values_.emplace(
104       std::piecewise_construct, std::forward_as_tuple(value_id),
105       std::forward_as_tuple(value_id, instruction, index, is_phi));
106   CHECK(emplaced.second);
107 
108   VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString();
109 
110   return &emplaced.first->second;
111 }
112 
MarkValueForDeletion(HloValue::Id value_id)113 void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
114   HloValue& value = values_.at(value_id);
115   VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
116 
117   value_ids_to_delete_.push_back(value_id);
118 }
119 
DeleteMarkedValues()120 void HloDataflowAnalysis::DeleteMarkedValues() {
121 #ifndef NDEBUG
122   // Verify that no marked-for-deletion values are in any of the value sets.
123   absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
124                                            value_ids_to_delete_.end());
125   for (const auto& pair : value_sets_) {
126     const HloInstruction* instruction = pair.first;
127     const InstructionValueSet& instruction_value_set = pair.second;
128     for (const auto& index_value_set : instruction_value_set) {
129       const HloValueSet& value_set = index_value_set.second;
130       for (const HloValue* value : value_set.values()) {
131         DCHECK(!ContainsKey(id_set, value->id()))
132             << "Value " << value->ToShortString()
133             << " marked for deletion, but still exists in value set for "
134                "instruction "
135             << instruction->name();
136       }
137     }
138   }
139 #endif
140 
141   for (HloValue::Id value_id : value_ids_to_delete_) {
142     values_.erase(value_id);
143   }
144   value_ids_to_delete_.clear();
145 }
146 
ToString() const147 string HloDataflowAnalysis::ToString() const {
148   string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n");
149   StrAppend(&out, "  Instruction value sets:\n");
150   for (const HloComputation* computation : module_.computations()) {
151     for (const HloInstruction* instruction : computation->instructions()) {
152       StrAppend(&out, "    ", instruction->name(), ":\n");
153       if (instruction->shape().IsTuple()) {
154         GetInstructionValueSet(instruction)
155             .ForEachElement([this, &instruction, &out](
156                                 const ShapeIndex& index,
157                                 const HloValueSet& value_set) {
158               StrAppend(&out, "      tuple index ", index.ToString(), ":\n");
159               for (const HloValue* value : value_set.values()) {
160                 StrAppend(&out, "        ", value->ToShortString(),
161                           ValueIsDefinedAt(instruction, index) ? " (def)" : "",
162                           "\n");
163               }
164             });
165       } else {
166         const HloValueSet& top_level_value_set =
167             GetValueSet(instruction, /*index=*/{});
168         for (const HloValue* value : top_level_value_set.values()) {
169           StrAppend(&out, "      ", value->ToShortString(),
170                     ValueIsDefinedAt(instruction) ? " (def)" : "", "\n");
171         }
172       }
173     }
174   }
175   StrAppend(&out, "  HloValues:\n");
176   for (const HloValue* value : values()) {
177     StrAppend(&out, value->ToString(/*indent=*/4));
178   }
179   return out;
180 }
181 
Phi(HloInstruction * instruction,absl::Span<const InstructionValueSet * const> inputs)182 bool HloDataflowAnalysis::Phi(
183     HloInstruction* instruction,
184     absl::Span<const InstructionValueSet* const> inputs) {
185   CHECK(ssa_form_);
186   VLOG(4) << "Phi(" << instruction->name() << ")";
187   VLOG(5) << "instruction value set = "
188           << GetInstructionValueSet(instruction).ToString();
189   for (const InstructionValueSet* input : inputs) {
190     VLOG(5) << "input value set = " << input->ToString();
191   }
192   for (const InstructionValueSet* input : inputs) {
193     DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
194   }
195 
196   bool changed = false;
197   for (auto& pair : GetInstructionValueSet(instruction)) {
198     const ShapeIndex& index = pair.first;
199     HloValueSet& value_set = pair.second;
200 
201     // Positions with phi values should never have more than one value in the
202     // value set.
203     CHECK_LE(value_set.values().size(), 1);
204     const HloValue* current_value =
205         value_set.values().size() == 1 ? value_set.values()[0] : nullptr;
206 
207     // Construct a vector of unique value IDs of the inputs.
208     // Don't add value ids where the input is equal to the definition.
209     std::vector<HloValue::Id> input_value_ids;
210     for (const InstructionValueSet* input : inputs) {
211       for (const HloValue* value : input->element(index).values()) {
212         if (value->defining_instruction() == instruction &&
213             value->defining_index() == index) {
214           continue;
215         }
216         input_value_ids.push_back(value->id());
217       }
218     }
219     absl::c_sort(input_value_ids);
220     input_value_ids.erase(
221         std::unique(input_value_ids.begin(), input_value_ids.end()),
222         input_value_ids.end());
223 
224     // Remove the existing phi value (if it exists). The phi can be its own
225     // input, for example, in while body parameters where the body passes
226     // through the parameter value.
227     bool current_value_defined_here =
228         (current_value != nullptr &&
229          current_value->defining_instruction() == instruction &&
230          current_value->defining_index() == index);
231     if (current_value_defined_here) {
232       VLOG(5) << "current_value_defined_here: " << current_value->ToString();
233       CHECK(current_value->is_phi());
234       auto it = absl::c_find(input_value_ids, current_value->id());
235       if (it != input_value_ids.end()) {
236         input_value_ids.erase(it);
237       }
238     }
239     VLOG(5) << "after input_value_ids.size = " << input_value_ids.size();
240     if (input_value_ids.empty()) {
241       // A value set which has at least one element should never have its value
242       // set reduced to zero elements. During dataflow value sets only can go
243       // from empty to non-empty, not the reverse.
244       CHECK_EQ(value_set.values().size(), 0)
245           << "Instruction " << instruction->name() << " at index " << index
246           << " previously had non-empty value set. Value set: " << value_set;
247     } else if (input_value_ids.size() == 1) {
248       // Only a single value reaches this point. There should be no phi, and
249       // this value set should contain this single value.
250       const HloValue& new_value = GetValue(input_value_ids[0]);
251       if (current_value == nullptr) {
252         value_set.Clear();
253         value_set.AddValue(&new_value);
254         changed = true;
255       } else if (current_value != &new_value) {
256         if (current_value_defined_here) {
257           // Remove the existing phi.
258           MarkValueForDeletion(current_value->id());
259         }
260         value_set.Clear();
261         value_set.AddValue(&new_value);
262         changed = true;
263       }
264     } else {
265       // Multiple distinct values reach this point. A phi value is
266       // necessary.
267       CHECK_GT(input_value_ids.size(), 1);
268       if (current_value == nullptr ||
269           !(current_value->is_phi() && current_value_defined_here)) {
270         value_set.Clear();
271         value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
272         changed = true;
273       }
274     }
275   }
276   return changed;
277 }
278 
GetValue(HloValue::Id value_id) const279 const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
280   return values_.at(value_id);
281 }
282 
GetValue(HloValue::Id value_id)283 HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) {
284   return values_.at(value_id);
285 }
286 
GetValueSet(const HloInstruction * instruction,const ShapeIndex & index) const287 const HloValueSet& HloDataflowAnalysis::GetValueSet(
288     const HloInstruction* instruction, const ShapeIndex& index) const {
289   return GetInstructionValueSet(instruction).element(index);
290 }
291 
GetValueSet(const HloInstruction * instruction,const ShapeIndex & index)292 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction,
293                                               const ShapeIndex& index) {
294   return *GetInstructionValueSet(instruction).mutable_element(index);
295 }
296 
GetValueSet(const HloPosition & position) const297 const HloValueSet& HloDataflowAnalysis::GetValueSet(
298     const HloPosition& position) const {
299   return GetValueSet(position.instruction, position.index);
300 }
301 
GetValueSet(const HloPosition & position)302 HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
303   return GetValueSet(position.instruction, position.index);
304 }
305 
UpdateBitcastValueSet(HloInstruction * bitcast)306 bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
307   CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
308   const InstructionValueSet& operand_set =
309       GetInstructionValueSet(bitcast->operand(0));
310   InstructionValueSet& bitcast_set = GetInstructionValueSet(bitcast);
311   if (!bitcast_defines_value_ && operand_set != bitcast_set) {
312     bitcast_set = operand_set;
313     return true;
314   }
315   return false;
316 }
317 
UpdateSendValueSet(HloInstruction * send)318 bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) {
319   CHECK_EQ(send->opcode(), HloOpcode::kSend);
320   bool changed = false;
321   // Send forwards the operand value to the output tuple at {0}.
322   for (auto& pair : GetInstructionValueSet(send->operand(0))) {
323     const ShapeIndex& operand_index = pair.first;
324     const HloValueSet& operand_value_set = pair.second;
325 
326     ShapeIndex index = {0};
327     for (int64 i : operand_index) {
328       index.push_back(i);
329     }
330 
331     HloValueSet& value_set = GetValueSet(send, index);
332     if (value_set != operand_value_set) {
333       value_set = operand_value_set;
334       changed = true;
335     }
336   }
337   return changed;
338 }
339 
UpdateRecvDoneValueSet(HloInstruction * recv_done)340 bool HloDataflowAnalysis::UpdateRecvDoneValueSet(HloInstruction* recv_done) {
341   CHECK_EQ(recv_done->opcode(), HloOpcode::kRecvDone);
342   bool changed = false;
343   // RecvDone forwards the operand value at {0} to element {0} of its output.
344   for (auto& pair : GetInstructionValueSet(recv_done)) {
345     ShapeIndex& index = pair.first;
346     HloValueSet& value_set = pair.second;
347 
348     if (index.empty() || index[0] != 0) {
349       continue;
350     }
351 
352     const HloValueSet& operand_value_set =
353         GetValueSet(recv_done->operand(0), index);
354     if (value_set != operand_value_set) {
355       value_set = operand_value_set;
356       changed = true;
357     }
358   }
359   return changed;
360 }
361 
UpdateCallValueSet(HloInstruction * call)362 bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
363   CHECK_EQ(call->opcode(), HloOpcode::kCall);
364   InstructionValueSet& value_set = GetInstructionValueSet(call);
365   InstructionValueSet& root_value_set =
366       GetInstructionValueSet(call->to_apply()->root_instruction());
367   if (value_set != root_value_set) {
368     value_set = root_value_set;
369     return true;
370   }
371   return false;
372 }
373 
UpdateConditionalValueSet(HloInstruction * conditional)374 bool HloDataflowAnalysis::UpdateConditionalValueSet(
375     HloInstruction* conditional) {
376   CHECK_EQ(conditional->opcode(), HloOpcode::kConditional);
377   std::vector<const InstructionValueSet*> inputs(conditional->branch_count());
378   for (int j = 0; j < conditional->branch_count(); ++j) {
379     inputs[j] = &GetInstructionValueSet(
380         conditional->branch_computation(j)->root_instruction());
381   }
382   if (ssa_form_) {
383     return Phi(conditional, inputs);
384   } else {
385     return GetInstructionValueSet(conditional).AssignUnionOf(inputs);
386   }
387 }
388 
UpdateCopyValueSet(HloInstruction * copy)389 bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
390   CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
391   bool changed = false;
392   for (auto& pair : GetInstructionValueSet(copy)) {
393     const ShapeIndex& index = pair.first;
394     if (index.empty()) {
395       // kCopy shallow copies and thus defines the top-level value so nothing to
396       // update.
397       continue;
398     }
399 
400     HloValueSet& value_set = pair.second;
401     HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index);
402     if (value_set != operand_value_set) {
403       value_set = operand_value_set;
404       changed = true;
405     }
406   }
407   return changed;
408 }
409 
UpdateDomainValueSet(HloInstruction * domain)410 bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) {
411   // Domain instructions just forward their operand. Given that domains can have
412   // a tuple operand, we iterate through its indexes, like for copies.
413   // Unlike copies though we also propagate the top-level value.
414   CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
415   bool changed = false;
416   for (auto& pair : GetInstructionValueSet(domain)) {
417     const ShapeIndex& index = pair.first;
418     HloValueSet& value_set = pair.second;
419     HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
420     if (value_set != operand_value_set) {
421       value_set = operand_value_set;
422       changed = true;
423     }
424   }
425   return changed;
426 }
427 
UpdateAddDependencyValueSet(HloInstruction * add_dependency)428 bool HloDataflowAnalysis::UpdateAddDependencyValueSet(
429     HloInstruction* add_dependency) {
430   // AddDependency just forwards the value of its zero-th operand.
431   CHECK_EQ(add_dependency->opcode(), HloOpcode::kAddDependency);
432   const InstructionValueSet& operand_set =
433       GetInstructionValueSet(add_dependency->operand(0));
434   InstructionValueSet& add_dependency_set =
435       GetInstructionValueSet(add_dependency);
436   if (operand_set != add_dependency_set) {
437     add_dependency_set = operand_set;
438     return true;
439   }
440   return false;
441 }
442 
UpdateGetTupleElementValueSet(HloInstruction * gte)443 bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
444   CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
445   bool changed = false;
446   // The GetTupleElement instruction forwards the values from the specified
447   // tuple element.
448   for (auto& pair : GetInstructionValueSet(gte)) {
449     const ShapeIndex& index = pair.first;
450     HloValueSet& value_set = pair.second;
451 
452     // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex
453     // with the tuple element number prefixed.
454     ShapeIndex operand_index = {gte->tuple_index()};
455     for (int64 i : index) {
456       operand_index.push_back(i);
457     }
458 
459     HloValueSet& operand_value_set =
460         GetValueSet(gte->operand(0), operand_index);
461     if (value_set != operand_value_set) {
462       value_set = operand_value_set;
463       changed = true;
464     }
465   }
466   return changed;
467 }
468 
UpdateParameterValueSet(HloInstruction * parameter)469 bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
470   CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
471   const CallGraphNode& call_graph_node =
472       call_graph_->GetNode(parameter->parent());
473 
474   // Subcomputations called in a parallel context (eg, map) do not have dataflow
475   // from the caller operands.
476   if (call_graph_node.context() == CallContext::kParallel ||
477       call_graph_node.caller_callsites().empty()) {
478     return false;
479   }
480   CHECK_EQ(call_graph_node.context(), CallContext::kSequential);
481 
482   std::vector<const InstructionValueSet*> inputs;
483   bool need_phi = false;
484   for (const CallSite& callsite : call_graph_node.caller_callsites()) {
485     if (callsite.instruction()->opcode() == HloOpcode::kCall) {
486       // The operand values of a call instruction are forwarded to the
487       // respective parameter instruction of the subcomputation.
488       inputs.push_back(&GetInstructionValueSet(
489           callsite.instruction()->operand(parameter->parameter_number())));
490     } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
491       // In a while instruction, the while operand (ie, the init value) and the
492       // backedge are dataflow inputs to the parameter instruction. This is the
493       // case for parameters of both the body and condition computations.
494       CHECK_EQ(parameter->parameter_number(), 0);
495       inputs.push_back(
496           &GetInstructionValueSet(callsite.instruction()->operand(0)));
497       // If the parameter *is* the root, then don't consider it's current state
498       // (InstructionValueSet) as we are recomputing its current
499       // state. Otherwise, the parameter state would never be updated.
500       if (parameter !=
501           callsite.instruction()->while_body()->root_instruction()) {
502         inputs.push_back(&GetInstructionValueSet(
503             callsite.instruction()->while_body()->root_instruction()));
504       }
505       need_phi = true;
506     } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) {
507       CHECK_EQ(parameter->parameter_number(), 0);
508       auto conditional = callsite.instruction();
509       // Conditional has branch_count+1 operands. Operand 0 is the branch_index,
510       // operands 1 and onward are the arguments to the branch computations.
511       //
512       // If the parameter belongs to conditional's branch 0 computation, then
513       // operand 1 is forwarded to this parameter instruction. If the parameter
514       // belongs to conditional's branch 5 computation, then operand 6 is
515       // forwarded to this parameter instruction.
516       bool found_parent = false;
517       for (int j = 0; j < conditional->branch_count(); ++j) {
518         if (parameter->parent() == conditional->branch_computation(j)) {
519           inputs.push_back(
520               &GetInstructionValueSet(conditional->operand(j + 1)));
521           found_parent = true;
522           break;
523         }
524       }
525       CHECK(found_parent);
526       need_phi = true;
527     } else {
528       LOG(FATAL) << "CallContext::kSequential computations should only be "
529                     "called from call, while, or conditional instructions";
530     }
531   }
532 
533   if (ssa_form_ && need_phi) {
534     return Phi(parameter, inputs);
535   } else {
536     return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
537   }
538 }
539 
UpdateTupleSelectValueSet(HloInstruction * select)540 bool HloDataflowAnalysis::UpdateTupleSelectValueSet(HloInstruction* select) {
541   CHECK_EQ(select->opcode(), HloOpcode::kTupleSelect);
542   // A phi value is not defined at a kTupleSelect instruction because
543   // kTupleSelect does not create a new value. Rather it forwards a value from
544   // its operands. This contrasts with kWhile instruction (which does define a
545   // phi value) which has in-place update semantics.
546   bool changed = false;
547   for (auto& pair : GetInstructionValueSet(select)) {
548     const ShapeIndex& index = pair.first;
549     if (index.empty()) {
550       // kTupleSelect copies (not forwards) the top-level value.
551       continue;
552     }
553     HloValueSet& value_set = pair.second;
554     changed |=
555         value_set.AssignUnionOf({&GetValueSet(select->operand(1), index),
556                                  &GetValueSet(select->operand(2), index)});
557   }
558   return changed;
559 }
560 
UpdateTupleValueSet(HloInstruction * tuple)561 bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) {
562   CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
563   bool changed = false;
564   for (int64 i = 0; i < tuple->operands().size(); ++i) {
565     // Copy the value set(s) of each operand into the respective position in the
566     // kTuple instruction's value sets.
567     for (auto& pair : GetInstructionValueSet(tuple->operand(i))) {
568       const ShapeIndex& operand_index = pair.first;
569       HloValueSet& operand_value_set = pair.second;
570 
571       ShapeIndex index = {i};
572       for (int64 op_index : operand_index) {
573         index.push_back(op_index);
574       }
575       HloValueSet& value_set = GetValueSet(tuple, index);
576 
577       if (value_set != operand_value_set) {
578         value_set = operand_value_set;
579         changed = true;
580       }
581     }
582   }
583   return changed;
584 }
585 
UpdateWhileValueSet(HloInstruction * xla_while)586 bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
587   CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
588   const InstructionValueSet* const inputs[] = {
589       &GetInstructionValueSet(xla_while->while_body()->root_instruction()),
590       &GetInstructionValueSet(xla_while->operand(0))};
591   if (ssa_form_) {
592     return Phi(xla_while, inputs);
593   } else {
594     return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
595   }
596 }
597 
UpdateInstructionValueSet(HloInstruction * instruction)598 bool HloDataflowAnalysis::UpdateInstructionValueSet(
599     HloInstruction* instruction) {
600   // Recompute from operands.
601   switch (instruction->opcode()) {
602     case HloOpcode::kAddDependency:
603       return UpdateAddDependencyValueSet(instruction);
604     case HloOpcode::kBitcast:
605       return UpdateBitcastValueSet(instruction);
606     case HloOpcode::kDomain:
607       return UpdateDomainValueSet(instruction);
608     case HloOpcode::kCopy:
609       return UpdateCopyValueSet(instruction);
610     case HloOpcode::kGetTupleElement:
611       return UpdateGetTupleElementValueSet(instruction);
612     case HloOpcode::kTupleSelect:
613       return UpdateTupleSelectValueSet(instruction);
614     case HloOpcode::kTuple:
615       return UpdateTupleValueSet(instruction);
616     case HloOpcode::kParameter:
617       return UpdateParameterValueSet(instruction);
618     case HloOpcode::kCall:
619       return UpdateCallValueSet(instruction);
620     case HloOpcode::kWhile:
621       return UpdateWhileValueSet(instruction);
622     case HloOpcode::kSend:
623       return UpdateSendValueSet(instruction);
624     case HloOpcode::kRecvDone:
625       return UpdateRecvDoneValueSet(instruction);
626     case HloOpcode::kConditional:
627       return UpdateConditionalValueSet(instruction);
628     default:
629       // Instruction does not forward HloValues (it defines all values in its
630       // output). No update is necessary.
631       return false;
632   }
633 }
634 
Propagate()635 void HloDataflowAnalysis::Propagate() {
636   std::queue<HloInstruction*> worklist;
637   absl::flat_hash_set<HloInstruction*> workset;
638   auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) {
639     if (workset.insert(instruction).second) {
640       worklist.push(instruction);
641     }
642   };
643 
644   for (HloComputation* computation : module_.computations()) {
645     for (HloInstruction* instruction : computation->instructions()) {
646       add_to_worklist(instruction);
647     }
648   }
649 
650   while (!worklist.empty()) {
651     HloInstruction* instruction = worklist.front();
652     worklist.pop();
653     workset.erase(workset.find(instruction));
654 
655     VLOG(3) << "Worklist top: " << instruction->name();
656     VLOG(3) << ToString();
657 
658     if (!UpdateInstructionValueSet(instruction)) {
659       // No change to the instruction's value set.
660       VLOG(4) << "No change.";
661       continue;
662     }
663 
664     VLOG(4) << "New value set for " << instruction->name() << ": "
665             << GetInstructionValueSet(instruction);
666 
667     // Instruction value was updated. Add users to work list if we haven't
668     // already.
669     for (HloInstruction* user : instruction->users()) {
670       add_to_worklist(user);
671 
672       // If user sequentially calls a computation, then the respective
673       // parameter(s) of the computation need to be updated.
674       if (user->opcode() == HloOpcode::kConditional) {
675         // If operand 0 is the use of instruction, then no parameters need to be
676         // updated, since that is the branch_index of the conditional.
677         // If operand n+1 is the use of instruction, then the branch_computation
678         // n's parameter need to be updated.
679         //
680         // Note that the same instruction can be used in multiple branches'
681         // operands.
682         for (int j = 0; j < user->branch_count(); ++j) {
683           if (user->operand(j + 1) == instruction) {
684             add_to_worklist(
685                 user->branch_computation(j)->parameter_instruction(0));
686           }
687         }
688       } else {
689         for (HloComputation* called_computation : user->called_computations()) {
690           const CallGraphNode& call_graph_node =
691               call_graph_->GetNode(called_computation);
692           if (call_graph_node.context() == CallContext::kSequential) {
693             for (int64 operand_number : user->OperandIndices(instruction)) {
694               add_to_worklist(
695                   called_computation->parameter_instruction(operand_number));
696             }
697           }
698         }
699       }
700     }
701 
702     // If instruction is a root instruction, then propagate out to any calling
703     // instruction and across any while backedge.
704     if (instruction == instruction->parent()->root_instruction()) {
705       const CallGraphNode& call_graph_node =
706           call_graph_->GetNode(instruction->parent());
707       for (const CallSite& callsite : call_graph_node.caller_callsites()) {
708         if (callsite.instruction()->opcode() == HloOpcode::kCall ||
709             callsite.instruction()->opcode() == HloOpcode::kConditional) {
710           add_to_worklist(callsite.instruction());
711         } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
712           // Add the while itself, and the body and condition parameters.
713           add_to_worklist(callsite.instruction());
714           add_to_worklist(
715               callsite.instruction()->while_body()->parameter_instruction(0));
716           add_to_worklist(
717               callsite.instruction()->while_condition()->parameter_instruction(
718                   0));
719         }
720       }
721     }
722   }
723 }
724 
GetInstructionValueSet(const HloInstruction * instruction) const725 const InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
726     const HloInstruction* instruction) const {
727   return value_sets_.at(instruction);
728 }
729 
GetInstructionValueSet(const HloInstruction * instruction)730 InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
731     const HloInstruction* instruction) {
732   return value_sets_.at(instruction);
733 }
734 
InitializeInstructionValueSets()735 Status HloDataflowAnalysis::InitializeInstructionValueSets() {
736   for (const HloComputation* computation : module_.computations()) {
737     const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);
738     for (HloInstruction* instruction : computation->instructions()) {
739       // Create an empty shape tree.
740       value_sets_.emplace(std::piecewise_construct,
741                           std::forward_as_tuple(instruction),
742                           std::forward_as_tuple(instruction->shape()));
743 
744       // Lambda to set the value set to define all values in the output of the
745       // instruction.
746       auto define_all_values = [this, &instruction](bool is_phi = false) {
747         for (auto& pair : GetInstructionValueSet(instruction)) {
748           const ShapeIndex& index = pair.first;
749           HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
750           GetValueSet(instruction, index).AddValue(value);
751         }
752       };
753 
754       // Lambda to set the value set to define only the top-level buffer in the
755       // output of the instruction. Any other values flow from the operands of
756       // the instruction (or from cross-computation dataflow).
757       auto define_top_level_only = [this, &instruction]() {
758         HloValue* value =
759             NewHloValue(instruction, /*index=*/{}, /*is_phi=*/false);
760         GetValueSet(instruction, /*index=*/{}).AddValue(value);
761       };
762 
763       // Lambda to set the value set at the given index of the output.
764       auto define_value_at = [this, &instruction](const ShapeIndex& index) {
765         HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false);
766         GetValueSet(instruction, index).AddValue(value);
767       };
768 
769       switch (instruction->opcode()) {
770         case HloOpcode::kBitcast:
771           if (bitcast_defines_value_) {
772             define_all_values();
773           }
774           break;
775         case HloOpcode::kAddDependency:
776         case HloOpcode::kWhile:
777         case HloOpcode::kCall:
778         case HloOpcode::kConditional:
779         case HloOpcode::kGetTupleElement:
780         case HloOpcode::kDomain:
781           // These instructions define no values. The values in their output
782           // flow from their operands or from cross computation dataflow.
783           break;
784         case HloOpcode::kParameter:
785           if (call_graph_node.context() == CallContext::kBoth) {
786             // We do not support a subcomputation that is called from both a
787             // parallel and sequential context. In this case, the parameter
788             // would both define a value and propagate a value from its
789             // caller. This limitation is not really a problem because the call
790             // graph is typically flattened.
791             return Unimplemented(
792                 "Computation %s is called in both a parallel (eg, kMap) and "
793                 "sequential (eg, kCall) context",
794                 computation->name());
795           }
796           if (call_graph_node.caller_callsites().empty() ||
797               call_graph_node.context() == CallContext::kParallel) {
798             // Parameters of computations called in a parallel context (eg, map
799             // and reduce) as well as parameters of dead computations define all
800             // values in their output. Otherwise the values of the parameter
801             // come from the caller (eg, operands to the kCall instruction).
802             define_all_values();
803           }
804           break;
805         case HloOpcode::kCopy:
806         case HloOpcode::kTupleSelect:
807         case HloOpcode::kTuple:
808           // These instructions only define their top-level values. Any other
809           // values flow from their operands.
810           define_top_level_only();
811           break;
812         case HloOpcode::kRecvDone:
813           // RecvDone produces a two-element tuple. Element zero aliases its
814           // input tuple element {0}; element one is a token.
815           define_value_at(/*index=*/{});
816           define_value_at(/*index=*/{1});
817           break;
818         case HloOpcode::kSend:
819           // Send produces a tuple of {aliased operand, U32 context, token},
820           // therefore only defines the top-level tuple and the tuple elements
821           // at {1} and {2}.
822           define_value_at(/*index=*/{});
823           define_value_at(/*index=*/{1});
824           define_value_at(/*index=*/{2});
825           break;
826         default:
827           define_all_values();
828           break;
829       }
830     }
831   }
832 
833   return Status::OK();
834 }
835 
836 /* static */
Run(const HloModule & module,bool ssa_form,bool bitcast_defines_value,const FusionCanShareBufferFunction & fusion_can_share_buffer)837 StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
838     const HloModule& module, bool ssa_form, bool bitcast_defines_value,
839     const FusionCanShareBufferFunction& fusion_can_share_buffer) {
840   VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
841   XLA_VLOG_LINES(2, module.ToString());
842 
843   auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
844       module, ssa_form, bitcast_defines_value, fusion_can_share_buffer));
845 
846   TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
847   dataflow_analysis->Propagate();
848 
849   // Delete all values marked for deletion.
850   dataflow_analysis->DeleteMarkedValues();
851 
852   // Gather and set all non-definition positions of all values. Value deletion
853   // is rare, so just use a vector indexed by Value::Id rather than a map from
854   // Value::Id to positions. There should be very few holes in the vector, and
855   // lookup is faster.
856   std::vector<std::vector<HloPosition>> value_positions(
857       dataflow_analysis->next_value_id_);
858   for (const HloComputation* computation : module.computations()) {
859     for (HloInstruction* instruction : computation->instructions()) {
860       for (const auto& pair :
861            dataflow_analysis->GetInstructionValueSet(instruction)) {
862         const ShapeIndex& index = pair.first;
863         const HloValueSet& value_set = pair.second;
864         for (const HloValue* value : value_set.values()) {
865           if (value->defining_instruction() != instruction) {
866             value_positions[value->id()].push_back(
867                 HloPosition{instruction, index});
868           }
869         }
870       }
871     }
872   }
873   for (auto& pair : dataflow_analysis->values_) {
874     HloValue::Id value_id = pair.first;
875     HloValue& value = pair.second;
876     value.SetPositionsAndComputeUses(value_positions[value_id]);
877   }
878 
879   // Construct vector of values.
880   dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
881   for (auto& pair : dataflow_analysis->values_) {
882     dataflow_analysis->values_vector_.push_back(&pair.second);
883   }
884   absl::c_sort(dataflow_analysis->values_vector_, HloValue::IdLessThan);
885 
886   TF_DCHECK_OK(dataflow_analysis->Verify());
887 
888   XLA_VLOG_LINES(1, dataflow_analysis->ToString());
889 
890   return std::move(dataflow_analysis);
891 }
892 
Verify() const893 Status HloDataflowAnalysis::Verify() const {
894   // Verify each HloValue appears in the value sets that the value's positions()
895   // indicate.
896   for (const HloValue* value : values()) {
897     for (const HloPosition& position : value->positions()) {
898       const HloValueSet& value_set = GetValueSet(position);
899       TF_RET_CHECK(absl::c_linear_search(value_set.values(), value))
900           << "Value set at position " << position << " does not contain value "
901           << value->ToShortString();
902     }
903   }
904 
905   // For each value in each value set, verify that the value set's position
906   // appears in the value's positions().
907   for (const auto& computation : module_.computations()) {
908     for (const auto& instruction : computation->instructions()) {
909       for (const auto& pair : GetInstructionValueSet(instruction)) {
910         const ShapeIndex& index = pair.first;
911         const HloValueSet& value_set = pair.second;
912         const HloPosition position{instruction, index};
913         for (const HloValue* value : value_set.values()) {
914           TF_RET_CHECK(absl::c_linear_search(value->positions(), position))
915               << "Value set at position " << position
916               << " unexpectedly contains value " << value->ToShortString();
917         }
918       }
919     }
920   }
921 
922   return Status::OK();
923 }
924 
DoesNotUseOperandBuffer(const HloInstruction * operand,const ShapeIndex & index,const HloInstruction * user) const925 bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
926     const HloInstruction* operand, const ShapeIndex& index,
927     const HloInstruction* user) const {
928   // Return false if no value at 'operand' and 'index' is used at 'user'.
929   for (const HloValue* value : GetValueSet(operand, index).values()) {
930     for (const HloUse& use : value->uses()) {
931       if (use.instruction == user) {
932         if (user->opcode() == HloOpcode::kFusion &&
933             user->fusion_kind() == HloInstruction::FusionKind::kLoop) {
934           HloInstruction* fusion_param =
935               user->fused_parameter(use.operand_number);
936           const HloValue& value =
937               GetValueDefinedAt(fusion_param, use.operand_index);
938           return value.uses().empty();
939         }
940         return false;
941       }
942     }
943   }
944   return true;
945 }
946 
947 // Given a fusion whose root is a dynamic-update-slice op, determines whether
948 // the fusion's output buffer can be shared with the buffer of fusion_param,
949 // which must be a fused parameter of the fusion.
950 //
951 // Preconditions:
952 //
953 //  - fusion's root is a dynamic-update-slice op.
954 //  - fusion_param is a parameter within the fusion.
955 //
956 // fusion_param may point to a subelement of the actual parameter instruction if
957 // the param is a tuple; i.e. fusion_param->index() need not be the empty list.
958 //
959 // Returns true if:
960 //
961 //  * fusion is a loop or input fusion, AND
962 //  * fusion_param is used by the root of dynamic-update-slice as the "base" of
963 //    the update, i.e. the thing being updated, AND
964 //  * all other uses of fusion_param are dynamic-slices that slice the same
965 //    indices as are overwritten in the dynamic-update-slice.
966 //
967 // In the case that there are no other uses of fusion_param (last bullet point
968 // is vacuously true) it's easy to see why an in-place DUS is safe; this is just
969 // the "natural" implementation of DUS.  If there are other users, in-place DUS
970 // is safe on the assumption that the thread which writes element i of the
971 // output will be the only one to read element i of fusion_param (via the
972 // dynamic-slice ops).
CanDoInPlaceDynamicUpdateSlice(HloInstruction * fusion,const HloValue & fusion_param_value)973 static bool CanDoInPlaceDynamicUpdateSlice(HloInstruction* fusion,
974                                            const HloValue& fusion_param_value) {
975   auto* root =
976       Cast<HloDynamicUpdateSliceInstruction>(fusion->fused_expression_root());
977   auto* fusion_param = fusion_param_value.instruction();
978   CHECK_EQ(fusion_param->opcode(), HloOpcode::kParameter);
979   CHECK_EQ(fusion_param->parent(), fusion->fused_instructions_computation());
980 
981   // fusion must be a loop or input fusion.
982   auto kind = fusion->fusion_kind();
983   if (kind != HloInstruction::FusionKind::kLoop &&
984       kind != HloInstruction::FusionKind::kInput) {
985     return false;
986   }
987 
988   // fusion_param must be used by the root as the "base" of the
989   // dynamic-update-slice.  The natural way to check this would be
990   //
991   //   `if (root->operand(0) != fusion_param)`
992   //
993   // but we also have to handle the case where the fusion parameter is
994   // tuple-shaped and we're considering just one element of that tuple, i.e.
995   // fusion_param.index() != {}.
996   if (absl::c_count_if(fusion_param_value.uses(), [&](const HloUse& use) {
997         return use.instruction == root;
998       }) != 1) {
999     return false;
1000   }
1001 
1002   // All other uses of fusion_param must be dynamic-slices that slice the same
1003   // indices as are overwritten by the dynamic-update-slice.
1004   for (const HloUse& use : fusion_param_value.uses()) {
1005     auto* user = use.instruction;
1006     if (user == root) {
1007       continue;
1008     }
1009 
1010     // Check that `user` is a dynamic-slice op and has the same slice indices as
1011     // `root`.
1012     auto* ds = DynCast<HloDynamicSliceInstruction>(user);
1013     if (!ds || ds->index_operands() != root->index_operands()) {
1014       return false;
1015     }
1016   }
1017   return true;
1018 }
1019 
CanShareOperandBufferWithUser(HloInstruction * operand,const ShapeIndex & operand_index,HloInstruction * user,const ShapeIndex & user_index) const1020 bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
1021     HloInstruction* operand, const ShapeIndex& operand_index,
1022     HloInstruction* user, const ShapeIndex& user_index) const {
1023   CHECK(user->IsUserOf(operand))
1024       << "user: " << user->ToString() << " operand: " << operand->ToString();
1025   const Shape& operand_subshape =
1026       ShapeUtil::GetSubshape(operand->shape(), operand_index);
1027   const Shape& user_subshape =
1028       ShapeUtil::GetSubshape(user->shape(), user_index);
1029 
1030   // Check that operand and user emit the same shape and layout.
1031   if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {
1032     return false;
1033   }
1034 
1035   if (user->opcode() == HloOpcode::kFusion) {
1036     // Get the parameter associated with 'operand';
1037     HloInstruction* fusion_param =
1038         user->fused_parameter(user->operand_index(operand));
1039 
1040     const HloValue& fusion_param_value =
1041         GetValueDefinedAt(fusion_param, operand_index);
1042 
1043     // TODO(b/80315712): This code is in a bit of a weird intermediate state
1044     // at the moment. The in-place DUS check really needs to be common to all
1045     // backends, so it runs first. Then we run the backend-specific check if
1046     // provided, or go through the target-indepdendent check if not.
1047     // Unfortunately, the notionally "target-independent" path actually contains
1048     // some target-specific code, so we can't run all of it *in addition* to the
1049     // target-specific function, like the interface documentation says.
1050     if (user->fused_expression_root()->opcode() ==
1051         HloOpcode::kDynamicUpdateSlice) {
1052       return CanDoInPlaceDynamicUpdateSlice(user, fusion_param_value);
1053     }
1054 
1055     if (fusion_can_share_buffer_ != nullptr) {
1056       return fusion_can_share_buffer_(user, operand);
1057     }
1058 
1059     if (user->fusion_kind() == HloInstruction::FusionKind::kLoop ||
1060         user->fusion_kind() == HloInstruction::FusionKind::kInput) {
1061       return AreTransitiveUsesElementwiseOrTuple(fusion_param);
1062     }
1063 
1064     if (user->fusion_kind() == HloInstruction::FusionKind::kOutput &&
1065         user->fused_expression_root()->opcode() == HloOpcode::kAdd) {
1066       // Output fusion with kAdd fused root.
1067 
1068       // Check if one operand of kAdd fused root is kDot or kConvolution.
1069       auto* add = user->fused_expression_root();
1070       auto add_operand_it =
1071           absl::c_find_if(add->operands(), [&](HloInstruction* operand) {
1072             return operand->opcode() == HloOpcode::kConvolution ||
1073                    operand->opcode() == HloOpcode::kDot;
1074           });
1075       if (add_operand_it == add->operands().end()) {
1076         return false;
1077       }
1078       auto* matched_add_operand = *add_operand_it;
1079       // Calculate operand index of 'add' operand which was not matched above.
1080       const int64 other_add_operand_index =
1081           matched_add_operand == add->operand(0) ? 1 : 0;
1082       // Returns true iff there is exactly one use of 'operand' at shape index
1083       // 'operand_index', and this singleton use is the fused root (at operand
1084       // index 'other_add_operand_index').
1085       if (fusion_param_value.uses().size() == 1) {
1086         const HloUse& use = fusion_param_value.uses()[0];
1087         return use.instruction == user->fused_expression_root() &&
1088                use.operand_number == other_add_operand_index;
1089       }
1090       return false;
1091     }
1092   }
1093 
1094   if (user->opcode() == HloOpcode::kDynamicUpdateSlice ||
1095       user->opcode() == HloOpcode::kScatter ||
1096       user->opcode() == HloOpcode::kWhile) {
1097     // We eliminated other users in BufferLiveness::live_range_strictly_before,
1098     // so here we just need to check that the use is at operand index 0.
1099     std::vector<int64> operand_indices = user->OperandIndices(operand);
1100     return operand_indices.size() == 1 && operand_indices[0] == 0;
1101   }
1102   if (user->opcode() == HloOpcode::kSort) {
1103     // Only valid if there are no other users.
1104     if (operand->users().size() != 1) {
1105       return false;
1106     }
1107     // If we only sort keys, the output of sort is not a tuple, so we can always
1108     // share the buffer.
1109     if (user->operand_count() == 1) {
1110       return true;
1111     }
1112     CHECK(!user_index.empty());
1113     // Only share with the right tuple element buffer.
1114     std::vector<int64> operand_indices = user->OperandIndices(operand);
1115     return operand_indices.size() == 1 && user_index[0] == operand_indices[0];
1116   }
1117   if (user->opcode() == HloOpcode::kCall) {
1118     // Get all uses of value defined by 'operand' at 'operand_index'.
1119     const auto& uses = GetValueDefinedAt(operand, operand_index).uses();
1120     // Return true iff:
1121     // *) There exists two uses of 'operand'.
1122     // *) One use is by 'user' (caller).
1123     // *) One use is by root instruction of called computation (callee root).
1124     //    (Note: we check the root of the called computation, because the
1125     //     root result buffer is required to alias with the Call result buffer).
1126     // *) The root instruction of the called computation is element-wise on
1127     //    'operand'.
1128     const bool found_caller_use =
1129         absl::c_find_if(uses, [user](const HloUse& use) {
1130           return use.instruction == user;
1131         }) != uses.end();
1132     auto* callee_root = user->to_apply()->root_instruction();
1133     const bool found_elementwise_callee_use =
1134         absl::c_find_if(uses, [callee_root](const HloUse& use) {
1135           return use.instruction == callee_root &&
1136                  callee_root->IsElementwiseOnOperand(use.operand_number);
1137         }) != uses.end();
1138     return uses.size() == 2 && found_caller_use && found_elementwise_callee_use;
1139   }
1140 
1141   // Loop fusions that contain transposing copies won't reach here as they have
1142   // different layouts, which fails the check in the beginning of this function.
1143   return user->IsElementwiseOnOperand(user->operand_index(operand));
1144 }
1145 
1146 }  // namespace xla
1147