/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/copy_insertion.h"

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

using absl::StrAppend;

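// Returns true if the given value is defined by a parameter of the entry
// computation which is not aliased with any module output.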
bool IsReadonlyEntryParameterValue(const HloValue& value) {
  const HloComputation* computation = value.defining_instruction()->parent();
  return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
         computation == computation->parent()->entry_computation() &&
         !computation->parent()->input_output_alias_config().ParameterHasAlias(
             value.defining_instruction()->parameter_number(), value.index());
}

bool IsConstantValue(const HloValue& value) {
  return value.defining_instruction()->opcode() == HloOpcode::kConstant;
}

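// Returns true if the given value lives in a read-only buffer: it is either a
// constant or a read-only entry parameter.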
bool ValueIsReadOnly(const HloValue& value) {
  return IsConstantValue(value) || IsReadonlyEntryParameterValue(value);
}

// Data structure describing the action which should be taken on parts of a
// computation's buffers with respect to the adding of special-case copies.
struct SpecialCaseCopyPolicy {
  // Insert a copy if the same buffer is found at multiple indices within the
  // output tuple.
  bool copy_root_replicated_buffers = false;
  // If true, insert a copy if a buffer coming from a constant or a parameter
  // is found within the output tuple.
  bool copy_parameters_and_constants = false;
};

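// Returns the special-case copy policy for the given computation. Only the
// root of the entry computation currently requires these copies.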
SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
                                               HloModule* module,
                                               HloComputation* computation) {
  SpecialCaseCopyPolicy policy;
  if (computation == module->entry_computation()) {
    policy.copy_parameters_and_constants = true;
    policy.copy_root_replicated_buffers = true;
  }
  return policy;
}

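// Returns whether the given value, which appears at the root of a
// computation, must be copied under the given policy.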
bool ShouldCopyRootValue(const HloValue& value,
                         const SpecialCaseCopyPolicy& policy) {
  if (policy.copy_parameters_and_constants) {
    return ValueIsReadOnly(value);
  }
  return false;
}

// Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given
// in 'indices_to_copy'. Add control edges from the respective kCopy
// instructions in the deep copy of 'from' to the respective kCopy
// instructions in the deep copy of 'to'.
//
// Requirements: 'from' and 'to' must have compatible shapes.
//
// For example, suppose 'from' and 'to' are two-element tuples where index 0 is
// the only index to copy. Prior to deep-copying we have:
//
//      'from'
//         |
//        ...
//         |
//       'to'
//
// DeepCopyAndAddControlEdges produces:
//
//       'from'
//        /   \
//      GTE   GTE
//       |     |
//     Copy    |
//    /   \   /
//   |    Tuple
//   |      |
//  ctrl   ...
//  edge    |
//   |      |
//   |    'to'
//   |    /   \
//   |  GTE   GTE
//    \  |     |
//     Copy    |
//       \    /
//        Tuple
//
StatusOr<std::pair<HloInstruction*, HloInstruction*>>
DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
                           const ShapeTree<bool>& indices_to_copy) {
  DCHECK(ShapeUtil::Compatible(from->shape(), to->shape()));
  // to/from_copy_tree hold the kCopy instructions produced by the deep
  // copies. Elements which are not copied (indices_to_copy.element(index) ==
  // false) have nullptr at that index.
  ShapeTree<HloInstruction*> from_copy_tree(from->shape(),
                                            /*init_value=*/nullptr);
  TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy,
                      from->parent()->DeepCopyInstruction(
                          from, &indices_to_copy, &from_copy_tree));

  ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr);
  TF_ASSIGN_OR_RETURN(
      HloInstruction * to_deep_copy,
      to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree));

  // Add control edges between the respective kCopy instructions.
  for (const auto& pair : from_copy_tree) {
    const ShapeIndex& index = pair.first;
    HloInstruction* from_copy = pair.second;
    HloInstruction* to_copy = to_copy_tree.element(index);
    if (from_copy == nullptr) {
      TF_RET_CHECK(to_copy == nullptr);
      continue;
    }
    TF_RET_CHECK(to_copy != nullptr);
    TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy));
  }

  return std::make_pair(from_deep_copy, to_deep_copy);
}

// Compute the indices of the loop state which need copies in order to avoid
// live range interference. Generally, an element in the loop state does not
// need to be copied if the element is passed transparently through the body.
//
// Returns whether any indices need to be copied.
bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
                           const HloInstruction* xla_while,
                           ShapeTree<bool>* indices_to_copy) {
  DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape()));

  bool any_copies = false;
  const HloInstruction* init = xla_while->operand(0);
  for (auto& pair : *indices_to_copy) {
    const ShapeIndex& index = pair.first;
    bool& should_copy = pair.second;
    // If there is any ambiguity, then the loop state must be copied.
    if (dataflow.GetValueSet(init, index).values().size() > 1 ||
        dataflow.GetValueSet(xla_while, index).values().size() > 1) {
      should_copy = true;
    } else {
      // If the output of the while instruction is not the same as the init
      // value of the while, then this element is not passed through the body
      // transparently and must be copied.
      should_copy = dataflow.GetUniqueValueAt(xla_while, index) !=
                    dataflow.GetUniqueValueAt(init, index);
    }
    any_copies |= should_copy;
  }
  return any_copies;
}

// Compute the indices of the conditional outputs which need copies.
// Unambiguous buffers (buffers with only one value) don't need copies.
bool IndicesToCopyForConditional(const HloDataflowAnalysis& dataflow,
                                 const HloInstruction* xla_conditional,
                                 ShapeTree<bool>* indices_to_copy) {
  DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(),
                               xla_conditional->shape()));

  bool any_copies = false;
  for (auto& pair : *indices_to_copy) {
    const ShapeIndex& index = pair.first;
    bool& should_copy = pair.second;

    CHECK_EQ(dataflow.GetValueSet(xla_conditional, index).values().size(), 1);

    auto value = dataflow.GetValueSet(xla_conditional, index).values()[0];
    // The conditional must be copied if the value is a phi.
    should_copy =
        value->is_phi() && value->defining_instruction() == xla_conditional;
    any_copies |= should_copy;
  }
  return any_copies;
}

// Add kCopy instructions around the given kWhile instruction to eliminate any
// possible live range interference of HLO values assuming a dependency-based
// ordering. Copies are added conservatively. There likely are copies which are
// not strictly necessary, but they are removed later in the pass via
// RemoveUnnecessaryCopies.
//
// Elements (each ShapeIndex) in the loop state are considered independently. A
// copy is added to each element of the loop state which is modified in the
// while body. For each such element, a total of three kCopy instructions are
// added at the following locations:
//
//   (1) The init value is copied before the kWhile instruction. Before:
//
//           (Init)
//             |
//           kWhile
//             |
//            ...
//
//       After:
//
//           (Init)
//             |
//           kCopy
//             |
//           kWhile
//             |
//            ...
//
//       This copy is necessary in case the init value is simultaneously live
//       with the kWhile.
//
//   (2) Copies are added to the parameter and root of the while body
//       computation. Before:
//
//           kParameter
//               |
//              ...
//               |
//           (body root)
//
//       After:
//
//           kParameter
//               |
//             kCopy ----------+
//               |             |
//              ...           ctrl
//               |            edge
//           (body root)       |
//               |             |
//             kCopy <---------+
//
//       The root kCopy becomes the new root of the computation. Both copies
//       are necessary to avoid any potential interference between the
//       parameter value and the root value. The control edge prevents
//       potential interference between the copies themselves.
//
// If the loop state is a tuple then the above kCopy instructions are a deep
// copy constructed of kCopy, kGetTupleElement, and kTuple instructions as
// constructed by HloInstruction::DeepCopyInstruction.
Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
                         HloInstruction* xla_while) {
  VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name();
  TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile);

  ShapeTree<bool> indices_to_copy(xla_while->shape());
  if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while,
                             &indices_to_copy)) {
    VLOG(2) << "No copies necessary for kWhile instruction "
            << xla_while->name();
    return Status::OK();
  }

  VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:";
  for (auto& pair : indices_to_copy) {
    if (pair.second) {
      VLOG(2) << "  " << pair.first;
    }
  }

  // Deep copy init.
  HloInstruction* while_init = xla_while->mutable_operand(0);
  TF_ASSIGN_OR_RETURN(
      HloInstruction * while_init_copy,
      xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy));
  TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy));

  // Deep copy the parameter and the root. Extend a control edge from the copy
  // of the parameter value to the corresponding copy value of the root.
  HloComputation* body = xla_while->while_body();
  HloInstruction* param = body->parameter_instruction(0);
  HloInstruction* root = body->root_instruction();

  // If param is the root then all indices should have been passed through the
  // while body and we should have returned early above.
  TF_RET_CHECK(param != root);

  // Copy users before making a deep copy of the parameter as the deep copy
  // will create new users of the parameter (eg, the GTE instructions of the
  // deep copy).
  std::vector<HloInstruction*> param_users = param->users();

  TF_ASSIGN_OR_RETURN(auto pair,
                      DeepCopyAndAddControlEdges(param, root, indices_to_copy));

  HloInstruction* param_copy = pair.first;
  HloInstruction* root_copy = pair.second;

  for (HloInstruction* user : param_users) {
    TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy));
  }

  body->set_root_instruction(root_copy);
  return Status::OK();
}

// Add copies for the operands of in-place operations. RemoveUnnecessaryCopies
// will remove the unnecessary copies.
Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis,
                                    HloInstruction* in_place_op,
                                    int64 operand_number) {
  VLOG(2) << "Adding copies for in-place operation " << in_place_op->name();
  HloInstruction* operand = in_place_op->mutable_operand(operand_number);
  TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
                      in_place_op->parent()->DeepCopyInstruction(operand));
  TF_RETURN_IF_ERROR(operand->ReplaceUseWith(in_place_op, deep_copy));
  return Status::OK();
}

// Conservatively adds copies before the root instruction of the entry
// computation and each aliased parameter to resolve interference between
// aliased input and output buffers. We later rely on RemoveUnnecessaryCopies
// to drop the unnecessary ones.
Status AddCopiesForAliasedInputOutputs(HloModule* module) {
  HloComputation* entry = module->entry_computation();
  HloInstruction* root = entry->root_instruction();

  ShapeTree<bool> output_indices_to_copy(root->shape());
  std::vector<absl::optional<ShapeTree<HloInstruction*>>> copied_parameters(
      entry->num_parameters());
  bool has_alias = false;
  for (auto* param : entry->parameter_instructions()) {
    bool param_has_alias = false;
    ShapeTree<bool> param_indices_to_copy(param->shape());

    module->input_output_alias_config().ForEachAlias(
        [&](const ShapeIndex& output_index,
            const HloInputOutputAliasConfig::Alias& alias) {
          if (alias.parameter_number == param->parameter_number()) {
            param_has_alias = true;
            *(param_indices_to_copy.mutable_element(alias.parameter_index)) =
                true;
            *(output_indices_to_copy.mutable_element(output_index)) = true;
          }
        });

    if (!param_has_alias) {
      continue;
    }

    TF_RET_CHECK(param->parameter_number() < entry->num_parameters());
    TF_RET_CHECK(!copied_parameters[param->parameter_number()]);

    has_alias = true;
    // Store a snapshot of users before DeepCopyInstruction, as
    // DeepCopyInstruction introduces new users of the instruction.
    std::vector<HloInstruction*> users = param->users();
    ShapeTree<HloInstruction*> param_copy_tree(param->shape(),
                                               /*init_value=*/nullptr);
    TF_ASSIGN_OR_RETURN(HloInstruction * copied,
                        entry->DeepCopyInstruction(
                            param, &param_indices_to_copy, &param_copy_tree));
    for (HloInstruction* user : users) {
      TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied));
    }

    copied_parameters[param->parameter_number()] = param_copy_tree;
  }

  if (!has_alias) {
    return Status::OK();
  }

  // Add copies before the root instruction.
  ShapeTree<HloInstruction*> output_copy_tree(root->shape(),
                                              /*init_value=*/nullptr);

  TF_ASSIGN_OR_RETURN(HloInstruction * root_copied,
                      root->parent()->DeepCopyInstruction(
                          root, &output_indices_to_copy, &output_copy_tree));

  // Add control dependencies between the input/output copies.
  TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus(
      [&](const ShapeIndex& output_index,
          const HloInputOutputAliasConfig::Alias& alias) -> Status {
        if (!copied_parameters[alias.parameter_number]) {
          return Status::OK();
        }
        HloInstruction* from =
            copied_parameters[alias.parameter_number]->element(
                alias.parameter_index);
        HloInstruction* to = output_copy_tree.element(output_index);

        TF_RET_CHECK(from != nullptr);
        TF_RET_CHECK(to != nullptr);
        TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to));
        return Status::OK();
      }));

  entry->set_root_instruction(root_copied);

  return Status::OK();
}

// Removes any control dependencies to or from the given instruction.
Status StripControlDependenciesFrom(HloInstruction* instruction) {
  while (!instruction->control_successors().empty()) {
    TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo(
        instruction->control_successors().front()));
  }

  while (!instruction->control_predecessors().empty()) {
    TF_RETURN_IF_ERROR(
        instruction->control_predecessors().front()->RemoveControlDependencyTo(
            instruction));
  }

  return Status::OK();
}

// Class which tracks the HLO values within each HLO buffer in the module
// during copy removal.
//
// The values are held in a linked list where there is one list for each
// buffer. Removing a copy instruction merges the values in the source buffer
// of the copy into the destination buffer of the copy. This class tracks
// these value lists as copies are removed from the graph (and value lists
// are merged).
//
// The CopyRemover object is initialized to match the state of
// HloAliasAnalysis. However, as copies are removed this state diverges. The
// values-to-buffer mapping is maintained outside of HloAliasAnalysis because
// a fully updatable alias analysis is very slow.
class CopyRemover {
 public:
  // The values held in a single HLO buffer are represented using a linked
  // list. Each element in the list is a ValueNode.
  //
  // This linked list is hand-rolled to enable efficient splicing of lists
  // using only references to list elements without knowing which lists are
  // being spliced. std::list requires a reference to the list object to
  // splice.
  struct ValueNode {
    explicit ValueNode(const HloValue* v) : value(v) {}

    const HloValue* value;

    // The uses are maintained outside of HloValue::uses() because
    // HloValue::uses() is not updatable (a fully updatable dataflow analysis
    // is slow).
    std::vector<const HloUse*> uses;

    // next/prev elements in the linked list. The list is circularly linked so
    // these values are never null for elements in the list.
    ValueNode* prev = nullptr;
    ValueNode* next = nullptr;
  };

  CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
              const HloOrdering& ordering)
      : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
    // Construct a list for each HLO buffer in the alias analysis. Maintain a
    // map from HloValue to the respective list element representing that
    // value. The map is used to construct the copy info map below.
    absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
    for (const HloBuffer& buffer : alias_analysis.buffers()) {
      // No copies should have been inserted within fused computations, so no
      // need to remove them. HloOrdering isn't compatible with HloValues
      // inside fusions, so skip copy removal for them.
      if (buffer.values().at(0)->defining_instruction()->IsFused()) {
        continue;
      }
      // Verify values contained in the buffer are strictly ordered. This
      // should always be the case after adding copies to eliminate
      // interference. Specifically, the addition of the control flow edges
      // between copies added around aliased operations (kWhile) guarantees
      // this strict order.
      for (const HloValue* value_a : buffer.values()) {
        if (value_a->shape().IsToken()) {
          // Token values have no representation and cannot interfere.
          continue;
        }
        for (const HloValue* value_b : buffer.values()) {
          if (value_a != value_b) {
            DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
                                                     dataflow_) ||
                   ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
                                                     dataflow_))
                << value_a->ToShortString() << " and "
                << value_b->ToShortString() << " are not ordered";
          }
        }
      }

      std::vector<const HloValue*> values = buffer.values();
      absl::c_sort(values, [this](const HloValue* a, const HloValue* b) {
        return ordering_.IsDefinedBefore(*a, *b);
      });

      // Create a list containing all of the values in the buffer.
      AddValueList(values, &value_to_node);
    }

    // Create copy_map_ which contains the source and destination values
    // of all copies.
    CreateCopyMap(module, value_to_node);

    XLA_VLOG_LINES(3, ToString());
    TF_DCHECK_OK(Verify());
  }

  // Add a list containing the given values to CopyRemover. This
  // represents the values contained in a single buffer. For each value in
  // 'values' an entry is created in value_to_node which indicates the
  // respective ValueNode representing that value.
  void AddValueList(
      absl::Span<const HloValue* const> values,
      absl::flat_hash_map<const HloValue*, ValueNode*>* value_to_node) {
    ValueNode* tail = nullptr;
    ValueNode* head = nullptr;
    for (const HloValue* value : values) {
      auto new_node = new ValueNode(value);
      (*value_to_node)[value] = new_node;

      // Copy the HLO value's uses into the ValueNode for the value. These
      // uses in ValueNode are updated as copies are removed.
      new_node->uses.reserve(value->uses().size());
      for (const HloUse& use : value->uses()) {
        new_node->uses.push_back(&use);
      }

      // Connect the new node into the linked list.
      if (tail == nullptr) {
        head = new_node;
      } else {
        tail->next = new_node;
        new_node->prev = tail;
      }
      tail = new_node;
    }

    // The linked list is circular so connect the head and tail.
    tail->next = head;
    head->prev = tail;
    value_lists_.insert(head);
  }

  // Fills in copy_map_, which maps each kCopy instruction in the module to
  // the nodes in the value lists corresponding to the source and destination
  // values of that copy. value_to_node should map each HloValue to its
  // respective ValueNode.
  void CreateCopyMap(
      const HloModule& module,
      const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
    for (HloComputation* computation : module.MakeNonfusionComputations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        // Add copies with unambiguous source values to the map. Copies with
        // ambiguous sources are not removable.
        if (instruction->opcode() == HloOpcode::kCopy) {
          const HloValueSet& src_value_set =
              dataflow_.GetValueSet(instruction->operand(0));
          if (src_value_set.values().size() == 1) {
            CopyNodes& copy_node = copy_map_[instruction];
            copy_node.dest =
                value_to_node.at(&dataflow_.GetUniqueValueAt(instruction));
            copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue());
          }
        }
      }
    }
  }

  ~CopyRemover() {
    for (const ValueNode* head : value_lists_) {
      const ValueNode* p = head;
      do {
        const ValueNode* tmp = p->next;
        delete p;
        p = tmp;
      } while (p != head);
    }
  }

  // Verify invariants within the linked lists.
  Status Verify() const {
    for (const ValueNode* head : value_lists_) {
      const ValueNode* p = head;
      do {
        // Verify links between elements are consistent.
        TF_RET_CHECK(p->prev->next == p);
        TF_RET_CHECK(p->next->prev == p);

        const HloInstruction* def = p->value->defining_instruction();
        if (def->opcode() == HloOpcode::kCopy && ContainsKey(copy_map_, def)) {
          TF_RET_CHECK(copy_map_.at(def).dest == p);
        }
        for (const HloUse* use : p->uses) {
          if (use->instruction->opcode() == HloOpcode::kCopy &&
              ContainsKey(copy_map_, use->instruction)) {
            TF_RET_CHECK(copy_map_.at(use->instruction).src == p);
          }
        }

        p = p->next;
      } while (p != head);
    }
    return Status::OK();
  }

  // Try to elide the given copy. Elision of a copy is possible only if no
  // live range interference is introduced by the copy's elimination. If
  // elision is possible, then the internal state (value lists) is updated,
  // and true is returned. Returns false otherwise.
  bool TryElideCopy(const HloInstruction* copy) {
    VLOG(2) << "Trying to remove " << copy->name();

    if (!ContainsKey(copy_map_, copy)) {
      VLOG(2) << copy->name() << " is not removable";
      return false;
    }
    if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
      VLOG(2) << copy->name() << " is not removable (shape mismatch)";
      return false;
    }
    const CopyNodes& copy_node = copy_map_.at(copy);
    ValueNode* src = copy_node.src;
    ValueNode* dest = copy_node.dest;
    DCHECK(src != nullptr);
    DCHECK(dest != nullptr);

    VLOG(3) << copy->name() << " copies value " << src->value->ToShortString();
    VLOG(3) << "Source buffer values: " << ValueListToString(src);
    VLOG(3) << "Dest buffer values: " << ValueListToString(dest);

    // A kCopy instruction copies an HLO value from a source buffer and
    // defines an HLO value in a destination buffer. Most generally, the
    // source and destination buffers may each hold more than one value at
    // different points in the computation so we define the following:
    //
    //   Values in source buffer:      {s_0, ..., s_n}
    //   Values in destination buffer: {d_0, ..., d_m}
    //
    // A kCopy instruction between these buffers copies a value s_x in the
    // source buffer and defines a value d_y in the destination buffer. The
    // elision of a copy merges the source and destination buffers together,
    // so the list of values for the source and destination buffers are
    // merged.
    //
    // We handle two different cases for copy elision:
    //
    //  (1) the kCopy defines the first value in the destination buffer (d_0).
    //
    //  (2) the kCopy copies the last value in the source buffer (s_n).
    //
    // For the remaining case where the kCopy copies a not-last value from the
    // source buffer to a not-first value of the destination buffer, the kCopy
    // instruction cannot be removed. This case is generated, for example, if
    // the kCopy copies a while body parameter of the loop state at one tuple
    // index to a different tuple index in the while body root. Removal of the
    // copy necessarily results in live range interference of values in the
    // loop state at the two different tuple indices.
    //
    // We can only perform copy elision if the resulting merged values have
    // totally ordered live ranges; otherwise the merged buffer would have
    // live range interference.
    if (src->next == dest) {
      // In the process of eliding copies, it's possible for a copy to have
      // the same source and destination buffer. In this case, the copy can be
      // safely removed.
      VLOG(2) << copy->name()
              << " source and destination buffers are the same.";
    } else if (IsHead(*dest)) {
      // The copy copies an arbitrary value in the source buffer (call it s_x)
      // and defines d_0, the first value in the destination buffer. After
      // merging, the values in the combined buffer must be strictly ordered
      // as follows** to elide the copy:
      //
      //   {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
      //
      // Removing the copy eliminates d_0, and uses of d_0 become uses of
      // s_x. In the above ordering, the live range of d_m will be ordered
      // before the live range of s_{x+1} and the definition and all uses of
      // s_x will be ordered before the definition of d_1. To make sure the
      // copy elision is safe, the following code checks that this ordering is
      // valid --- in particular we check it is safe to order d_m ahead of all
      // the live ranges at and after s_{x+1}, and it is safe to order all
      // uses of s_x before the definition of d_1, by checking the live range
      // constraints for each pair --- we cannot skip the later checks because
      // the live range ordering is not guaranteed to be transitive --- while
      // it may be ok to have lr_1 before lr_2, and lr_2 before lr_3 while
      // merging their buffers, it may not be ok to merge the buffers of lr_1
      // and lr_3, because the exclusiveness relation of non-overlapping
      // computations is not transitive.
      //
      // ** Technically it might be possible to have a non-interfering
      //    non-trivial interleaving of the values of the source and
      //    destination buffers in the resulting order. However, this case is
      //    slow and complicated to check and likely not worth it. So instead
      //    we simply check for the case where *all* values of the destination
      //    buffer (d_1 through d_m) are spliced into the point where the copy
      //    used to be.
      VLOG(2) << copy->name() << " defines the first value in its buffer";
      for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
           next_dest = Next(*next_dest)) {
        // Live range of 'from' value (s_x) must be before 'next_dest' (d_1).
        if (!LiveRangeBefore(*src, *next_dest)) {
          VLOG(2) << "Not removing the copy: live range of "
                  << src->value->ToShortString() << " is not before "
                  << next_dest->value->ToShortString();
          return false;
        }
      }
      for (ValueNode* next_src = Next(*src); next_src != nullptr;
           next_src = Next(*next_src)) {
        // Live range of 'last_dest' (d_m) must be before 'next_src'
        // (s_{x+1}).
        ValueNode* last_dest = dest->prev;
        DCHECK(IsTail(*last_dest));
        if (!LiveRangeBefore(*last_dest, *next_src)) {
          VLOG(2) << "Not removing the copy: live range of "
                  << last_dest->value->ToShortString() << " is not before "
                  << next_src->value->ToShortString();
          return false;
        }
      }

      // Splice in the destination buffer values list right after 'src'.
      SpliceAfter(dest, src);
    } else if (IsTail(*src)) {
      // The copy copies the last value in the source buffer, s_n, and defines
      // an arbitrary value in the destination buffer, d_y. After
      // merging, the values in the combined buffer must be strictly ordered
      // as follows** to elide the copy:
      //
      //   {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m}
      //
      // Removing the copy eliminates d_y, and uses of d_y become uses of
      // s_n. To enforce the above order, the live range of d_{y-1} must be
      // before the live range of s_0, and the live range of s_n must be
      // before the live range of d_{y+1}.
      //
      // ** See comment above in the code handling Case (1).
      VLOG(2) << copy->name() << " copies the last value ("
              << src->value->ToShortString() << ") in its buffer";

      ValueNode* first_src = src->next;
      DCHECK(IsHead(*first_src));
      for (ValueNode* prev_dest = Prev(*dest);
           // nullptr condition handled above in the first 'if' case.
           prev_dest != nullptr; prev_dest = Prev(*prev_dest)) {
        if (!LiveRangeBefore(*prev_dest, *first_src)) {
          // Live range of value d_{y-1} is not before s_0.
          VLOG(2) << "Not removing the copy: live range of "
                  << prev_dest->value->ToShortString() << " is not before "
                  << first_src->value->ToShortString();
          return false;
        }
      }
      for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
           next_dest = Next(*next_dest)) {
        if (!LiveRangeBefore(*src, *next_dest)) {
          // Live range of value s_n is not before d_{y+1}.
          VLOG(2) << "Not removing the copy: live range of "
                  << src->value->ToShortString() << " is not before "
                  << next_dest->value->ToShortString();
          return false;
        }
      }

      // Splice the source buffer values list right after 'prev_dest'.
      SpliceAfter(first_src, Prev(*dest));
    } else {
      VLOG(2) << copy->name()
              << " copies value in middle of source buffer to value in middle "
                 "of destination buffer";
      return false;
    }

    RemoveCopyValue(dest);

    XLA_VLOG_LINES(4, ToString());
    TF_DCHECK_OK(Verify());

    return true;
  }

  // Delete the given ValueNode associated with an elided kCopy
  // instruction. This should be called after splicing the value lists of the
  // source and destination buffers together.
  void RemoveCopyValue(ValueNode* copy_value_node) {
    CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(),
             HloOpcode::kCopy);
    ValueNode* operand_node = copy_value_node->prev;
    CHECK(operand_node != copy_value_node);

    VLOG(2) << "Removing copy " << operand_node->value->ToShortString()
            << " => " << copy_value_node->value->ToShortString();

    // Splice out the copy value node.
    operand_node->next = copy_value_node->next;
    copy_value_node->next->prev = operand_node;

    // Patch up uses. Remove use of copy from operand_node uses.
    auto it = absl::c_find_if(operand_node->uses, [copy_value_node](
                                                      const HloUse* use) {
      return use->instruction == copy_value_node->value->defining_instruction();
    });
    CHECK(it != operand_node->uses.end());
    operand_node->uses.erase(it);

    // If the elided copy has any uses which are themselves kCopy instructions
    // then patch up the copy info to reflect the fact that this kCopy
    // instruction has a different operand (the operand of the elided copy).
    for (const HloUse* copy_use : copy_value_node->uses) {
      operand_node->uses.push_back(copy_use);
      if (copy_use->instruction->opcode() == HloOpcode::kCopy &&
          ContainsKey(copy_map_, copy_use->instruction)) {
        copy_map_.at(copy_use->instruction).src = operand_node;
      }
    }

    // Delete the copy info and the value node.
    copy_map_.erase(copy_value_node->value->defining_instruction());
    delete copy_value_node;
  }

  // Returns true if the live range of given value 'a' is before the live
  // range of 'b'.
  //
  // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not
  // updated as copies are removed.
  bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
    if (a.uses.empty()) {
      VLOG(2) << "Empty uses for " << *a.value;
      return ordering_.IsDefinedBefore(*a.value, *b.value);
    }
    return ordering_.UsesBeforeValueDefinition(a.uses, *b.value, dataflow_);
  }

  // Returns whether 'node' is the last node in its list.
  bool IsTail(const ValueNode& node) const {
    return ContainsKey(value_lists_, node.next);
  }

  // Returns whether 'node' is the first node in its list.
  bool IsHead(const ValueNode& node) const {
    return ContainsKey(value_lists_, &node);
  }

  // Returns the next node in the list after 'node'. If 'node' is the
  // tail, then nullptr is returned.
  ValueNode* Next(const ValueNode& node) const {
    if (IsTail(node)) {
      return nullptr;
    } else {
      return node.next;
    }
  }

  // Returns the previous node in the list before 'node'. If 'node'
  // is the head, then nullptr is returned.
  ValueNode* Prev(const ValueNode& node) const {
    if (IsHead(node)) {
      return nullptr;
    } else {
      return node.prev;
    }
  }

  // Splices the entire linked list with 'head' as its head right after the
  // node 'insert_after' in another linked list.
  void SpliceAfter(ValueNode* head, ValueNode* insert_after) {
    DCHECK(IsHead(*head));
    value_lists_.erase(head);

    ValueNode* tail = head->prev;
    tail->next = insert_after->next;
    insert_after->next->prev = tail;

    insert_after->next = head;
    head->prev = insert_after;
  }

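  // Returns a string representation of the entire value list containing
  // 'element', for debugging.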
  string ValueListToString(const ValueNode* element) {
    const ValueNode* head = element;
    while (!IsHead(*head)) {
      head = Prev(*head);
    }
    std::vector<const HloValue*> values;
    for (const ValueNode* p = head; p != nullptr; p = Next(*p)) {
      values.push_back(p->value);
    }
    return absl::StrCat("{",
                        absl::StrJoin(values, ", ",
                                      [](string* s, const HloValue* value) {
                                        StrAppend(s, value->ToShortString());
                                      }),
                        "}");
  }

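  // Returns a string representation of the internal state of the
  // CopyRemover, for debugging.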
  string ToString() const {
    string out = absl::StrCat("CopyRemover:\n");
    StrAppend(&out, "  Def-use chains in each buffer:\n");
    for (const ValueNode* head : value_lists_) {
      StrAppend(&out, "    Buffer defined by ", head->value->ToShortString(),
                ":\n");
      const ValueNode* p = head;
      do {
        StrAppend(&out, "      ", p->value->ToShortString(), ", uses: ",
                  absl::StrJoin(p->uses, "; ",
                                [](string* s, const HloUse* use) {
                                  StrAppend(s, use->ToString());
                                }),
                  "\n");

        p = p->next;
      } while (p != head);
    }
    StrAppend(&out, "  Potentially removable copies:\n");
    for (const auto& pair : copy_map_) {
      const HloInstruction* copy = pair.first;
      const CopyNodes& copy_info = pair.second;

      StrAppend(&out, "    ", copy->name(), " : ",
                copy_info.src->value->ToShortString(), " => ",
                copy_info.dest->value->ToShortString(), "\n");
    }
    return out;
  }

 private:
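  // The dataflow analysis and instruction ordering used to answer queries
  // about value definitions and live ranges. Neither is updated as copies
  // are removed.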
  const HloDataflowAnalysis& dataflow_;
  const HloOrdering& ordering_;

  // The heads of all the value lists. Each value list represents the HLO
  // values contained in a particular HLO buffer. The values in the list are
  // in dependency order.
  absl::flat_hash_set<const ValueNode*> value_lists_;

  // Copy removal requires fast access to the value list elements
  // corresponding to the source and destination values of the kCopy
  // instruction. This data structure holds pointers to these elements for
  // each kCopy instruction in the graph.
  struct CopyNodes {
    // The source and destination values of the kCopy instruction.
    ValueNode* src = nullptr;
    ValueNode* dest = nullptr;
  };
  absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_;
};

}  // namespace

// We add copies for all phi indices of the branch computation roots, in
// order to resolve interference. We later rely on RemoveUnnecessaryCopies to
// drop the unnecessary ones.
Status CopyInsertion::AddCopiesForConditional(
    const HloAliasAnalysis& alias_analysis, HloInstruction* conditional) {
  VLOG(2) << "Adding copies for kConditional instruction "
          << conditional->name();
  ShapeTree<bool> indices_to_copy(conditional->shape());
  TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);
  if (!IndicesToCopyForConditional(alias_analysis.dataflow_analysis(),
                                   conditional, &indices_to_copy)) {
    VLOG(2) << "No copies necessary for kConditional instruction "
            << conditional->name();
    return Status::OK();
  }

  for (HloComputation* computation : conditional->branch_computations()) {
    HloInstruction* root = computation->root_instruction();
    std::vector<HloInstruction*> users = root->users();
    TF_ASSIGN_OR_RETURN(
        HloInstruction * deep_copy,
        computation->DeepCopyInstruction(root, &indices_to_copy));
    for (HloInstruction* user : users) {
      TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
    }
    computation->set_root_instruction(deep_copy);
  }
  return Status::OK();
}

// Add kCopy instructions to the given module to guarantee there is no
// live-range interference. Generally interference can only occur around
// kWhile instructions which have update-in-place semantics.
Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, can_share_buffer_));

  for (HloComputation* computation : module->MakeNonfusionComputations()) {
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      if (instruction->opcode() == HloOpcode::kWhile) {
        TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
      } else if (instruction->opcode() == HloOpcode::kConditional) {
        TF_RETURN_IF_ERROR(
            AddCopiesForConditional(*alias_analysis, instruction));
      } else {
        for (const auto& operand_and_output_index :
             HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
          const HloUse& operand = operand_and_output_index.first;
          CHECK_EQ(operand.operand_index, ShapeIndex{})
              << "Support for non-{} shape operand not currently implemented.";
          TF_RETURN_IF_ERROR(AddCopiesForInPlaceOperation(
              *alias_analysis, instruction, operand.operand_number));
        }
      }
    }
  }

  TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module));
  return Status::OK();
}

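// Convenience overload which builds a fresh call graph before adding the
// special-case copies.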
Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) {
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  return AddSpecialCaseCopies(*call_graph, module);
}

Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
                                           HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, can_share_buffer_));

  // Identify which shape indices of which instructions need to be copied.
  // Store these results in 'instructions_to_copy'.
  HloInstructionMap<ShapeTree<bool>> instructions_to_copy;
  auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction,
                                                   const ShapeIndex& index) {
    auto it = instructions_to_copy.find(instruction);
    if (it == instructions_to_copy.end()) {
      auto it_added = instructions_to_copy.emplace(
          std::piecewise_construct, std::forward_as_tuple(instruction),
          std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
      it = it_added.first;
    }
    *it->second.mutable_element(index) = true;
  };

  // Iterate through values of all constants and entry parameters. These
  // values are special because they are held in read-only buffers. If any of
  // these values share a buffer with other values (for example, the init
  // value of a while is a constant) then copy the value at its definition and
  // replace all its uses with the copy.
  for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
    if (ValueIsReadOnly(*value) &&
        alias_analysis->GetBufferContainingValue(*value).values().size() > 1) {
      VLOG(2) << "Value " << value->ToShortString()
              << " is read only, but its buffer contains more than one value. "
                 "Copying.";
      add_index_to_copy(value->defining_instruction(), value->defining_index());
    }
  }

  // Identify copies which must be added at root instructions.
  for (HloComputation* computation : module->computations()) {
    const CallGraphNode& node = call_graph.GetNode(computation);
    if (node.context() == CallContext::kParallel) {
      continue;
    }
    TF_RET_CHECK(node.context() == CallContext::kSequential);

    SpecialCaseCopyPolicy policy =
        GetSpecialCaseCopyPolicy(node, module, computation);
    HloInstruction* root = computation->root_instruction();

    // Mark nondistinct/ambiguous indices.
    absl::flat_hash_map<const HloBuffer*, ShapeIndex> seen;
    ShapeUtil::ForEachSubshape(
        root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          std::vector<const HloBuffer*> buffers_at_index =
              alias_analysis->ComputeBuffersAt(root, index);
          bool buffer_seen_before = false;
          for (const HloBuffer* buffer : buffers_at_index) {
            buffer_seen_before |= !seen.emplace(buffer, index).second;
          }

          if (buffer_seen_before && policy.copy_root_replicated_buffers &&
              computation == module->entry_computation() &&
              module->input_output_alias_config().OutputHasAlias(index) &&
              buffers_at_index.size() == 1) {
            absl::optional<HloInputOutputAliasConfig::Alias> alias =
                module->input_output_alias_config().GetAliasedParameter(index);
            CHECK(alias) << "Alias does not exist";
            const ShapeIndex& other_index = seen[buffers_at_index[0]];
            VLOG(2) << "Output indices " << index.ToString() << " and "
                    << other_index.ToString() << " are both aliased to "
                    << alias->parameter_number << " copying " << other_index;
            add_index_to_copy(root, other_index);
            return;
          }

          if (buffers_at_index.size() > 1 ||
              (buffer_seen_before && policy.copy_root_replicated_buffers)) {
            VLOG(2) << "Index " << index << " of computation "
                    << computation->name() << " (" << root->name()
                    << ") has ambiguous or non-distinct buffer. Copying.";
            add_index_to_copy(root, index);
          }
        });

    for (const auto& pair :
         alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
      const ShapeIndex& index = pair.first;
      const HloValueSet& value_set = pair.second;
      for (const HloValue* value : value_set.values()) {
        if (ShouldCopyRootValue(*value, policy)) {
          VLOG(2) << "Root of (" << root->name() << ") of computation("
                  << computation->name()
                  << ") has constant or parameter value at index " << index
                  << ". Copying.";
          add_index_to_copy(root, index);
        }
      }
    }
  }

  // Add copy instructions indicated in 'instructions_to_copy' to the module.
  for (const auto& pair : instructions_to_copy) {
    HloInstruction* instruction = pair.first;
    const ShapeTree<bool>& indices_to_copy = pair.second;

    ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
    std::vector<HloInstruction*> users = instruction->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
                        instruction->parent()->DeepCopyInstruction(
                            instruction, &indices_to_copy, &copies_added));
    for (HloInstruction* user : users) {
      TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
    }
    if (instruction == instruction->parent()->root_instruction()) {
      instruction->parent()->set_root_instruction(deep_copy);
    }
  }
  return Status::OK();
}

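// Returns the number of kCopy instructions in the module.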
static int64 GetNumExistingCopies(const HloModule* module) {
  int64 num_existing_copies = 0;
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopy) {
        ++num_existing_copies;
      }
    }
  }
  return num_existing_copies;
}

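// Removes unnecessary copies by repeatedly attempting to elide every kCopy
// instruction in the module until a fixed point is reached.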
Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
                                              HloModule* module) {
  XLA_VLOG_LINES(4, module->ToString());
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, can_share_buffer_));

  CopyRemover copy_remover(*module, *alias_analysis, ordering);
  if (VLOG_IS_ON(3)) {
    LOG(INFO) << "Removing unnecessary copies in " << module->name();
    LOG(INFO) << "Buffer values, in dependency order: ";
    for (const HloBuffer& buffer : alias_analysis->buffers()) {
      LOG(INFO) << "    HloBuffer " << buffer.id();
    }
  }

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);

  int64 num_existing_copies = GetNumExistingCopies(module);
  bool changed = true;
  int64 num_iterations = -1;
  while (changed) {
    CHECK_LE(++num_iterations, num_existing_copies);
    changed = false;
    VLOG(2) << "Running fixpoint iteration " << num_iterations
            << " of copy elision";
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        if (instruction->opcode() == HloOpcode::kCopy &&
            copy_remover.TryElideCopy(instruction)) {
          changed = true;
          TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
          TF_RETURN_IF_ERROR(
              instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
        }
      }
    }
  }
  return Status::OK();
}

StatusOr<bool> CopyInsertion::Run(HloModule* module) {
  // Copy insertion is performed in three steps:
  //
  // (1) Add copies conservatively to guarantee that there is no live-range
  //     interference. This is done simplistically and usually results in more
  //     copies than is strictly necessary.
  //
  // (2) Using a more fine-grained analysis, remove as many copies that were
  //     added in (1) as possible while ensuring no live-range interference.
  //
  // (3) Add copies to resolve issues not related to live range interference
  //     such as parameters and constants live out of the entry computation.
  //
  // We add copies then remove them (step (1) then (2)) rather than simply
  // adding only the copies that are necessary because, in general, it is
  // difficult to figure out the minimal set of copies to add once there is
  // interference. On the other hand, it is easy to determine if removing a
  // copy will introduce interference.
  //
  // The final copy insertion in (3) is done separately to simplify the
  // implementation of copy removal in (2) which is the most complicated part
  // of the pass. As is, copy removal only has to reason about live range
  // interference. If all copies were added in step (1) then copy removal
  // would also have to reason about things like constants and parameters live
  // out of the computation.
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  if (!call_graph->IsFlattened()) {
    return FailedPrecondition(
        "Call graph must be flattened before copy insertion.");
  }

  // Record the copy count before the pass runs so the statistics logged below
  // compare the state before and after copy insertion.
  const int64 num_copies_before = GetNumExistingCopies(module);

  TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));

  // Simplify the tuple structures introduced by the deep copies. This should
  // be done before removing copies (RemoveUnnecessaryCopies) because tuple
  // simplification changes dependencies in the graph which changes live range
  // interference in the graph. Also run DCE to remove the dead Tuple/GTE
  // instructions introduced by tuple simplification.
  TupleSimplifier tuple_simplifier;
  HloDCE dce;
  TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
  TF_RETURN_IF_ERROR(dce.Run(module).status());
  DumpHloModuleDuringPassIfEnabled(
      name(), "after adding copies to resolve interference", *module);

  TF_RETURN_IF_ERROR(
      RemoveUnnecessaryCopies(DependencyHloOrdering(module), module));
  DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
                                   *module);
  TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
  DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
                                   *module);

  TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
  TF_RETURN_IF_ERROR(dce.Run(module).status());

  if (VLOG_IS_ON(1)) {
    int64 num_total_copies = 0;
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        if (instruction->opcode() == HloOpcode::kCopy) {
          num_total_copies++;
        }
      }
    }
    // Counting here would reflect the post-insertion state, so report the
    // count captured before the pass ran.
    VLOG(1) << "Num copies before copy-insertion: " << num_copies_before;
    VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
  }

  return true;
}
}  // namespace xla