/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/copy_insertion.h"

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

using absl::StrAppend;

bool IsReadonlyEntryParameterValue(const HloValue& value) {
  const HloComputation* computation = value.defining_instruction()->parent();
  return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
         computation == computation->parent()->entry_computation() &&
         !computation->parent()->input_output_alias_config().ParameterHasAlias(
             value.defining_instruction()->parameter_number(), value.index());
}

bool IsConstantValue(const HloValue& value) {
  return value.defining_instruction()->opcode() == HloOpcode::kConstant;
}

bool ValueIsReadOnly(const HloValue& value) {
  return IsConstantValue(value) || IsReadonlyEntryParameterValue(value);
}
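
// A hypothetical illustration (HLO text with invented names) of a read-only
// value: in
//
//   ENTRY entry {
//     p0 = f32[4] parameter(0)
//     ROOT r = f32[4] add(p0, p0)
//   }
//
// the value defined by 'p0' satisfies ValueIsReadOnly above, assuming the
// module's input/output alias config does not alias parameter 0 with the
// output.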

// Data structure describing the action which should be taken on parts of a
// computation's buffers, with respect to the addition of special-case copies.
struct SpecialCaseCopyPolicy {
  // Insert a copy if the same buffer is found at multiple indices within the
  // output tuple.
  bool copy_root_replicated_buffers = false;
  // If true, insert a copy if a buffer coming from a constant or a parameter
  // is found within the output tuple.
  bool copy_parameters_and_constants = false;
};

SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
                                               HloModule* module,
                                               HloComputation* computation) {
  SpecialCaseCopyPolicy policy;
  if (computation == module->entry_computation()) {
    policy.copy_parameters_and_constants = true;
    policy.copy_root_replicated_buffers = true;
  }
  return policy;
}

bool ShouldCopyRootValue(const HloValue& value,
                         const SpecialCaseCopyPolicy& policy) {
  if (policy.copy_parameters_and_constants) {
    return ValueIsReadOnly(value);
  }
  return false;
}
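
// As a hedged example of the entry policy: if the entry root returns a
// constant directly, AddSpecialCaseCopies (below) inserts a copy so that the
// read-only constant buffer is not live out of the module. In hypothetical
// HLO (names invented):
//
//   ENTRY entry {                        ENTRY entry {
//     ROOT c = f32[] constant(42)   =>     c = f32[] constant(42)
//   }                                      ROOT copy = f32[] copy(c)
//                                        }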

// Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given
// in 'indices_to_copy'. Add control edges from the respective kCopy
// instructions in the deep copy of 'from' to the respective kCopy
// instructions in the deep copy of 'to'.
//
// Requirements: 'from' and 'to' must have compatible shapes.
//
// For example, suppose 'from' and 'to' are two-element tuples where index 0 is
// the only index to copy. Prior to deep-copying we have:
//
//      'from'
//         |
//        ...
//         |
//       'to'
//
// DeepCopyAndAddControlEdges produces:
//
//       'from'
//        /   \
//      GTE   GTE
//       |     |
//     Copy    |
//    /   \   /
//   |    Tuple
//   |      |
//  ctrl   ...
//  edge    |
//   |      |
//   |    'to'
//   |    /   \
//   |  GTE   GTE
//    \  |     |
//     Copy    |
//        \   /
//        Tuple
//
StatusOr<std::pair<HloInstruction*, HloInstruction*>>
DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
                           const ShapeTree<bool>& indices_to_copy) {
  DCHECK(ShapeUtil::Compatible(from->shape(), to->shape()));
  // to/from_copy_tree hold the kCopy instructions produced by the deep
  // copies. Elements which are not copied (indices_to_copy.element(index) ==
  // false) have nullptr at that index.
  ShapeTree<HloInstruction*> from_copy_tree(from->shape(),
                                            /*init_value=*/nullptr);
  TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy,
                      from->parent()->DeepCopyInstruction(
                          from, &indices_to_copy, &from_copy_tree));

  ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr);
  TF_ASSIGN_OR_RETURN(
      HloInstruction * to_deep_copy,
      to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree));

  // Add control edges between the respective kCopy instructions.
  for (const auto& pair : from_copy_tree) {
    const ShapeIndex& index = pair.first;
    HloInstruction* from_copy = pair.second;
    HloInstruction* to_copy = to_copy_tree.element(index);
    if (from_copy == nullptr) {
      TF_RET_CHECK(to_copy == nullptr);
      continue;
    }
    TF_RET_CHECK(to_copy != nullptr);
    TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy));
  }

  return std::make_pair(from_deep_copy, to_deep_copy);
}

// Compute the indices of the loop state which need copies in order to avoid
// live range interference. Generally, an element in the loop state does not
// need to be copied if the element is passed transparently through the body.
//
// Returns whether any indices need to be copied.
bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
                           const HloInstruction* xla_while,
                           ShapeTree<bool>* indices_to_copy) {
  DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape()));

  bool any_copies = false;
  const HloInstruction* init = xla_while->operand(0);
  for (auto& pair : *indices_to_copy) {
    const ShapeIndex& index = pair.first;
    bool& should_copy = pair.second;
    // If there is any ambiguity, then loop state must be copied.
    if (dataflow.GetValueSet(init, index).values().size() > 1 ||
        dataflow.GetValueSet(xla_while, index).values().size() > 1) {
      should_copy = true;
    } else {
      // If the output of the while instruction is not the same as the init
      // value of the while, then this element is not passed through the body
      // transparently and must be copied.
      should_copy = dataflow.GetUniqueValueAt(xla_while, index) !=
                    dataflow.GetUniqueValueAt(init, index);
    }
    any_copies |= should_copy;
  }
  return any_copies;
}
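
// As a hypothetical illustration (HLO names invented): for a while body
//
//   body {
//     p = (f32[], f32[]) parameter(0)
//     gte0 = f32[] get-tuple-element(p), index=0
//     gte1 = f32[] get-tuple-element(p), index=1
//     add = f32[] add(gte1, gte1)
//     ROOT t = (f32[], f32[]) tuple(gte0, add)
//   }
//
// index {0} is passed transparently through the body (the while output holds
// the same value as the init there), so it needs no copy, while index {1} is
// modified in the body and must be copied.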

// Add kCopy instructions around the given kWhile instruction to eliminate any
// possible live range interference of HLO values assuming a dependency-based
// ordering (DependencyHloOrdering). Copies are added conservatively. There
// likely are copies which are not strictly necessary, but they are removed
// later in the pass via RemoveUnnecessaryCopies.
//
// Elements (each ShapeIndex) in the loop state are considered independently.
// A copy is added to each element of the loop state which is modified in the
// while body. For each such element, a total of three kCopy instructions are
// added at the following locations:
//
//   (1) The init value is copied before the kWhile instruction. Before:
//
//           (Init)
//             |
//           kWhile
//             |
//            ...
//
//       After:
//
//           (Init)
//             |
//           kCopy
//             |
//           kWhile
//             |
//            ...
//
//       This copy is necessary in case the init value is simultaneously live
//       with the kWhile.
//
//   (2) Copies are added to the parameter and root of the while body
//       computation. Before:
//
//           kParameter
//               |
//              ...
//               |
//           (body root)
//
//       After:
//
//           kParameter
//               |
//             kCopy ----------+
//               |             |
//              ...           ctrl
//               |            edge
//           (body root)       |
//               |             |
//             kCopy <---------+
//
//       The root kCopy becomes the new root of the computation. Both copies
//       are necessary to resolve any potential interference between the
//       parameter value and the root value. The control edge prevents
//       potential interference between the copies themselves.
//
// If the loop state is a tuple then each of the above kCopy instructions is a
// deep copy built of kCopy, kGetTupleElement, and kTuple instructions, as
// created by HloInstruction::DeepCopyInstruction.
Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
                         HloInstruction* xla_while) {
  VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name();
  TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile);

  ShapeTree<bool> indices_to_copy(xla_while->shape());
  if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while,
                             &indices_to_copy)) {
    VLOG(2) << "No copies necessary for kWhile instruction "
            << xla_while->name();
    return Status::OK();
  }

  VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:";
  for (auto& pair : indices_to_copy) {
    if (pair.second) {
      VLOG(2) << "  " << pair.first;
    }
  }

  // Deep copy init.
  HloInstruction* while_init = xla_while->mutable_operand(0);
  TF_ASSIGN_OR_RETURN(
      HloInstruction * while_init_copy,
      xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy));
  TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy));

  // Deep copy the parameter and the root. Extend a control edge from the copy
  // of the parameter value to the corresponding copy value of the root.
  HloComputation* body = xla_while->while_body();
  HloInstruction* param = body->parameter_instruction(0);
  HloInstruction* root = body->root_instruction();

  // If param is the root then all indices should have been passed through the
  // while body and we should have returned early above.
  TF_RET_CHECK(param != root);

  // Copy users before making a deep copy of the parameter as the deep copy
  // will create new users of the parameter (e.g., the GTE instructions of the
  // deep copy).
  std::vector<HloInstruction*> param_users = param->users();

  TF_ASSIGN_OR_RETURN(auto pair,
                      DeepCopyAndAddControlEdges(param, root, indices_to_copy));

  HloInstruction* param_copy = pair.first;
  HloInstruction* root_copy = pair.second;

  for (HloInstruction* user : param_users) {
    TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy));
  }

  body->set_root_instruction(root_copy);

  return Status::OK();
}

// We add copies for all the indices of the roots of the branch computations,
// in order to resolve interference. We later rely on RemoveUnnecessaryCopies
// to drop the unnecessary ones.
Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
                               HloInstruction* conditional) {
  VLOG(2) << "Adding copies for kConditional instruction "
          << conditional->name();
  TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);

  for (HloComputation* computation : conditional->branch_computations()) {
    HloInstruction* root = computation->root_instruction();
    std::vector<HloInstruction*> users = root->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
                        computation->DeepCopyInstruction(root));
    for (HloInstruction* user : users) {
      TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
    }
    computation->set_root_instruction(deep_copy);
  }
  return Status::OK();
}
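
// A hedged sketch of the effect on one branch (hypothetical HLO, invented
// names; for an array-shaped root the deep copy is a single kCopy):
//
//   branch {                           branch {
//     p = f32[] parameter(0)      =>     p = f32[] parameter(0)
//     ROOT r = f32[] negate(p)           r = f32[] negate(p)
//   }                                    ROOT r.copy = f32[] copy(r)
//                                      }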

// Conservatively adds copies before the root instruction of the entry
// computation and each aliased parameter to resolve interference between
// aliased input and output buffers. We later rely on RemoveUnnecessaryCopies
// to drop the unnecessary ones.
Status AddCopiesForAliasedInputOutputs(HloModule* module) {
  HloComputation* entry = module->entry_computation();
  HloInstruction* root = entry->root_instruction();

  ShapeTree<bool> output_indices_to_copy(root->shape());
  std::vector<absl::optional<ShapeTree<HloInstruction*>>> copied_parameters(
      entry->num_parameters());
  bool has_alias = false;
  for (auto* param : entry->parameter_instructions()) {
    bool param_has_alias = false;
    ShapeTree<bool> param_indices_to_copy(param->shape());

    module->input_output_alias_config().ForEachAlias(
        [&](const ShapeIndex& output_index,
            const HloInputOutputAliasConfig::Alias& alias) {
          if (alias.parameter_number == param->parameter_number()) {
            param_has_alias = true;
            *(param_indices_to_copy.mutable_element(alias.parameter_index)) =
                true;
            *(output_indices_to_copy.mutable_element(output_index)) = true;
          }
        });

    if (!param_has_alias) {
      continue;
    }

    TF_RET_CHECK(param->parameter_number() < entry->num_parameters());
    TF_RET_CHECK(!copied_parameters[param->parameter_number()]);

    has_alias = true;
    // Store a snapshot of users before DeepCopyInstruction, as
    // DeepCopyInstruction introduces new users of the instruction.
    std::vector<HloInstruction*> users = param->users();
    ShapeTree<HloInstruction*> param_copy_tree(param->shape(),
                                               /*init_value=*/nullptr);
    TF_ASSIGN_OR_RETURN(HloInstruction * copied,
                        entry->DeepCopyInstruction(
                            param, &param_indices_to_copy, &param_copy_tree));
    for (HloInstruction* user : users) {
      TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied));
    }

    copied_parameters[param->parameter_number()] = param_copy_tree;
  }

  if (!has_alias) {
    return Status::OK();
  }

  // Add copies before the root instruction.
  ShapeTree<HloInstruction*> output_copy_tree(root->shape(),
                                              /*init_value=*/nullptr);

  TF_ASSIGN_OR_RETURN(HloInstruction * root_copied,
                      root->parent()->DeepCopyInstruction(
                          root, &output_indices_to_copy, &output_copy_tree));

  // Add control dependencies between the input/output copies.
  TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus(
      [&](const ShapeIndex& output_index,
          const HloInputOutputAliasConfig::Alias& alias) -> Status {
        if (!copied_parameters[alias.parameter_number]) {
          return Status::OK();
        }
        HloInstruction* from =
            copied_parameters[alias.parameter_number]->element(
                alias.parameter_index);
        HloInstruction* to = output_copy_tree.element(output_index);

        TF_RET_CHECK(from != nullptr);
        TF_RET_CHECK(to != nullptr);
        TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to));
        return Status::OK();
      }));

  entry->set_root_instruction(root_copied);

  return Status::OK();
}
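
// As a hypothetical illustration of the above: given an alias config entry
// mapping output index {0} to parameter 0 at index {}, parameter 0 is copied
// at its uses, the root is deep-copied at index {0}, and a control edge runs
// from the parameter copy to the root copy so the two copies of the shared
// buffer stay ordered with respect to each other.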

// Removes any control dependencies to or from the given instruction.
Status StripControlDependenciesFrom(HloInstruction* instruction) {
  while (!instruction->control_successors().empty()) {
    TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo(
        instruction->control_successors().front()));
  }

  while (!instruction->control_predecessors().empty()) {
    TF_RETURN_IF_ERROR(
        instruction->control_predecessors().front()->RemoveControlDependencyTo(
            instruction));
  }

  return Status::OK();
}

// Class which tracks the HLO values within each HLO buffer in the module
// during copy removal.
//
// The values are held in a linked list where there is one list for each
// buffer. Removing a copy instruction merges the values in the source buffer
// of the copy into the destination buffer of the copy. This class tracks
// these value lists as copies are removed from the graph (and value lists are
// merged).
//
// The CopyRemover object is initialized to match the state of
// HloAliasAnalysis. However, as copies are removed this state diverges. The
// values-to-buffer mapping is maintained outside of HloAliasAnalysis because
// a fully updatable alias analysis is very slow.
class CopyRemover {
 public:
  // The values held in a single HLO buffer are represented using a linked
  // list. Each element in the list is a ValueNode.
  //
  // This linked list is hand-rolled to enable efficient splicing of lists
  // using only references to list elements without knowing which lists are
  // being spliced. std::list requires a reference to the list object to
  // splice.
  struct ValueNode {
    explicit ValueNode(const HloValue* v) : value(v) {}

    const HloValue* value;

    // The uses are maintained outside of HloValue::uses() because
    // HloValue::uses() is not updatable (a fully updatable dataflow analysis
    // is slow).
    std::vector<const HloUse*> uses;

    // next/prev elements in the linked list. The list is circularly linked so
    // these values are never null for elements in the list.
    ValueNode* prev = nullptr;
    ValueNode* next = nullptr;
  };

  CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
              const HloOrdering& ordering)
      : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
    // Construct a list for each HLO buffer in the alias analysis. Maintain a
    // map from HloValue to the respective list element representing that
    // value. The map is used to construct the copy info map below.
    absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
    for (const HloBuffer& buffer : alias_analysis.buffers()) {
      // Verify values contained in the buffer are strictly ordered. This
      // should always be the case after adding copies to eliminate
      // interference. Specifically, the addition of the control flow edges
      // between copies added around aliased operations (kWhile) guarantees
      // this strict order.
      for (const HloValue* value_a : buffer.values()) {
        if (value_a->shape().IsToken()) {
          // Token values have no representation and cannot interfere.
          continue;
        }
        for (const HloValue* value_b : buffer.values()) {
          if (value_a != value_b) {
            DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
                                                     dataflow_) ||
                   ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
                                                     dataflow_))
                << value_a->ToShortString() << " and "
                << value_b->ToShortString() << " are not ordered";
          }
        }
      }

      std::vector<const HloValue*> values = buffer.values();
      absl::c_sort(values, [this](const HloValue* a, const HloValue* b) {
        return ordering_.IsDefinedBefore(*a, *b);
      });

      // Create a list containing all of the values in the buffer.
      AddValueList(values, &value_to_node);
    }

    // Create copy_map_ which contains the source and destination values
    // of all copies.
    CreateCopyMap(module, value_to_node);

    XLA_VLOG_LINES(3, ToString());
    TF_DCHECK_OK(Verify());
  }

  // Add a list containing the given values to CopyRemover. This
  // represents the values contained in a single buffer. For each value in
  // 'values' an entry is created in value_to_node which indicates the
  // respective ValueNode representing that value.
  void AddValueList(
      absl::Span<const HloValue* const> values,
      absl::flat_hash_map<const HloValue*, ValueNode*>* value_to_node) {
    ValueNode* tail = nullptr;
    ValueNode* head = nullptr;
    for (const HloValue* value : values) {
      auto new_node = new ValueNode(value);
      (*value_to_node)[value] = new_node;

      // Copy the HLO value's uses into the ValueNode for the value. These
      // uses in ValueNode are updated as copies are removed.
      new_node->uses.reserve(value->uses().size());
      for (const HloUse& use : value->uses()) {
        new_node->uses.push_back(&use);
      }

      // Connect the new node into the linked list.
      if (tail == nullptr) {
        head = new_node;
      } else {
        tail->next = new_node;
        new_node->prev = tail;
      }
      tail = new_node;
    }

    // The linked list is circular so connect the head and tail.
    tail->next = head;
    head->prev = tail;
    value_lists_.insert(head);
  }

  // Fills in copy_map_, which indicates which nodes in the value lists
  // correspond to the source and destination values of kCopy instructions.
  // value_to_node should map each HloValue to its respective ValueNode.
  void CreateCopyMap(
      const HloModule& module,
      const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
    for (HloComputation* computation : module.computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        // Add copies with unambiguous source values to the map. Copies with
        // ambiguous sources are not removable.
        if (instruction->opcode() == HloOpcode::kCopy) {
          const HloValueSet& src_value_set =
              dataflow_.GetValueSet(instruction->operand(0));
          if (src_value_set.values().size() == 1) {
            CopyNodes& copy_node = copy_map_[instruction];
            copy_node.dest =
                value_to_node.at(&dataflow_.GetUniqueValueAt(instruction));
            copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue());
          }
        }
      }
    }
  }

  ~CopyRemover() {
    for (const ValueNode* head : value_lists_) {
      const ValueNode* p = head;
      do {
        const ValueNode* tmp = p->next;
        delete p;
        p = tmp;
      } while (p != head);
    }
  }

  // Verify invariants within the linked lists.
  Status Verify() const {
    for (const ValueNode* head : value_lists_) {
      const ValueNode* p = head;
      do {
        // Verify links between elements are consistent.
        TF_RET_CHECK(p->prev->next == p);
        TF_RET_CHECK(p->next->prev == p);

        const HloInstruction* def = p->value->defining_instruction();
        if (def->opcode() == HloOpcode::kCopy && ContainsKey(copy_map_, def)) {
          TF_RET_CHECK(copy_map_.at(def).dest == p);
        }
        for (const HloUse* use : p->uses) {
          if (use->instruction->opcode() == HloOpcode::kCopy &&
              ContainsKey(copy_map_, use->instruction)) {
            TF_RET_CHECK(copy_map_.at(use->instruction).src == p);
          }
        }

        p = p->next;
      } while (p != head);
    }
    return Status::OK();
  }

  // Try to elide the given copy. Elision of a copy is possible only if no
  // live range interference is introduced by the copy's elimination. If
  // elision is possible, then the internal state (value lists) is updated,
  // and true is returned. Returns false otherwise.
  bool TryElideCopy(const HloInstruction* copy) {
    VLOG(2) << "Trying to remove " << copy->name();

    if (!ContainsKey(copy_map_, copy)) {
      VLOG(2) << copy->name() << " is not removable";
      return false;
    }
    if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
      VLOG(2) << copy->name() << " is not removable (shape mismatch)";
      return false;
    }
    const CopyNodes& copy_node = copy_map_.at(copy);
    ValueNode* src = copy_node.src;
    ValueNode* dest = copy_node.dest;
    DCHECK(src != nullptr);
    DCHECK(dest != nullptr);

    auto is_live_range_before = [this](const ValueNode& a, const ValueNode& b) {
      VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value;
      if (LiveRangeBefore(a, b)) {
        VLOG(2) << "  Live range of " << a.value->ToShortString()
                << " is before " << b.value->ToShortString();
        return true;
      } else {
        VLOG(2) << "  Live range of " << a.value->ToShortString()
                << " is not before " << b.value->ToShortString();
        return false;
      }
    };

    VLOG(3) << copy->name() << " copies value " << src->value->ToShortString();
    VLOG(3) << "Source buffer values: " << ValueListToString(src);
    VLOG(3) << "Dest buffer values: " << ValueListToString(dest);

    // A kCopy instruction copies an HLO value from a source buffer and
    // defines an HLO value in a destination buffer. Most generally, the
    // source and destination buffers may each hold more than one value at
    // different points in the computation so we define the following:
    //
    //   Values in source buffer:      {s_0, ..., s_n}
    //   Values in destination buffer: {d_0, ..., d_m}
    //
    // A kCopy instruction between these buffers copies a value s_x in the
    // source buffer and defines a value d_y in the destination buffer. The
    // elision of a copy merges the source and destination buffers together,
    // so the list of values for the source and destination buffers are
    // merged.
    //
    // We handle two different cases for copy elision:
    //
    //  (1) the kCopy defines the first value in the destination buffer (d_0).
    //
    //  (2) the kCopy copies the last value in the source buffer (s_n).
    //
    // For the remaining case where the kCopy copies a not-last value from the
    // source buffer to a not-first value of the destination buffer, the kCopy
    // instruction cannot be removed. This case is generated, for example, if
    // the kCopy copies a while body parameter of the loop state at one tuple
    // index to a different tuple index in the while body root. Removal of the
    // copy necessarily results in live range interference of values in the
    // loop state at the two different tuple indices.
    //
    // We can only perform copy elision if the resulting merged values have
    // totally ordered live ranges; otherwise the merged buffer would have
    // live range interference.
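    //
    // As a concrete (hypothetical) illustration of case (1): with source
    // values {s_0, s_1} and destination values {d_0, d_1}, a copy reading
    // s_0 and defining d_0 can be elided iff the uses of s_0 are before the
    // definition of d_1 and the live range of d_1 is before that of s_1; the
    // merged buffer then holds {s_0, d_1, s_1}.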
    if (src->next == dest) {
      // In the process of eliding copies, it's possible for a copy to have
      // the same source and destination buffer. In this case, the copy can be
      // safely removed.
      VLOG(2) << copy->name()
              << " source and destination buffers are the same.";
    } else if (IsHead(*dest)) {
      // The copy copies an arbitrary value in the source buffer (call it s_x)
      // and defines d_0, the first value in the destination buffer. After
      // merging, the values in the combined buffer must be strictly ordered
      // as follows** to elide the copy:
      //
      // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
      //
      // Removing the copy eliminates d_0, and uses of d_0 become uses of
      // s_x. In the above ordering, the live range of d_m must be ordered
      // before the live range of s_{x+1} and the definition and all uses of
      // s_x must be ordered before the definition of d_1. These conditions
      // are checked below prior to elision.
      //
      // ** Technically it might be possible to have a non-interfering
      //    non-trivial interleaving of the values of the source and
      //    destination buffers in the resulting order. However, this case is
      //    slow and complicated to check and likely not worth it. So instead
      //    we simply check for the case where *all* values of the destination
      //    buffer (d_1 through d_m) are spliced into the point where the copy
      //    used to be.
      VLOG(2) << copy->name() << " defines the first value in its buffer";
      ValueNode* next_dest = Next(*dest);
      if (next_dest != nullptr) {
        // Live range of 'from' value (s_x) must be before 'next_dest' (d_1);
        if (!is_live_range_before(*src, *next_dest)) {
          return false;
        }
      }
      ValueNode* next_src = Next(*src);

      if (next_src != nullptr) {
        // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
        ValueNode* last_dest = dest->prev;
        DCHECK(IsTail(*last_dest));
        if (!is_live_range_before(*last_dest, *next_src)) {
          return false;
        }
      }

      // Splice in destination buffer values list right after 'src'.
      SpliceAfter(dest, src);
    } else if (IsTail(*src)) {
      // The copy copies the last value in the source buffer, s_n, and defines
      // an arbitrary value in the destination buffer, d_y. After merging, the
      // values in the combined buffer must be strictly ordered as follows**
      // to elide the copy:
      //
      // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m}
      //
      // Removing the copy eliminates d_y, and uses of d_y become uses of
      // s_n. To enforce the above order, the live range of d_{y-1} must be
      // before the live range of s_0, and the live range of s_n must be
      // before the live range of d_{y+1}.
      //
      // ** See comment above in the code handling Case (1).
      VLOG(2) << copy->name() << " copies the last value ("
              << src->value->ToShortString() << ") in its buffer";

      ValueNode* prev_dest = Prev(*dest);
      // nullptr condition handled above in the first 'if' case.
      DCHECK(prev_dest != nullptr);
      ValueNode* first_src = src->next;
      DCHECK(IsHead(*first_src));
      if (!is_live_range_before(*prev_dest, *first_src)) {
        // Live range of value d_{y-1} is not before s_0.
        return false;
      }
      ValueNode* next_dest = Next(*dest);
      if (next_dest != nullptr) {
        if (!is_live_range_before(*src, *next_dest)) {
          // Live range of value s_n is not before d_{y+1}.
          return false;
        }
      }

      // Splice source buffer values list right after 'prev_dest'.
      SpliceAfter(first_src, prev_dest);
    } else {
      VLOG(2) << copy->name()
              << " copies value in middle of source buffer to value in middle "
                 "of destination buffer";
      return false;
    }

    RemoveCopyValue(dest);

    XLA_VLOG_LINES(4, ToString());
    TF_DCHECK_OK(Verify());

    return true;
  }

  // Delete the given ValueNode associated with an elided kCopy
  // instruction. This should be called after splicing the value lists of the
  // source and destination buffers together.
  void RemoveCopyValue(ValueNode* copy_value_node) {
    CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(),
             HloOpcode::kCopy);
    ValueNode* operand_node = copy_value_node->prev;
    CHECK(operand_node != copy_value_node);

    VLOG(2) << "Removing copy " << operand_node->value->ToShortString()
            << " => " << copy_value_node->value->ToShortString();

    // Splice out the copy value node.
    operand_node->next = copy_value_node->next;
    copy_value_node->next->prev = operand_node;

    // Patch up uses. Remove use of copy from operand_node uses.
    auto it = absl::c_find_if(operand_node->uses, [copy_value_node](
                                                      const HloUse* use) {
      return use->instruction == copy_value_node->value->defining_instruction();
    });
    CHECK(it != operand_node->uses.end());
    operand_node->uses.erase(it);

    // If the elided copy has any uses which are themselves kCopy instructions
    // then patch up the copy info to reflect that this kCopy instruction has
    // a different operand (the operand of the elided copy).
    for (const HloUse* copy_use : copy_value_node->uses) {
      operand_node->uses.push_back(copy_use);
      if (copy_use->instruction->opcode() == HloOpcode::kCopy &&
          ContainsKey(copy_map_, copy_use->instruction)) {
        copy_map_.at(copy_use->instruction).src = operand_node;
      }
    }

    // Delete the copy info and the value node.
    copy_map_.erase(copy_value_node->value->defining_instruction());
    delete copy_value_node;
  }

  // Returns true if the live range of given value 'a' is before the live
  // range of 'b'.
  //
  // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not
  // updated as copies are removed.
  bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
    if (a.uses.empty()) {
      VLOG(2) << "Empty uses for " << *a.value;
      return ordering_.IsDefinedBefore(*a.value, *b.value);
    }
    for (const HloUse* use : a.uses) {
      VLOG(2) << "Checking use " << *use << " against " << *b.value;
      if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) {
        VLOG(2) << "Use " << *use << " is NOT before " << *b.value;
        return false;
      }
      VLOG(2) << "Use " << *use << " is before " << *b.value;
    }
    return true;
  }

  // Returns whether 'node' is the last node in its list.
  bool IsTail(const ValueNode& node) const {
    return ContainsKey(value_lists_, node.next);
  }

  // Returns whether 'node' is the first node in its list.
  bool IsHead(const ValueNode& node) const {
    return ContainsKey(value_lists_, &node);
  }

  // Returns the next node in the list after 'node'. If 'node' is the
  // tail, then nullptr is returned.
  ValueNode* Next(const ValueNode& node) const {
    if (IsTail(node)) {
      return nullptr;
    } else {
      return node.next;
    }
  }

  // Returns the previous node in the list before 'node'. If 'node'
  // is the head, then nullptr is returned.
  ValueNode* Prev(const ValueNode& node) const {
    if (IsHead(node)) {
      return nullptr;
    } else {
      return node.prev;
    }
  }

  // Splices the entire linked list with 'head' as its head right after the
  // node 'insert_after' in another linked list.
  void SpliceAfter(ValueNode* head, ValueNode* insert_after) {
    DCHECK(IsHead(*head));
    value_lists_.erase(head);

    ValueNode* tail = head->prev;
    tail->next = insert_after->next;
    insert_after->next->prev = tail;

    insert_after->next = head;
    head->prev = insert_after;
  }
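
  // For example (hypothetical lists): splicing the list {x, y} after node 'a'
  // of the list {a, b} via SpliceAfter(x, a) yields {a, x, y, b}; 'x' is
  // erased from value_lists_ since it is no longer the head of a list.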

  string ValueListToString(const ValueNode* element) {
    const ValueNode* head = element;
    while (!IsHead(*head)) {
      head = Prev(*head);
    }
    std::vector<const HloValue*> values;
    for (const ValueNode* p = head; p != nullptr; p = Next(*p)) {
      values.push_back(p->value);
    }
    return absl::StrCat("{",
                        absl::StrJoin(values, ", ",
                                      [](string* s, const HloValue* value) {
                                        StrAppend(s, value->ToShortString());
                                      }),
                        "}");
  }

  string ToString() const {
    string out = absl::StrCat("CopyRemover:\n");
    StrAppend(&out, "  Def-use chains in each buffer:\n");
    for (const ValueNode* head : value_lists_) {
      StrAppend(&out, "    Buffer defined by ", head->value->ToShortString(),
                ":\n");
      const ValueNode* p = head;
      do {
        StrAppend(&out, "      ", p->value->ToShortString(), ", uses: ",
                  absl::StrJoin(p->uses, "; ",
                                [](string* s, const HloUse* use) {
                                  StrAppend(s, use->ToString());
                                }),
                  "\n");

        p = p->next;
      } while (p != head);
    }
    StrAppend(&out, "  Potentially removable copies:\n");
    for (const auto& pair : copy_map_) {
      const HloInstruction* copy = pair.first;
      const CopyNodes& copy_info = pair.second;

      StrAppend(&out, "    ", copy->name(), " : ",
                copy_info.src->value->ToShortString(), " => ",
                copy_info.dest->value->ToShortString(), "\n");
    }
    return out;
  }

 private:
  const HloDataflowAnalysis& dataflow_;
  const HloOrdering& ordering_;

  // The heads of all the value lists. Each value list represents the HLO
  // values contained in a particular HLO buffer. The values in the list are
  // in dependency order.
  absl::flat_hash_set<const ValueNode*> value_lists_;

  // Copy removal requires fast access to the value list elements
  // corresponding to the source and destination values of the kCopy
  // instruction. This data structure holds pointers to these elements for
  // each kCopy instruction in the graph.
  struct CopyNodes {
    // The source and destination values of the kCopy instruction.
    ValueNode* src = nullptr;
    ValueNode* dest = nullptr;
  };
  absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_;
};

}  // namespace

// Add kCopy instructions to the given module to guarantee there is no
// live-range interference. Generally interference can only occur around
// kWhile instructions, which have update-in-place semantics.
Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, fusion_can_share_buffer_));

  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kWhile) {
        TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
      } else if (instruction->opcode() == HloOpcode::kConditional) {
        TF_RETURN_IF_ERROR(
            AddCopiesForConditional(*alias_analysis, instruction));
      }
    }
  }

  TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module));
  return Status::OK();
}

Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) {
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  return AddSpecialCaseCopies(*call_graph, module);
}

Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
                                           HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, fusion_can_share_buffer_));

  // Identify which shape indices of which instructions need to be copied.
  // Store these results in 'instructions_to_copy'.
  HloInstructionMap<ShapeTree<bool>> instructions_to_copy;
  auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction,
                                                   const ShapeIndex& index) {
    auto it = instructions_to_copy.find(instruction);
    if (it == instructions_to_copy.end()) {
      auto it_added = instructions_to_copy.emplace(
          std::piecewise_construct, std::forward_as_tuple(instruction),
          std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
      it = it_added.first;
    }
    *it->second.mutable_element(index) = true;
  };

  // Iterate through values of all constants and entry parameters. These values
  // are special because they are held in read-only buffers. If any of these
  // values share a buffer with other values (for example, the init value of a
  // while is a constant) then copy the value at its definition and replace all
  // its uses with the copy.
  for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
    if (ValueIsReadOnly(*value) &&
        alias_analysis->GetBufferContainingValue(*value).values().size() > 1) {
      VLOG(2) << "Value " << value->ToShortString()
              << " is read only, but its buffer contains more than one value. "
                 "Copying.";
      add_index_to_copy(value->defining_instruction(), value->defining_index());
    }
  }

  // Identify copies which must be added at root instructions.
  for (HloComputation* computation : module->computations()) {
    const CallGraphNode& node = call_graph.GetNode(computation);
    if (node.context() == CallContext::kParallel) {
      continue;
    }
    TF_RET_CHECK(node.context() == CallContext::kSequential);

    SpecialCaseCopyPolicy policy =
        GetSpecialCaseCopyPolicy(node, module, computation);
    HloInstruction* root = computation->root_instruction();

    // Mark nondistinct/ambiguous indices.
    absl::flat_hash_set<const HloBuffer*> seen;
    ShapeUtil::ForEachSubshape(
        root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          std::vector<const HloBuffer*> buffers_at_index =
              alias_analysis->ComputeBuffersAt(root, index);
          bool buffer_seen_before = false;
          for (const HloBuffer* buffer : buffers_at_index) {
            buffer_seen_before |= !seen.insert(buffer).second;
          }
          if (buffers_at_index.size() > 1 ||
              (buffer_seen_before && policy.copy_root_replicated_buffers)) {
            VLOG(2) << "Index " << index << " of computation "
                    << computation->name() << " (" << root->name()
                    << ") has ambiguous or non-distinct buffer. Copying.";
            add_index_to_copy(root, index);
          }
        });

    for (const auto& pair :
         alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
      const ShapeIndex& index = pair.first;
      const HloValueSet& value_set = pair.second;
      for (const HloValue* value : value_set.values()) {
        if (ShouldCopyRootValue(*value, policy)) {
          VLOG(2) << "Root (" << root->name() << ") of computation ("
                  << computation->name()
                  << ") has constant or parameter value at index " << index
                  << ". Copying.";
          add_index_to_copy(root, index);
        }
      }
    }
  }

  // Add copy instructions indicated in 'instructions_to_copy' to the module.
  for (const auto& pair : instructions_to_copy) {
    HloInstruction* instruction = pair.first;
    const ShapeTree<bool>& indices_to_copy = pair.second;

    ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
    std::vector<HloInstruction*> users = instruction->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
                        instruction->parent()->DeepCopyInstruction(
                            instruction, &indices_to_copy, &copies_added));
    for (HloInstruction* user : users) {
      TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
    }
    if (instruction == instruction->parent()->root_instruction()) {
      instruction->parent()->set_root_instruction(deep_copy);
    }
  }
  return Status::OK();
}

Status CopyInsertion::VerifyNoLiveRangeInterference(const HloOrdering& ordering,
                                                    HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
  TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering));
  return Status::OK();
}

Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
                                              HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, fusion_can_share_buffer_));

  CopyRemover copy_remover(*module, *alias_analysis, ordering);
  if (VLOG_IS_ON(3)) {
    LOG(INFO) << "Removing unnecessary copies in " << module->name();
    LOG(INFO) << "Buffer values, in dependency order: ";
    for (const HloBuffer& buffer : alias_analysis->buffers()) {
      LOG(INFO) << "    HloBuffer " << buffer.id();
    }
  }

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopy &&
          copy_remover.TryElideCopy(instruction)) {
        TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
        TF_RETURN_IF_ERROR(
            instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
      }
    }
  }
  return Status::OK();
}

StatusOr<bool> CopyInsertion::Run(HloModule* module) {
  // Copy insertion is performed in three steps:
  //
  // (1) Add copies conservatively to guarantee that there is no live-range
  //     interference. This is done simplistically and usually results in more
  //     copies than is strictly necessary.
  //
  // (2) Using a more fine-grained analysis, remove as many of the copies
  //     added in (1) as possible while ensuring no live-range interference.
  //
  // (3) Add copies to resolve issues not related to live range interference
  //     such as parameters and constants live out of the entry computation.
  //
  // We add copies then remove them (step (1) then (2)) rather than simply
  // adding only the copies that are necessary because, in general, it is
  // difficult to figure out the minimal set of copies to add once there is
  // interference. On the other hand, it is easy to determine if removing a
  // copy will introduce interference.
  //
  // The final copy insertion in (3) is done separately to simplify the
  // implementation of copy removal in (2), which is the most complicated part
  // of the pass. As is, copy removal only has to reason about live range
  // interference. If all copies were added in step (1) then copy removal
  // would also have to reason about things like constants and parameters live
  // out of the computation.
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  if (!call_graph->IsFlattened()) {
    return FailedPrecondition(
        "Call graph must be flattened before copy insertion.");
  }

  int64 num_existing_copies = 0;
  if (VLOG_IS_ON(1)) {
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        if (instruction->opcode() == HloOpcode::kCopy) {
          ++num_existing_copies;
        }
      }
    }
  }

  TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));

  // Simplify the tuple structures introduced by the deep copies. This should
  // be done before removing copies (RemoveUnnecessaryCopies) because tuple
  // simplification changes dependencies in the graph, which changes live
  // range interference. Also run DCE to remove the dead Tuple/GTE
  // instructions introduced by tuple simplification.
  TupleSimplifier tuple_simplifier;
  HloDCE dce;
  TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
  TF_RETURN_IF_ERROR(dce.Run(module).status());
  DumpHloModuleDuringPassIfEnabled(
      name(), "after adding copies to resolve interference", *module);

  DependencyHloOrdering dep_ordering(module);
  TF_DCHECK_OK(VerifyNoLiveRangeInterference(dep_ordering, module));

  TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(dep_ordering, module));
  DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
                                   *module);

  TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
  DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
                                   *module);

  TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
  TF_RETURN_IF_ERROR(dce.Run(module).status());
  TF_DCHECK_OK(
      VerifyNoLiveRangeInterference(DependencyHloOrdering(module), module));

  if (VLOG_IS_ON(1)) {
    int64 num_total_copies = 0;
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        if (instruction->opcode() == HloOpcode::kCopy) {
          num_total_copies++;
        }
      }
    }
    VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies;
    VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
  }

  return true;
}

namespace {

bool IsWhileBody(const HloComputation* computation,
                 const CallGraph& call_graph) {
  const CallGraphNode& node = call_graph.GetNode(computation);

  if (node.context() == CallContext::kSequential &&
      !node.caller_callsites().empty()) {
    // The call graph should be flattened, so computations in a sequential
    // context can have at most one caller.
    CHECK_EQ(node.caller_callsites().size(), 1);
    const HloInstruction* calling_instruction =
        node.caller_callsites()[0].instruction();
    if (calling_instruction->opcode() == HloOpcode::kWhile &&
        calling_instruction->while_body() == node.computation()) {
      return true;
    }
  }
  return false;
}

}  // namespace

/* static */ StatusOr<bool> CopyInsertion::AddCopiesForBufferAssignment(
    HloModule* module) {
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
                      HloDataflowAnalysis::Run(*module));

  bool changed = false;

  // If a buffer live out of a computation is a constant, a parameter, or not
  // defined in the computation, then copy it to account for the limited
  // computation-scoped analysis in buffer assignment. An exception to this
  // rule is the while body, which is handled properly without copies.
  for (HloComputation* computation : module->computations()) {
    if (computation == module->entry_computation() ||
        IsWhileBody(computation, *call_graph)) {
      continue;
    }

    HloInstruction* root = computation->root_instruction();
    ShapeTree<bool> indices_to_copy(root->shape(), /*init_value=*/false);
    bool copy_root = false;
    for (const auto& pair : dataflow->GetInstructionValueSet(root)) {
      const ShapeIndex& index = pair.first;
      const HloValueSet& value_set = pair.second;
      for (const HloValue* value : value_set.values()) {
        HloInstruction* def = value->defining_instruction();
        if (def->parent() != computation ||
            def->opcode() == HloOpcode::kConstant ||
            def->opcode() == HloOpcode::kParameter) {
          *indices_to_copy.mutable_element(index) = true;
          copy_root = true;
        }
      }
    }
    if (copy_root) {
      TF_ASSIGN_OR_RETURN(
          HloInstruction * root_copy,
          computation->DeepCopyInstruction(root, &indices_to_copy));
      computation->set_root_instruction(root_copy);
      changed = true;
    }
  }

  TupleSimplifier tuple_simplifier;
  HloDCE dce;
  TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed,
                      tuple_simplifier.Run(module));
  TF_ASSIGN_OR_RETURN(bool dce_changed, dce.Run(module));

  return changed || tuple_simplifier_changed || dce_changed;
}

}  // namespace xla