/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/copy_insertion.h"

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

using absl::StrAppend;
bool IsReadonlyEntryParameterValue(const HloValue& value) {
  const HloComputation* computation = value.defining_instruction()->parent();
  return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
         computation == computation->parent()->entry_computation() &&
         !computation->parent()->input_output_alias_config().ParameterHasAlias(
             value.defining_instruction()->parameter_number(), value.index());
}

bool IsConstantValue(const HloValue& value) {
  return value.defining_instruction()->opcode() == HloOpcode::kConstant;
}

bool ValueIsReadOnly(const HloValue& value) {
  return IsConstantValue(value) || IsReadonlyEntryParameterValue(value);
}

// Data structure describing the action that should be taken on parts of a
// computation's buffers, with respect to the addition of special-case copies.
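//
// For example (illustrative): if the entry computation's root is
// tuple(x, x), the same buffer appears at output indices {0} and {1};
// copy_root_replicated_buffers asks the pass to copy one of those
// occurrences so the two output locations get distinct buffers.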
struct SpecialCaseCopyPolicy {
  // Insert a copy if the same buffer is found at multiple indices within the
  // output tuple.
  bool copy_root_replicated_buffers = false;
  // If true, insert a copy if a buffer coming from a constant or a parameter
  // is found within the output tuple.
  bool copy_parameters_and_constants = false;
};

SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node,
                                               HloModule* module,
                                               HloComputation* computation) {
  SpecialCaseCopyPolicy policy;
  if (computation == module->entry_computation()) {
    policy.copy_parameters_and_constants = true;
    policy.copy_root_replicated_buffers = true;
  }
  return policy;
}

bool ShouldCopyRootValue(const HloValue& value,
                         const SpecialCaseCopyPolicy& policy) {
  if (policy.copy_parameters_and_constants) {
    return ValueIsReadOnly(value);
  }
  return false;
}

// Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given
// in 'indices_to_copy'. Add control edges from the respective kCopy
// instructions in the deep copy of 'from' to the respective kCopy
// instructions in the deep copy of 'to'.
//
// Requirements: 'from' and 'to' must have compatible shapes.
//
// For example, suppose 'from' and 'to' are two-element tuples where index 0
// is the only index to copy. Prior to deep-copying we have:
//
//       'from'
//          |
//         ...
//          |
//        'to'
//
// DeepCopyAndAddControlEdges produces:
//
//       'from'
//        /   \
//      GTE   GTE
//       |     |
//      Copy   |
//      /  \  /
//     |   Tuple
//     |     |
//    ctrl  ...
//    edge   |
//     |     |
//     |   'to'
//     |   /   \
//     | GTE   GTE
//      \ |     |
//      Copy    |
//         \   /
//         Tuple
//
StatusOr<std::pair<HloInstruction*, HloInstruction*>>
DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
                           const ShapeTree<bool>& indices_to_copy) {
  DCHECK(ShapeUtil::Compatible(from->shape(), to->shape()));
  // from_copy_tree and to_copy_tree hold the kCopy instructions produced by
  // the deep copies. Elements which are not copied
  // (indices_to_copy.element(index) == false) have nullptr at that index.
  ShapeTree<HloInstruction*> from_copy_tree(from->shape(),
                                            /*init_value=*/nullptr);
  TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy,
                      from->parent()->DeepCopyInstruction(
                          from, &indices_to_copy, &from_copy_tree));

  ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr);
  TF_ASSIGN_OR_RETURN(
      HloInstruction * to_deep_copy,
      to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree));

  // Add control edges between the respective kCopy instructions.
  for (const auto& pair : from_copy_tree) {
    const ShapeIndex& index = pair.first;
    HloInstruction* from_copy = pair.second;
    HloInstruction* to_copy = to_copy_tree.element(index);
    if (from_copy == nullptr) {
      TF_RET_CHECK(to_copy == nullptr);
      continue;
    }
    TF_RET_CHECK(to_copy != nullptr);
    TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy));
  }

  return std::make_pair(from_deep_copy, to_deep_copy);
}

// Compute the indices of the loop state which need copies in order to avoid
// live range interference. Generally, an element in the loop state does not
// need to be copied if the element is passed transparently through the body.
//
// Returns whether any indices need to be copied.
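//
// For example (an illustrative sketch): with loop state (counter, weights)
// where the body returns (counter + 1, weights), the counter element is
// modified by the body and must be copied, while the weights element is
// passed through untouched and need not be.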
bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
                           const HloInstruction* xla_while,
                           ShapeTree<bool>* indices_to_copy) {
  DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape()));

  bool any_copies = false;
  const HloInstruction* init = xla_while->operand(0);
  for (auto& pair : *indices_to_copy) {
    const ShapeIndex& index = pair.first;
    bool& should_copy = pair.second;
    // If there is any ambiguity, then the loop state must be copied.
    if (dataflow.GetValueSet(init, index).values().size() > 1 ||
        dataflow.GetValueSet(xla_while, index).values().size() > 1) {
      should_copy = true;
    } else {
      // If the output of the while instruction is not the same as the init
      // value of the while, then this element is not passed through the body
      // transparently and must be copied.
      should_copy = dataflow.GetUniqueValueAt(xla_while, index) !=
                    dataflow.GetUniqueValueAt(init, index);
    }
    any_copies |= should_copy;
  }
  return any_copies;
}

// Add kCopy instructions around the given kWhile instruction to eliminate any
// possible live range interference of HLO values assuming a dependency-based
// ordering (DependencyHloOrdering). Copies are added conservatively. There
// likely are copies which are not strictly necessary, but they are removed
// later in the pass via RemoveUnnecessaryCopies.
//
// Elements (each ShapeIndex) in the loop state are considered independently.
// A copy is added to each element of the loop state which is modified in the
// while body. For each such element, a total of three kCopy instructions are
// added at the following locations:
//
//   (1) The init value is copied before the kWhile instruction. Before:
//
//           (Init)
//             |
//           kWhile
//             |
//            ...
//
//       After:
//
//           (Init)
//             |
//           kCopy
//             |
//           kWhile
//             |
//            ...
//
//       This copy is necessary in case the init value is simultaneously live
//       with the kWhile.
//
//   (2) Copies are added to the parameter and root of the while body
//       computation. Before:
//
//           kParameter
//               |
//              ...
//               |
//           (body root)
//
//       After:
//
//           kParameter
//               |
//             kCopy ----------+
//               |             |
//              ...           ctrl
//               |            edge
//           (body root)       |
//               |             |
//             kCopy <---------+
//
//       The root kCopy becomes the new root of the computation. Both copies
//       are necessary to avoid any potential interference between the
//       parameter value and the root value. The control edge prevents
//       potential interference between the copies themselves.
//
// If the loop state is a tuple then the above kCopy instructions are deep
// copies constructed of kCopy, kGetTupleElement, and kTuple instructions, as
// constructed by HloInstruction::DeepCopyInstruction.
Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
                         HloInstruction* xla_while) {
  VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name();
  TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile);

  ShapeTree<bool> indices_to_copy(xla_while->shape());
  if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while,
                             &indices_to_copy)) {
    VLOG(2) << "No copies necessary for kWhile instruction "
            << xla_while->name();
    return Status::OK();
  }

  VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:";
  for (auto& pair : indices_to_copy) {
    if (pair.second) {
      VLOG(2) << " " << pair.first;
    }
  }

  // Deep copy init.
  HloInstruction* while_init = xla_while->mutable_operand(0);
  TF_ASSIGN_OR_RETURN(
      HloInstruction * while_init_copy,
      xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy));
  TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy));

  // Deep copy the parameter and the root. Extend a control edge from the copy
  // of the parameter value to the corresponding copy value of the root.
  HloComputation* body = xla_while->while_body();
  HloInstruction* param = body->parameter_instruction(0);
  HloInstruction* root = body->root_instruction();

  // If param is the root then all indices should have been passed through the
  // while body and we should have returned early above.
  TF_RET_CHECK(param != root);

  // Copy users before making a deep copy of the parameter as the deep copy
  // will create new users of the parameter (e.g., the GTE instructions of the
  // deep copy).
  std::vector<HloInstruction*> param_users = param->users();

  ShapeIndex current_index;
  TF_ASSIGN_OR_RETURN(auto pair,
                      DeepCopyAndAddControlEdges(param, root, indices_to_copy));

  HloInstruction* param_copy = pair.first;
  HloInstruction* root_copy = pair.second;

  for (HloInstruction* user : param_users) {
    TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy));
  }

  body->set_root_instruction(root_copy);

  return Status::OK();
}

// We add copies for all indices of the roots of the branch computations, in
// order to resolve interference. We later rely on RemoveUnnecessaryCopies to
// drop the unnecessary ones.
Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
                               HloInstruction* conditional) {
  VLOG(2) << "Adding copies for kConditional instruction "
          << conditional->name();
  TF_RET_CHECK(conditional->opcode() == HloOpcode::kConditional);

  for (HloComputation* computation : conditional->branch_computations()) {
    HloInstruction* root = computation->root_instruction();
    std::vector<HloInstruction*> users = root->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
                        computation->DeepCopyInstruction(root));
    for (HloInstruction* user : users) {
      TF_RETURN_IF_ERROR(root->ReplaceUseWith(user, deep_copy));
    }
    computation->set_root_instruction(deep_copy);
  }
  return Status::OK();
}

// Conservatively adds copies before the root instruction of the entry
// computation and each aliased parameter to resolve interference between
// aliased input and output buffers. We later rely on RemoveUnnecessaryCopies
// to drop the unnecessary ones.
Status AddCopiesForAliasedInputOutputs(HloModule* module) {
  HloComputation* entry = module->entry_computation();
  HloInstruction* root = entry->root_instruction();

  ShapeTree<bool> output_indices_to_copy(root->shape());
  std::vector<absl::optional<ShapeTree<HloInstruction*>>> copied_parameters(
      entry->num_parameters());
  bool has_alias = false;
  for (auto* param : entry->parameter_instructions()) {
    bool param_has_alias = false;
    ShapeTree<bool> param_indices_to_copy(param->shape());

    module->input_output_alias_config().ForEachAlias(
        [&](const ShapeIndex& output_index,
            const HloInputOutputAliasConfig::Alias& alias) {
          if (alias.parameter_number == param->parameter_number()) {
            param_has_alias = true;
            *(param_indices_to_copy.mutable_element(alias.parameter_index)) =
                true;
            *(output_indices_to_copy.mutable_element(output_index)) = true;
          }
        });

    if (!param_has_alias) {
      continue;
    }

    TF_RET_CHECK(param->parameter_number() < entry->num_parameters());
    TF_RET_CHECK(!copied_parameters[param->parameter_number()]);

    has_alias = true;
    // Store a snapshot of users before DeepCopyInstruction, as
    // DeepCopyInstruction introduces new users of the instruction.
    std::vector<HloInstruction*> users = param->users();
    ShapeTree<HloInstruction*> param_copy_tree(param->shape(),
                                               /*init_value=*/nullptr);
    TF_ASSIGN_OR_RETURN(HloInstruction * copied,
                        entry->DeepCopyInstruction(
                            param, &param_indices_to_copy, &param_copy_tree));
    for (HloInstruction* user : users) {
      TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied));
    }

    copied_parameters[param->parameter_number()] = param_copy_tree;
  }

  if (!has_alias) {
    return Status::OK();
  }

  // Add copies before the root instruction.
  ShapeTree<HloInstruction*> output_copy_tree(root->shape(),
                                              /*init_value=*/nullptr);

  TF_ASSIGN_OR_RETURN(HloInstruction * root_copied,
                      root->parent()->DeepCopyInstruction(
                          root, &output_indices_to_copy, &output_copy_tree));

  // Add control dependencies between the input/output copies.
  TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus(
      [&](const ShapeIndex& output_index,
          const HloInputOutputAliasConfig::Alias& alias) -> Status {
        if (!copied_parameters[alias.parameter_number]) {
          return Status::OK();
        }
        HloInstruction* from =
            copied_parameters[alias.parameter_number]->element(
                alias.parameter_index);
        HloInstruction* to = output_copy_tree.element(output_index);

        TF_RET_CHECK(from != nullptr);
        TF_RET_CHECK(to != nullptr);
        TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to));
        return Status::OK();
      }));

  entry->set_root_instruction(root_copied);

  return Status::OK();
}

// Removes any control dependencies to or from the given instruction.
Status StripControlDependenciesFrom(HloInstruction* instruction) {
  while (!instruction->control_successors().empty()) {
    TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo(
        instruction->control_successors().front()));
  }

  while (!instruction->control_predecessors().empty()) {
    TF_RETURN_IF_ERROR(
        instruction->control_predecessors().front()->RemoveControlDependencyTo(
            instruction));
  }

  return Status::OK();
}

// Class which tracks the HLO values within each HLO buffer in the module
// during copy removal.
//
// The values are held in a linked list where there is one list for each
// buffer. Removing a copy instruction merges the values in the source buffer
// of the copy into the destination buffer of the copy. This class tracks
// these value lists as copies are removed from the graph (and value lists are
// merged).
//
// The CopyRemover object is initialized to match the state of
// HloAliasAnalysis. However, as copies are removed this state diverges. The
// values-to-buffer mapping is maintained outside of HloAliasAnalysis because
// a fully updatable alias analysis is very slow.
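//
// For example (illustrative): if a buffer holds values {v_0, v_1, v_2} in
// dependency order, CopyRemover keeps them in a circular list
// v_0 -> v_1 -> v_2 -> v_0. Eliding a copy whose source value lives in one
// buffer and whose destination value lives in another splices the two lists
// into a single list.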
class CopyRemover {
 public:
  // The values held in a single HLO buffer are represented using a linked
  // list. Each element in the list is a ValueNode.
  //
  // This linked list is hand-rolled to enable efficient splicing of lists
  // using only references to list elements without knowing which lists are
  // being spliced. std::list requires a reference to the list object to
  // splice.
  struct ValueNode {
    explicit ValueNode(const HloValue* v) : value(v) {}

    const HloValue* value;

    // The uses are maintained outside of HloValue::uses() because
    // HloValue::uses() is not updatable (a fully updatable dataflow analysis
    // is slow).
    std::vector<const HloUse*> uses;

    // next/prev elements in the linked list. The list is circularly linked so
    // these values are never null for elements in the list.
    ValueNode* prev = nullptr;
    ValueNode* next = nullptr;
  };

  CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
              const HloOrdering& ordering)
      : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
    // Construct a list for each HLO buffer in the alias analysis. Maintain a
    // map from HloValue to the respective list element representing that
    // value. The map is used to construct the copy info map below.
    absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
    for (const HloBuffer& buffer : alias_analysis.buffers()) {
      // Verify values contained in the buffer are strictly ordered. This
      // should always be the case after adding copies to eliminate
      // interference. Specifically, the addition of the control flow edges
      // between copies added around aliased operations (kWhile) guarantees
      // this strict order.
      for (const HloValue* value_a : buffer.values()) {
        if (value_a->shape().IsToken()) {
          // Token values have no representation and cannot interfere.
          continue;
        }
        for (const HloValue* value_b : buffer.values()) {
          if (value_a != value_b) {
            DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
                                                     dataflow_) ||
                   ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
                                                     dataflow_))
                << value_a->ToShortString() << " and "
                << value_b->ToShortString() << " are not ordered";
          }
        }
      }

      std::vector<const HloValue*> values = buffer.values();
      absl::c_sort(values, [this](const HloValue* a, const HloValue* b) {
        return ordering_.IsDefinedBefore(*a, *b);
      });

      // Create a list containing all of the values in the buffer.
      AddValueList(values, &value_to_node);
    }

    // Create copy_map_ which contains the source and destination values
    // of all copies.
    CreateCopyMap(module, value_to_node);

    XLA_VLOG_LINES(3, ToString());
    TF_DCHECK_OK(Verify());
  }

  // Add a list containing the given values to CopyRemover. This
  // represents the values contained in a single buffer. For each value in
  // 'values' an entry is created in value_to_node which indicates the
  // respective ValueNode representing that value.
  void AddValueList(
      absl::Span<const HloValue* const> values,
      absl::flat_hash_map<const HloValue*, ValueNode*>* value_to_node) {
    ValueNode* tail = nullptr;
    ValueNode* head = nullptr;
    for (const HloValue* value : values) {
      auto new_node = new ValueNode(value);
      (*value_to_node)[value] = new_node;

      // Copy the HLO value's uses into the ValueNode for the value. These
      // uses in ValueNode are updated as copies are removed.
      new_node->uses.reserve(value->uses().size());
      for (const HloUse& use : value->uses()) {
        new_node->uses.push_back(&use);
      }

      // Connect the new node into the linked list.
      if (tail == nullptr) {
        head = new_node;
      } else {
        tail->next = new_node;
        new_node->prev = tail;
      }
      tail = new_node;
    }

    // The linked list is circular so connect the head and tail.
    tail->next = head;
    head->prev = tail;
    value_lists_.insert(head);
  }

  // Fills in copy_map_, which indicates which nodes in the value lists
  // correspond to the source and destination values of each kCopy
  // instruction. value_to_node should map each HloValue to its respective
  // ValueNode.
  void CreateCopyMap(
      const HloModule& module,
      const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
    for (HloComputation* computation : module.computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        // Add copies with unambiguous source values to the map. Copies with
        // ambiguous sources are not removable.
        if (instruction->opcode() == HloOpcode::kCopy) {
          const HloValueSet& src_value_set =
              dataflow_.GetValueSet(instruction->operand(0));
          if (src_value_set.values().size() == 1) {
            CopyNodes& copy_node = copy_map_[instruction];
            copy_node.dest =
                value_to_node.at(&dataflow_.GetUniqueValueAt(instruction));
            copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue());
          }
        }
      }
    }
  }

  ~CopyRemover() {
    for (const ValueNode* head : value_lists_) {
      const ValueNode* p = head;
      do {
        const ValueNode* tmp = p->next;
        delete p;
        p = tmp;
      } while (p != head);
    }
  }

  // Verify invariants within the linked lists.
  Status Verify() const {
    for (const ValueNode* head : value_lists_) {
      const ValueNode* p = head;
      do {
        // Verify links between elements are consistent.
        TF_RET_CHECK(p->prev->next == p);
        TF_RET_CHECK(p->next->prev == p);

        const HloInstruction* def = p->value->defining_instruction();
        if (def->opcode() == HloOpcode::kCopy && ContainsKey(copy_map_, def)) {
          TF_RET_CHECK(copy_map_.at(def).dest == p);
        }
        for (const HloUse* use : p->uses) {
          if (use->instruction->opcode() == HloOpcode::kCopy &&
              ContainsKey(copy_map_, use->instruction)) {
            TF_RET_CHECK(copy_map_.at(use->instruction).src == p);
          }
        }

        p = p->next;
      } while (p != head);
    }
    return Status::OK();
  }

  // Try to elide the given copy. Elision of a copy is possible only if no
  // live range interference is introduced by the copy's elimination. If
  // elision is possible, then the internal state (value lists) is updated,
  // and true is returned. Returns false otherwise.
  bool TryElideCopy(const HloInstruction* copy) {
    VLOG(2) << "Trying to remove " << copy->name();

    if (!ContainsKey(copy_map_, copy)) {
      VLOG(2) << copy->name() << " is not removable";
      return false;
    }
    if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
      VLOG(2) << copy->name() << " is not removable (shape mismatch)";
      return false;
    }
    const CopyNodes& copy_node = copy_map_.at(copy);
    ValueNode* src = copy_node.src;
    ValueNode* dest = copy_node.dest;
    DCHECK(src != nullptr);
    DCHECK(dest != nullptr);

    auto is_live_range_before = [this](const ValueNode& a,
                                       const ValueNode& b) {
      VLOG(3) << "Checking live range of " << *a.value << " WRT " << *b.value;
      if (LiveRangeBefore(a, b)) {
        VLOG(2) << "  Live range of " << a.value->ToShortString()
                << " is before " << b.value->ToShortString();
        return true;
      } else {
        VLOG(2) << "  Live range of " << a.value->ToShortString()
                << " is not before " << b.value->ToShortString();
        return false;
      }
    };

    VLOG(3) << copy->name() << " copies value " << src->value->ToShortString();
    VLOG(3) << "Source buffer values: " << ValueListToString(src);
    VLOG(3) << "Dest buffer values: " << ValueListToString(dest);

    // A kCopy instruction copies an HLO value from a source buffer and
    // defines an HLO value in a destination buffer. Most generally, the
    // source and destination buffers may each hold more than one value at
    // different points in the computation so we define the following:
    //
    //   Values in source buffer:      {s_0, ..., s_n}
    //   Values in destination buffer: {d_0, ..., d_m}
    //
    // A kCopy instruction between these buffers copies a value s_x in the
    // source buffer and defines a value d_y in the destination buffer. The
    // elision of a copy merges the source and destination buffers together,
    // so the lists of values for the source and destination buffers are
    // merged.
    //
    // We handle two different cases for copy elision:
    //
    //   (1) the kCopy defines the first value in the destination buffer
    //       (d_0).
    //
    //   (2) the kCopy copies the last value in the source buffer (s_n).
    //
    // For the remaining case, where the kCopy copies a non-last value from
    // the source buffer to a non-first value of the destination buffer, the
    // kCopy instruction cannot be removed. This case is generated, for
    // example, if the kCopy copies a while body parameter of the loop state
    // at one tuple index to a different tuple index in the while body root.
    // Removal of the copy necessarily results in live range interference of
    // values in the loop state at the two different tuple indices.
    //
    // We can only perform copy elision if the resulting merged values have
    // totally ordered live ranges; otherwise the merged buffer would have
    // live range interference.
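    //
    // As a concrete (hypothetical) example of case (1): with source values
    // {s_0, s_1} and destination values {d_0, d_1}, a copy reading s_0 and
    // defining d_0 can be elided only if the merged order {s_0, d_1, s_1}
    // has totally ordered live ranges.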
    if (src->next == dest) {
      // In the process of eliding copies, it's possible for a copy to have
      // the same source and destination buffer. In this case, the copy can
      // be safely removed.
      VLOG(2) << copy->name()
              << " source and destination buffers are the same.";
    } else if (IsHead(*dest)) {
      // The copy copies an arbitrary value in the source buffer (call it s_x)
      // and defines d_0, the first value in the destination buffer. After
      // merging, the values in the combined buffer must be strictly ordered
      // as follows** to elide the copy:
      //
      //   {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
      //
      // Removing the copy eliminates d_0, and uses of d_0 become uses of
      // s_x. In the above ordering, the live range of d_m must be ordered
      // before the live range of s_{x+1} and the definition and all uses of
      // s_x must be ordered before the definition of d_1. These conditions
      // are checked below prior to elision.
      //
      // ** Technically it might be possible to have a non-interfering
      //    non-trivial interleaving of the values of the source and
      //    destination buffers in the resulting order. However, this case is
      //    slow and complicated to check and likely not worth it. So instead
      //    we simply check for the case where *all* values of the destination
      //    buffer (d_1 through d_m) are spliced into the point where the copy
      //    used to be.
      VLOG(2) << copy->name() << " defines the first value in its buffer";
      ValueNode* next_dest = Next(*dest);
      if (next_dest != nullptr) {
        // Live range of 'from' value (s_x) must be before 'next_dest' (d_1).
        if (!is_live_range_before(*src, *next_dest)) {
          return false;
        }
      }
      ValueNode* next_src = Next(*src);

      if (next_src != nullptr) {
        // Live range of 'last_dest' (d_m) must be before 'next_src'
        // (s_{x+1}).
        ValueNode* last_dest = dest->prev;
        DCHECK(IsTail(*last_dest));
        if (!is_live_range_before(*last_dest, *next_src)) {
          return false;
        }
      }

      // Splice in destination buffer values list right after 'src'.
      SpliceAfter(dest, src);
    } else if (IsTail(*src)) {
      // The copy copies the last value in the source buffer, s_n, and defines
      // an arbitrary value in the destination buffer, d_y. After
      // merging, the values in the combined buffer must be strictly ordered
      // as follows** to elide the copy:
      //
      //   {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m}
      //
      // Removing the copy eliminates d_y, and uses of d_y become uses of
      // s_n. To enforce the above order, the live range of d_{y-1} must be
      // before the live range of s_0, and the live range of s_n must be
      // before the live range of d_{y+1}.
      //
      // ** See the comment above in the code handling case (1).
      VLOG(2) << copy->name() << " copies the last value ("
              << src->value->ToShortString() << ") in its buffer";

      ValueNode* prev_dest = Prev(*dest);
      // nullptr condition handled above in the first 'if' case.
      DCHECK(prev_dest != nullptr);
      ValueNode* first_src = src->next;
      DCHECK(IsHead(*first_src));
      if (!is_live_range_before(*prev_dest, *first_src)) {
        // Live range of value d_{y-1} is not before s_0.
        return false;
      }
      ValueNode* next_dest = Next(*dest);
      if (next_dest != nullptr) {
        if (!is_live_range_before(*src, *next_dest)) {
          // Live range of value s_n is not before d_{y+1}.
          return false;
        }
      }

      // Splice source buffer values list right after 'prev_dest'.
      SpliceAfter(first_src, prev_dest);
    } else {
      VLOG(2) << copy->name()
              << " copies value in middle of source buffer to value in middle "
                 "of destination buffer";
      return false;
    }

    RemoveCopyValue(dest);

    XLA_VLOG_LINES(4, ToString());
    TF_DCHECK_OK(Verify());

    return true;
  }

  // Delete the given ValueNode associated with an elided kCopy
  // instruction. This should be called after splicing the value lists of the
  // source and destination buffers together.
  void RemoveCopyValue(ValueNode* copy_value_node) {
    CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(),
             HloOpcode::kCopy);
    ValueNode* operand_node = copy_value_node->prev;
    CHECK(operand_node != copy_value_node);

    VLOG(2) << "Removing copy " << operand_node->value->ToShortString()
            << " => " << copy_value_node->value->ToShortString();

    // Splice out the copy value node.
    operand_node->next = copy_value_node->next;
    copy_value_node->next->prev = operand_node;

    // Patch up uses. Remove the use of the copy from operand_node's uses.
    auto it = absl::c_find_if(operand_node->uses, [copy_value_node](
                                                      const HloUse* use) {
      return use->instruction == copy_value_node->value->defining_instruction();
    });
    CHECK(it != operand_node->uses.end());
    operand_node->uses.erase(it);

    // If the elided copy has any uses which are themselves kCopy instructions
    // then patch up the copy info to reflect the fact that this kCopy
    // instruction has a different operand (the operand of the elided copy).
    for (const HloUse* copy_use : copy_value_node->uses) {
      operand_node->uses.push_back(copy_use);
      if (copy_use->instruction->opcode() == HloOpcode::kCopy &&
          ContainsKey(copy_map_, copy_use->instruction)) {
        copy_map_.at(copy_use->instruction).src = operand_node;
      }
    }

    // Delete the copy info and the value node.
    copy_map_.erase(copy_value_node->value->defining_instruction());
    delete copy_value_node;
  }

  // Returns true if the live range of the given value 'a' is before the live
  // range of 'b'.
  //
  // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not
  // updated as copies are removed.
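  //
  // For example (illustrative): if every use of 'a' occurs at instructions
  // that execute before the instruction defining 'b', the live ranges do not
  // overlap and 'a' is live-range-before 'b'. When 'a' has no uses, we fall
  // back to comparing the definitions, as in the code below.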
  bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
    if (a.uses.empty()) {
      VLOG(2) << "Empty uses for " << *a.value;
      return ordering_.IsDefinedBefore(*a.value, *b.value);
    }
    for (const HloUse* use : a.uses) {
      VLOG(2) << "Checking use " << *use << " against " << *b.value;
      if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) {
        VLOG(2) << "Use " << *use << " is NOT before " << *b.value;
        return false;
      }
      VLOG(2) << "Use " << *use << " is before " << *b.value;
    }
    return true;
  }

  // Returns whether 'node' is the last node in its list.
  bool IsTail(const ValueNode& node) const {
    return ContainsKey(value_lists_, node.next);
  }

  // Returns whether 'node' is the first node in its list.
  bool IsHead(const ValueNode& node) const {
    return ContainsKey(value_lists_, &node);
  }

  // Returns the next node in the list after 'node'. If 'node' is the
  // tail, then nullptr is returned.
  ValueNode* Next(const ValueNode& node) const {
    if (IsTail(node)) {
      return nullptr;
    } else {
      return node.next;
    }
  }

  // Returns the previous node in the list before 'node'. If 'node'
  // is the head, then nullptr is returned.
  ValueNode* Prev(const ValueNode& node) const {
    if (IsHead(node)) {
      return nullptr;
    } else {
      return node.prev;
    }
  }

  // Splices the entire linked list with 'head' as its head right after the
  // node 'insert_after' in another linked list.
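  //
  // For example (illustrative): splicing the list {h, t} (with head 'h')
  // after node 'n' in the list {n, m} produces {n, h, t, m}.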
  void SpliceAfter(ValueNode* head, ValueNode* insert_after) {
    DCHECK(IsHead(*head));
    value_lists_.erase(head);

    ValueNode* tail = head->prev;
    tail->next = insert_after->next;
    insert_after->next->prev = tail;

    insert_after->next = head;
    head->prev = insert_after;
  }

  string ValueListToString(const ValueNode* element) {
    const ValueNode* head = element;
    while (!IsHead(*head)) {
      head = Prev(*head);
    }
    std::vector<const HloValue*> values;
    for (const ValueNode* p = head; p != nullptr; p = Next(*p)) {
      values.push_back(p->value);
    }
    return absl::StrCat("{",
                        absl::StrJoin(values, ", ",
                                      [](string* s, const HloValue* value) {
                                        StrAppend(s, value->ToShortString());
                                      }),
                        "}");
  }

  string ToString() const {
    string out = absl::StrCat("CopyRemover:\n");
    StrAppend(&out, "  Def-use chains in each buffer:\n");
    for (const ValueNode* head : value_lists_) {
      StrAppend(&out, "    Buffer defined by ", head->value->ToShortString(),
                ":\n");
      const ValueNode* p = head;
      do {
        StrAppend(&out, "      ", p->value->ToShortString(), ", uses: ",
                  absl::StrJoin(p->uses, "; ",
                                [](string* s, const HloUse* use) {
                                  StrAppend(s, use->ToString());
                                }),
                  "\n");

        p = p->next;
      } while (p != head);
    }
    StrAppend(&out, "  Potentially removable copies:\n");
    for (const auto& pair : copy_map_) {
      const HloInstruction* copy = pair.first;
      const CopyNodes& copy_info = pair.second;

      StrAppend(&out, "    ", copy->name(), " : ",
                copy_info.src->value->ToShortString(), " => ",
                copy_info.dest->value->ToShortString(), "\n");
    }
    return out;
  }

 private:
  const HloDataflowAnalysis& dataflow_;
  const HloOrdering& ordering_;

  // The heads of all the value lists. Each value list represents the HLO
  // values contained in a particular HLO buffer. The values in the list are
  // in dependency order.
  absl::flat_hash_set<const ValueNode*> value_lists_;

  // Copy removal requires fast access to the value list elements
  // corresponding to the source and destination values of the kCopy
  // instruction. This data structure holds pointers to these elements for
  // each kCopy instruction in the graph.
  struct CopyNodes {
    // The source and destination values of the kCopy instruction.
    ValueNode* src = nullptr;
    ValueNode* dest = nullptr;
  };
  absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_;
};

}  // namespace

// Add kCopy instructions to the given module to guarantee there is no
// live-range interference. Generally interference can only occur around
// kWhile instructions which have update-in-place semantics.
Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, fusion_can_share_buffer_));

  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kWhile) {
        TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
      } else if (instruction->opcode() == HloOpcode::kConditional) {
        TF_RETURN_IF_ERROR(
            AddCopiesForConditional(*alias_analysis, instruction));
      }
    }
  }

  TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module));
  return Status::OK();
}

Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) {
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  return AddSpecialCaseCopies(*call_graph, module);
}

Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
                                           HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, fusion_can_share_buffer_));

  // Identify which shape indices of which instructions need to be copied.
  // Store these results in 'instructions_to_copy'.
  HloInstructionMap<ShapeTree<bool>> instructions_to_copy;
  auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction,
                                                   const ShapeIndex& index) {
    auto it = instructions_to_copy.find(instruction);
    if (it == instructions_to_copy.end()) {
      auto it_added = instructions_to_copy.emplace(
          std::piecewise_construct, std::forward_as_tuple(instruction),
          std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
      it = it_added.first;
    }
    *it->second.mutable_element(index) = true;
  };

  // Iterate through values of all constants and entry parameters. These
  // values are special because they are held in read-only buffers. If any of
  // these values share a buffer with other values (for example, the init
  // value of a while is a constant) then copy the value at its definition and
  // replace all its uses with the copy.
  for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
    if (ValueIsReadOnly(*value) &&
        alias_analysis->GetBufferContainingValue(*value).values().size() > 1) {
      VLOG(2) << "Value " << value->ToShortString()
              << " is read only, but its buffer contains more than one value. "
                 "Copying.";
      add_index_to_copy(value->defining_instruction(),
                        value->defining_index());
    }
  }

  // Identify copies which must be added at root instructions.
  for (HloComputation* computation : module->computations()) {
    const CallGraphNode& node = call_graph.GetNode(computation);
    if (node.context() == CallContext::kParallel) {
      continue;
    }
    TF_RET_CHECK(node.context() == CallContext::kSequential);

    SpecialCaseCopyPolicy policy =
        GetSpecialCaseCopyPolicy(node, module, computation);
    HloInstruction* root = computation->root_instruction();

    // Mark nondistinct/ambiguous indices.
    absl::flat_hash_set<const HloBuffer*> seen;
    ShapeUtil::ForEachSubshape(
        root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          std::vector<const HloBuffer*> buffers_at_index =
              alias_analysis->ComputeBuffersAt(root, index);
          bool buffer_seen_before = false;
          for (const HloBuffer* buffer : buffers_at_index) {
            buffer_seen_before |= !seen.insert(buffer).second;
          }
          if (buffers_at_index.size() > 1 ||
              (buffer_seen_before && policy.copy_root_replicated_buffers)) {
            VLOG(2) << "Index " << index << " of computation "
                    << computation->name() << " (" << root->name()
                    << ") has ambiguous or non-distinct buffer. Copying.";
            add_index_to_copy(root, index);
          }
        });

    for (const auto& pair :
         alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
      const ShapeIndex& index = pair.first;
      const HloValueSet& value_set = pair.second;
      for (const HloValue* value : value_set.values()) {
        if (ShouldCopyRootValue(*value, policy)) {
          VLOG(2) << "Root (" << root->name() << ") of computation ("
                  << computation->name()
                  << ") has constant or parameter value at index " << index
                  << ". Copying.";
          add_index_to_copy(root, index);
        }
      }
    }
  }

  // Add copy instructions indicated in 'instructions_to_copy' to the module.
  for (const auto& pair : instructions_to_copy) {
    HloInstruction* instruction = pair.first;
    const ShapeTree<bool>& indices_to_copy = pair.second;

    ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
    std::vector<HloInstruction*> users = instruction->users();
    TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
                        instruction->parent()->DeepCopyInstruction(
                            instruction, &indices_to_copy, &copies_added));
    for (HloInstruction* user : users) {
      TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
    }
    if (instruction == instruction->parent()->root_instruction()) {
      instruction->parent()->set_root_instruction(deep_copy);
    }
  }
  return Status::OK();
}

Status CopyInsertion::VerifyNoLiveRangeInterference(const HloOrdering& ordering,
                                                    HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
  TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering));
  return Status::OK();
}

Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
                                              HloModule* module) {
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module, fusion_can_share_buffer_));

  CopyRemover copy_remover(*module, *alias_analysis, ordering);
  if (VLOG_IS_ON(3)) {
    LOG(INFO) << "Removing unnecessary copies in " << module->name();
    LOG(INFO) << "Buffer values, in dependency order: ";
    for (const HloBuffer& buffer : alias_analysis->buffers()) {
      LOG(INFO) << "    HloBuffer " << buffer.id();
    }
  }

  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  for (HloComputation* computation : module->computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopy &&
          copy_remover.TryElideCopy(instruction)) {
        TF_RETURN_IF_ERROR(StripControlDependenciesFrom(instruction));
        TF_RETURN_IF_ERROR(
            instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
      }
    }
  }
  return Status::OK();
}

StatusOr<bool> CopyInsertion::Run(HloModule* module) {
  // Copy insertion is performed in three steps:
  //
  // (1) Add copies conservatively to guarantee that there is no live-range
  //     interference. This is done simplistically and usually results in more
  //     copies than are strictly necessary.
  //
  // (2) Using a more fine-grained analysis, remove as many of the copies
  //     added in (1) as possible while ensuring no live-range interference.
  //
  // (3) Add copies to resolve issues not related to live range interference
  //     such as parameters and constants live out of the entry computation.
  //
  // We add copies then remove them (step (1) then (2)) rather than simply
  // adding only the copies that are necessary because, in general, it is
  // difficult to figure out the minimal set of copies to add once there is
  // interference. On the other hand, it is easy to determine if removing a
  // copy will introduce interference.
  //
  // The final copy insertion in (3) is done separately to simplify the
  // implementation of copy removal in (2), which is the most complicated part
  // of the pass. As is, copy removal only has to reason about live range
  // interference. If all copies were added in step (1) then copy removal
  // would also have to reason about things like constants and parameters live
  // out of the computation.
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  if (!call_graph->IsFlattened()) {
    return FailedPrecondition(
        "Call graph must be flattened before copy insertion.");
  }

  int64 num_existing_copies = 0;
  if (VLOG_IS_ON(1)) {
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        if (instruction->opcode() == HloOpcode::kCopy) {
          ++num_existing_copies;
        }
      }
    }
  }

  TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));

  // Simplify the tuple structures introduced by the deep copies. This should
  // be done before removing copies (RemoveUnnecessaryCopies) because tuple
  // simplification changes dependencies in the graph, which changes live
  // range interference. Also run DCE to remove the dead Tuple/GTE
  // instructions introduced by tuple simplification.
  TupleSimplifier tuple_simplifier;
  HloDCE dce;
  TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
  TF_RETURN_IF_ERROR(dce.Run(module).status());
  DumpHloModuleDuringPassIfEnabled(
      name(), "after adding copies to resolve interference", *module);

  DependencyHloOrdering dep_ordering(module);
  TF_DCHECK_OK(VerifyNoLiveRangeInterference(dep_ordering, module));

  TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(dep_ordering, module));
  DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
                                   *module);

  TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
  DumpHloModuleDuringPassIfEnabled(name(), "after adding special-case copies",
                                   *module);

  TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
  TF_RETURN_IF_ERROR(dce.Run(module).status());
  TF_DCHECK_OK(
      VerifyNoLiveRangeInterference(DependencyHloOrdering(module), module));

  if (VLOG_IS_ON(1)) {
    int64 num_total_copies = 0;
    for (HloComputation* computation : module->computations()) {
      for (HloInstruction* instruction : computation->instructions()) {
        if (instruction->opcode() == HloOpcode::kCopy) {
          num_total_copies++;
        }
      }
    }
    VLOG(1) << "Num copies before copy-insertion: " << num_existing_copies;
    VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
  }

  return true;
}

namespace {

bool IsWhileBody(const HloComputation* computation,
                 const CallGraph& call_graph) {
  const CallGraphNode& node = call_graph.GetNode(computation);

  if (node.context() == CallContext::kSequential &&
      !node.caller_callsites().empty()) {
    // The call graph should be flattened, so sequential-context computations
    // can have at most one caller.
    CHECK_EQ(node.caller_callsites().size(), 1);
    const HloInstruction* calling_instruction =
        node.caller_callsites()[0].instruction();
    if (calling_instruction->opcode() == HloOpcode::kWhile &&
        calling_instruction->while_body() == node.computation()) {
      return true;
    }
  }
  return false;
}

}  // namespace

/* static */ StatusOr<bool> CopyInsertion::AddCopiesForBufferAssignment(
    HloModule* module) {
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow,
                      HloDataflowAnalysis::Run(*module));

  bool changed = false;

  // If a buffer live out of a computation is a constant, a parameter, or not
  // defined in the computation, then copy it to account for the limited
  // computation-scoped analysis in buffer assignment. An exception to this
  // rule is the while body, which is handled properly without copies.
  for (HloComputation* computation : module->computations()) {
    if (computation == module->entry_computation() ||
        IsWhileBody(computation, *call_graph)) {
      continue;
    }

    HloInstruction* root = computation->root_instruction();
    ShapeTree<bool> indices_to_copy(root->shape(), /*init_value=*/false);
    bool copy_root = false;
    for (const auto& pair : dataflow->GetInstructionValueSet(root)) {
      const ShapeIndex& index = pair.first;
      const HloValueSet& value_set = pair.second;
      for (const HloValue* value : value_set.values()) {
        HloInstruction* def = value->defining_instruction();
        if (def->parent() != computation ||
            def->opcode() == HloOpcode::kConstant ||
            def->opcode() == HloOpcode::kParameter) {
          *indices_to_copy.mutable_element(index) = true;
          copy_root = true;
        }
      }
    }
    if (copy_root) {
      TF_ASSIGN_OR_RETURN(
          HloInstruction * root_copy,
          computation->DeepCopyInstruction(root, &indices_to_copy));
      computation->set_root_instruction(root_copy);
      changed = true;
    }
  }

  TupleSimplifier tuple_simplifier;
  HloDCE dce;
  TF_ASSIGN_OR_RETURN(bool tuple_simplifier_changed,
                      tuple_simplifier.Run(module));
  TF_ASSIGN_OR_RETURN(bool dce_changed, dce.Run(module));

  return changed || tuple_simplifier_changed || dce_changed;
}

}  // namespace xla