/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/copy_insertion.h"
17 
18 #include "absl/container/flat_hash_map.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/str_join.h"
22 #include "tensorflow/compiler/xla/service/dump.h"
23 #include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
25 #include "tensorflow/compiler/xla/service/hlo_dce.h"
26 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_module.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
31 #include "tensorflow/compiler/xla/service/logical_buffer.h"
32 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/statusor.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/core/platform/logging.h"
38 
39 namespace xla {
40 namespace {
41 
42 using absl::StrAppend;
43 
IsReadonlyEntryParameterValue(const HloValue & value)44 bool IsReadonlyEntryParameterValue(const HloValue& value) {
45   const HloComputation* computation = value.defining_instruction()->parent();
46   return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
47          computation == computation->parent()->entry_computation() &&
48          !computation->parent()->input_output_alias_config().ParameterHasAlias(
49              value.defining_instruction()->parameter_number(), value.index());
50 }
51 
IsConstantValue(const HloValue & value)52 bool IsConstantValue(const HloValue& value) {
53   return value.defining_instruction()->opcode() == HloOpcode::kConstant;
54 }
55 
ValueIsReadOnly(const HloValue & value)56 bool ValueIsReadOnly(const HloValue& value) {
57   return IsConstantValue(value) || IsReadonlyEntryParameterValue(value);
58 }
59 
// Data structure describing the action which should be taken on parts of a
// computation's buffers, with respect to the adding of special case copies.
struct SpecialCaseCopyPolicy {
  // Insert a copy if the same buffer is found at multiple indices within the
  // output tuple.
  bool copy_root_replicated_buffers = false;
  // If true, insert a copy if a buffer coming from a constant or a parameter
  // is found within the output tuple.
  bool copy_parameters_and_constants = false;
};
70 
GetSpecialCaseCopyPolicy(const CallGraphNode & node,HloModule * module,HloComputation * computation)71 SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
72                                                HloModule* module,
73                                                HloComputation* computation) {
74   SpecialCaseCopyPolicy policy;
75   if (computation == module->entry_computation()) {
76     policy.copy_parameters_and_constants = true;
77     policy.copy_root_replicated_buffers = true;
78   }
79   return policy;
80 }
81 
ShouldCopyRootValue(const HloValue & value,const SpecialCaseCopyPolicy & policy)82 bool ShouldCopyRootValue(const HloValue& value,
83                          const SpecialCaseCopyPolicy& policy) {
84   if (policy.copy_parameters_and_constants) {
85     return ValueIsReadOnly(value);
86   }
87   return false;
88 }
89 
90 // Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in
91 // 'indices_to_copy'. Add control edges from the respective kCopy instructions
92 // in deep copy of 'from' to the respective kCopy instruction in the deep copy
93 // of 'to'.
94 //
95 // Requirements: 'from' and 'to' must have compatible shapes.
96 //
97 // For example, suppose 'from' and 'to' are two-element tuples where index 0 is
98 // the only index to copy. Prior to deep-copying we have:
99 //
100 //
101 //      'from'
102 //         |
103 //        ...
104 //         |
105 //       'to'
106 //
107 // DeepCopyAndAddControlEdges produces:
108 //
109 //       'from'
110 //        /   \
111 //      GTE   GTE
112 //       |     |
113 //     Copy    |
114 //    /   \   /
115 //   |    Tuple
116 //   |      |
117 //  ctrl   ...
118 //  edge    |
119 //   |      |
120 //   |    'to'
121 //   |    /   \
122 //   |  GTE   GTE
123 //    \  |     |
124 //     Copy    |
125 //        \   /
126 //        Tuple
127 //
128 StatusOr<std::pair<HloInstruction*, HloInstruction*>>
DeepCopyAndAddControlEdges(HloInstruction * from,HloInstruction * to,const ShapeTree<bool> & indices_to_copy)129 DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
130                            const ShapeTree<bool>& indices_to_copy) {
131   DCHECK(ShapeUtil::Compatible(from->shape(), to->shape()));
132   // to/from_copy_tree hold the kCopy instruction produces by the deep
133   // copies. Elements which are not copied (indices_to_copy.element(index) ==
134   // false) have nullptr at that index.
135   ShapeTree<HloInstruction*> from_copy_tree(from->shape(),
136                                             /*init_value=*/nullptr);
137   TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy,
138                       from->parent()->DeepCopyInstruction(
139                           from, &indices_to_copy, &from_copy_tree));
140 
141   ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr);
142   TF_ASSIGN_OR_RETURN(
143       HloInstruction * to_deep_copy,
144       to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree));
145 
146   // Add control edges between the respective kCopy instructions.
147   for (const auto& pair : from_copy_tree) {
148     const ShapeIndex& index = pair.first;
149     HloInstruction* from_copy = pair.second;
150     HloInstruction* to_copy = to_copy_tree.element(index);
151     if (from_copy == nullptr) {
152       TF_RET_CHECK(to_copy == nullptr);
153       continue;
154     }
155     TF_RET_CHECK(to_copy != nullptr);
156     TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy));
157   }
158 
159   return std::make_pair(from_deep_copy, to_deep_copy);
160 }
161 
162 // Compute the indices of the loop state which need copies in order to avoid
163 // live range interference. Generally, an element in the loop state does not
164 // need to be copied if the element is passed through transparently through the
165 // body.
166 //
167 // Returns whether any indices need to be copied.
IndicesToCopyForWhile(const HloDataflowAnalysis & dataflow,const HloInstruction * xla_while,ShapeTree<bool> * indices_to_copy)168 bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
169                            const HloInstruction* xla_while,
170                            ShapeTree<bool>* indices_to_copy) {
171   DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape()));
172 
173   bool any_copies = false;
174   const HloInstruction* init = xla_while->operand(0);
175   for (auto& pair : *indices_to_copy) {
176     const ShapeIndex& index = pair.first;
177     bool& should_copy = pair.second;
178     // If there is any ambiguity, then loop state must be copied.
179     if (dataflow.GetValueSet(init, index).values().size() > 1 ||
180         dataflow.GetValueSet(xla_while, index).values().size() > 1) {
181       should_copy = true;
182     } else {
183       // If the output of the while instruction is not the same as the init
184       // value of the while, then this element is not passed through the body
185       // transparently and must be copied.
186       should_copy = dataflow.GetUniqueValueAt(xla_while, index) !=
187                     dataflow.GetUniqueValueAt(init, index);
188     }
189     any_copies |= should_copy;
190   }
191   return any_copies;
192 }
193 
194 // Compute the indices of the conditional outputs which need copies. Umambiguous
195 // buffers(buffer with only one value) don't need copies.
IndicesToCopyForConditional(const HloDataflowAnalysis & dataflow,const HloInstruction * xla_conditional,ShapeTree<bool> * indices_to_copy)196 bool IndicesToCopyForConditional(const HloDataflowAnalysis& dataflow,
197                                  const HloInstruction* xla_conditional,
198                                  ShapeTree<bool>* indices_to_copy) {
199   DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(),
200                                xla_conditional->shape()));
201 
202   bool any_copies = false;
203   for (auto& pair : *indices_to_copy) {
204     const ShapeIndex& index = pair.first;
205     bool& should_copy = pair.second;
206 
207     CHECK_EQ(dataflow.GetValueSet(xla_conditional, index).values().size(), 1);
208 
209     auto value = dataflow.GetValueSet(xla_conditional, index).values()[0];
210     // The conditional must be copied if the value is a phi.
211     should_copy =
212         value->is_phi() && value->defining_instruction() == xla_conditional;
213     any_copies |= should_copy;
214   }
215   return any_copies;
216 }
217 
218 // Add kCopy instructions around the given kWhile instruction to eliminate any
219 // possible live range interference of HLO values assuming a dependency-based
220 // ordering. Copies are added conservatively. There  likely are copies which are
221 // not strictly necessary, but they are removed later in the pass via
222 // RemoveUnnecessaryCopies.
223 //
224 // Elements (each ShapeIndex) in the loop state are considered independently.  A
225 // copy is added to each element of the loop state which is modified in the
226 // while body. For each such element, a total of three kCopy instructions are
227 // added at following locations:
228 //
229 //   (1) The init value is copied before the kWhile instruction. Before:
230 //
231 //           (Init)
232 //             |
233 //           kWhile
234 //             |
235 //            ...
236 //
237 //       After:
238 //
239 //           (Init)
240 //             |
241 //           kCopy
242 //             |
243 //           kWhile
244 //             |
245 //            ...
246 //
247 //       This copy is necessary in case the init value is simultaneously live
248 //       with the kWhile.
249 //
250 //   (2) Copies are added to the parameter and root of the while body
251 //       computation. Before:
252 //
253 //           kParameter
254 //               |
255 //              ...
256 //               |
257 //           (body root)
258 //
259 //       After:
260 //
261 //           kParameter
262 //               |
263 //             kCopy ----------+
264 //               |             |
265 //              ...           ctrl
266 //               |            edge
267 //           (body root)       |
268 //               |             |
269 //             kCopy <---------+
270 //
271 //       The root kCopy becomes the new root of the computation. Both copies are
272 //       necessary to any potential interference between the parameter value and
273 //       the root value. The control edge prevents potential interference
274 //       between the copies themselves.
275 //
276 // If the loop state is a tuple then the above kCopy instructions are a deep
277 // copy constructed of kCopy, kGetTupleElement, and kTuple instruction as
278 // constructed by HloInstruction::DeepCopyInstruction.
AddCopiesForWhile(const HloAliasAnalysis & alias_analysis,HloInstruction * xla_while)279 Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
280                          HloInstruction* xla_while) {
281   VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name();
282   TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile);
283 
284   ShapeTree<bool> indices_to_copy(xla_while->shape());
285   if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while,
286                              &indices_to_copy)) {
287     VLOG(2) << "No copies necessary for kWhile instruction "
288             << xla_while->name();
289     return Status::OK();
290   }
291 
292   VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:";
293   for (auto& pair : indices_to_copy) {
294     if (pair.second) {
295       VLOG(2) << "  " << pair.first;
296     }
297   }
298 
299   // Deep copy init.
300   HloInstruction* while_init = xla_while->mutable_operand(0);
301   TF_ASSIGN_OR_RETURN(
302       HloInstruction * while_init_copy,
303       xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy));
304   TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy));
305 
306   // Deep copy the parameter and the root. Extend a control edge from the copy
307   // of the parameter value to the corresponding copy value of the root.
308   HloComputation* body = xla_while->while_body();
309   HloInstruction* param = body->parameter_instruction(0);
310   HloInstruction* root = body->root_instruction();
311 
312   // If param is the root then all indices should have been passed through the
313   // while body and we should have returned early above.
314   TF_RET_CHECK(param != root);
315 
316   // Copy users before making a deep copy of the parameter as the deep copy
317   // will create new users of the parameter (eg, the GTE instructions of the
318   // deep copy).
319   std::vector<HloInstruction*> param_users = param->users();
320 
321   TF_ASSIGN_OR_RETURN(auto pair,
322                       DeepCopyAndAddControlEdges(param, root, indices_to_copy));
323 
324   HloInstruction* param_copy = pair.first;
325   HloInstruction* root_copy = pair.second;
326 
327   for (HloInstruction* user : param_users) {
328     TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy));
329   }
330 
331   body->set_root_instruction(root_copy);
332   return Status::OK();
333 }
334 
335 // Add copies for the operands of in-place operations. RemoveUnnecessaryCopies
336 // will remove the unnecessary copies.
AddCopiesForInPlaceOperation(const HloAliasAnalysis & alias_analysis,HloInstruction * in_place_op,int64 operand_number)337 Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis,
338                                     HloInstruction* in_place_op,
339                                     int64 operand_number) {
340   VLOG(2) << "Adding copies for in-place operation " << in_place_op->name();
341   HloInstruction* operand = in_place_op->mutable_operand(operand_number);
342   TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
343                       in_place_op->parent()->DeepCopyInstruction(operand));
344   TF_RETURN_IF_ERROR(operand->ReplaceUseWith(in_place_op, deep_copy));
345   return Status::OK();
346 }
347 
348 // Conservatively adds copies before root instruction of entry computation and
349 // each aliased parameter to resolve interference of aliased input and output
350 // buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary
351 // ones.
AddCopiesForAliasedInputOutputs(HloModule * module)352 Status AddCopiesForAliasedInputOutputs(HloModule* module) {
353   HloComputation* entry = module->entry_computation();
354   HloInstruction* root = entry->root_instruction();
355 
356   ShapeTree<bool> output_indices_to_copy(root->shape());
357   std::vector<absl::optional<ShapeTree<HloInstruction*>>> copied_parameters(
358       entry->num_parameters());
359   bool has_alias = false;
360   for (auto* param : entry->parameter_instructions()) {
361     bool param_has_alias = false;
362     ShapeTree<bool> param_indices_to_copy(param->shape());
363 
364     module->input_output_alias_config().ForEachAlias(
365         [&](const ShapeIndex& output_index,
366             const HloInputOutputAliasConfig::Alias& alias) {
367           if (alias.parameter_number == param->parameter_number()) {
368             param_has_alias = true;
369             *(param_indices_to_copy.mutable_element(alias.parameter_index)) =
370                 true;
371             *(output_indices_to_copy.mutable_element(output_index)) = true;
372           }
373         });
374 
375     if (!param_has_alias) {
376       continue;
377     }
378 
379     TF_RET_CHECK(param->parameter_number() < entry->num_parameters());
380     TF_RET_CHECK(!copied_parameters[param->parameter_number()]);
381 
382     has_alias = true;
383     // Store a snapshot of users before DeepCopyInstruction, as
384     // DeepCopyInstruction introduces new users of the instruction.
385     std::vector<HloInstruction*> users = param->users();
386     ShapeTree<HloInstruction*> param_copy_tree(param->shape(),
387                                                /*init_value=*/nullptr);
388     TF_ASSIGN_OR_RETURN(HloInstruction * copied,
389                         entry->DeepCopyInstruction(
390                             param, &param_indices_to_copy, &param_copy_tree));
391     for (HloInstruction* user : users) {
392       TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied));
393     }
394 
395     copied_parameters[param->parameter_number()] = param_copy_tree;
396   }
397 
398   if (!has_alias) {
399     return Status::OK();
400   }
401 
402   // Add copies before root instruction.
403   ShapeTree<HloInstruction*> output_copy_tree(root->shape(),
404                                               /*init_value=*/nullptr);
405 
406   TF_ASSIGN_OR_RETURN(HloInstruction * root_copied,
407                       root->parent()->DeepCopyInstruction(
408                           root, &output_indices_to_copy, &output_copy_tree));
409 
410   // Add control dependencies between the input/output copies.
411   TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus(
412       [&](const ShapeIndex& output_index,
413           const HloInputOutputAliasConfig::Alias& alias) -> Status {
414         if (!copied_parameters[alias.parameter_number]) {
415           return Status::OK();
416         }
417         HloInstruction* from =
418             copied_parameters[alias.parameter_number]->element(
419                 alias.parameter_index);
420         HloInstruction* to = output_copy_tree.element(output_index);
421 
422         TF_RET_CHECK(from != nullptr);
423         TF_RET_CHECK(to != nullptr);
424         TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to));
425         return Status::OK();
426       }));
427 
428   entry->set_root_instruction(root_copied);
429 
430   return Status::OK();
431 }
432 
433 // Removes any control dependencies to or from the given instruction.
StripControlDependenciesFrom(HloInstruction * instruction)434 Status StripControlDependenciesFrom(HloInstruction* instruction) {
435   while (!instruction->control_successors().empty()) {
436     TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo(
437         instruction->control_successors().front()));
438   }
439 
440   while (!instruction->control_predecessors().empty()) {
441     TF_RETURN_IF_ERROR(
442         instruction->control_predecessors().front()->RemoveControlDependencyTo(
443             instruction));
444   }
445 
446   return Status::OK();
447 }
448 
// Class which tracks the HLO values within each HLO buffer in the module
// during copy removal.
//
// The values are held in a linked list where there is one list for each
// buffer. Removing a copy instruction merges together the values in the
// source buffer of the copy to the destination buffer of the copy. This class
// tracks these value lists as copies are removed from the graph (and value
// lists are merged).
//
// The CopyRemover object is initialized to match the state of
// HloAliasAnalysis. However, as copies are removed this state diverges. The
// values-to-buffer mapping is maintained outside of HloAliasAnalysis because
// a fully updatable alias analysis is very slow.
462 class CopyRemover {
463  public:
464   // The values held in a single HLO buffer are represented using a linked
465   // list. An element type in this list is ValueNode.
466   //
467   // This linked list is hand-rolled to enable efficient splicing of lists
468   // using only references to list elements without knowing which lists are
469   // being spliced. std::list requires a reference to the list object to
470   // splice.
471   struct ValueNode {
ValueNodexla::__anon747efefc0111::CopyRemover::ValueNode472     explicit ValueNode(const HloValue* v) : value(v) {}
473 
474     const HloValue* value;
475 
476     // The uses are maintained outside of HloValue::uses() because
477     // HloValue::uses() is not updatable (a fully updatable dataflow analysis
478     // is slow).
479     std::vector<const HloUse*> uses;
480 
481     // next/prev elements in the linked list. The list is circularly linked so
482     // these values are never null for elements in the list.
483     ValueNode* prev = nullptr;
484     ValueNode* next = nullptr;
485   };
486 
CopyRemover(const HloModule & module,const HloAliasAnalysis & alias_analysis,const HloOrdering & ordering)487   CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
488               const HloOrdering& ordering)
489       : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
490     // Construct a list for each HLO buffer in the alias analysis. Maintain a
491     // map from HloValue to the respective list element representing that
492     // value. The map is used to construct the copy info map below.
493     absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
494     for (const HloBuffer& buffer : alias_analysis.buffers()) {
495       // No copies should have been inserted within fused computations, so no
496       // need to remove them. HloOrdering isn't compatible with HloValues inside
497       // fusions, so skip copy removal for them.
498       if (buffer.values().at(0)->defining_instruction()->IsFused()) {
499         continue;
500       }
501       // Verify values contained in the buffer are strictly ordered. This
502       // should always be the case after adding copies to eliminate
503       // interference. Specifically, the addition of the control flow edges
504       // between copies added around aliased operations (kWhile) guarantees
505       // this strict order.
506       for (const HloValue* value_a : buffer.values()) {
507         if (value_a->shape().IsToken()) {
508           // Token values have no representation and cannot interfere.
509           continue;
510         }
511         for (const HloValue* value_b : buffer.values()) {
512           if (value_a != value_b) {
513             DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
514                                                      dataflow_) ||
515                    ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
516                                                      dataflow_))
517                 << value_a->ToShortString() << " and "
518                 << value_b->ToShortString() << " are not ordered";
519           }
520         }
521       }
522 
523       std::vector<const HloValue*> values = buffer.values();
524       absl::c_sort(values, [this](const HloValue* a, const HloValue* b) {
525         return ordering_.IsDefinedBefore(*a, *b);
526       });
527 
528       // Create a list containing all of the values in the buffer.
529       AddValueList(values, &value_to_node);
530     }
531 
532     // Create copy_map_ which contains the source and destination values
533     // of all copies.
534     CreateCopyMap(module, value_to_node);
535 
536     XLA_VLOG_LINES(3, ToString());
537     TF_DCHECK_OK(Verify());
538   }
539 
540   // Add a list containing the given values to CopyRemover. This
541   // represents the values contained in a single buffer. For each value in
542   // 'values' an entry is created in value_to_node which indicates the
543   // respective ValueNode representing that value.
AddValueList(absl::Span<const HloValue * const> values,absl::flat_hash_map<const HloValue *,ValueNode * > * value_to_node)544   void AddValueList(
545       absl::Span<const HloValue* const> values,
546       absl::flat_hash_map<const HloValue*, ValueNode*>* value_to_node) {
547     ValueNode* tail = nullptr;
548     ValueNode* head = nullptr;
549     for (const HloValue* value : values) {
550       auto new_node = new ValueNode(value);
551       (*value_to_node)[value] = new_node;
552 
553       // Copy the HLO values's uses into the ValueNode for the value. These
554       // uses in ValueNode are updated as copies are removed.
555       new_node->uses.reserve(value->uses().size());
556       for (const HloUse& use : value->uses()) {
557         new_node->uses.push_back(&use);
558       }
559 
560       // Connect the new node into the linked list.
561       if (tail == nullptr) {
562         head = new_node;
563       } else {
564         tail->next = new_node;
565         new_node->prev = tail;
566       }
567       tail = new_node;
568     }
569 
570     // The linked list is circular so connect the head and tail.
571     tail->next = head;
572     head->prev = tail;
573     value_lists_.insert(head);
574   }
575 
576   // This method also fills in copy_map_ which indicates which nodes
577   // in the value lists corresponding to the source and destination values of
578   // kCopy instructions. value_to_node should map each HloValue to its
579   // respective ValueNode.
CreateCopyMap(const HloModule & module,const absl::flat_hash_map<const HloValue *,ValueNode * > & value_to_node)580   void CreateCopyMap(
581       const HloModule& module,
582       const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
583     for (HloComputation* computation : module.MakeNonfusionComputations()) {
584       for (HloInstruction* instruction : computation->instructions()) {
585         // Add copies with unambiguous source values to the map. Copies with
586         // ambiguous sources are not removable.
587         if (instruction->opcode() == HloOpcode::kCopy) {
588           const HloValueSet& src_value_set =
589               dataflow_.GetValueSet(instruction->operand(0));
590           if (src_value_set.values().size() == 1) {
591             CopyNodes& copy_node = copy_map_[instruction];
592             copy_node.dest =
593                 value_to_node.at(&dataflow_.GetUniqueValueAt(instruction));
594             copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue());
595           }
596         }
597       }
598     }
599   }
600 
~CopyRemover()601   ~CopyRemover() {
602     for (const ValueNode* head : value_lists_) {
603       const ValueNode* p = head;
604       do {
605         const ValueNode* tmp = p->next;
606         delete p;
607         p = tmp;
608       } while (p != head);
609     }
610   }
611 
612   // Verify invariants within the linked lists.
Verify() const613   Status Verify() const {
614     for (const ValueNode* head : value_lists_) {
615       const ValueNode* p = head;
616       do {
617         // Verify links between elements are consistent.
618         TF_RET_CHECK(p->prev->next == p);
619         TF_RET_CHECK(p->next->prev == p);
620 
621         const HloInstruction* def = p->value->defining_instruction();
622         if (def->opcode() == HloOpcode::kCopy && ContainsKey(copy_map_, def)) {
623           TF_RET_CHECK(copy_map_.at(def).dest == p);
624         }
625         for (const HloUse* use : p->uses) {
626           if (use->instruction->opcode() == HloOpcode::kCopy &&
627               ContainsKey(copy_map_, use->instruction)) {
628             TF_RET_CHECK(copy_map_.at(use->instruction).src == p);
629           }
630         }
631 
632         p = p->next;
633       } while (p != head);
634     }
635     return Status::OK();
636   }
637 
638   // Try to elide the given copy. Elision of a copy is possible only if no
639   // live range interference is introduced by the copy's elimination. If
640   // elision is possible, then the internal state (value lists) are updated,
641   // and true is returned. Returns false otherwise.
TryElideCopy(const HloInstruction * copy)642   bool TryElideCopy(const HloInstruction* copy) {
643     VLOG(2) << "Trying to remove " << copy->name();
644 
645     if (!ContainsKey(copy_map_, copy)) {
646       VLOG(2) << copy->name() << " is not removable";
647       return false;
648     }
649     if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
650       VLOG(2) << copy->name() << " is not removable (shape mismatch)";
651       return false;
652     }
653     const CopyNodes& copy_node = copy_map_.at(copy);
654     ValueNode* src = copy_node.src;
655     ValueNode* dest = copy_node.dest;
656     DCHECK(src != nullptr);
657     DCHECK(dest != nullptr);
658 
659     VLOG(3) << copy->name() << " copies value " << src->value->ToShortString();
660     VLOG(3) << "Source buffer values: " << ValueListToString(src);
661     VLOG(3) << "Dest buffer values: " << ValueListToString(dest);
662 
663     // A kCopy instruction copies an HLO value from a source buffer and
664     // defines an HLO value in a destination buffer. Most generally, the
665     // source and destination buffers may each hold more than one value at
666     // different points in the computation so we define the following:
667     //
668     //   Values in source buffer:      {s_0, ..., s_n}
669     //   Values in destination buffer: {d_0, ..., d_m}
670     //
671     // A kCopy instruction between these buffers copies a value s_x in the
672     // source buffer and defines a value d_y in the destination buffer. The
673     // elision of a copy merges the source and destination buffers together,
674     // so the list of values for the source and destination buffers are
675     // merged.
676     //
677     // We handle two different cases for copy elision:
678     //
679     //  (1) the kCopy defines the first value in the destination buffer (d_0).
680     //
681     //  (2) the kCopy copies the last value in the source buffer (s_n).
682     //
683     // For the remaining case where the kCopy copies a not-last value from the
684     // source buffer to a not-first value of the destination buffer, the kCopy
685     // instruction cannot be removed. This case is generated, for example, if
686     // the kCopy copies a while body parameter of the loop state at one tuple
687     // index to a different tuple index in the while body root. Removal of the
688     // copy necessarily results in live range interference of values in the
689     // loop state at the two different tuple indices.
690     //
691     //  We can only perform copy elision if the resulting merged values have
692     //  totally ordered live ranges; otherwise the merged buffer would have
693     //  live range interference.
694     if (src->next == dest) {
695       // In the process of eliding copies, its possible for a copy to have the
696       // same source and destination buffer. In this case, the copy can be
697       // safely removed.
698       VLOG(2) << copy->name() << " source and destination buffers are same.";
699     } else if (IsHead(*dest)) {
700       // The copy copies an arbitrary value in the source buffer (call it s_x)
701       // and defines d_0, the first value in the destination buffer. After
702       // merging, the values in the combined buffer must be strictly ordered
703       // as follows** to elide the copy:
704       //
705       // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
706       //
707       // Removing the copy eliminates d_0, and uses of d_0 become uses of
708       // s_x. In the above ordering, the live range of d_m will be ordered
709       // before the live range of s_{x+1} and the definition and all uses of
710       // s_x will be ordered before the definition of d_1. To make sure the
711       // copy elision is safe, the following code checks that this ordering is
712       // valid --- in particular we check it is safe to order d_m ahead of all
713       // the liverages at and after x_{x+1}, and it is safe to order all uses
714       // of s_x before the definition of d_1, by checking the live range
715       // constraints for each pair --- we cannot skip the later checks because
716       // the live range ordering is not guranteed to be transitive --- while it
717       // may be ok to have lr_1 before lr_2, and lr_2 before lv_3 while merging
718       // their buffers, it may not be ok to merge the buffers of lr_1 and lv_3,
719       // because the exclusiveness relation of non-overlapping computations is
720       // not transitive.
721       //
722       // ** Technically it might be possible to have a non-interfering
723       //    non-trivial interleaving of the values of the source and
724       //    destination buffers in the resulting order. However, this case is
725       //    slow and complicated to check and likely not worth it. So instead
726       //    we simply check for the case where *all* values of the destination
727       //    buffer (d_1 through d_m) are spliced into the point where the copy
728       //    used to be.
729       VLOG(2) << copy->name() << " defines the first value in its buffer";
730       for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
731            next_dest = Next(*next_dest)) {
732         // Live range of 'from' value (s_x) must be before 'next_dest' (d_1);
733         if (!LiveRangeBefore(*src, *next_dest)) {
734           VLOG(2) << "Not removing the copy: live range of "
735                   << src->value->ToShortString() << " is not before "
736                   << next_dest->value->ToShortString();
737           return false;
738         }
739       }
740       for (ValueNode* next_src = Next(*src); next_src != nullptr;
741            next_src = Next(*next_src)) {
742         // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
743         ValueNode* last_dest = dest->prev;
744         DCHECK(IsTail(*last_dest));
745         if (!LiveRangeBefore(*last_dest, *next_src)) {
746           VLOG(2) << "Not removing the copy: live range of "
747                   << last_dest->value->ToShortString() << " is not before "
748                   << next_src->value->ToShortString();
749           return false;
750         }
751       }
752 
753       // Splice in destination buffer values list right after 'src'.
754       SpliceAfter(dest, src);
755     } else if (IsTail(*src)) {
756       // The copy copies the last value in the source buffer, s_n, and defines
757       // an arbitrary value in the destination buffer, d_y.  After
758       // merging, the values in the combined buffer must be strictly ordered
759       // as follows** to elide the copy:
760       //
761       // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m}
762       //
763       // Removing the copy eliminates d_y, and uses of d_y become uses of
764       // s_n. To enforce the above order, the live range of d_{y-1} must be
765       // before the live range of s_0, and the live range of s_n must be
766       // before the live range of d_{y+1}.
767       //
768       // ** See comment above in the code handling Case (1).
769       VLOG(2) << copy->name() << " copies the last value ("
770               << src->value->ToShortString() << ") in its buffer";
771 
772       ValueNode* first_src = src->next;
773       DCHECK(IsHead(*first_src));
774       for (ValueNode* prev_dest = Prev(*dest);
775            // nullptr condition handled above in the first 'if' case.
776            prev_dest != nullptr; prev_dest = Prev(*prev_dest)) {
777         if (!LiveRangeBefore(*prev_dest, *first_src)) {
778           // Live range of value d_{y-1} is not before s_0.
779           VLOG(2) << "Not removing the copy: live range of "
780                   << prev_dest->value->ToShortString() << " is not before "
781                   << first_src->value->ToShortString();
782           return false;
783         }
784       }
785       for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
786            next_dest = Next(*next_dest)) {
787         if (!LiveRangeBefore(*src, *next_dest)) {
788           // Live range of value s_n is not before d_{y+1}.
789           VLOG(2) << "Not removing the copy: live range of "
790                   << src->value->ToShortString() << " is not before "
791                   << next_dest->value->ToShortString();
792           return false;
793         }
794       }
795 
796       // Splice source buffer values list right after 'prev_dest'.
797       SpliceAfter(first_src, Prev(*dest));
798     } else {
799       VLOG(2) << copy->name()
800               << " copies value in middle of source buffer to value in middle "
801                  "of destination buffer";
802       return false;
803     }
804 
805     RemoveCopyValue(dest);
806 
807     XLA_VLOG_LINES(4, ToString());
808     TF_DCHECK_OK(Verify());
809 
810     return true;
811   }
812 
813   // Delete the given ValueNode associated with a elided kCopy
814   // instruction. This should be called after splicing the value lists of the
815   // source and destination buffers together.
  void RemoveCopyValue(ValueNode* copy_value_node) {
    // Only values defined by a kCopy instruction may be removed here.
    CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(),
             HloOpcode::kCopy);
    // After the caller's splice, the copy's operand value is the node
    // immediately preceding the copy's value in the merged list.
    ValueNode* operand_node = copy_value_node->prev;
    CHECK(operand_node != copy_value_node);

    VLOG(2) << "Removing copy " << operand_node->value->ToShortString()
            << " => " << copy_value_node->value->ToShortString();

    // Splice out the copy value node.
    operand_node->next = copy_value_node->next;
    copy_value_node->next->prev = operand_node;

    // Patch up uses. Remove use of copy from operand_node uses.
    auto it = absl::c_find_if(operand_node->uses, [copy_value_node](
                                                      const HloUse* use) {
      return use->instruction == copy_value_node->value->defining_instruction();
    });
    CHECK(it != operand_node->uses.end());
    operand_node->uses.erase(it);

    // If the elided copy has any uses which are themselves kCopy instructions
    // then patch up the copy info to reflect that this kCopy instruction
    // has a different operand (the operand of the elided copy).
    for (const HloUse* copy_use : copy_value_node->uses) {
      // All uses of the removed copy's value become uses of the operand value.
      operand_node->uses.push_back(copy_use);
      if (copy_use->instruction->opcode() == HloOpcode::kCopy &&
          ContainsKey(copy_map_, copy_use->instruction)) {
        copy_map_.at(copy_use->instruction).src = operand_node;
      }
    }

    // Delete the copy info and the value node.
    copy_map_.erase(copy_value_node->value->defining_instruction());
    delete copy_value_node;
  }
