1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_computation.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <functional>
21 #include <list>
22 #include <queue>
23 #include <set>
24 #include <sstream>
25 
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/container/flat_hash_set.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/numbers.h"
31 #include "absl/strings/str_cat.h"
32 #include "absl/strings/str_join.h"
33 #include "tensorflow/compiler/xla/layout_util.h"
34 #include "tensorflow/compiler/xla/map_util.h"
35 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
36 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
37 #include "tensorflow/compiler/xla/service/hlo_module.h"
38 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/status_macros.h"
41 #include "tensorflow/compiler/xla/types.h"
42 #include "tensorflow/compiler/xla/util.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/status.h"
45 #include "tensorflow/core/platform/logging.h"
46 
47 namespace xla {
48 
49 using absl::StrCat;
50 
Build(HloInstruction * root_instruction)51 std::unique_ptr<HloComputation> HloComputation::Builder::Build(
52     HloInstruction* root_instruction) {
53   int parameter_count = 0;
54   for (auto& instruction : instructions_) {
55     if (instruction->opcode() == HloOpcode::kParameter) {
56       parameter_count++;
57     }
58   }
59   // If root_instruction is not specified use the last added instruction.
60   HloInstruction* root =
61       root_instruction ? root_instruction : last_added_instruction_;
62   CHECK_NE(nullptr, root);
63   return absl::WrapUnique(new HloComputation(
64       name_, parameter_count, &instructions_, root, fusion_instruction_));
65 }
66 
HloComputation(const string & name,int parameter_count,std::vector<std::unique_ptr<HloInstruction>> * instructions,HloInstruction * root_instruction,HloInstruction * fusion_instruction)67 HloComputation::HloComputation(
68     const string& name, int parameter_count,
69     std::vector<std::unique_ptr<HloInstruction>>* instructions,
70     HloInstruction* root_instruction, HloInstruction* fusion_instruction)
71     : name_(NameUniquer::GetSanitizedName(name)),
72       unique_id_(-1),
73       root_instruction_(root_instruction),
74       fusion_instruction_(fusion_instruction) {
75   param_instructions_.resize(parameter_count, nullptr);
76   bool root_found = false;
77   for (auto& instruction : *instructions) {
78     if (instruction->opcode() == HloOpcode::kParameter) {
79       int64 param_no = instruction->parameter_number();
80       CHECK(param_no >= 0 && param_no < parameter_count)
81           << "\nERROR: invalid parameter number.  Expected [0, "
82           << parameter_count << "), got " << param_no;
83       CHECK(param_instructions_[param_no] == nullptr)
84           << "\nERROR: parameter number " << param_no
85           << " already allocated in this computation";
86       param_instructions_[param_no] = instruction.get();
87     }
88     root_found |= instruction.get() == root_instruction_;
89     AddInstructionInternal(std::move(instruction));
90   }
91   CHECK(root_found)
92       << "\nERROR: root instruction is not present in computation.";
93 }
94 
AddInstruction(std::unique_ptr<HloInstruction> instruction,const std::string & new_name)95 HloInstruction* HloComputation::AddInstruction(
96     std::unique_ptr<HloInstruction> instruction, const std::string& new_name) {
97   CHECK(instruction->opcode() != HloOpcode::kParameter)
98       << "Parameter instructions cannot be added to a computation after "
99       << "it has been built";
100   if (!new_name.empty()) {
101     instruction->SetAndSanitizeName(new_name);
102   }
103   return AddInstructionInternal(std::move(instruction));
104 }
105 
AddInstructionInternal(std::unique_ptr<HloInstruction> instruction)106 HloInstruction* HloComputation::AddInstructionInternal(
107     std::unique_ptr<HloInstruction> instruction) {
108   if (parent() != nullptr) {
109     instruction->UniquifyName(&parent()->instruction_name_uniquer());
110     instruction->SetUniqueId(parent()->NewUniqueInstructionId());
111   }
112   instruction->set_parent(this);
113   HloInstruction* pinst = instruction.get();
114   instruction_iterators_[pinst] =
115       instructions_.insert(instructions_.end(), std::move(instruction));
116   return pinst;
117 }
118 
AddParameter(std::unique_ptr<HloInstruction> instruction)119 HloInstruction* HloComputation::AddParameter(
120     std::unique_ptr<HloInstruction> instruction) {
121   CHECK(instruction->opcode() == HloOpcode::kParameter);
122   CHECK(IsFusionComputation());
123   CHECK(fusion_instruction_->operand_count() == param_instructions_.size());
124   instruction->set_parent(this);
125   param_instructions_.push_back(instruction.get());
126   AddInstructionInternal(std::move(instruction));
127   return instructions_.back().get();
128 }
129 
AddEntryComputationParameter(std::unique_ptr<HloInstruction> instruction)130 HloInstruction* HloComputation::AddEntryComputationParameter(
131     std::unique_ptr<HloInstruction> instruction) {
132   CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
133   CHECK_EQ(instruction->parameter_number(), num_parameters());
134   CHECK(parent()->entry_computation() == this);
135 
136   HloModuleConfig config = parent()->config();
137   config.mutable_entry_computation_layout()->add_parameter_layout(
138       ShapeLayout(instruction->shape()));
139   parent()->set_config(config);
140 
141   instruction->set_parent(this);
142   param_instructions_.push_back(instruction.get());
143   AddInstructionInternal(std::move(instruction));
144 
145   return instructions_.back().get();
146 }
147 
ReplaceEntryComputationParameter(int64 param_no,HloInstruction * old_instruction,std::unique_ptr<HloInstruction> instruction)148 Status HloComputation::ReplaceEntryComputationParameter(
149     int64 param_no, HloInstruction* old_instruction,
150     std::unique_ptr<HloInstruction> instruction) {
151   CHECK_GE(param_no, 0);
152   CHECK_LT(param_no, param_instructions_.size());
153   CHECK_EQ(instruction->opcode(), HloOpcode::kParameter);
154   CHECK(parent()->entry_computation() == this);
155 
156   HloModuleConfig config = parent()->config();
157   *config.mutable_entry_computation_layout()->mutable_parameter_layout(
158       param_no) = ShapeLayout(instruction->shape());
159   parent()->set_config(config);
160 
161   instruction->set_parent(this);
162   param_instructions_[param_no] = instruction.get();
163   AddInstructionInternal(std::move(instruction));
164 
165   return ForceRemoveInstruction(old_instruction);
166 }
167 
RemoveParameter(int64 param_no)168 Status HloComputation::RemoveParameter(int64 param_no) {
169   CHECK_GE(param_no, 0);
170   CHECK_LT(param_no, param_instructions_.size());
171   CHECK(IsFusionComputation());
172   HloInstruction* param_instruction = param_instructions_[param_no];
173   auto param_instruction_iterator = param_instructions_.begin() + param_no;
174   param_instructions_.erase(param_instruction_iterator);
175   // Throw removed fused parameter instruction away.
176   TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
177 
178   while (param_no < param_instructions_.size()) {
179     param_instruction = param_instructions_[param_no];
180     HloInstruction* new_instr =
181         AddInstructionInternal(HloInstruction::CreateParameter(
182             param_no, param_instruction->shape(), StrCat("param_", param_no)));
183     TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
184     param_instructions_[param_no] = new_instr;
185     TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
186     param_no++;
187   }
188 
189   return Status::OK();
190 }
191 
RemoveUnusedParametersFromFusedComputation()192 Status HloComputation::RemoveUnusedParametersFromFusedComputation() {
193   return RemoveUnusedParametersImpl(/*allow_non_fusion=*/false);
194 }
195 
RemoveUnusedParametersFromAnyComputation()196 Status HloComputation::RemoveUnusedParametersFromAnyComputation() {
197   return RemoveUnusedParametersImpl(/*allow_non_fusion=*/true);
198 }
199 
// Shared implementation of RemoveUnusedParametersFromFusedComputation() and
// RemoveUnusedParametersFromAnyComputation().
//
// Removes every parameter that has no users and is not the root. The
// remaining parameters are then renumbered densely from 0: each survivor whose
// number shifts is replaced by a fresh kParameter named "param_<n>", and its
// uses are redirected to the replacement.
//
// Requires a fusion computation unless `allow_non_fusion` is true.
// `allow_non_fusion` is also passed to RemoveInstructionImpl as
// `ignore_safety_check`, so it additionally disables IsSafelyRemovable() there.
//
// NOTE(review): nothing here updates the module's entry computation layout —
// confirm callers handle that when used on an entry computation. If a removal
// fails part-way, the error is returned and param_instructions_ may be left
// partially compacted.
Status HloComputation::RemoveUnusedParametersImpl(bool allow_non_fusion) {
  CHECK(allow_non_fusion || IsFusionComputation());
  // Number of parameters removed so far; survivors shift down by this much.
  int64 removed = 0;
  for (int64 i = 0; i < param_instructions_.size(); ++i) {
    HloInstruction* param_instruction = param_instructions_[i];
    if (param_instruction->user_count() == 0 &&
        param_instruction != root_instruction()) {
      TF_RETURN_IF_ERROR(
          RemoveInstructionImpl(param_instruction, allow_non_fusion));
      ++removed;
      continue;
    }

    if (removed > 0) {
      // This parameter survives but its number shifts; replace it with a new
      // parameter carrying the compacted number.
      const int64 param_no = i - removed;
      HloInstruction* new_instr = AddInstructionInternal(
          HloInstruction::CreateParameter(param_no, param_instruction->shape(),
                                          StrCat("param_", param_no)));
      TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
      param_instructions_[param_no] = new_instr;
      TF_RETURN_IF_ERROR(
          RemoveInstructionImpl(param_instruction, allow_non_fusion));
    }
  }
  // The last `removed` entries are stale after compaction; drop them.
  param_instructions_.resize(param_instructions_.size() - removed);
  return Status::OK();
}
227 
IsSafelyRemovable(const HloInstruction * instruction)228 bool HloComputation::IsSafelyRemovable(const HloInstruction* instruction) {
229   // If the instruction has control predecessors or successors then we cannot
230   // remove the instruction without violating ordering constraints (added, for
231   // example, to avert interference due to buffer aliasing).
232   if (!instruction->control_predecessors().empty() ||
233       !instruction->control_successors().empty()) {
234     return false;
235   }
236 
237   if (instruction->opcode() == HloOpcode::kParameter &&
238       !IsFusionComputation()) {
239     return false;
240   }
241 
242   return true;
243 }
244 
HasSideEffect() const245 bool HloComputation::HasSideEffect() const {
246   for (auto* instruction : instructions()) {
247     if (instruction->HasSideEffect()) {
248       return true;
249     }
250   }
251   return false;
252 }
253 
IsMarkedAsDead(const HloInstruction * inst)254 bool HloComputation::IsMarkedAsDead(const HloInstruction* inst) {
255   return inst->IsMarkedAsDead();
256 }
257 
// Removes `instruction`, which must have no users, must not be the root, and
// must pass IsSafelyRemovable(). It then walks operands breadth-first (FIFO
// worklist) and removes every one that, once its users are gone, also:
//   - has no users,
//   - is not the root,
//   - is safely removable, and
//   - has no side effect.
// `instruction` itself is exempt from the side-effect test. If `cleanup` is
// set, it is invoked on each instruction just before that instruction is
// removed.
//
// NOTE(review): in a fusion computation IsSafelyRemovable() admits
// kParameter, yet this path does not update param_instructions_ — confirm
// callers never reach a fused parameter here.
Status HloComputation::RemoveInstructionAndUnusedOperands(
    HloInstruction* instruction, std::function<void(HloInstruction*)> cleanup) {
  TF_RET_CHECK(root_instruction() != instruction);

  TF_RET_CHECK(instruction->user_count() == 0);
  TF_RET_CHECK(IsSafelyRemovable(instruction))
      << "Cannot remove instruction: " << instruction->ToString();
  // Pointers of instructions already removed. An operand can be queued once
  // per use, so repeated entries must be skipped.
  absl::flat_hash_set<HloInstruction*> removed;
  std::queue<HloInstruction*> worklist;
  worklist.push(instruction);
  while (!worklist.empty()) {
    HloInstruction* item = worklist.front();
    worklist.pop();

    if (removed.contains(item) || item->user_count() != 0 ||
        item == root_instruction() || !IsSafelyRemovable(item) ||
        (item->HasSideEffect() && item != instruction)) {
      continue;
    }
    // Queue the operands before removal: RemoveInstruction clears the operand
    // list (RemoveAllOperands) of the instruction it removes.
    for (int i = 0; i < item->operand_count(); ++i) {
      worklist.push(item->mutable_operand(i));
    }

    if (cleanup) {
      cleanup(item);
    }
    TF_RETURN_IF_ERROR(RemoveInstruction(item));
    removed.insert(item);
  }
  return Status::OK();
}
289 
RemoveInstruction(HloInstruction * instruction)290 Status HloComputation::RemoveInstruction(HloInstruction* instruction) {
291   return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/false);
292 }
293 
ForceRemoveInstruction(HloInstruction * instruction)294 Status HloComputation::ForceRemoveInstruction(HloInstruction* instruction) {
295   return RemoveInstructionImpl(instruction, /*ignore_safety_check=*/true);
296 }
297 
// Detaches `instruction` from this computation.
//
// All checks run before anything is mutated, so a failure leaves the
// computation untouched. Removal fails if the instruction:
//   - fails IsSafelyRemovable() (unless `ignore_safety_check` is set),
//   - is the root,
//   - still has users,
//   - has control predecessors or successors, or
//   - is not in this computation.
//
// On success, ownership moves to to_be_deleted_ instead of the object being
// destroyed here. The instruction is then detached from its operands and
// users, stripped of operands and called computations, and marked dead.
Status HloComputation::RemoveInstructionImpl(HloInstruction* instruction,
                                             bool ignore_safety_check) {
  VLOG(2) << "Removing instruction " << instruction->name()
          << " from computation " << name();
  TF_RET_CHECK(ignore_safety_check || IsSafelyRemovable(instruction))
      << "cannot remove instruction: " << instruction->ToString();
  TF_RET_CHECK(root_instruction() != instruction)
      << "cannot remove root instruction " << instruction->name();
  TF_RET_CHECK(instruction->user_count() == 0)
      << "instruction " << instruction->name()
      << " has users and cannot be removed";
  TF_RET_CHECK(instruction->control_predecessors().empty())
      << "instruction " << instruction->name()
      << " has control predecessors and cannot be removed";
  TF_RET_CHECK(instruction->control_successors().empty())
      << "instruction " << instruction->name()
      << " has control successors and cannot be removed";

  auto inst_it = instruction_iterators_.find(instruction);
  TF_RET_CHECK(inst_it != instruction_iterators_.end());
  (*inst_it->second)->set_parent(nullptr);
  // Move ownership out of the instructions_ slot (which then holds null until
  // it is erased below) into to_be_deleted_.
  to_be_deleted_.emplace_back(inst_it->second->release());
  to_be_deleted_.back()->DetachFromOperandsAndUsers();
  // Clear all operands to avoid Null operands.
  to_be_deleted_.back()->RemoveAllOperands();
  to_be_deleted_.back()->ClearCalledComputations();
  to_be_deleted_.back()->MarkAsDead();
  instructions_.erase(inst_it->second);
  instruction_iterators_.erase(inst_it);
  return Status::OK();
}
329 
set_root_instruction(HloInstruction * new_root_instruction,bool accept_different_shape)330 void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
331                                           bool accept_different_shape) {
332   // The shape of the root (ignoring layout) is an invariant of the computation
333   // for non-fusion cases.
334   if (!IsFusionComputation() && !accept_different_shape) {
335     CHECK(ShapeUtil::Compatible(new_root_instruction->shape(),
336                                 root_instruction_->shape()))
337         << new_root_instruction->shape() << " is incompatible with "
338         << root_instruction_->shape();
339   }
340   bool root_found = false;
341   for (auto& instruction : instructions_) {
342     if (new_root_instruction == instruction.get()) {
343       root_found = true;
344       break;
345     }
346   }
347   DCHECK(root_found);
348 
349   if (parent() && parent()->has_entry_computation() &&
350       parent()->entry_computation() == this) {
351     if (!Shape::Equal().IgnoreLayout()(new_root_instruction->shape(),
352                                        root_instruction_->shape())) {
353       // Rebuild input output alias config now that we have a new output shape.
354       parent()->input_output_alias_config() =
355           HloInputOutputAliasConfig(new_root_instruction->shape());
356     }
357   }
358 
359   root_instruction_ = new_root_instruction;
360 }
361 
362 namespace {
363 
364 // Helper which builds a post order of the HLO call graph.
ComputeComputationPostOrder(HloComputation * computation,absl::flat_hash_set<HloComputation * > * visited,std::vector<HloComputation * > * post_order)365 void ComputeComputationPostOrder(HloComputation* computation,
366                                  absl::flat_hash_set<HloComputation*>* visited,
367                                  std::vector<HloComputation*>* post_order) {
368   if (visited->insert(computation).second) {
369     for (auto* instruction : computation->instructions()) {
370       for (HloComputation* called_computation :
371            instruction->called_computations()) {
372         ComputeComputationPostOrder(called_computation, visited, post_order);
373       }
374     }
375     post_order->push_back(computation);
376   }
377 }
378 
379 }  // namespace
380 
// Iterative DFS that appends to `post_order` every instruction reachable from
// `root` through operands and control predecessors, each one after all of its
// predecessors. Every reached instruction must belong to this computation
// (CHECK-enforced).
//
// `visited` is shared across calls:
//   - kVisited: the instruction has already been emitted and is skipped.
//   - kVisiting: its predecessors have been pushed but it is not yet emitted.
//
// Channel handling: instructions that share a channel id in
// `channel_dependency_group` (see ComputeChannelDependencies) are treated as a
// unit.
//   - When a kRecvDone or kAllReduce with a grouped channel id would be
//     pushed, every member of its group is pushed instead.
//   - When such an instruction is expanded, the predecessors of every group
//     member are pushed.
// As a result, the predecessors of all group members are emitted before the
// expanded instruction.
void HloComputation::ComputeInstructionPostOrder(
    const HloComputation::ChannelDependencyGroup& channel_dependency_group,
    std::vector<HloInstruction*>* post_order, HloInstruction* root,
    absl::flat_hash_map<HloInstruction*, VisitState>* visited) const {
  std::vector<HloInstruction*> dfs_stack;
  dfs_stack.push_back(root);
  while (!dfs_stack.empty()) {
    const auto current = dfs_stack.back();
    CHECK_EQ(current->parent(), this)
        << "Instruction " << current->name()
        << " is not in the current computation (" << name() << ").";
    auto it = visited->find(current);
    if (it != visited->end()) {
      if (it->second == kVisited) {
        // Already visited.
        dfs_stack.pop_back();
        continue;
      }
      // Visit this node. It is back on top of the stack, so every predecessor
      // pushed above it has been handled and it can be emitted now.
      CHECK_EQ(kVisiting, it->second);
      dfs_stack.pop_back();
      post_order->push_back(current);
      it->second = kVisited;
      continue;
    }

    // First time on top of the stack: mark it and push its predecessors above
    // it (it stays on the stack underneath them).
    visited->insert({current, kVisiting});

    // Only kRecvDone and kAllReduce are looked up in channel groups here.
    // A kSend is not looked up itself, but it is a member of the groups built
    // by ComputeChannelDependencies and is pushed along with its group.
    const auto get_channel_id =
        [](HloInstruction* inst) -> absl::optional<int64> {
      switch (inst->opcode()) {
        case HloOpcode::kRecvDone:
          return inst->channel_id();
        case HloOpcode::kAllReduce:
          return inst->channel_id();
        default:
          return absl::nullopt;
      }
    };

    // When adding a predecessor to the dfs_stack, we need to also add its
    // associated channel dependencies.
    const auto add_dfs_stack = [&](HloInstruction* inst) {
      auto channel_id = get_channel_id(inst);
      if (channel_id && channel_dependency_group.count(*channel_id)) {
        auto it = channel_dependency_group.find(*channel_id);
        for (HloInstruction* cinst : it->second) {
          dfs_stack.emplace_back(cinst);
        }
      } else {
        dfs_stack.emplace_back(inst);
      }
    };

    const auto add_predecessors = [&](HloInstruction* inst) {
      // Add the operands to the stack in reverse order so the first operand is
      // processed first. This will produce a more natural ordering and a nicer
      // result for things like HLO stringification.
      const auto& operands = inst->operands();
      for (int64 i = operands.size() - 1; i >= 0; --i) {
        add_dfs_stack(operands[i]);
      }

      for (HloInstruction* op : inst->control_predecessors()) {
        add_dfs_stack(op);
      }
    };

    // If the current instruction is a channel instruction, add the dependencies
    // from all associated instructions of the channel.
    auto channel_id = get_channel_id(current);
    if (channel_id && channel_dependency_group.count(*channel_id)) {
      auto it = channel_dependency_group.find(*channel_id);
      for (HloInstruction* cinst : it->second) {
        add_predecessors(cinst);
      }
    } else {
      add_predecessors(current);
    }
  }
}
462 
463 HloComputation::ChannelDependencyGroup
ComputeChannelDependencies() const464 HloComputation::ComputeChannelDependencies() const {
465   ChannelDependencyGroup channel_dependency_group;
466   for (const auto& instruction : instructions_) {
467     switch (instruction->opcode()) {
468       case HloOpcode::kSend:
469       case HloOpcode::kRecvDone:
470       case HloOpcode::kAllReduce: {
471         auto channel_id = instruction->channel_id();
472         if (channel_id) {
473           channel_dependency_group[channel_id.value()].push_back(
474               instruction.get());
475         }
476         break;
477       }
478       default:
479         break;
480     }
481   }
482   return channel_dependency_group;
483 }
484 
HasOnlyTraceUsers(const HloInstruction * instruction)485 static inline bool HasOnlyTraceUsers(const HloInstruction* instruction) {
486   return absl::c_all_of(instruction->users(), [](HloInstruction* user) {
487     return user->opcode() == HloOpcode::kTrace;
488   });
489 }
490 
MakeInstructionPostOrder() const491 std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
492   auto channel_dependency_group = ComputeChannelDependencies();
493   std::vector<HloInstruction*> post_order;
494   post_order.reserve(instruction_count());
495   std::vector<HloInstruction*> trace_instructions;
496   absl::flat_hash_map<HloInstruction*, VisitState> visited;
497   visited.reserve(instruction_count());
498   for (auto& instruction : instructions_) {
499     if (instruction->opcode() == HloOpcode::kTrace) {
500       // Trace instructions aren't handled by the DFS visitor. Add trace
501       // instructions to the post order at the end (necessarily they have no
502       // users).
503       trace_instructions.push_back(instruction.get());
504     } else if (HasOnlyTraceUsers(instruction.get())) {
505       ComputeInstructionPostOrder(channel_dependency_group, &post_order,
506                                   instruction.get(), &visited);
507     }
508   }
509   post_order.insert(post_order.end(), trace_instructions.begin(),
510                     trace_instructions.end());
511   CHECK_EQ(instructions_.size(), post_order.size())
512       << "number of instructions does not match post order size";
513   return post_order;
514 }
515 
MakeEmbeddedComputationsList() const516 std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
517     const {
518   absl::flat_hash_set<HloComputation*> visited;
519   std::vector<HloComputation*> post_order;
520 
521   // To avoid special handling of this computation, cast away const of
522   // 'this'. 'this' is immediately removed from the post order after
523   // construction.
524   //
525   // TODO(b/78350259): This violates const-correctness, since while the original
526   // computation is not returned, we still retrieve non-const computations from
527   // a const one. Consider also avoiding const for HloComputation, or review XLA
528   // for const-correctness of non-HloInstruction* types like this.
529   ComputeComputationPostOrder(const_cast<HloComputation*>(this), &visited,
530                               &post_order);
531 
532   // We don't want to include this computation in the post order.
533   CHECK_EQ(this, post_order.back());
534   post_order.pop_back();
535 
536   return post_order;
537 }
538 
ToString(const HloPrintOptions & options) const539 string HloComputation::ToString(const HloPrintOptions& options) const {
540   return ToString(options, MakeInstructionPostOrder());
541 }
542 
ToString(const HloPrintOptions & options,absl::Span<const HloInstruction * const> instruction_order) const543 string HloComputation::ToString(
544     const HloPrintOptions& options,
545     absl::Span<const HloInstruction* const> instruction_order) const {
546   CHECK_EQ(instruction_order.size(), instruction_count());
547 
548   const string tab(2 * options.indent_amount(), ' ');
549 
550   std::ostringstream s;
551   s << tab;
552 
553   if (!options.is_in_nested_computation()) {
554     if (options.print_percent()) {
555       s << "%";
556     }
557     if (options.print_ids()) {
558       // Exclude entry computation's name because it includes and leads to
559       // non-deterministic fingerprint.
560       s << PrintName(name(), options.print_ids()) << " ";
561     }
562   }
563 
564   if (options.print_program_shape()) {
565     s << ShapeUtil::HumanString(ComputeProgramShape(options.print_ids()))
566       << " ";
567   }
568   s << "{\n";
569 
570   // There are instructions which are required to be printed. Additionally, we
571   // print some instructions before and after required ones. The resulting
572   // output has the following format.
573   //
574   //  computation {
575   //    ...
576   //    additional_instructions
577   //    required_instructions
578   //    additional_instructions
579   //    ...
580   //    additional_instructions
581   //    required_instructions
582   //    additional_instructions
583   //    ...
584   //  }
585   std::set<int> instructions_to_print;
586   {
587     // Find all the instructions that should be printed.
588     auto add_instruction = [&instructions_to_print,
589                             &instruction_order](int index) {
590       if (index < 0 || index >= instruction_order.size()) {
591         return;
592       }
593       instructions_to_print.insert(index);
594     };
595 
596     auto add_instructions_arround = [&add_instruction, &options](int index) {
597       for (int i = index - options.leading_and_trailing_instructions_number();
598            i <= index + options.leading_and_trailing_instructions_number();
599            ++i) {
600         add_instruction(i);
601       }
602     };
603 
604     for (int i = 0; i < instruction_order.size(); ++i) {
605       const HloInstruction* instruction = instruction_order[i];
606       CHECK_EQ(this, instruction->parent());
607       if (options.print_instruction(instruction)) {
608         add_instructions_arround(i);
609       }
610     }
611   }
612 
613   {
614     // Print the instructions in this computation.
615     HloPrintOptions new_options = options;
616     new_options.set_indent_amount(options.indent_amount() + 1)
617         .set_is_in_nested_computation(true);
618 
619     const string new_tab(2 * new_options.indent_amount(), ' ');
620 
621     CanonicalNameMap name_map;
622 
623     bool print_prev = true;
624     for (int index = 0; index < instruction_order.size(); ++index) {
625       const HloInstruction* instruction = instruction_order[index];
626       if (instructions_to_print.find(index) != instructions_to_print.end()) {
627         s << new_options.format_instruction(
628                  instruction,
629                  instruction->ToStringWithCanonicalNameMap(new_options,
630                                                            &name_map),
631                  new_options.indent_amount(), instruction == root_instruction_)
632           << "\n";
633         print_prev = true;
634       } else if (print_prev) {
635         s << new_tab << "...\n";
636         print_prev = false;
637       }
638     }
639   }
640 
641   s << tab << "}";
642   return s.str();
643 }
644 
ToProto() const645 HloComputationProto HloComputation::ToProto() const {
646   HloComputationProto proto;
647   CHECK(unique_id_ != -1)
648       << "This computation does not have a valid id. Please make sure the "
649          "computation is inside a module before dumping it.";
650   proto.set_id(unique_id_);
651   proto.set_name(name_);
652   for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
653     HloInstructionProto instruction_proto = instruction->ToProto();
654     proto.add_instructions()->Swap(&instruction_proto);
655   }
656   proto.set_root_id(root_instruction()->unique_id());
657   *proto.mutable_program_shape() = ComputeProgramShape().ToProto();
658   return proto;
659 }
660 
// Builds an HloComputation from `proto`.
//
// Instructions are deserialized in proto order via
// HloInstruction::CreateFromProto. Each call receives the instructions
// deserialized so far (`instruction_map`, keyed by proto id) and
// `computation_map` for resolving references.
//
// Fails if:
//   - an instruction id repeats,
//   - the root id is -1 or unknown, or
//   - parameter numbers are not exactly {0, ..., N-1} without duplicates.
//
// The resulting computation lists its instructions in ascending proto-id
// order, has unique id proto.id(), and has no fusion instruction.
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
    const HloComputationProto& proto,
    const absl::flat_hash_map<int64, HloComputation*>& computation_map,
    bool prohibit_empty_literal) {
  absl::flat_hash_map<int64, HloInstruction*> instruction_map;
  absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
  std::vector<std::unique_ptr<HloInstruction>> instructions;
  int64 parameter_count = 0;
  for (const HloInstructionProto& instruction_proto : proto.instructions()) {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<HloInstruction> instruction,
                        HloInstruction::CreateFromProto(
                            instruction_proto, instruction_map, computation_map,
                            prohibit_empty_literal));
    if (instruction->opcode() == HloOpcode::kParameter) {
      parameter_count++;
    }
    TF_RET_CHECK(!ContainsKey(instruction_map, instruction_proto.id()));
    instruction_map[instruction_proto.id()] = instruction.get();
    to_proto_id[instruction.get()] = instruction_proto.id();
    instructions.push_back(std::move(instruction));
  }

  TF_RET_CHECK(proto.root_id() != -1);
  TF_RET_CHECK(ContainsKey(instruction_map, proto.root_id()));
  HloInstruction* root = instruction_map.at(proto.root_id());

  // Sort the instructions in the proto id's order.
  absl::c_sort(instructions, [&](const std::unique_ptr<HloInstruction>& a,
                                 const std::unique_ptr<HloInstruction>& b) {
    return to_proto_id[a.get()] < to_proto_id[b.get()];
  });

  // Validate parameter numbering here so malformed protos yield an error
  // Status instead of tripping the CHECKs in the HloComputation constructor.
  TF_RETURN_IF_ERROR([&]() -> Status {
    std::vector<bool> parameters_seen(parameter_count);
    int parameters_seen_count = 0;
    for (auto& instruction : instructions) {
      if (instruction->opcode() == HloOpcode::kParameter) {
        int64 param_no = instruction->parameter_number();
        TF_RET_CHECK(param_no >= 0 && param_no < parameter_count)
            << "Invalid parameter number.  Expected [0, " << parameter_count
            << "), got " << param_no;
        TF_RET_CHECK(!parameters_seen[param_no])
            << "Parameter number " << param_no
            << " already allocated in this computation";
        parameters_seen[param_no] = true;
        parameters_seen_count++;
      }
    }
    TF_RET_CHECK(parameters_seen_count == parameter_count)
        << "Not all parameters in range [0, " << parameter_count
        << ") were referenced";
    return Status::OK();
  }());

  auto computation = absl::WrapUnique(
      new HloComputation(proto.name(), parameter_count, &instructions, root,
                         /*fusion_instruction=*/nullptr));
  computation->unique_id_ = proto.id();
  return std::move(computation);
}
722 
// Fuses `instructions_to_fuse` into the existing kFusion instruction
// `fusion_instruction`.
//
// The first element is treated as the fusion root:
//   - all its uses are redirected to `fusion_instruction`,
//   - `fusion_instruction` becomes this computation's root if the fused root
//     was the root, and
//   - the fused root is removed.
// Each remaining element is fused via HloInstruction::FuseInstruction() and is
// removed if it is left without users. Failures are fatal (TF_CHECK_OK).
//
// NOTE(review): presumably instructions_to_fuse must be ordered
// users-before-operands (root first) — confirm with callers.
void HloComputation::FuseInstructionsInto(
    absl::Span<HloInstruction* const> instructions_to_fuse,
    HloInstruction* fusion_instruction) {
  CHECK_EQ(HloOpcode::kFusion, fusion_instruction->opcode());
  HloInstruction* root = instructions_to_fuse.front();
  TF_CHECK_OK(root->ReplaceAllUsesWith(fusion_instruction));
  // Re-root before removal: RemoveInstruction refuses to remove the root.
  if (root == root_instruction()) {
    set_root_instruction(fusion_instruction);
  }
  TF_CHECK_OK(RemoveInstruction(root));
  for (size_t i = 1; i < instructions_to_fuse.size(); ++i) {
    HloInstruction* instruction = instructions_to_fuse[i];
    fusion_instruction->FuseInstruction(instruction);
    // Instructions still used outside the fusion stay in this computation.
    if (instruction->user_count() == 0) {
      TF_CHECK_OK(RemoveInstruction(instruction));
    }
  }
}
741 
CreateFusionInstruction(absl::Span<HloInstruction * const> instructions_to_fuse,HloInstruction::FusionKind fusion_kind)742 HloInstruction* HloComputation::CreateFusionInstruction(
743     absl::Span<HloInstruction* const> instructions_to_fuse,
744     HloInstruction::FusionKind fusion_kind) {
745   HloInstruction* root = instructions_to_fuse.front();
746   HloInstruction* fusion_instruction = AddInstruction(
747       HloInstruction::CreateFusion(root->shape(), fusion_kind, root));
748   FuseInstructionsInto(instructions_to_fuse, fusion_instruction);
749   return fusion_instruction;
750 }
751 
DeepCopyHelper(HloInstruction * instruction,ShapeIndex * index,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)752 StatusOr<HloInstruction*> HloComputation::DeepCopyHelper(
753     HloInstruction* instruction, ShapeIndex* index,
754     const std::function<
755         HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
756                         HloComputation* computation)>& copy_leaf) {
757   if (instruction->shape().IsTuple()) {
758     std::vector<HloInstruction*> elements;
759     for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape());
760          i++) {
761       HloInstruction* gte =
762           AddInstruction(HloInstruction::CreateGetTupleElement(
763               ShapeUtil::GetTupleElementShape(instruction->shape(), i),
764               instruction, i));
765 
766       index->push_back(i);
767       TF_ASSIGN_OR_RETURN(HloInstruction * element,
768                           DeepCopyHelper(gte, index, copy_leaf));
769       elements.push_back(element);
770       index->pop_back();
771     }
772     return AddInstruction(HloInstruction::CreateTuple(elements));
773   }
774   if (instruction->shape().IsToken()) {
775     // Tokens have no on-device representation and cannot be copied. Pass
776     // through transparently.
777     return instruction;
778   }
779 
780   // Array shape.
781   TF_RET_CHECK(instruction->shape().IsArray());
782   return copy_leaf(instruction, *index, this);
783 }
784 
DeepCopyInstruction(HloInstruction * instruction,const ShapeTree<bool> * indices_to_copy,ShapeTree<HloInstruction * > * copies_added)785 StatusOr<HloInstruction*> HloComputation::DeepCopyInstruction(
786     HloInstruction* instruction, const ShapeTree<bool>* indices_to_copy,
787     ShapeTree<HloInstruction*>* copies_added) {
788   if (instruction->parent() != this) {
789     return FailedPrecondition(
790         "Can't deep copy instruction %s: instruction is not in computation %s",
791         instruction->name(), name());
792   }
793   if (indices_to_copy != nullptr &&
794       !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) {
795     return FailedPrecondition(
796         "Can't deep copy instruction %s: given shape tree of indices to copy "
797         "has incompatible shapes: %s vs. %s",
798         instruction->name(), ShapeUtil::HumanString(instruction->shape()),
799         ShapeUtil::HumanString(indices_to_copy->shape()));
800   }
801 
802   ShapeIndex index;
803   auto copy_leaf = [indices_to_copy, copies_added](
804                        HloInstruction* leaf, const ShapeIndex& leaf_index,
805                        HloComputation* computation) {
806     if (indices_to_copy == nullptr || indices_to_copy->element(leaf_index)) {
807       HloInstruction* copy = computation->AddInstruction(
808           HloInstruction::CreateUnary(leaf->shape(), HloOpcode::kCopy, leaf));
809       if (copies_added != nullptr) {
810         *copies_added->mutable_element(leaf_index) = copy;
811       }
812       return copy;
813     }
814     // Elements which are not to be copied are passed through
815     // transparently.
816     return leaf;
817   };
818   return DeepCopyHelper(instruction, &index, copy_leaf);
819 }
820 
DeepCopyInstructionWithCustomCopier(HloInstruction * instruction,const std::function<HloInstruction * (HloInstruction * leaf,const ShapeIndex & leaf_index,HloComputation * computation)> & copy_leaf)821 StatusOr<HloInstruction*> HloComputation::DeepCopyInstructionWithCustomCopier(
822     HloInstruction* instruction,
823     const std::function<
824         HloInstruction*(HloInstruction* leaf, const ShapeIndex& leaf_index,
825                         HloComputation* computation)>& copy_leaf) {
826   if (instruction->parent() != this) {
827     return FailedPrecondition(
828         "Can't deep copy instruction %s: instruction is not in computation %s",
829         instruction->name(), name());
830   }
831   ShapeIndex index;
832   return DeepCopyHelper(instruction, &index, copy_leaf);
833 }
834 
ComputeProgramShape(bool include_ids) const835 ProgramShape HloComputation::ComputeProgramShape(bool include_ids) const {
836   ProgramShape program_shape;
837 
838   for (auto* param_instruction : param_instructions_) {
839     *program_shape.add_parameters() = param_instruction->shape();
840     *program_shape.add_parameter_names() =
841         PrintName(param_instruction->name(), include_ids);
842   }
843   *program_shape.mutable_result() = root_instruction_->shape();
844 
845   return program_shape;
846 }
847 
EqualInternal(const HloComputation & other,bool is_layout_sensitive,bool ignore_channel_id_values) const848 bool HloComputation::EqualInternal(const HloComputation& other,
849                                    bool is_layout_sensitive,
850                                    bool ignore_channel_id_values) const {
851   if (this == &other) {
852     return true;
853   }
854   absl::flat_hash_set<std::pair<const HloInstruction*, const HloInstruction*>>
855       visited;
856   std::vector<std::pair<const HloInstruction*, const HloInstruction*>> worklist;
857 
858   worklist.push_back({root_instruction(), other.root_instruction()});
859 
860   while (!worklist.empty()) {
861     auto pair = worklist.back();
862     worklist.pop_back();
863 
864     if (visited.contains(pair)) {
865       continue;
866     }
867     visited.emplace(pair);
868     // TODO(b/123082518): Avoid recursively invoking Equal because it may
869     // cause a stack overflow with deeply nested subcomputations.
870     auto operands_eq = [](const HloInstruction*, const HloInstruction*) {
871       return true;
872     };
873     auto comp_eq = [&](const HloComputation* a, const HloComputation* b) {
874       return a->EqualInternal(*b, is_layout_sensitive,
875                               ignore_channel_id_values);
876     };
877     bool identical_ignoring_operands =
878         ignore_channel_id_values
879             ? pair.first->IdenticalIgnoringChannelIdValues(
880                   *pair.second, operands_eq, comp_eq, is_layout_sensitive)
881             : pair.first->Identical(*pair.second, operands_eq, comp_eq,
882                                     is_layout_sensitive);
883     if (!identical_ignoring_operands) {
884       return false;
885     }
886     for (size_t i = 0; i < pair.first->operands().size(); ++i) {
887       worklist.push_back({pair.first->operand(i), pair.second->operand(i)});
888     }
889   }
890   return true;
891 }
892 
ReplaceWithNewInstruction(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)893 Status HloComputation::ReplaceWithNewInstruction(
894     HloInstruction* old_instruction,
895     std::unique_ptr<HloInstruction> new_instruction) {
896   return ReplaceInstruction(old_instruction,
897                             AddInstruction(std::move(new_instruction)));
898 }
899 
ReplaceWithNewEntryComputationParameter(HloInstruction * old_instruction,std::unique_ptr<HloInstruction> new_instruction)900 Status HloComputation::ReplaceWithNewEntryComputationParameter(
901     HloInstruction* old_instruction,
902     std::unique_ptr<HloInstruction> new_instruction) {
903   return ReplaceInstruction(old_instruction, AddEntryComputationParameter(
904                                                  std::move(new_instruction)));
905 }
906 
ReplaceInstruction(HloInstruction * old_instruction,HloInstruction * new_instruction)907 Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction,
908                                           HloInstruction* new_instruction) {
909   TF_RET_CHECK(
910       ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape()))
911       << ShapeUtil::HumanString(old_instruction->shape()) << " vs "
912       << ShapeUtil::HumanString(new_instruction->shape());
913 
914   VLOG(10) << "transformed " << old_instruction->ToString() << " to "
915            << new_instruction->ToString();
916   // Try to add metadata for HLO instructions that are created to replace
917   // existing HLO instructions (e.g. during optimizations). The assumption is
918   // that the old instruction and the new instruction would perform the same
919   // function, and that they would be correlated to the same TF op. This might
920   // not always be correct since HLO optimizations can cross TF op boundaries.
921   // But still this seems to be better than nothing.
922   bool overwrite_op_name = new_instruction->metadata().op_name().empty() &&
923                            !old_instruction->metadata().op_name().empty();
924   bool overwrite_pass_id =
925       new_instruction->metadata().op_name().empty() &&
926       new_instruction->metadata().logical_creation_pass_id() == 0 &&
927       old_instruction->metadata().logical_creation_pass_id() != 0;
928   if (overwrite_op_name || overwrite_pass_id) {
929     new_instruction->set_metadata(old_instruction->metadata());
930   }
931   if (new_instruction->frontend_attributes().map().empty()) {
932     new_instruction->set_frontend_attributes(
933         old_instruction->frontend_attributes());
934   }
935 
936   // Like the metadata above, if the user didn't specify any sharding
937   // information on the new instruction we should copy the old sharding
938   // information (if any).
939   if (!new_instruction->has_sharding()) {
940     new_instruction->set_sharding(old_instruction->sharding_ptr());
941   }
942 
943   TF_RETURN_IF_ERROR(old_instruction->ReplaceAllUsesWith(new_instruction));
944   return RemoveInstructionAndUnusedOperands(old_instruction);
945 }
946 
CollectUnreachableRoots() const947 std::vector<HloInstruction*> HloComputation::CollectUnreachableRoots() const {
948   std::vector<HloInstruction*> unreachable_roots;
949   for (auto* instruction : instructions()) {
950     if (instruction->user_count() == 0 &&
951         instruction->control_successors().empty() &&
952         instruction != root_instruction()) {
953       unreachable_roots.push_back(instruction);
954     }
955   }
956   VLOG(3) << "Unreachable roots:"
957           << absl::StrJoin(unreachable_roots, "\n\t",
958                            [](string* out, const HloInstruction* hlo) {
959                              absl::StrAppend(out, hlo->ToString());
960                            });
961   return unreachable_roots;
962 }
963 
AcceptWithOperandOrder(DfsHloVisitor * visitor,const HloInstruction::CompareFunction & operand_order) const964 Status HloComputation::AcceptWithOperandOrder(
965     DfsHloVisitor* visitor,
966     const HloInstruction::CompareFunction& operand_order) const {
967   // Visit unreachable roots. Beware that the visitor might delete the currently
968   // visited root, which would invalidate iterators if the unreachable roots
969   // weren't computed ahead of time.
970   for (HloInstruction* root : CollectUnreachableRoots()) {
971     TF_RETURN_IF_ERROR(
972         root->AcceptWithOperandOrder(visitor, operand_order,
973                                      /*call_finish_visit=*/false));
974   }
975   // Visit the computation root instruction last.
976   return root_instruction()->AcceptWithOperandOrder(visitor, operand_order,
977                                                     /*call_finish_visit=*/true);
978 }
979 
Clone(const string & suffix,HloCloneContext * context)980 std::unique_ptr<HloComputation> HloComputation::Clone(
981     const string& suffix, HloCloneContext* context) {
982   return CloneWithReplacements(
983       /*replacements=*/absl::flat_hash_map<const HloInstruction*,
984                                            std::unique_ptr<HloInstruction>>(),
985       /*extra_parameters=*/{}, context, suffix);
986 }
987 
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,HloCloneContext * context,const string & suffix)988 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
989     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
990     HloCloneContext* context, const string& suffix) {
991   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
992       replacements;
993   replacements.emplace(std::move(r1));
994   return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
995                                context, suffix);
996 }
997 
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,HloCloneContext * context,const string & suffix)998 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
999     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
1000     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
1001     HloCloneContext* context, const string& suffix) {
1002   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1003       replacements;
1004   replacements.emplace(std::move(r1));
1005   replacements.emplace(std::move(r2));
1006   return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
1007                                context, suffix);
1008 }
1009 
CloneWithReplacementPairs(std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r1,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r2,std::pair<const HloInstruction *,std::unique_ptr<HloInstruction>> r3,HloCloneContext * context,const string & suffix)1010 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacementPairs(
1011     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r1,
1012     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r2,
1013     std::pair<const HloInstruction*, std::unique_ptr<HloInstruction>> r3,
1014     HloCloneContext* context, const string& suffix) {
1015   absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1016       replacements;
1017   replacements.emplace(std::move(r1));
1018   replacements.emplace(std::move(r2));
1019   replacements.emplace(std::move(r3));
1020   return CloneWithReplacements(std::move(replacements), /*extra_parameters=*/{},
1021                                context, suffix);
1022 }
1023 
CloneWithReplacements(absl::flat_hash_map<const HloInstruction *,std::unique_ptr<HloInstruction>> replacements,absl::Span<const HloInstruction * const> extra_parameters,HloCloneContext * context,const string & suffix,const HloInstruction * new_root)1024 std::unique_ptr<HloComputation> HloComputation::CloneWithReplacements(
1025     absl::flat_hash_map<const HloInstruction*, std::unique_ptr<HloInstruction>>
1026         replacements,
1027     absl::Span<const HloInstruction* const> extra_parameters,
1028     HloCloneContext* context, const string& suffix,
1029     const HloInstruction* new_root) {
1030   std::unique_ptr<HloCloneContext> context_ptr;
1031   if (context == nullptr) {
1032     context_ptr = absl::make_unique<HloCloneContext>(parent(), suffix);
1033     context = context_ptr.get();
1034   }
1035   if (new_root == nullptr) {
1036     new_root = root_instruction();
1037   }
1038 
1039   // Look up instr in the replacements map, and return either the replacement,
1040   // or instr, if the replacement isn't present.
1041   //
1042   // Note: This can return null, indicating that instr should not be present in
1043   // the new computation.
1044   auto replace = [&](const HloInstruction* instr) {
1045     auto it = replacements.find(instr);
1046     return it != replacements.end() ? it->second.get() : instr;
1047   };
1048 
1049   VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n";
1050 
1051   // We want to do a postorder walk over [replace(i) for i in instructions_].
1052   // We can't reuse MakeInstructionPostOrder() for this, because that will
1053   // generate a postorder of plain instructions_, and our replacements may
1054   // change the postorder!
1055   //
1056   // The postorder we want here is simpler than what MakeInstructionPostOrder()
1057   // does -- we only care about operand dependencies -- so let's just do it
1058   // ourselves.
1059   std::vector<const HloInstruction*> postorder;
1060   absl::flat_hash_map<const HloInstruction*, VisitState> visited;
1061   for (const auto& instr : instructions_) {
1062     std::vector<const HloInstruction*> dfs_stack;
1063     const HloInstruction* new_instr = replace(instr.get());
1064     if (!new_instr) {
1065       continue;
1066     }
1067     dfs_stack.push_back(new_instr);
1068 
1069     while (!dfs_stack.empty()) {
1070       auto* cur = dfs_stack.back();
1071       auto it = visited.find(cur);
1072       if (it != visited.end()) {
1073         dfs_stack.pop_back();
1074         if (it->second == kVisited) {
1075           continue;
1076         }
1077         CHECK_EQ(it->second, kVisiting);
1078         postorder.push_back(cur);
1079         it->second = kVisited;
1080         continue;
1081       }
1082 
1083       visited.insert({cur, kVisiting});
1084       for (HloInstruction* operand : cur->operands()) {
1085         const HloInstruction* new_operand = replace(operand);
1086         if (new_operand) {
1087           dfs_stack.emplace_back(new_operand);
1088         }
1089       }
1090     }
1091   }
1092 
1093   std::vector<std::unique_ptr<HloInstruction>> instructions;
1094   // First add the extra parameters to 'instructions'.
1095   for (const auto& instr : extra_parameters) {
1096     CHECK_EQ(instr->opcode(), HloOpcode::kParameter)
1097         << "Only parameter instructions are allowed in 'extra_parameters'";
1098     instructions.emplace_back(instr->Clone());
1099   }
1100   for (auto instr : postorder) {
1101     std::vector<HloInstruction*> new_operands;
1102     for (auto operand : instr->operands()) {
1103       auto replaced_operand = replace(operand);
1104       CHECK_NE(replaced_operand, nullptr)
1105           << "replacements map tried to eliminate a used instruction "
1106           << operand->ToString() << ", used by " << instr->ToString();
1107       new_operands.push_back(context->GetInstruction(replaced_operand));
1108     }
1109     std::unique_ptr<HloInstruction> new_instr =
1110         instr->CloneWithNewOperands(instr->shape(), new_operands, context);
1111     if (instr->opcode() == HloOpcode::kParameter &&
1112         instr->parameter_replicated_at_leaf_buffers().has_value()) {
1113       new_instr->set_parameter_replicated_at_leaf_buffers(
1114           instr->parameter_replicated_at_leaf_buffers().value());
1115     }
1116     instructions.push_back(std::move(new_instr));
1117   }
1118   Builder builder(name() + "." + suffix);
1119   for (auto& instr : instructions) {
1120     builder.AddInstruction(std::move(instr));
1121   }
1122   auto result = builder.Build(
1123       /*root_instruction=*/context->GetInstruction(replace(new_root)));
1124 
1125   // Clone control dependencies.
1126   for (auto instr : postorder) {
1127     HloInstruction* new_instr = context->GetInstruction(instr);
1128     for (auto successor : instr->control_successors()) {
1129       auto replaced_successor = replace(successor);
1130       // successor may not have been remapped, because it might have been
1131       // removed by the replacements map.
1132       if (replaced_successor != nullptr) {
1133         TF_CHECK_OK(new_instr->AddControlDependencyTo(
1134             context->GetInstruction(replaced_successor)));
1135       }
1136     }
1137   }
1138   context->MapComputation(this, result.get());
1139   return result;
1140 }
1141 
UniquifyName(NameUniquer * name_uniquer)1142 void HloComputation::UniquifyName(NameUniquer* name_uniquer) {
1143   name_ = name_uniquer->GetUniqueName(name_);
1144 }
1145 
GetInstructionWithName(absl::string_view name)1146 HloInstruction* HloComputation::GetInstructionWithName(absl::string_view name) {
1147   auto instructions_in_computation = instructions();
1148   auto it = absl::c_find_if(
1149       instructions_in_computation,
1150       [&](HloInstruction* instr) { return instr->name() == name; });
1151   return it == instructions_in_computation.end() ? nullptr : *it;
1152 }
1153 
IsEntryComputation() const1154 bool HloComputation::IsEntryComputation() const {
1155   return parent()->entry_computation() == this;
1156 }
1157 }  // namespace xla
1158