1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
7     http://www.apache.org/licenses/LICENSE-2.0
9 Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
16 #include "tensorflow/compiler/xla/service/hlo_computation.h"
18 #include <algorithm>
19 #include <cstddef>
20 #include <functional>
21 #include <list>
22 #include <queue>
23 #include <set>
24 #include <sstream>
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/numbers.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_join.h"
33 #include "tensorflow/compiler/xla/layout_util.h"
34 #include "tensorflow/compiler/xla/map_util.h"
35 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
36 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
37 #include "tensorflow/compiler/xla/service/hlo_module.h"
38 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/platform/logging.h"
47 namespace xla {
49 using absl::StrCat;
Build(HloInstruction * root_instruction)51 std::unique_ptr<HloComputation> HloComputation::Builder::Build(
52     HloInstruction* root_instruction) {
53   int parameter_count = 0;
54   for (auto& instruction : instructions_) {
55     if (instruction->opcode() == HloOpcode::kParameter) {
56       parameter_count++;
57     }
58   }
59   // If root_instruction is not specified use the last added instruction.
60   HloInstruction* root =
61       root_instruction ? root_instruction : last_added_instruction_;
62   CHECK_NE(nullptr, root);
63   return absl::WrapUnique(new HloComputation(
64       name_, parameter_count, &instructions_, root, fusion_instruction_));
65 }
HloComputation(const string & name,int parameter_count,std::vector<std::unique_ptr<HloInstruction>> * instructions,HloInstruction * root_instruction,HloInstruction * fusion_instruction)67 HloComputation::HloComputation(
68     const string& name, int parameter_count,
69     std::vector<std::unique_ptr<HloInstruction>>* instructions,
70     HloInstruction* root_instruction, HloInstruction* fusion_instruction)
71     : name_(NameUniquer::GetSanitizedName(name)),
72       unique_id_(-1),
73       root_instruction_(root_instruction),
74       fusion_instruction_(fusion_instruction) {
75   param_instructions_.resize(parameter_count, nullptr);
76   bool root_found = false;
77   for (auto& instruction : *instructions) {
78     if (instruction->opcode() == HloOpcode::kParameter) {
79       int64 param_no = instruction->parameter_number();
80       CHECK(param_no >= 0 && param_no < parameter_count)
81           << "\nERROR: invalid parameter number.  Expected [0, "
82           << parameter_count << "), got " << param_no;
83       CHECK(param_instructions_[param_no] == nullptr)
84           << "\nERROR: parameter number " << param_no
85           << " already allocated in this computation";
86       param_instructions_[param_no] = instruction.get();
87     }
88     root_found |= instruction.get() == root_instruction_;
89     AddInstructionInternal(std::move(instruction));
90   }
91   CHECK(root_found)
92       << "\nERROR: root instruction is not present in computation.";
93 }
AddInstruction(std::unique_ptr<HloInstruction> instruction,const std::string & new_name)95 HloInstruction* HloComputation::AddInstruction(
96     std::unique_ptr<HloInstruction> instruction, const std::string& new_name) {
97   CHECK(instruction->opcode() != HloOpcode::kParameter)
98       << "Parameter instructions cannot be added to a computation after "
99       << "it has been built";
100   if (!new_name.empty()) {
101     instruction->SetAndSanitizeName(new_name);
102   }
103   return AddInstructionInternal(std::move(instruction));
104 }
AddInstructionInternal(std::unique_ptr<HloInstruction> instruction)106 HloInstruction* HloComputation::AddInstructionInternal(
107     std::unique_ptr<HloInstruction> instruction) {
108   if (parent() != nullptr) {
109     instruction->UniquifyName(&parent()->instruction_name_uniquer());
110     instruction->SetUniqueId(parent()->NewUniqueInstructionId());
111   }
112   instruction->set_parent(this);
113   HloInstruction* pinst = instruction.get();
114   instruction_iterators_[pinst] =
115       instructions_.insert(instructions_.end(), std::move(instruction));
116   return pinst;
117 }
AddParameter(std::unique_ptr<HloInstruction> instruction)119 HloInstruction* HloComputation::AddParameter(
120     std::unique_ptr<HloInstruction> instruction) {
121   CHECK(instruction->opcode() == HloOpcode::kParameter);
122   CHECK(IsFusionComputation());
123   CHECK(fusion_instruction_->operand_count() == param_instructions_.size());
124   instruction->set_parent(this);
125   param_instructions_.push_back(instruction.get());
126   AddInstructionInternal(std::move(instruction));
127   return instructions_.back().get();
128 }
AddEntryComputationParameter(std::unique_ptr<HloInstruction> instruction)130 HloInstruction* HloComputation::AddEntryComputationParameter(
131     std::unique_ptr<HloInstruction> instruction) {
132   CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
133   CHECK_EQ(instruction->parameter_number(), num_parameters());
134   CHECK(parent()->entry_computation() == this);
136   HloModuleConfig config = parent()->config();
137   config.mutable_entry_computation_layout()->add_parameter_layout(
138       ShapeLayout(instruction->shape()));
139   parent()->set_config(config);
141   instruction->set_parent(this);
142   param_instructions_.push_back(instruction.get());
143   AddInstructionInternal(std::move(instruction));
145   return instructions_.back().get();
146 }
ReplaceEntryComputationParameter(int64 param_no,HloInstruction * old_instruction,std::unique_ptr<HloInstruction> instruction)148 Status HloComputation::ReplaceEntryComputationParameter(
149     int64 param_no, HloInstruction* old_instruction,
150     std::unique_ptr<HloInstruction> instruction) {
151   CHECK_GE(param_no, 0);
152   CHECK_LT(param_no, param_instructions_.size());
153   CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
154   CHECK(parent()->entry_computation() == this);
156   HloModuleConfig config = parent()->config();
157   *config.mutable_entry_computation_layout()->mutable_parameter_layout(
158       param_no) = ShapeLayout(instruction->shape());
159   parent()->set_config(config);
161   instruction->set_parent(this);
162   param_instructions_[param_no] = instruction.get();
163   AddInstructionInternal(std::move(instruction));
165   return ForceRemoveInstruction(old_instruction);
166 }
RemoveParameter(int64 param_no)168 Status HloComputation::RemoveParameter(int64 param_no) {
169   CHECK_GE(param_no, 0);
170   CHECK_LT(param_no, param_instructions_.size());
171   CHECK(IsFusionComputation());
172   HloInstruction* param_instruction = param_instructions_[param_no];
173   auto param_instruction_iterator = param_instructions_.begin() + param_no;
174   param_instructions_.erase(param_instruction_iterator);
175   // Throw removed fused parameter instruction away.
176   TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
178   while (param_no < param_instructions_.size()) {
179     param_instruction = param_instructions_[param_no];
180     HloInstruction* new_instr =
181         AddInstructionInternal(HloInstruction::CreateParameter(
182             param_no, param_instruction->shape(), StrCat("param_", param_no)));
183     TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
184     param_instructions_[param_no] = new_instr;
185     TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
186     param_no++;
187   }
189   return Status::OK();
190 }
RemoveUnusedParametersFromFusedComputation()192 Status HloComputation::RemoveUnusedParametersFromFusedComputation() {
193   return RemoveUnusedParametersImpl(/*allow_non_fusion=*/false);
194 }
RemoveUnusedParametersFromAnyComputation()196 Status HloComputation::RemoveUnusedParametersFromAnyComputation() {
197   return RemoveUnusedParametersImpl(/*allow_non_fusion=*/true);
198 }
RemoveUnusedParametersImpl(bool allow_non_fusion)200 Status HloComputation::RemoveUnusedParametersImpl(bool allow_non_fusion) {
201   CHECK(allow_non_fusion || IsFusionComputation());
202   int64 removed = 0;
203   for (int64 i = 0; i < param_instructions_.size(); ++i) {
204     HloInstruction* param_instruction = param_instructions_[i];
205     if (param_instruction->user_count() == 0 &&
206         param_instruction != root_instruction()) {
208           RemoveInstructionImpl(param_instruction, allow_non_fusion));
209       ++removed;
210       continue;
211     }
213     if (removed > 0) {
214       const int64 param_no = i - removed;
215       HloInstruction* new_instr = AddInstructionInternal(
216           HloInstruction::CreateParameter(param_no, param_instruction->shape(),
217                                           StrCat("param_", param_no)));
218       TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
219       param_instructions_[param_no] = new_instr;
221           RemoveInstructionImpl(param_instruction, allow_non_fusion));
222     }
223   }
224   param_instructions_.resize(param_instructions_.size() - removed);
225   return Status::OK();
226 }
IsSafelyRemovable(const HloInstruction * instruction)228 bool HloComputation::IsSafelyRemovable(const HloInstruction* instruction) {
229   // If the instruction has control predecessors or successors then we cannot
230   // remove the instruction without violating ordering constraints (added, for
231   // example, to avert interference due to buffer aliasing).
232   if (!instruction->control_predecessors().empty() ||
233       !instruction->control_successors().empty()) {
234     return false;
235   }
237   if (instruction->opcode() == HloOpcode::kParameter &&
238       !IsFusionComputation()) {
239     return false;
240   }
242   return true;
243 }
HasSideEffect() const245 bool HloComputation::HasSideEffect() const {
246   for (auto* instruction : instructions()) {
247     if (instruction->HasSideEffect()) {
248       return true;
249     }
250   }
251   return false;
252 }
IsMarkedAsDead(const HloInstruction * inst)254 bool HloComputation::IsMarkedAsDead(const HloInstruction* inst) {
255   return inst->IsMarkedAsDead();
256 }
RemoveInstructionAndUnusedOperands(HloInstruction * instruction,std::function<void (HloInstruction *)> cleanup)258 Status HloComputation::RemoveInstructionAndUnusedOperands(
259     HloInstruction* instruction, std::function<void(HloInstruction*)> cleanup) {
260   TF_RET_CHECK(root_instruction() != instruction);
262   TF_RET_CHECK(instruction->user_count() == 0);
263   TF_RET_CHECK(IsSafelyRemovable(instruction))
264       << "Cannot remove instruction: " << instruction->ToString();
265   absl::flat_hash_set<HloInstruction*> removed;
266   std::queue<HloInstruction*> worklist;
267   worklist.push(instruction);
268   while (!worklist.empty()) {
269     HloInstruction* item = worklist.front();
270     worklist.pop();
272     if (removed.contains(item) || item->user_count() != 0 ||
273         item == root_instruction() || !IsSafelyRemovable(item) ||
274         (item->HasSideEffect() && item != instruction)) {
275       continue;
276     }
277     for (int i = 0; i < item->operand_count(); ++i) {
278       worklist.push(item->mutable_operand(i));
279     }
281     if (cleanup) {
282       cleanup(item);
283     }
284     TF_RETURN_IF_ERROR(RemoveInstruction(item));
285     removed.insert(item);
286   }
287   return Status::OK();
288 }
RemoveInstruction(HloInstruction * instruction)290 Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
291   return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/false);
292 }
ForceRemoveInstruction(HloInstruction * instruction)294 Status HloComputation::ForceRemoveInstruction(HloInstruction* instruction) {
295   return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/true);
296 }
RemoveInstructionImpl(HloInstruction * instruction,bool ignore_safety_check)298 Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,
299                                              bool ignore_safety_check) {
300   VLOG(2) << "Removing instruction " << instruction->name()
301           << " from computation " << name();
302   TF_RET_CHECK(ignore_safety_check || IsSafelyRemovable(instruction))
303       << "cannot remove instruction: " << instruction->ToString();
304   TF_RET_CHECK(root_instruction() != instruction)
305       << "cannot remove root instruction " << instruction->name();
306   TF_RET_CHECK(instruction->user_count() == 0)
307       << "instruction " << instruction->name()
308       << " has users and cannot be removed";
309   TF_RET_CHECK(instruction->control_predecessors().empty())
310       << "instruction " << instruction->name()
311       << " has control predecessors and cannot be removed";
312   TF_RET_CHECK(instruction->control_successors().empty())
313       << "instruction " << instruction->name()
314       << " has control successors and cannot be removed";
316   auto inst_it = instruction_iterators_.find(instruction);
317   TF_RET_CHECK(inst_it != instruction_iterators_.end());
318   (*inst_it->second)->set_parent(nullptr);
319   to_be_deleted_.emplace_back(inst_it->second->release());
320   to_be_deleted_.back()->DetachFromOperandsAndUsers();
321   // Clear all operands to avoid Null operands.
322   to_be_deleted_.back()->RemoveAllOperands();
323   to_be_deleted_.back()->ClearCalledComputations();
324   to_be_deleted_.back()->MarkAsDead();
325   instructions_.erase(inst_it->second);
326   instruction_iterators_.erase(inst_it);
327   return Status::OK();
328 }
set_root_instruction(HloInstruction * new_root_instruction,bool accept_different_shape)330 void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
331                                           bool accept_different_shape) {
332   // The shape of the root (ignoring layout) is an invariant of the computation
333   // for non-fusion cases.
334   if (!IsFusionComputation() && !accept_different_shape) {
335     CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
336                                 root_instruction_->shape()))
337         << new_root_instruction->shape() << " is incompatible with "
338         << root_instruction_->shape();
339   }
340   bool root_found = false;
341   for (auto& instruction : instructions_) {
342     if (new_root_instruction == instruction.get()) {
343       root_found = true;
344       break;
345     }
346   }
347   DCHECK(root_found);
349   if (parent() && parent()->has_entry_computation() &&
350       parent()->entry_computation() == this) {
351     if (!Shape::Equal().IgnoreLayout()(new_root_instruction->shape(),
352                                        root_instruction_->shape())) {
353       // Rebuild input output alias config now that we have a new output shape.
354       parent()->input_output_alias_config() =
355           HloInputOutputAliasConfig(new_root_instruction->shape());
356     }
357   }
359   root_instruction_ = new_root_instruction;
360 }
362 namespace {
364 // Helper which builds a post order of the HLO call graph.
ComputeComputationPostOrder(HloComputation * computation,absl::flat_hash_set<HloComputation * > * visited,std::vector<HloComputation * > * post_order)365 void ComputeComputationPostOrder(HloComputation* computation,
366                                  absl::flat_hash_set<HloComputation*>* visited,
367                                  std::vector<HloComputation*>* post_order) {
368   if (visited->insert(computation).second) {
369     for (auto* instruction : computation->instructions()) {
370       for (HloComputation* called_computation :
371            instruction->called_computations()) {
372         ComputeComputationPostOrder(called_computation, visited, post_order);
373       }
374     }
375     post_order->push_back(computation);
376   }
377 }
379 }  // namespace
// Appends to `post_order` the not-yet-visited instructions reachable from
// `root` through operand and control-predecessor edges, emitting an
// instruction only after the predecessors pushed for it have been handled.
// The traversal is iterative and uses an explicit stack.
//
// `visited` holds per-instruction traversal state and is shared with the
// caller, so instructions already marked kVisited are skipped.
// `channel_dependency_group` maps a channel id to the instructions recorded
// for that channel; for kRecvDone and kAllReduce instructions with a known
// group, the whole group is traversed together.
//
// Every instruction encountered must belong to this computation (CHECKed).
void HloComputation::ComputeInstructionPostOrder(
    const HloComputation::ChannelDependencyGroup& channel_dependency_group,
    std::vector<HloInstruction*>* post_order, HloInstruction* root,
    absl::flat_hash_map<HloInstruction*, VisitState>* visited) const {
  std::vector<HloInstruction*> dfs_stack;
  dfs_stack.push_back(root);
  while (!dfs_stack.empty()) {
    const auto current = dfs_stack.back();
    CHECK_EQ(current->parent(), this)
        << "Instruction " << current->name()
        << " is not in the current computation (" << name() << ").";
    auto it = visited->find(current);
    if (it != visited->end()) {
      if (it->second == kVisited) {
        // Already visited.
        dfs_stack.pop_back();
        continue;
      }
      // Visit this node.
      // Second encounter: everything pushed above it has been processed, so
      // it can now be emitted.
      CHECK_EQ(kVisiting, it->second);
      dfs_stack.pop_back();
      post_order->push_back(current);
      it->second = kVisited;
      continue;
    }

    // First encounter: mark as in progress and leave it on the stack beneath
    // the predecessors pushed below.
    visited->insert({current, kVisiting});

    // Returns the channel id for the opcodes that take part in channel
    // grouping here (kRecvDone and kAllReduce), and nullopt for all others.
    const auto get_channel_id =
        [](HloInstruction* inst) -> absl::optional<int64> {
      switch (inst->opcode()) {
        case HloOpcode::kRecvDone:
          return inst->channel_id();
        case HloOpcode::kAllReduce:
          return inst->channel_id();
        default:
          return absl::nullopt;
      }
    };

    // When adding a predecessor to the dfs_stack, we need to also add its
    // associated channel dependencies.
    const auto add_dfs_stack = [&](HloInstruction* inst) {
      auto channel_id = get_channel_id(inst);
      if (channel_id && channel_dependency_group.count(*channel_id)) {
        auto it = channel_dependency_group.find(*channel_id);
        for (HloInstruction* cinst : it->second) {
          dfs_stack.emplace_back(cinst);
        }
      } else {
        dfs_stack.emplace_back(inst);
      }
    };

    // Pushes the data operands and control predecessors of `inst`.
    const auto add_predecessors = [&](HloInstruction* inst) {
      // Add the operands to the stack in reverse order so the first operand is
      // processed first. This will produce a more natural ordering and a nicer
      // result for things like HLO stringification.
      const auto& operands = inst->operands();
      for (int64 i = operands.size() - 1; i >= 0; --i) {
        add_dfs_stack(operands[i]);
      }

      for (HloInstruction* op : inst->control_predecessors()) {
        add_dfs_stack(op);
      }
    };

    // If the current instruction is a channel instruction, add the dependencies
    // from all associated instructions of the channel.
    auto channel_id = get_channel_id(current);
    if (channel_id && channel_dependency_group.count(*channel_id)) {
      auto it = channel_dependency_group.find(*channel_id);
      for (HloInstruction* cinst : it->second) {
        add_predecessors(cinst);
      }
    } else {
      add_predecessors(current);
    }
  }
}
463 HloComputation::ChannelDependencyGroup
ComputeChannelDependencies() const464 HloComputation::ComputeChannelDependencies() const {
465   ChannelDependencyGroup channel_dependency_group;
466   for (const auto& instruction : instructions_) {
467     switch (instruction->opcode()) {
468       case HloOpcode::kSend:
469       case HloOpcode::kRecvDone:
470       case HloOpcode::kAllReduce: {
471         auto channel_id = instruction->channel_id();
472         if (channel_id) {
473           channel_dependency_group[channel_id.value()].push_back(
474               instruction.get());
475         }
476         break;
477       }
478       default:
479         break;
480     }
481   }
482   return channel_dependency_group;
483 }
HasOnlyTraceUsers(const HloInstruction * instruction)485 static inline bool HasOnlyTraceUsers(const HloInstruction* instruction) {
486   return absl::c_all_of(instruction->users(), [](HloInstruction* user) {
487     return user->opcode() == HloOpcode::kTrace;
488   });
489 }
MakeInstructionPostOrder() const491 std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
492   auto channel_dependency_group = ComputeChannelDependencies();
493   std::vector<HloInstruction*> post_order;
494   post_order.reserve(instruction_count());
495   std::vector<HloInstruction*> trace_instructions;
496   absl::flat_hash_map<HloInstruction*, VisitState> visited;
497   visited.reserve(instruction_count());
498   for (auto& instruction : instructions_) {
499     if (instruction->opcode() == HloOpcode::kTrace) {
500       // Trace instructions aren't handled by the DFS visitor. Add trace
501       // instructions to the post order at the end (necessarily they have no
502       // users).
503       trace_instructions.push_back(instruction.get());
504     } else if (HasOnlyTraceUsers(instruction.get())) {
505       ComputeInstructionPostOrder(channel_dependency_group, &post_order,
506                                   instruction.get(), &visited);
507     }
508   }
509   post_order.insert(post_order.end(), trace_instructions.begin(),
510                     trace_instructions.end());
511   CHECK_EQ(instructions_.size(), post_order.size())
512       << "number of instructions does not match post order size";
513   return post_order;
514 }
MakeEmbeddedComputationsList() const516 std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
517     const {
518   absl::flat_hash_set<HloComputation*> visited;
519   std::vector<HloComputation*> post_order;
521   // To avoid special handling of this computation, cast away const of
522   // 'this'. 'this' is immediately removed from the post order after
523   // construction.
524   //
525   // TODO(b/78350259): This violates const-correctness, since while the original
526   // computation is not returned, we still retrieve non-const computations from
527   // a const one. Consider also avoiding const for HloComputation, or review XLA
528   // for const-correctness of non-HloInstruction* types like this.
529   ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
530                               &post_order);
532   // We don't want to include this computation in the post order.
533   CHECK_EQ(this, post_order.back());
534   post_order.pop_back();
536   return post_order;
537 }
ToString(const HloPrintOptions & options) const539 string HloComputation::ToString(const HloPrintOptions& options) const {
540   return ToString(options, MakeInstructionPostOrder());
541 }
ToString(const HloPrintOptions & options,absl::Span<const HloInstruction * const> instruction_order) const543 string HloComputation::ToString(
544     const HloPrintOptions& options,
545     absl::Span<const HloInstruction* const> instruction_order) const {
546   CHECK_EQ(instruction_order.size(), instruction_count());
548   const string tab(2 * options.indent_amount(), ' ');
550   std::ostringstream s;
551   s << tab;
553   if (!options.is_in_nested_computation()) {
554     if (options.print_percent()) {
555       s << "%";
556     }
557     if (options.print_ids()) {
558       // Exclude entry computation's name because it includes and leads to
559       // non-deterministic fingerprint.
560       s << PrintName(name(), options.print_ids()) << " ";
561     }
562   }
564   if (options.print_program_shape()) {
565     s << ShapeUtil::HumanString(ComputeProgramShape(options.print_ids()))
566       << " ";
567   }
568   s << "{\n";
570   // There are instructions which are required to be printed. Additionally, we
571   // print some instructions before and after required ones. The resulting
572   // output has the following format.
573   //
574   //  computation {
575   //    ...
576   //    additional_instructions
577   //    required_instructions
578   //    additional_instructions
579   //    ...
580   //    additional_instructions
581   //    required_instructions
582   //    additional_instructions
583   //    ...
584   //  }
585   std::set<int> instructions_to_print;
586   {
587     // Find all the instructions that should be printed.
588     auto add_instruction = [&instructions_to_print,
589                             &instruction_order](int index) {
590       if (index < 0 || index >= instruction_order.size()) {
591         return;
592       }
593       instructions_to_print.insert(index);
594     };
596     auto add_instructions_arround = [&add_instruction, &options](int index) {
597       for (int i = index - options.leading_and_trailing_instructions_number();
598            i <= index + options.leading_and_trailing_instructions_number();
599            ++i) {
600         add_instruction(i);
601       }
602     };
604     for (int i = 0; i < instruction_order.size(); ++i) {
605       const HloInstruction* instruction = instruction_order[i];
606       CHECK_EQ(this, instruction->parent());
607       if (options.print_instruction(instruction)) {
608         add_instructions_arround(i);
609       }
610     }
611   }
613   {
614     // Print the instructions in this computation.
615     HloPrintOptions new_options = options;
616     new_options.set_indent_amount(options.indent_amount() + 1)
617         .set_is_in_nested_computation(true);
619     const string new_tab(2 * new_options.indent_amount(), ' ');
621     CanonicalNameMap name_map;
623     bool print_prev = true;
624     for (int index = 0; index < instruction_order.size(); ++index) {
625       const HloInstruction* instruction = instruction_order[index];
626       if (instructions_to_print.find(index) != instructions_to_print.end()) {
627         s << new_options.format_instruction(
628                  instruction,
629                  instruction->ToStringWithCanonicalNameMap(new_options,
630                                                            &name_map),
631                  new_options.indent_amount(), instruction == root_instruction_)
632           << "\n";
633         print_prev = true;
634       } else if (print_prev) {
635         s << new_tab << "...\n";
636         print_prev = false;
637       }
638     }
639   }
641   s << tab << "}";
642   return s.str();
643 }
ToProto() const645 HloComputationProto HloComputation::ToProto() const {
646   HloComputationProto proto;
647   CHECK(unique_id_ != -1)
648       << "This computation does not have a valid id. Please make sure the "
649          "computation is inside a module before dumping it.";
650   proto.set_id(unique_id_);
651   proto.set_name(name_);
652   for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
653     HloInstructionProto instruction_proto = instruction->ToProto();
654     proto.add_instructions()->Swap(&instruction_proto);
655   }
656   proto.set_root_id(root_instruction()->unique_id());
657   *proto.mutable_program_shape() = ComputeProgramShape().ToProto();
658   return proto;
659 }
661 /* static */ StatusOr<std::unique_ptr<HloComputation>>
CreateFromProto(const HloComputationProto & proto,const absl::flat_hash_map<int64,HloComputation * > & computation_map,bool prohibit_empty_literal)662 HloComputation::CreateFromProto(
663     const HloComputationProto& proto,
664     const absl::flat_hash_map<int64, HloComputation*>& computation_map,
665     bool prohibit_empty_literal) {
666   absl::flat_hash_map<int64, HloInstruction*> instruction_map;
667   absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
668   std::vector<std::unique_ptr<HloInstruction>> instructions;
669   int64 parameter_count = 0;
670   for (const HloInstructionProto& instruction_proto : proto.instructions()) {
671     TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
672                         HloInstruction::CreateFromProto(
673                             instruction_proto, instruction_map, computation_map,
674                             prohibit_empty_literal));
675     if (instruction->opcode() == HloOpcode::kParameter) {
676       parameter_count++;
677     }
678     TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
679     instruction_map[instruction_proto.id()] = instruction.get();
680     to_proto_id[instruction.get()] = instruction_proto.id();
681     instructions.push_back(std::move(instruction));
682   }
684   TF_RET_CHECK(proto.root_id() != -1);
685   TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
686   HloInstruction* root = instruction_map.at(proto.root_id());
688   // Sort the instructions in the proto id's order.
689   absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a,
690                                  const std::unique_ptr<HloInstruction>& b) {
691     return to_proto_id[a.get()] < to_proto_id[b.get()];
692   });
694   TF_RETURN_IF_ERROR([&]() -> Status {
695     std::vector<bool> parameters_seen(parameter_count);
696     int parameters_seen_count = 0;
697     for (auto& instruction : instructions) {
698       if (instruction->opcode() == HloOpcode::kParameter) {
699         int64 param_no = instruction->parameter_number();
700         TF_RET_CHECK(param_no >= 0 && param_no < parameter_count)
701             << "Invalid parameter number.  Expected [0, " << parameter_count
702             << "), got " << param_no;
703         TF_RET_CHECK(!parameters_seen[param_no])
704             << "Parameter number " << param_no
705             << " already allocated in this computation";
706         parameters_seen[param_no] = true;
707         parameters_seen_count++;
708       }
709     }
710     TF_RET_CHECK(parameters_seen_count == parameter_count)
711         << "Not all parameters in range [0, " << parameter_count
712         << ") were referenced";
713     return Status::OK();
714   }());
716   auto computation = absl::WrapUnique(
717       new HloComputation(proto.name(), parameter_count, &instructions, root,
718                          /*fusion_instruction=*/nullptr));
719   computation->unique_id_ = proto.id();
720   return std::move(computation);
721 }
FuseInstructionsInto(absl::Span<HloInstruction * const> instructions_to_fuse,HloInstruction * fusion_instruction)723 void HloComputation::FuseInstructionsInto(
724     absl::Span<HloInstruction* const> instructions_to_fuse,
725     HloInstruction* fusion_instruction) {
726   CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
727   HloInstruction* root = instructions_to_fuse.front();
728   TF_CHECK_OK(root->ReplaceAllUsesWith(fusion_instruction));
729   if (root == root_instruction()) {
730     set_root_instruction(fusion_instruction);
731   }
732   TF_CHECK_OK(RemoveInstruction(root));
733   for (size_t i = 1; i < instructions_to_fuse.size(); ++i) {
734     HloInstruction* instruction = instructions_to_fuse[i];
735     fusion_instruction->FuseInstruction(instruction);
736     if (instruction->user_count() == 0) {
737       TF_CHECK_OK(RemoveInstruction(instruction));
738     }
739   }
740 }
CreateFusionInstruction(absl::Span<HloInstruction * const> instructions_to_fuse,HloInstruction::FusionKind fusion_kind)742 HloInstruction* HloComputation::CreateFusionInstruction(
743     absl::Span<HloInstruction* const> instructions_to_fuse,
744     HloInstruction::FusionKind fusion_kind) {
745   HloInstruction* root = instructions_to_fuse.front();
746   HloInstruction* fusion_instruction = AddInstruction(
747       HloInstruction::CreateFusion(root->shape(), fusion_kind, root));
748   FuseInstructionsInto(instructions_to_fuse, fusion_instruction);
749   return fusion_instruction;
750 }
DeepCopyHelper(HloInstruction * instruction,ShapeIndex * index,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)752 StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
753     HloInstruction* instruction, ShapeIndex* index,
754     const std::function<
755         HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
756                         HloComputation* computation)>& copy_leaf) {
757   if (instruction->shape().IsTuple()) {
758     std::vector<HloInstruction*> elements;
759     for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
760          i++) {
761       HloInstruction* gte =
762           AddInstruction(HloInstruction::CreateGetTupleElement(
763               ShapeUtil::GetTupleElementShape(instruction->shape(), i),
764               instruction, i));
766       index->push_back(i);
767       TF_ASSIGN_OR_RETURN(HloInstruction * element,
768                           DeepCopyHelper(gte, index, copy_leaf));
769       elements.push_back(element);
770       index->pop_back();
771     }
772     return AddInstruction(HloInstruction::CreateTuple(elements));
773   }
774   if (instruction->shape().IsToken()) {
775     // Tokens have no on-device representation and cannot be copied. Pass
776     // through transparently.
777     return instruction;
778   }
780   // Array shape.
781   TF_RET_CHECK(instruction->shape().IsArray());
782   return copy_leaf(instruction, *index, this);
783 }
DeepCopyInstruction(HloInstruction * instruction,const ShapeTree<bool> * indices_to_copy,ShapeTree<HloInstruction * > * copies_added)785 StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
786     HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
787     ShapeTree<HloInstruction*>* copies_added) {
788   if (instruction->parent() != this) {
789     return FailedPrecondition(
790         "Can't deep copy instruction %s: instruction is not in computation %s",
791         instruction->name(), name());
792   }
793   if (indices_to_copy != nullptr &&
794       !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
795     return FailedPrecondition(
796         "Can't deep copy instruction %s: given shape tree of indices to copy "
797         "has incompatible shapes: %s vs. %s",
798         instruction->name(), ShapeUtil::HumanString(instruction->shape()),
799         ShapeUtil::HumanString(indices_to_copy->shape()));
800   }
802   ShapeIndex index;
803   auto copy_leaf = [indices_to_copy, copies_added](
804                        HloInstruction* leaf, const ShapeIndex& leaf_index,
805                        HloComputation* computation) {
806     if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) {
807       HloInstruction* copy = computation->AddInstruction(
808           HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf));
809       if (copies_added != nullptr) {
810         *copies_added->mutable_element(leaf_index) = copy;
811       }
812       return copy;
813     }
814     // Elements which are not to be copied are passed through
815     // transparently.
816     return leaf;
817   };
818   return DeepCopyHelper(instruction, &index, copy_leaf);
819 }
DeepCopyInstructionWithCustomCopier(HloInstruction * instruction,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)821 StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
822     HloInstruction* instruction,
823     const std::function<
824         HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
825                         HloComputation* computation)>& copy_leaf) {
826   if (instruction->parent() != this) {
827     return FailedPrecondition(
828         "Can't deep copy instruction %s: instruction is not in computation %s",
829         instruction->name(), name());
830   }
831   ShapeIndex index;
832   return DeepCopyHelper(instruction, &index, copy_leaf);
833 }
ComputeProgramShape(bool include_ids) const835 ProgramShape HloComputation::ComputeProgramShape(bool include_ids) const {
836   ProgramShape program_shape;
838   for (auto* param_instruction : param_instructions_) {
839     *program_shape.add_parameters() = param_instruction->shape();
840     *program_shape.add_parameter_names() =
841         PrintName(param_instruction->name(), include_ids);
842   }
843   *program_shape.mutable_result() = root_instruction_->shape();
845   return program_shape;
846 }
EqualInternal(const HloComputation & other,bool is_layout_sensitive,bool ignore_channel_id_values) const848 bool HloComputation::EqualInternal(const HloComputation& other,
849                                    bool is_layout_sensitive,
850                                    bool ignore_channel_id_values) const {
851   if (this == &other) {
852     return true;
853   }
854   absl::flat_hash_set<std::pair<const HloInstruction*, const HloInstruction*>>
855       visited;
856   std::vector<std::pair<const HloInstruction*, const HloInstruction*>> worklist;
858   worklist.push_back({root_instruction(), other.root_instruction()});
860   while (!worklist.empty()) {
861     auto pair = worklist.back();
862     worklist.pop_back();
864     if (visited.contains(pair)) {
865       continue;
866     }
867     visited.emplace(pair);
868     // TODO(b/123082518): Avoid recursively invoking Equal because it may
869     // cause a stack overflow with deeply nested subcomputations.
870     auto operands_eq = [](const HloInstruction*, const HloInstruction*) {
871       return true;
872     };
873     auto comp_eq = [&](const HloComputation* a, const HloComputation* b) {
874       return a->EqualInternal(*b, is_layout_sensitive,
875                               ignore_channel_id_values);
876     };
877     bool identical_ignoring_operands =
878         ignore_channel_id_values
879             ? pair.first->IdenticalIgnoringChannelIdValues(
880                   *pair.second, operands_eq, comp_eq, is_layout_sensitive)
881             : pair.first->Identical(*pair.second, operands_eq, comp_eq,
882                                     is_layout_sensitive);
883     if (!identical_ignoring_operands) {
884       return false;
885     }
886     for (size_t i = 0; i < pair.first->operands().size(); ++i) {
887       worklist.push_back({pair.first->operand(i), pair.second->operand(i)});
888     }
889   }
890   return true;
891 }
ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)893 Status HloComputation::ReplaceWithNewInstruction(
894     HloInstruction* old_instruction,
895     std::unique_ptr<HloInstruction> new_instruction) {
896   return ReplaceInstruction(old_instruction,
897                             AddInstruction(std::move(new_instruction)));
898 }
ReplaceWithNewEntryComputationParameter(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)900 Status HloComputation::ReplaceWithNewEntryComputationParameter(
901     HloInstruction* old_instruction,
902     std::unique_ptr<HloInstruction> new_instruction) {
903   return ReplaceInstruction(old_instruction, AddEntryComputationParameter(
904                                                  std::move(new_instruction)));
905 }
ReplaceInstruction(HloInstruction * old_instruction,HloInstruction * new_instruction)907 Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
908                                           HloInstruction* new_instruction) {
910       ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()))
911       << ShapeUtil::HumanString(old_instruction->shape()) << " vs "
912       << ShapeUtil::HumanString(new_instruction->shape());
914   VLOG(10) << "transformed " << old_instruction->ToString() << " to "
915            << new_instruction->ToString();
916   // Try to add metadata for HLO instructions that are created to replace
917   // existing HLO instructions (e.g. during optimizations). The assumption is
918   // that the old instruction and the new instruction would perform the same
919   // function, and that they would be correlated to the same TF op. This might
920   // not always be correct since HLO optimizations can cross TF op boundaries.
921   // But still this seems to be better than nothing.
922   bool overwrite_op_name = new_instruction->metadata().op_name().empty() &&
923                            !old_instruction->metadata().op_name().empty();
924   bool overwrite_pass_id =
925       new_instruction->metadata().op_name().empty() &&
926       new_instruction->metadata().logical_creation_pass_id() == 0 &&
927       old_instruction->metadata().logical_creation_pass_id() != 0;
928   if (overwrite_op_name || overwrite_pass_id) {
929     new_instruction->set_metadata(old_instruction->metadata());
930   }
931   if (new_instruction->frontend_attributes().map().empty()) {
932     new_instruction->set_frontend_attributes(
933         old_instruction->frontend_attributes());
934   }
936   // Like the metadata above, if the user didn't specify any sharding
937   // information on the new instruction we should copy the old sharding
938   // information (if any).
939   if (!new_instruction->has_sharding()) {
940     new_instruction->set_sharding(old_instruction->sharding_ptr());
941   }
943   TF_RETURN_IF_ERROR(old_instruction->ReplaceAllUsesWith(new_instruction));
944   return RemoveInstructionAndUnusedOperands(old_instruction);
945 }
CollectUnreachableRoots() const947 std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
948   std::vector<HloInstruction*> unreachable_roots;
949   for (auto* instruction : instructions()) {
950     if (instruction->user_count() == 0 &&
951         instruction->control_successors().empty() &&
952         instruction != root_instruction()) {
953       unreachable_roots.push_back(instruction);
954     }
955   }
956   VLOG(3) << "Unreachable roots:"
957           << absl::StrJoin(unreachable_roots, "\n\t",
958                            [](string* out, const HloInstruction* hlo) {
959                              absl::StrAppend(out, hlo->ToString());
960                            });
961   return unreachable_roots;
962 }
AcceptWithOperandOrder(DfsHloVisitor * visitor,const HloInstruction::CompareFunction & operand_order) const964 Status HloComputation::AcceptWithOperandOrder(
965     DfsHloVisitor* visitor,
966     const HloInstruction::CompareFunction& operand_order) const {
967   // Visit unreachable roots. Beware that the visitor might delete the currently
968   // visited root, which would invalidate iterators if the unreachable roots
969   // weren't computed ahead of time.
970   for (HloInstruction* root : CollectUnreachableRoots()) {
972         root->AcceptWithOperandOrder(visitor, operand_order,
973                                      /*call_finish_visit=*/false));
974   }
975   // Visit the computation root instruction last.
976   return root_instruction()->AcceptWithOperandOrder(visitor, operand_order,
977                                                     /*call_finish_visit=*/true);
978 }
Clone(const string & suffix,HloCloneContext * context)980 std::unique_ptr<HloComputation> HloComputation::Clone(
981     const string& suffix, HloCloneContext* context) {
982   return CloneWithReplacements(
983       /*replacements=*/absl::flat_hash_map<const HloInstruction*,
984                                            std::unique_ptr<HloInstruction>>(),
985       /*extra_parameters=*/{}, context, suffix);
986 }
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,HloCloneContext * context,const string & suffix)988 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
989     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
990     HloCloneContext* context, const string& suffix) {
991   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
992       replacements;
993   replacements.emplace(std::move(r1));
994   return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
995                                context, suffix);
996 }
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,HloCloneContext * context,const string & suffix)998 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
999     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
1000     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
1001     HloCloneContext* context, const string& suffix) {
1002   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1003       replacements;
1004   replacements.emplace(std::move(r1));
1005   replacements.emplace(std::move(r2));
1006   return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
1007                                context, suffix);
1008 }
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r3,HloCloneContext * context,const string & suffix)1010 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
1011     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
1012     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
1013     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
1014     HloCloneContext* context, const string& suffix) {
1015   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1016       replacements;
1017   replacements.emplace(std::move(r1));
1018   replacements.emplace(std::move(r2));
1019   replacements.emplace(std::move(r3));
1020   return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
1021                                context, suffix);
1022 }
CloneWithReplacements(absl::flat_hash_map<const HloInstruction *,std::unique_ptr<HloInstruction>> replacements,absl::Span<const HloInstruction * const> extra_parameters,HloCloneContext * context,const string & suffix,const HloInstruction * new_root)1024 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
1025     absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1026         replacements,
1027     absl::Span<const HloInstruction* const> extra_parameters,
1028     HloCloneContext* context, const string& suffix,
1029     const HloInstruction* new_root) {
1030   std::unique_ptr<HloCloneContext> context_ptr;
1031   if (context == nullptr) {
1032     context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
1033     context = context_ptr.get();
1034   }
1035   if (new_root == nullptr) {
1036     new_root = root_instruction();
1037   }
1039   // Look up instr in the replacements map, and return either the replacement,
1040   // or instr, if the replacement isn't present.
1041   //
1042   // Note: This can return null, indicating that instr should not be present in
1043   // the new computation.
1044   auto replace = [&](const HloInstruction* instr) {
1045     auto it = replacements.find(instr);
1046     return it != replacements.end() ? it->second.get() : instr;
1047   };
1049   VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";
1051   // We want to do a postorder walk over [replace(i) for i in instructions_].
1052   // We can't reuse MakeInstructionPostOrder() for this, because that will
1053   // generate a postorder of plain instructions_, and our replacements may
1054   // change the postorder!
1055   //
1056   // The postorder we want here is simpler than what MakeInstructionPostOrder()
1057   // does -- we only care about operand dependencies -- so let's just do it
1058   // ourselves.
1059   std::vector<const HloInstruction*> postorder;
1060   absl::flat_hash_map<const HloInstruction*, VisitState> visited;
1061   for (const auto& instr : instructions_) {
1062     std::vector<const HloInstruction*> dfs_stack;
1063     const HloInstruction* new_instr = replace(instr.get());
1064     if (!new_instr) {
1065       continue;
1066     }
1067     dfs_stack.push_back(new_instr);
1069     while (!dfs_stack.empty()) {
1070       auto* cur = dfs_stack.back();
1071       auto it = visited.find(cur);
1072       if (it != visited.end()) {
1073         dfs_stack.pop_back();
1074         if (it->second == kVisited) {
1075           continue;
1076         }
1077         CHECK_EQ(it->second, kVisiting);
1078         postorder.push_back(cur);
1079         it->second = kVisited;
1080         continue;
1081       }
1083       visited.insert({cur, kVisiting});
1084       for (HloInstruction* operand : cur->operands()) {
1085         const HloInstruction* new_operand = replace(operand);
1086         if (new_operand) {
1087           dfs_stack.emplace_back(new_operand);
1088         }
1089       }
1090     }
1091   }
1093   std::vector<std::unique_ptr<HloInstruction>> instructions;
1094   // First add the extra parameters to 'instructions'.
1095   for (const auto& instr : extra_parameters) {
1096     CHECK_EQ(instr->opcode(), HloOpcode::kParameter)
1097         << "Only parameter instructions are allowed in 'extra_parameters'";
1098     instructions.emplace_back(instr->Clone());
1099   }
1100   for (auto instr : postorder) {
1101     std::vector<HloInstruction*> new_operands;
1102     for (auto operand : instr->operands()) {
1103       auto replaced_operand = replace(operand);
1104       CHECK_NE(replaced_operand, nullptr)
1105           << "replacements map tried to eliminate a used instruction "
1106           << operand->ToString() << ", used by " << instr->ToString();
1107       new_operands.push_back(context->GetInstruction(replaced_operand));
1108     }
1109     std::unique_ptr<HloInstruction> new_instr =
1110         instr->CloneWithNewOperands(instr->shape(), new_operands, context);
1111     if (instr->opcode() == HloOpcode::kParameter &&
1112         instr->parameter_replicated_at_leaf_buffers().has_value()) {
1113       new_instr->set_parameter_replicated_at_leaf_buffers(
1114           instr->parameter_replicated_at_leaf_buffers().value());
1115     }
1116     instructions.push_back(std::move(new_instr));
1117   }
1118   Builder builder(name() + "." + suffix);
1119   for (auto& instr : instructions) {
1120     builder.AddInstruction(std::move(instr));
1121   }
1122   auto result = builder.Build(
1123       /*root_instruction=*/context->GetInstruction(replace(new_root)));
1125   // Clone control dependencies.
1126   for (auto instr : postorder) {
1127     HloInstruction* new_instr = context->GetInstruction(instr);
1128     for (auto successor : instr->control_successors()) {
1129       auto replaced_successor = replace(successor);
1130       // successor may not have been remapped, because it might have been
1131       // removed by the replacements map.
1132       if (replaced_successor != nullptr) {
1133         TF_CHECK_OK(new_instr->AddControlDependencyTo(
1134             context->GetInstruction(replaced_successor)));
1135       }
1136     }
1137   }
1138   context->MapComputation(this, result.get());
1139   return result;
1140 }
UniquifyName(NameUniquer * name_uniquer)1142 void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
1143   name_ = name_uniquer->GetUniqueName(name_);
1144 }
GetInstructionWithName(absl::string_view name)1146 HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) {
1147   auto instructions_in_computation = instructions();
1148   auto it = absl::c_find_if(
1149       instructions_in_computation,
1150       [&](HloInstruction* instr) { return instr->name() == name; });
1151   return it == instructions_in_computation.end() ? nullptr : *it;
1152 }
IsEntryComputation() const1154 bool HloComputation::IsEntryComputation() const {
1155   return parent()->entry_computation() == this;
1156 }
1157 }  // namespace xla