852 
853   // Returns true if the live range of given value 'a' is before the live
854   // range of 'b'.
855   //
856   // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not
857   // updated as copies are removed.
LiveRangeBefore(const ValueNode & a,const ValueNode & b)858   bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
859     if (a.uses.empty()) {
860       VLOG(2) << "Empty uses for " << *a.value;
861       return ordering_.IsDefinedBefore(*a.value, *b.value);
862     }
863     return ordering_.UsesBeforeValueDefinition(a.uses, *b.value, dataflow_);
864   }
865 
866   // Returns whether 'node' is the last node in its list.
IsTail(const ValueNode & node) const867   bool IsTail(const ValueNode& node) const {
868     return ContainsKey(value_lists_, node.next);
869   }
870 
871   // Returns whether 'node' is the first node in its list.
IsHead(const ValueNode & node) const872   bool IsHead(const ValueNode& node) const {
873     return ContainsKey(value_lists_, &node);
874   }
875 
876   // Returns the next node in the list after 'node'. If 'node' is the
877   // tail, then nullptr is returned.
Next(const ValueNode & node) const878   ValueNode* Next(const ValueNode& node) const {
879     if (IsTail(node)) {
880       return nullptr;
881     } else {
882       return node.next;
883     }
884   }
885 
886   // Returns the previous node in the list before 'node'. If 'node'
887   // is the head, then nullptr is returned.
Prev(const ValueNode & node) const888   ValueNode* Prev(const ValueNode& node) const {
889     if (IsHead(node)) {
890       return nullptr;
891     } else {
892       return node.prev;
893     }
894   }
895 
896   // Splices the entire linked list with 'head' as its head right after the
897   // node 'insert_after' in another linked list.
  void SpliceAfter(ValueNode* head, ValueNode* insert_after) {
    DCHECK(IsHead(*head));
    // 'head' is no longer a list head once it is spliced into another list.
    value_lists_.erase(head);

    // The lists are circular, so the head's 'prev' pointer is the tail.
    ValueNode* tail = head->prev;
    tail->next = insert_after->next;
    insert_after->next->prev = tail;

    insert_after->next = head;
    head->prev = insert_after;
  }
