1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_module.h"
17 
18 #include <iterator>
19 #include <set>
20 #include <sstream>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <utility>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/flat_hash_map.h"
27 #include "absl/container/flat_hash_set.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "tensorflow/compiler/xla/map_util.h"
31 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/types.h"
34 #include "tensorflow/core/lib/gtl/map_util.h"
35 #include "tensorflow/core/platform/types.h"
36 
37 namespace xla {
38 
HloModule(const string & name,const HloModuleConfig & config)39 HloModule::HloModule(const string& name, const HloModuleConfig& config)
40     : name_(NameUniquer::GetSanitizedName(name)),
41       config_(config),
42       unique_id_(next_unique_module_id_++) {}
43 
set_schedule(HloSchedule schedule)44 Status HloModule::set_schedule(HloSchedule schedule) {
45   TF_RET_CHECK(schedule.module() == this);
46   TF_RETURN_IF_ERROR(schedule.Verify());
47   schedule_ = std::move(schedule);
48   return Status::OK();
49 }
50 
AddComputationInternal(std::unique_ptr<HloComputation> computation,bool is_entry,bool uniquify_identifiers)51 HloComputation* HloModule::AddComputationInternal(
52     std::unique_ptr<HloComputation> computation, bool is_entry,
53     bool uniquify_identifiers) {
54   if (is_entry) {
55     CHECK_EQ(nullptr, entry_computation_);
56     entry_computation_ = computation.get();
57 
58     // If the module configuration has no entry layout computation set, create a
59     // default one based on the program shape.
60     if (!config_.has_entry_computation_layout()) {
61       config_.SetDefaultComputationLayout(
62           entry_computation_->ComputeProgramShape());
63     }
64     input_output_alias_config_ = HloInputOutputAliasConfig(
65         entry_computation_->root_instruction()->shape());
66   }
67 
68   if (uniquify_identifiers) {
69     computation->UniquifyName(&computation_name_uniquer_);
70     for (auto* instruction : computation->instructions()) {
71       instruction->UniquifyName(&instruction_name_uniquer_);
72     }
73 
74     // Pick unique IDs for each instruction.
75     for (auto* instruction : computation->instructions()) {
76       instruction->SetUniqueId(NewUniqueInstructionId());
77     }
78     // Set unique id to this computation.
79     CHECK_NE(computation->root_instruction()->unique_id(), -1)
80         << "Root has no valid id: " << computation->ToString();
81     computation->SetUniqueId(computation->root_instruction()->unique_id());
82   } else {
83     // Don't uniquify the names of the computation or instruction, but we must
84     // run the names through the uniquifiers to prevent future name collisions
85     // for computations and instructions created later. Also, set the
86     // next_unique_id_ to the one greater than the max unique id of any
87     // instruction (or the computation) to avoid ID collisions.
88     computation_name_uniquer_.GetUniqueName(computation->name());
89     for (auto* instruction : computation->instructions()) {
90       instruction_name_uniquer_.GetUniqueName(instruction->name());
91       next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
92     }
93     if (next_unique_id_ < computation->unique_id() + 1) {
94       next_unique_id_ = computation->unique_id() + 1;
95     }
96   }
97 
98   computation->set_parent(this);
99   computations_.push_back(std::move(computation));
100   return computations_.back().get();
101 }
102 
AddEntryComputation(std::unique_ptr<HloComputation> computation)103 HloComputation* HloModule::AddEntryComputation(
104     std::unique_ptr<HloComputation> computation) {
105   return AddComputationInternal(std::move(computation), /*is_entry=*/true,
106                                 /*uniquify_identifiers=*/true);
107 }
108 
RemoveEmbeddedComputation(HloComputation * to_remove)109 Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
110   auto it = absl::c_find_if(
111       computations_, [&to_remove](const std::unique_ptr<HloComputation>& comp) {
112         return comp.get() == to_remove;
113       });
114   TF_RET_CHECK(it->get() == to_remove);
115   computations_.erase(it);
116   return Status::OK();
117 }
118 
AddEmbeddedComputation(std::unique_ptr<HloComputation> computation)119 HloComputation* HloModule::AddEmbeddedComputation(
120     std::unique_ptr<HloComputation> computation) {
121   return AddComputationInternal(std::move(computation), /*is_entry=*/false,
122                                 /*uniquify_identifiers=*/true);
123 }
124 
ReplaceComputations(const std::unordered_map<HloComputation *,HloComputation * > & replacements)125 void HloModule::ReplaceComputations(
126     const std::unordered_map<HloComputation*, HloComputation*>& replacements) {
127   // Replace all uses of non-canonical computations with their
128   // representatives.
129   std::vector<std::unique_ptr<HloComputation>> new_computations;
130   new_computations.reserve(computations_.size());
131 
132   for (std::unique_ptr<HloComputation>& computation : computations_) {
133     for (auto* instruction : computation->instructions()) {
134       switch (instruction->opcode()) {
135         case HloOpcode::kCall:
136         case HloOpcode::kMap:
137         case HloOpcode::kReduce:
138         case HloOpcode::kReduceWindow:
139         case HloOpcode::kScatter: {
140           HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
141               replacements, instruction->to_apply(), nullptr);
142           if (new_arg != nullptr) {
143             instruction->set_to_apply(new_arg);
144           }
145           break;
146         }
147         case HloOpcode::kWhile: {
148           HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
149               replacements, instruction->while_condition(), nullptr);
150           if (new_condition != nullptr) {
151             instruction->set_while_condition(new_condition);
152           }
153           HloComputation* new_body = tensorflow::gtl::FindWithDefault(
154               replacements, instruction->while_body(), nullptr);
155           if (new_body != nullptr) {
156             instruction->set_while_body(new_body);
157           }
158           break;
159         }
160         case HloOpcode::kConditional: {
161           for (int b = 0; b < instruction->branch_count(); ++b) {
162             HloComputation* new_computation = tensorflow::gtl::FindWithDefault(
163                 replacements, instruction->branch_computation(b), nullptr);
164             if (new_computation != nullptr) {
165               instruction->set_branch_computation(b, new_computation);
166             }
167           }
168           break;
169         }
170         case HloOpcode::kSelectAndScatter: {
171           HloComputation* new_select = tensorflow::gtl::FindWithDefault(
172               replacements, instruction->select(), nullptr);
173           if (new_select != nullptr) {
174             instruction->set_select(new_select);
175           }
176           HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
177               replacements, instruction->scatter(), nullptr);
178           if (new_scatter != nullptr) {
179             instruction->set_scatter(new_scatter);
180           }
181           break;
182         }
183         default:
184           break;
185       }
186     }
187 
188     if (replacements.find(computation.get()) == replacements.end()) {
189       new_computations.push_back(std::move(computation));
190     }
191   }
192 
193   // Replace entry_computation if necessary.
194   entry_computation_ = tensorflow::gtl::FindWithDefault(
195       replacements, entry_computation_, entry_computation_);
196 
197   computations_ = std::move(new_computations);
198 }
199 
ToString(const HloPrintOptions & options) const200 string HloModule::ToString(const HloPrintOptions& options) const {
201   std::ostringstream s;
202   s << "HloModule " << name();
203   if (has_schedule()) {
204     TF_CHECK_OK(schedule().Verify());
205     s << ", is_scheduled=true";
206   }
207   s << "\n\n";
208   for (const HloComputation* computation : MakeComputationPostOrder()) {
209     if (computation == entry_computation()) {
210       s << "ENTRY ";
211     }
212     if (has_schedule() && schedule().is_computation_scheduled(computation)) {
213       s << computation->ToString(
214                options, schedule().sequence(computation).instructions())
215         << "\n\n";
216     } else {
217       s << computation->ToString(options) << "\n\n";
218     }
219   }
220   return s.str();
221 }
222 
ToProto() const223 HloModuleProto HloModule::ToProto() const {
224   HloModuleProto proto;
225   proto.set_id(unique_id_);
226   proto.set_name(name_);
227   proto.set_entry_computation_name(entry_computation_->name());
228   proto.set_entry_computation_id(entry_computation_->unique_id());
229   for (const HloComputation* computation : MakeComputationPostOrder()) {
230     HloComputationProto computation_proto = computation->ToProto();
231     proto.add_computations()->Swap(&computation_proto);
232   }
233   if (has_schedule()) {
234     *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
235   }
236   *proto.mutable_host_program_shape() =
237       entry_computation_layout().ComputeProgramShape().ToProto();
238   *proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
239   *proto.mutable_dynamic_parameter_binding() =
240       dynamic_parameter_binding().ToProto();
241   return proto;
242 }
243 
CheckUniqueNamesAndIdsForComputationsAndInstructions() const244 Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const {
245   absl::flat_hash_set<string> computation_names;
246   absl::flat_hash_set<int> computation_ids;
247   absl::flat_hash_set<string> instruction_names;
248   absl::flat_hash_set<int> instruction_ids;
249 
250   for (const HloComputation* computation : computations()) {
251     TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
252         << "Computation name is not unique: " << computation->name();
253     computation_names.insert(computation->name());
254 
255     TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
256         << "Computation id is not unique: " << computation->unique_id();
257     computation_ids.insert(computation->unique_id());
258 
259     for (const HloInstruction* instruction : computation->instructions()) {
260       TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
261           << "Instruction name is not unique: " << instruction->name();
262       instruction_names.insert(instruction->name());
263 
264       TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
265           << "Instruction id is not unique: " << instruction->unique_id();
266       instruction_ids.insert(instruction->unique_id());
267     }
268   }
269   return Status::OK();
270 }
271 
272 /* static */
CreateFromProto(const HloModuleProto & proto,const HloModuleConfig & module_config)273 StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
274     const HloModuleProto& proto, const HloModuleConfig& module_config) {
275   VLOG(2) << "CreateFromProto()";
276   XLA_VLOG_LINES(3, proto.DebugString());
277 
278   // The ProgramShape in the passed in module config must match the shapes of
279   // the entry parameters and root.
280   TF_RET_CHECK(proto.has_host_program_shape())
281       << "No program shape found in the proto";
282   ProgramShape expected_program_shape(proto.host_program_shape());
283   TF_RET_CHECK(expected_program_shape.parameters_size() ==
284                module_config.entry_computation_layout().parameter_count());
285   for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
286     const Shape& parameter_shape =
287         module_config.entry_computation_layout().parameter_layout(i).shape();
288     TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
289                                        parameter_shape))
290         << "HloModuleConfig has different shape for parameter " << i
291         << " than the HLO module. Expected: "
292         << ShapeUtil::HumanStringWithLayout(
293                expected_program_shape.parameters(i))
294         << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
295   }
296   const Shape& result_shape =
297       module_config.entry_computation_layout().result_layout().shape();
298   TF_RET_CHECK(
299       ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
300       << "HloModuleConfig has different result shape than the HLO module. "
301          "Expected: "
302       << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
303       << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);
304 
305   absl::flat_hash_map<int64, HloComputation*> computation_map;
306   absl::flat_hash_map<HloComputation*, int64> to_proto_id;
307   std::vector<std::unique_ptr<HloComputation>> computations;
308   HloComputation* entry = nullptr;
309   for (const HloComputationProto& computation_proto : proto.computations()) {
310     TF_ASSIGN_OR_RETURN(
311         std::unique_ptr<HloComputation> computation,
312         HloComputation::CreateFromProto(computation_proto, computation_map));
313     CHECK_NE(computation.get(), nullptr);
314     int64 computation_id = computation_proto.id();
315     TF_RET_CHECK(computation_id != -1);
316     TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
317     computation_map[computation_id] = computation.get();
318     to_proto_id[computation.get()] = computation_id;
319     if (computation_id == proto.entry_computation_id()) {
320       entry = computation.get();
321     }
322     computations.push_back(std::move(computation));
323   }
324   TF_RET_CHECK(entry != nullptr);
325 
326   auto module = absl::make_unique<HloModule>(proto.name(), module_config);
327 
328   // Sort the computations in the proto id's order.
329   absl::c_sort(computations, [&](const std::unique_ptr<HloComputation>& a,
330                                  const std::unique_ptr<HloComputation>& b) {
331     return to_proto_id[a.get()] < to_proto_id[b.get()];
332   });
333 
334   // Add sorted computations to the module.
335   for (auto& computation : computations) {
336     bool is_entry = computation.get() == entry;
337     // Don't uniquify names because we want names to be stable across
338     // serialization and deserialization.
339     module->AddComputationInternal(std::move(computation), is_entry,
340                                    /*uniquify_identifiers=*/false);
341   }
342   TF_RET_CHECK(module->entry_computation_ != nullptr);
343 
344   TF_ASSIGN_OR_RETURN(
345       module->input_output_alias_config_,
346       HloInputOutputAliasConfig::CreateFromProto(
347           entry->ComputeProgramShape().result(), proto.input_output_alias()));
348 
349   // Because we didn't uniquify the names or the ids, double-check that the
350   // instruction and computation names and ids are unique from the proto.
351   TF_ASSIGN_OR_RETURN(module->dynamic_parameter_binding_,
352                       DynamicParameterBinding::CreateFromProto(
353                           proto.dynamic_parameter_binding()));
354 
355   TF_RETURN_IF_ERROR(
356       module->CheckUniqueNamesAndIdsForComputationsAndInstructions());
357 
358   if (proto.has_schedule()) {
359     TF_ASSIGN_OR_RETURN(
360         HloSchedule schedule,
361         HloSchedule::CreateFromProto(module.get(), proto.schedule()));
362     TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
363   }
364 
365   return std::move(module);
366 }
367 
368 /* static */
CreateModuleConfigFromProto(const HloModuleProto & module,const DebugOptions & debug_options)369 StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
370     const HloModuleProto& module, const DebugOptions& debug_options) {
371   TF_RET_CHECK(module.has_host_program_shape())
372       << "No program shape found in the proto";
373   ProgramShape program_shape(module.host_program_shape());
374 
375   HloModuleConfig module_config(ProgramShape{program_shape});
376   module_config.set_debug_options(debug_options);
377 
378   // The module config is constructed with default layouts regardless of what is
379   // passed in via the ProgramShape. Set the layouts to the appropriate values.
380   ComputationLayout* entry_layout =
381       module_config.mutable_entry_computation_layout();
382   for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
383     TF_RETURN_IF_ERROR(
384         entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
385             program_shape.parameters(i)));
386   }
387   TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
388       program_shape.result()));
389   return module_config;
390 }
391 
392 namespace {
393 // Returns whether `hlo` is used outside the given subcomputation.
394 // `instructions_in_subcomputation` is the instruction set of the given
395 // subcomputation.
IsUsedOutsideSubcomputation(const HloInstruction & hlo,const absl::flat_hash_set<HloInstruction * > & instructions_in_subcomputation)396 bool IsUsedOutsideSubcomputation(const HloInstruction& hlo,
397                                  const absl::flat_hash_set<HloInstruction*>&
398                                      instructions_in_subcomputation) {
399   return absl::c_any_of(hlo.users(), [&](HloInstruction* user) {
400     return !instructions_in_subcomputation.contains(user);
401   });
402 }
403 }  // anonymous namespace
404 
OutlineExpressionFromComputation(absl::Span<HloInstruction * const> instructions_to_outline,const string & outlined_computation_name,HloComputation * computation)405 HloInstruction* HloModule::OutlineExpressionFromComputation(
406     absl::Span<HloInstruction* const> instructions_to_outline,
407     const string& outlined_computation_name, HloComputation* computation) {
408   auto builder = HloComputation::Builder(outlined_computation_name);
409 
410   // A map from original instructions to their counterparts in the new outlined
411   // function.
412   absl::flat_hash_map<HloInstruction*, HloInstruction*> outlined_instructions;
413   // A set that contains all instructions to be outlined.
414   absl::flat_hash_set<HloInstruction*> instruction_set_to_outline(
415       instructions_to_outline.begin(), instructions_to_outline.end());
416   std::vector<HloInstruction*> arguments;
417   std::vector<HloInstruction*> outputs;
418   int64 parameter_count = 0;
419   for (HloInstruction* instruction_to_outline : instructions_to_outline) {
420     // Clone the original instruction.
421     HloInstruction* outlined_instruction =
422         builder.AddInstruction(instruction_to_outline->Clone());
423 
424     // Replace its operands to their counterparts in the new function.
425     for (int64 operand_num = 0;
426          operand_num < outlined_instruction->operand_count(); ++operand_num) {
427       HloInstruction* old_operand =
428           outlined_instruction->mutable_operand(operand_num);
429 
430       HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
431       if (*operand_slot == nullptr) {
432         // Because instructions_to_outline is in topological order, if
433         // old_operand is not in outlined_instructions, old_operand must be an
434         // input of the outlined subcomputation and thus should be represented
435         // as a parameter in the new function.
436         arguments.push_back(old_operand);
437         *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
438             parameter_count, old_operand->shape(), "p"));
439         ++parameter_count;
440       }
441       TF_CHECK_OK(
442           outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot));
443     }
444 
445     // Insert the new instruction into the outlined_instructions map.
446     InsertOrDie(&outlined_instructions, instruction_to_outline,
447                 outlined_instruction);
448 
449     // Mark instruction_to_outline an output if it is used outside the
450     // subcomputation or is the output of the original computation (i.e. used
451     // externally).
452     if (instruction_to_outline->user_count() == 0 ||
453         IsUsedOutsideSubcomputation(*instruction_to_outline,
454                                     instruction_set_to_outline)) {
455       outputs.push_back(instruction_to_outline);
456     }
457   }
458 
459   if (outputs.size() != 1) {
460     string error_message =
461         "The subcomputation to outline has multiple outputs:\n";
462     for (HloInstruction* output : outputs) {
463       absl::StrAppend(&error_message, output->ToString(), "\n");
464     }
465     LOG(FATAL) << error_message;
466   }
467   HloInstruction* output = outputs[0];
468 
469   // Creates a call to the nested computation.
470   HloComputation* nested_computation = AddEmbeddedComputation(
471       builder.Build(FindOrDie(outlined_instructions, output)));
472   HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
473       output->shape(), arguments, nested_computation));
474 
475   VLOG(2) << "Outlining the following instructions";
476   for (auto* instruction_to_outline : instructions_to_outline) {
477     VLOG(2) << "  " << instruction_to_outline->ToString();
478   }
479   VLOG(2) << "as a call " << call->ToString();
480   VLOG(2) << "to " << nested_computation->ToString();
481 
482   TF_CHECK_OK(output->ReplaceAllUsesWith(call));
483   for (auto i = instructions_to_outline.rbegin();
484        i != instructions_to_outline.rend(); ++i) {
485     TF_CHECK_OK(computation->RemoveInstruction(*i));
486   }
487 
488   return call;
489 }
490 
instruction_count() const491 int64 HloModule::instruction_count() const {
492   int64 n = 0;
493   for (const auto& computation : computations_) {
494     n += computation->instruction_count();
495   }
496   return n;
497 }
498 
MakeComputationPostOrder() const499 std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
500   // First determine all root computations by building a set of nonroot
501   // computations (computations which are called by an instruction in the
502   // module).
503   absl::flat_hash_set<HloComputation*> nonroot_computations;
504   for (auto& computation : computations_) {
505     for (auto* instruction : computation->instructions()) {
506       for (HloComputation* called_computation :
507            instruction->called_computations()) {
508         nonroot_computations.insert(called_computation);
509       }
510     }
511   }
512 
513   // Keep track of computations which have already been added to the post
514   // order. This prevents duplication as an embedded computation may be called
515   // from two different root computations.
516   absl::flat_hash_set<HloComputation*> added_computations;
517   std::vector<HloComputation*> post_order;
518   for (auto& computation : computations_) {
519     if (!nonroot_computations.contains(computation.get())) {
520       for (HloComputation* embedded_computation :
521            computation->MakeEmbeddedComputationsList()) {
522         if (!added_computations.contains(embedded_computation)) {
523           post_order.push_back(embedded_computation);
524           added_computations.insert(embedded_computation);
525         }
526       }
527       // Root computations should only be encountered once.
528       CHECK(!added_computations.contains(computation.get()));
529       post_order.push_back(computation.get());
530       added_computations.insert(computation.get());
531     }
532   }
533   if (post_order.size() != computations_.size()) {
534     for (HloComputation* computation : post_order) {
535       LOG(ERROR) << "Post Order: " << computation->name() << " ("
536                  << computation->parent()->name() << ")";
537     }
538     for (auto& computation : computations_) {
539       LOG(ERROR) << "Computations: " << computation->name() << " ("
540                  << computation->parent()->name() << ")";
541     }
542     LOG(FATAL) << "Mismatch computation count: post_order=" << post_order.size()
543                << " computation_count=" << computations_.size();
544   }
545   return post_order;
546 }
547 
MakeNonfusionComputations() const548 std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
549   std::vector<HloComputation*> result;
550   for (auto* c : computations()) {
551     if (c->IsFusionComputation()) {
552       continue;
553     }
554     result.push_back(c);
555   }
556   return result;
557 }
558 
Clone(const string & suffix) const559 std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
560   return Clone(config(), suffix);
561 }
562 
Clone(const HloModuleConfig & config,const string & suffix) const563 std::unique_ptr<HloModule> HloModule::Clone(const HloModuleConfig& config,
564                                             const string& suffix) const {
565   VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n";
566   auto module = absl::make_unique<HloModule>(
567       absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config);
568 
569   HloCloneContext context(module.get(), suffix);
570   auto cloned_computation = entry_computation_->Clone(suffix, &context);
571   module->AddEntryComputation(std::move(cloned_computation));
572 
573   if (has_schedule() && schedule().Verify().ok()) {
574     HloSchedule clone_schedule(module.get());
575     for (HloComputation* computation : computations()) {
576       if (schedule().is_computation_scheduled(computation)) {
577         HloInstructionSequence& clone_sequence =
578             clone_schedule.GetOrCreateSequence(
579                 context.GetComputation(computation));
580         for (const HloInstruction* instruction :
581              schedule().sequence(computation).instructions()) {
582           clone_sequence.push_back(context.GetInstruction(instruction));
583         }
584       }
585     }
586     TF_CHECK_OK(module->set_schedule(std::move(clone_schedule)));
587   }
588   return module;
589 }
590 
DeepCloneComputation(HloComputation * computation,HloCloneContext * context)591 HloComputation* HloModule::DeepCloneComputation(HloComputation* computation,
592                                                 HloCloneContext* context) {
593   HloComputation* new_computation;
594   if (context != nullptr) {
595     if ((new_computation = context->FindComputation(computation)) != nullptr) {
596       return new_computation;
597     }
598     new_computation =
599         AddEmbeddedComputation(computation->Clone(context->suffix(), context));
600   } else {
601     new_computation = AddEmbeddedComputation(computation->Clone(""));
602   }
603   return new_computation;
604 }
605 
RandomNew64() const606 uint64 HloModule::RandomNew64() const {
607   tensorflow::mutex_lock l(rng_mutex_);
608   return rng_();
609 }
610 
GetComputationWithName(absl::string_view name)611 HloComputation* HloModule::GetComputationWithName(absl::string_view name) {
612   auto computations_in_module = computations();
613   auto it = absl::c_find_if(
614       computations_in_module,
615       [&](HloComputation* computation) { return computation->name() == name; });
616   return it == computations_in_module.end() ? nullptr : *it;
617 }
618 
619 /* static */ std::atomic<int> HloModule::next_unique_module_id_(0);
620 
621 }  // namespace xla
622