909 
ValueListToString(const ValueNode * element)910   string ValueListToString(const ValueNode* element) {
911     const ValueNode* head = element;
912     while (!IsHead(*head)) {
913       head = Prev(*head);
914     }
915     std::vector<const HloValue*> values;
916     for (const ValueNode* p = head; p != nullptr; p = Next(*p)) {
917       values.push_back(p->value);
918     }
919     return absl::StrCat("{",
920                         absl::StrJoin(values, ", ",
921                                       [](string* s, const HloValue* value) {
922                                         StrAppend(s, value->ToShortString());
923                                       }),
924                         "}");
925   }
926 
ToString() const927   string ToString() const {
928     string out = absl::StrCat("CopyRemover:\n");
929     StrAppend(&out, "  Def-use chains in each buffer:\n");
930     for (const ValueNode* head : value_lists_) {
931       StrAppend(&out, "    Buffer defined by ", head->value->ToShortString(),
932                 ":\n");
933       const ValueNode* p = head;
934       do {
935         StrAppend(&out, "      ", p->value->ToShortString(), ", uses: ",
936                   absl::StrJoin(p->uses, "; ",
937                                 [](string* s, const HloUse* use) {
938                                   StrAppend(s, use->ToString());
939                                 }),
940                   "\n");
941 
942         p = p->next;
943       } while (p != head);
944     }
945     StrAppend(&out, "  Potentially removable copies:\n");
946     for (const auto& pair : copy_map_) {
947       const HloInstruction* copy = pair.first;
948       const CopyNodes& copy_info = pair.second;
949 
950       StrAppend(&out, "    ", copy->name(), " : ",
951                 copy_info.src->value->ToShortString(), " => ",
952                 copy_info.dest->value->ToShortString(), "\n");
953     }
954     return out;
955   }
956 
 private:
  // Dataflow analysis of the module; consulted for live-range queries.
  const HloDataflowAnalysis& dataflow_;
  // Ordering used to decide whether live ranges may be safely merged.
  const HloOrdering& ordering_;

  // The heads of all the value lists. Each value list represents the HLO
  // values contained in a particular HLO buffer. The values in the list are
  // in dependency order.
  absl::flat_hash_set<const ValueNode*> value_lists_;

  // Copy removal requires fast access to the value list elements
  // corresponding to the source and destination values of the kCopy
  // instruction. This data structure holds pointers to these elements for
  // each kCopy instruction in the graph.
  struct CopyNodes {
    // The source and destination values of the kCopy instruction.
    ValueNode* src = nullptr;
    ValueNode* dest = nullptr;
  };
  absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_;
976 };
977 
978 }  // namespace
979 
980 // We add copies for all non-phi indices of the true and false computation
981 // roots, in order to resolve interference. We later rely on
982 // RemoveUnnecessaryCopies to drop the unnecessary ones.
AddCopiesForConditional(const HloAliasAnalysis & alias_analysis,HloInstruction * conditional)983 Status CopyInsertion::AddCopiesForConditional(
984     const HloAliasAnalysis& alias_analysis, HloInstruction* conditional) {
985   VLOG(2) << "Adding copies for kConditional instruction "
986           << conditional->name();
987   ShapeTree<bool> indices_to_copy(conditional->shape());
988   TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
989   if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(),
990                                    conditional, &indices_to_copy)) {
991     VLOG(2) << "No copies necessary for kWhile instruction "
992             << conditional->name();
993     return Status::OK();
994   }
995 
996   for (HloComputation* computation : conditional->branch_computations()) {
997     HloInstruction* root = computation->root_instruction();
998     std::vector<HloInstruction*> users = root->users();
999     TF_ASSIGN_OR_RETURN(
1000         HloInstruction * deep_copy,
1001         computation->DeepCopyInstruction(root, &indices_to_copy));
1002     for (HloInstruction* user : users) {
1003       TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
1004     }
1005     computation->set_root_instruction(deep_copy);
1006   }
1007   return Status::OK();
1008 }
1009 
1010 // Add kCopy instructions to the given module to guarantee there is no
1011 // live-range interference. Generally interference can only occur around kWhile
1012 // instructions which have update-in-place semantics.
AddCopiesToResolveInterference(HloModule * module)1013 Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
1014   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
1015                       HloAliasAnalysis::Run(module, can_share_buffer_));
1016 
1017   for (HloComputation* computation : module->MakeNonfusionComputations()) {
1018     for (HloInstruction* instruction :
1019          computation->MakeInstructionPostOrder()) {
1020       if (instruction->opcode() == HloOpcode::kWhile) {
1021         TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
1022       } else if (instruction->opcode() == HloOpcode::kConditional) {
1023         TF_RETURN_IF_ERROR(
1024             AddCopiesForConditional(*alias_analysis, instruction));
1025       } else {
1026         for (const auto& operand_and_output_index :
1027              HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
1028           const HloUse& operand = operand_and_output_index.first;
1029           CHECK_EQ(operand.operand_index, ShapeIndex{})
1030               << "Support for non-{} shape operand not currently implemented.";
1031           TF_RETURN_IF_ERROR(AddCopiesForInPlaceOperation(
1032               *alias_analysis, instruction, operand.operand_number));
1033         }
1034       }
1035     }
1036   }
1037 
1038   TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module));
1039   return Status::OK();
1040 }
1041 
AddSpecialCaseCopies(HloModule * module)1042 Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) {
1043   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
1044   return AddSpecialCaseCopies(*call_graph, module);
1045 }
1046 
Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
                                           HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, can_share_buffer_));

  // Identify which shape indices of which instructions need to be copied. Store
  // these results in 'instructions_to_copy'.
  HloInstructionMap<ShapeTree<bool>> instructions_to_copy;
  // Marks 'index' of 'instruction' for copying, lazily creating the
  // instruction's ShapeTree entry on first touch.
  auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction,
                                                   const ShapeIndex& index) {
    auto it = instructions_to_copy.find(instruction);
    if (it == instructions_to_copy.end()) {
      auto it_added = instructions_to_copy.emplace(
          std::piecewise_construct, std::forward_as_tuple(instruction),
          std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
      it = it_added.first;
    }
    *it->second.mutable_element(index) = true;
  };

  // Iterate through values of all constants and entry parameters. These values
  // are special because they are held in read-only buffers. If any of these
  // values share a buffer with other values (for example, the init value of a
  // while is a constant) then copy the value at its definition and replace all
  // its uses with the copy.
  for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
    if (ValueIsReadOnly(*value) &&
        alias_analysis->GetBufferContainingValue(*value).values().size() > 1) {
      VLOG(2) << "Value " << value->ToShortString()
              << " is read only, but its buffer contains more than one value. "
                 "Copying.";
      add_index_to_copy(value->defining_instruction(), value->defining_index());
    }
  }

  // Identify copies which must be added at root instructions
  for (HloComputation* computation : module->computations()) {
    const CallGraphNode& node = call_graph.GetNode(computation);
    if (node.context() == CallContext::kParallel) {
      continue;
    }
    TF_RET_CHECK(node.context() == CallContext::kSequential);

    SpecialCaseCopyPolicy policy =
        GetSpecialCaseCopyPolicy(node, module, computation);
    HloInstruction* root = computation->root_instruction();

    // Mark nondistinct/ambiguous indices.
    absl::flat_hash_map<const HloBuffer*, ShapeIndex> seen;
    ShapeUtil::ForEachSubshape(
        root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          std::vector<const HloBuffer*> buffers_at_index =
              alias_analysis->ComputeBuffersAt(root, index);
          // 'buffer_seen_before' is true when any buffer at this index
          // already appeared at an earlier index of the root's shape.
          bool buffer_seen_before = false;
          for (const HloBuffer* buffer : buffers_at_index) {
            buffer_seen_before |= !seen.emplace(buffer, index).second;
          }

          // Special case: two output indices of the entry computation alias
          // the same parameter buffer. Copy the earlier-seen index instead of
          // this one so the aliased output at 'index' stays in place.
          if (buffer_seen_before && policy.copy_root_replicated_buffers &&
              computation == module->entry_computation() &&
              module->input_output_alias_config().OutputHasAlias(index) &&
              buffers_at_index.size() == 1) {
            absl::optional<HloInputOutputAliasConfig::Alias> alias =
                module->input_output_alias_config().GetAliasedParameter(index);
            CHECK(alias) << "Alias does not exist";
            const ShapeIndex& other_index = seen[buffers_at_index[0]];
            VLOG(2) << "Output indices " << index.ToString() << " and "
                    << other_index.ToString() << " are both aliased to "
                    << alias->parameter_number << " copying " << other_index;
            add_index_to_copy(root, other_index);
            return;
          }

          if (buffers_at_index.size() > 1 ||
              (buffer_seen_before && policy.copy_root_replicated_buffers)) {
            VLOG(2) << "Index " << index << " of computation "
                    << computation->name() << " (" << root->name()
                    << ") has ambiguous or non-distinct buffer. Copying.";
            add_index_to_copy(root, index);
          }
        });

    // Copy root values that must not alias read-only storage (constants,
    // parameters) per the computation's copy policy.
    for (const auto& pair :
         alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
      const ShapeIndex& index = pair.first;
      const HloValueSet& value_set = pair.second;
      for (const HloValue* value : value_set.values()) {
        if (ShouldCopyRootValue(*value, policy)) {
          VLOG(2) << "Root of (" << root->name() << ") of computation("
                  << computation->name()
                  << ") has constant or parameter value at index " << index
                  << ". Copying.";
          add_index_to_copy(root, index);
        }
      }
    }
  }

  // Add copy instructions indicated in 'instructions_to_copy' to the module.
  for (const auto& pair : instructions_to_copy) {
    HloInstruction* instruction = pair.first;
    const ShapeTree<bool>& indices_to_copy = pair.second;

    ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
    // Snapshot users before rewriting, since ReplaceUseWith mutates the list.
    std::vector<HloInstruction*> users = instruction->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
                        instruction->parent()->DeepCopyInstruction(
                            instruction, &indices_to_copy, &copies_added));
    for (HloInstruction* user : users) {
      TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
    }
    if (instruction == instruction->parent()->root_instruction()) {
      instruction->parent()->set_root_instruction(deep_copy);
    }
  }
  return Status::OK();
}
1164 
GetNumExistingCopies(const HloModule * module)1165 static int64 GetNumExistingCopies(const HloModule* module) {
1166   int64 num_existing_copies = 0;
1167   for (HloComputation* computation : module->computations()) {
1168     for (HloInstruction* instruction : computation->instructions()) {
1169       if (instruction->opcode() == HloOpcode::kCopy) {
1170         ++num_existing_copies;
1171       }
1172     }
1173   }
1174   return num_existing_copies;
1175 }
1176 
Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
                                              HloModule* module) {
  XLA_VLOG_LINES(4, module->ToString());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, can_share_buffer_));

  CopyRemover copy_remover(*module, *alias_analysis, ordering);
  if (VLOG_IS_ON(3)) {
    LOG(INFO) << "Removing unnecessary copies in " << module->name();
    LOG(INFO) << "Buffer values, in dependency order: ";
    for (const HloBuffer& buffer : alias_analysis->buffers()) {
      LOG(INFO) << "    HloBuffer " << buffer.id();
    }
  }

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);

  // Iterate to a fixpoint: eliding one copy may unblock the elision of
  // another. Every productive iteration elides at least one copy, so the
  // number of iterations is bounded by the number of copies in the module;
  // the CHECK_LE below enforces this termination invariant.
  int64 num_existing_copies = GetNumExistingCopies(module);
  bool changed = true;
  int64 num_iterations = -1;
  while (changed) {
    CHECK_LE(++num_iterations, num_existing_copies);
    changed = false;
    VLOG(2) << "Running fixpoint iteration " << num_iterations
            << " of copy elision";
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        if (instruction->opcode() == HloOpcode::kCopy &&
            copy_remover.TryElideCopy(instruction)) {
          changed = true;
          // Drop control edges from the elided copy, then forward its uses
          // to the copy's operand.
          TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
          TF_RETURN_IF_ERROR(
              instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
        }
      }
    }
  }
  return Status::OK();
}
1216 
Run(HloModule * module)1217 StatusOr<bool> CopyInsertion::Run(HloModule* module) {
1218   // Copy insertion is performed in three steps:
1219   //
1220   // (1) Add copies conservatively to guarantee that there is no live-range
1221   //     interference. This is done simplistically and usually results in more
1222   //     copies than is strictly necessary.
1223   //
1224   // (2) Using a more fine-grained analysis, remove as many copies that were
1225   //     added in (1) as possible while ensuring no live-range interference.
1226   //
1227   // (3) Add copies to resolve issues not related to live range interference
1228   //     such as parameters and constants live out of the entry computation.
1229   //
1230   // We add copies then remove them (step (1) then (2)) rather than simply
1231   // adding only the copies that are necessary because, in general, it is
1232   // difficult to figure out the minimal set of copies to add once there is
1233   // interference. On the other hand, it is easy to determine if removing a copy
1234   // will introduce interference.
1235   //
1236   // The final copy insertion in (3) is done separately to simplify the
1237   // implementation of copy removal in (2) which is the most complicated part of
1238   // the pass. As is, copy removal only has to reason about live range
1239   // interference. If all copies were added in step (1) then copy removal would
1240   // also have to reason about things like constants and parameters live out of
1241   // the computation.
1242   std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
1243   if (!call_graph->IsFlattened()) {
1244     return FailedPrecondition(
1245         "Call graph must be flattened before copy insertion.");
1246   }
1247 
1248   TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));
1249 
1250   // Simplify the tuple structures introduced by the deep copies. This should be
1251   // done before removing copies (RemoveUnnecessaryCopies) because tuple
1252   // simplification changes dependencies in the graph which changes live range
1253   // interference in the graph. Also run DCE to remove the dead Tuple/GTE
1254   // instructions introduced by tuple simplification.
1255   TupleSimplifier tuple_simplifier;
1256   HloDCE dce;
1257   TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
1258   TF_RETURN_IF_ERROR(dce.Run(module).status());
1259   DumpHloModuleDuringPassIfEnabled(
1260       name(), "after adding copies to resolve interference", *module);
1261 
1262   TF_RETURN_IF_ERROR(
1263       RemoveUnnecessaryCopies(DependencyHloOrdering(module), module));
1264   DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
1265                                    *module);
1266   TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
1267   DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
1268                                    *module);
1269 
1270   TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
1271   TF_RETURN_IF_ERROR(dce.Run(module).status());
1272 
1273   if (VLOG_IS_ON(1)) {
1274     int64 num_total_copies = 0;
1275     for (HloComputation* computation : module->computations()) {
1276       for (HloInstruction* instruction : computation->instructions()) {
1277         if (instruction->opcode() == HloOpcode::kCopy) {
1278           num_total_copies++;
1279         }
1280       }
1281     }
1282     VLOG(1) << "Num copies before copy-insertion: "
1283             << GetNumExistingCopies(module);
1284     VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
1285   }
1286 
1287   return true;
1288 }
1289 }  // namespace xla
1290