/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_module.h"

#include <algorithm>
#include <iterator>
#include <set>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"  // For absl::StrFormat in ToString().
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

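// Illustrative usage (hypothetical names, not from the original source):
//   HloModuleConfig config;
//   auto module = absl::make_unique<HloModule>("example_module", config);
// The given name is sanitized, and each module receives a process-unique id
// from the next_unique_module_id_ counter defined at the bottom of this file.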
HloModule::HloModule(const string& name, HloModuleConfig config)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(std::move(config)),
      unique_id_(next_unique_module_id_++),
      metadata_(tensorflow::Env::Default()) {
  metadata_.set_canonical_module_id(unique_id_);
}

Status HloModule::set_schedule(HloSchedule schedule) {
  TF_RET_CHECK(schedule.module() == this);
  TF_RETURN_IF_ERROR(schedule.Verify());
  schedule_ = std::move(schedule);
  return Status::OK();
}

void HloModule::ReplaceEntryComputation(HloComputation* entry_computation) {
  entry_computation_ = entry_computation;
  config_.SetDefaultComputationLayout(
      entry_computation_->ComputeProgramShape());
  input_output_alias_config_ = HloInputOutputAliasConfig(
      entry_computation_->root_instruction()->shape());
}

HloComputation* HloModule::AddComputationInternal(
    std::unique_ptr<HloComputation> computation, bool is_entry,
    bool uniquify_identifiers, bool preserve_entry_layouts) {
  if (is_entry) {
    CHECK_EQ(nullptr, entry_computation_);
    entry_computation_ = computation.get();

    if (preserve_entry_layouts) {
      config_.SetComputationLayoutIfExists(
          entry_computation_->ComputeProgramShape());
    } else if (!config_.has_entry_computation_layout()) {
      // If the module configuration has no entry computation layout set,
      // create a default one based on the program shape.
      config_.SetDefaultComputationLayout(
          entry_computation_->ComputeProgramShape());
    }
    input_output_alias_config_ = HloInputOutputAliasConfig(
        entry_computation_->root_instruction()->shape());
  }

  if (uniquify_identifiers) {
    computation->UniquifyName(&computation_name_uniquer_);
    for (auto* instruction : computation->instructions()) {
      instruction->UniquifyName(&instruction_name_uniquer_);
    }

    // Pick unique IDs for each instruction.
    for (auto* instruction : computation->instructions()) {
      instruction->SetUniqueId(NewUniqueInstructionId());
    }
    // Set a unique id for this computation.
    CHECK_NE(computation->root_instruction()->unique_id(), -1)
        << "Root has no valid id: " << computation->ToString();
    computation->SetUniqueId(computation->root_instruction()->unique_id());
  } else {
    // Don't uniquify the names of the computation or instructions, but we must
    // run the names through the uniquifiers to prevent future name collisions
    // for computations and instructions created later. Also, set
    // next_unique_id_ to one greater than the max unique id of any instruction
    // (or the computation) to avoid ID collisions.
    computation_name_uniquer_.GetUniqueName(computation->name());
    for (auto* instruction : computation->instructions()) {
      instruction_name_uniquer_.GetUniqueName(instruction->name());
      next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
    }
    if (next_unique_id_ < computation->unique_id() + 1) {
      next_unique_id_ = computation->unique_id() + 1;
    }
  }

  computation->set_parent(this);
  computations_.push_back(std::move(computation));
  return computations_.back().get();
}

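// A minimal sketch of building and installing an entry computation (the
// builder name, shape, and instruction names below are hypothetical):
//   HloComputation::Builder builder("add_one");
//   Shape s32 = ShapeUtil::MakeShape(S32, {});
//   auto* p = builder.AddInstruction(
//       HloInstruction::CreateParameter(0, s32, "p"));
//   auto* c = builder.AddInstruction(
//       HloInstruction::CreateConstant(LiteralUtil::One(S32)));
//   builder.AddInstruction(
//       HloInstruction::CreateBinary(s32, HloOpcode::kAdd, p, c));
//   module->AddEntryComputation(builder.Build());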
HloComputation* HloModule::AddEntryComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_identifiers=*/true,
                                /*preserve_entry_layouts=*/false);
}

HloComputation* HloModule::AddEntryComputationWithLayouts(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_identifiers=*/true,
                                /*preserve_entry_layouts=*/true);
}

Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
  if (has_schedule() && !to_remove->IsFusionComputation()) {
    schedule_->remove_computation(to_remove);
  }

  auto it = absl::c_find_if(
      computations_, [&to_remove](const std::unique_ptr<HloComputation>& comp) {
        return comp.get() == to_remove;
      });
  TF_RET_CHECK(it != computations_.end());
  TF_RET_CHECK(it->get() == to_remove);
  computations_.erase(it);
  return Status::OK();
}

HloComputation* HloModule::AddEmbeddedComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/false,
                                /*uniquify_identifiers=*/true,
                                /*preserve_entry_layouts=*/false);
}

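// Illustrative call (hypothetical names; `dup` and `canonical` are
// computations already owned by this module):
//   std::unordered_map<HloComputation*, HloComputation*> replacements;
//   replacements[dup] = canonical;
//   module->ReplaceComputations(replacements);
// Afterwards every caller of `dup` points at `canonical`, and `dup` itself
// is dropped from the module.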
void HloModule::ReplaceComputations(
    const std::unordered_map<HloComputation*, HloComputation*>& replacements) {
  // Replace all uses of non-canonical computations with their
  // representatives.
  std::vector<std::unique_ptr<HloComputation>> new_computations;
  new_computations.reserve(computations_.size());

  for (std::unique_ptr<HloComputation>& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      switch (instruction->opcode()) {
        case HloOpcode::kAllReduce:
        case HloOpcode::kCall:
        case HloOpcode::kMap:
        case HloOpcode::kReduce:
        case HloOpcode::kReduceWindow:
        case HloOpcode::kScatter:
        case HloOpcode::kSort: {
          HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
              replacements, instruction->to_apply(), nullptr);
          if (new_arg != nullptr) {
            instruction->set_to_apply(new_arg);
          }
          break;
        }
        case HloOpcode::kWhile: {
          HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_condition(), nullptr);
          if (new_condition != nullptr) {
            instruction->set_while_condition(new_condition);
          }
          HloComputation* new_body = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_body(), nullptr);
          if (new_body != nullptr) {
            instruction->set_while_body(new_body);
          }
          break;
        }
        case HloOpcode::kConditional: {
          for (int b = 0; b < instruction->branch_count(); ++b) {
            HloComputation* new_computation = tensorflow::gtl::FindWithDefault(
                replacements, instruction->branch_computation(b), nullptr);
            if (new_computation != nullptr) {
              instruction->set_branch_computation(b, new_computation);
            }
          }
          break;
        }
        case HloOpcode::kSelectAndScatter: {
          HloComputation* new_select = tensorflow::gtl::FindWithDefault(
              replacements, instruction->select(), nullptr);
          if (new_select != nullptr) {
            instruction->set_select(new_select);
          }
          HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
              replacements, instruction->scatter(), nullptr);
          if (new_scatter != nullptr) {
            instruction->set_scatter(new_scatter);
          }
          break;
        }
        default:
          break;
      }
    }

    if (replacements.find(computation.get()) == replacements.end()) {
      new_computations.push_back(std::move(computation));
    }
  }

  // Replace entry_computation if necessary.
  entry_computation_ = tensorflow::gtl::FindWithDefault(
      replacements, entry_computation_, entry_computation_);

  computations_ = std::move(new_computations);
}

string HloModule::ToString(const HloPrintOptions& options) const {
  std::ostringstream s;
  // When print_ids() is false, exclude the module's name, since including it
  // leads to a non-deterministic fingerprint.
  s << "HloModule "
    << (options.print_ids() ? PrintName(name(), options.print_ids()) : "");
  if (has_schedule()) {
    TF_CHECK_OK(schedule().Verify());
    s << ", is_scheduled=true";
  }
  std::string serialized_aliasing = input_output_alias_config().ToShortString();
  if (!serialized_aliasing.empty()) {
    s << absl::StrFormat(", input_output_alias={ %s }", serialized_aliasing);
  }
  s << "\n\n";
  const auto& computations = options.canonicalize_computations()
                                 ? MakeComputationSorted()
                                 : MakeComputationPostOrder();
  for (const HloComputation* computation : computations) {
    if (!options.print_computation(computation)) {
      continue;
    }
    if (computation == entry_computation()) {
      s << "ENTRY ";
    }
    if (has_schedule() && schedule().is_computation_scheduled(computation)) {
      s << computation->ToString(
               options, schedule().sequence(computation).instructions())
        << "\n\n";
    } else {
      s << computation->ToString(options) << "\n\n";
    }
  }
  return s.str();
}

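// Proto round-trip sketch (hedged; assumes `module` is a valid HloModule):
//   HloModuleProto proto = module->ToProto();
//   StatusOr<std::unique_ptr<HloModule>> restored = HloModule::CreateFromProto(
//       proto, module->config(), /*prohibit_empty_literal=*/true);
// Names and ids are preserved across the round trip; see CreateFromProto
// below, which deliberately skips re-uniquification.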
HloModuleProto HloModule::ToProto() const {
  HloModuleProto proto;
  proto.set_id(unique_id_);
  proto.set_name(name_);
  proto.set_entry_computation_name(entry_computation_->name());
  proto.set_entry_computation_id(entry_computation_->unique_id());
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    HloComputationProto computation_proto = computation->ToProto();
    proto.add_computations()->Swap(&computation_proto);
  }
  if (has_schedule()) {
    *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
  }
  *proto.mutable_host_program_shape() =
      entry_computation_layout().ComputeProgramShape().ToProto();
  *proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
  *proto.mutable_dynamic_parameter_binding() =
      dynamic_parameter_binding().ToProto();
  for (const auto& parameter_indices : CrossProgramPrefetches()) {
    const auto& parameter = parameter_indices.first;
    const auto& indices = parameter_indices.second;
    auto* prefetch = proto.mutable_cross_program_prefetches()->Add();
    prefetch->set_parameter(parameter);
    for (auto index : indices) {
      prefetch->add_index(index);
    }
  }
  return proto;
}

Status HloModule::CheckUniqueNamesAndIdsForComputationsAndInstructions() const {
  absl::flat_hash_set<string> computation_names;
  absl::flat_hash_set<int> computation_ids;
  absl::flat_hash_set<string> instruction_names;
  absl::flat_hash_set<int> instruction_ids;

  for (const HloComputation* computation : computations()) {
    TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
        << "Computation name is not unique: " << computation->name();
    computation_names.insert(computation->name());

    TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
        << "Computation id is not unique: " << computation->unique_id();
    computation_ids.insert(computation->unique_id());

    for (const HloInstruction* instruction : computation->instructions()) {
      TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
          << "Instruction name is not unique: " << instruction->name();
      instruction_names.insert(instruction->name());

      TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
          << "Instruction id is not unique: " << instruction->unique_id();
      instruction_ids.insert(instruction->unique_id());
    }
  }
  return Status::OK();
}

/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
    const HloModuleProto& proto, const HloModuleConfig& module_config,
    bool prohibit_empty_literal) {
  VLOG(2) << "CreateFromProto()";
  XLA_VLOG_LINES(3, proto.DebugString());

  // The ProgramShape in the passed-in module config must match the shapes of
  // the entry parameters and root.
  TF_RET_CHECK(proto.has_host_program_shape())
      << "No program shape found in the proto";
  ProgramShape expected_program_shape(proto.host_program_shape());
  TF_RET_CHECK(expected_program_shape.parameters_size() ==
               module_config.entry_computation_layout().parameter_count());
  for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
    const Shape& parameter_shape =
        module_config.entry_computation_layout().parameter_layout(i).shape();
    TF_RET_CHECK(ShapeUtil::Compatible(expected_program_shape.parameters(i),
                                       parameter_shape))
        << "HloModuleConfig has different shape for parameter " << i
        << " than the HLO module. Expected: "
        << ShapeUtil::HumanStringWithLayout(
               expected_program_shape.parameters(i))
        << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
  }
  const Shape& result_shape =
      module_config.entry_computation_layout().result_layout().shape();
  TF_RET_CHECK(
      ShapeUtil::Compatible(expected_program_shape.result(), result_shape))
      << "HloModuleConfig has different result shape than the HLO module. "
         "Expected: "
      << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
      << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);

  absl::flat_hash_map<int64, HloComputation*> computation_map;
  absl::flat_hash_map<HloComputation*, int64> to_proto_id;
  std::vector<std::unique_ptr<HloComputation>> computations;
  HloComputation* entry = nullptr;
  for (const HloComputationProto& computation_proto : proto.computations()) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloComputation> computation,
        HloComputation::CreateFromProto(computation_proto, computation_map,
                                        prohibit_empty_literal));
    CHECK_NE(computation.get(), nullptr);
    int64 computation_id = computation_proto.id();
    TF_RET_CHECK(computation_id != -1);
    TF_RET_CHECK(!ContainsKey(computation_map, computation_id));
    computation_map[computation_id] = computation.get();
    to_proto_id[computation.get()] = computation_id;
    if (computation_id == proto.entry_computation_id()) {
      entry = computation.get();
    }
    computations.push_back(std::move(computation));
  }
  TF_RET_CHECK(entry != nullptr);

  auto module = absl::make_unique<HloModule>(proto.name(), module_config);

  // Sort the computations in the order of their proto ids.
  absl::c_sort(computations, [&](const std::unique_ptr<HloComputation>& a,
                                 const std::unique_ptr<HloComputation>& b) {
    return to_proto_id[a.get()] < to_proto_id[b.get()];
  });

  // Add sorted computations to the module.
  for (auto& computation : computations) {
    bool is_entry = computation.get() == entry;
    // Don't uniquify names because we want names to be stable across
    // serialization and deserialization.
    module->AddComputationInternal(std::move(computation), is_entry,
                                   /*uniquify_identifiers=*/false,
                                   /*preserve_entry_layouts=*/false);
  }
  TF_RET_CHECK(module->entry_computation_ != nullptr);

  TF_ASSIGN_OR_RETURN(
      module->input_output_alias_config_,
      HloInputOutputAliasConfig::CreateFromProto(
          entry->ComputeProgramShape().result(), proto.input_output_alias()));

  TF_ASSIGN_OR_RETURN(module->dynamic_parameter_binding_,
                      DynamicParameterBinding::CreateFromProto(
                          proto.dynamic_parameter_binding()));

  // Because we didn't uniquify the names or the ids, double-check that the
  // instruction and computation names and ids from the proto are unique.
  TF_RETURN_IF_ERROR(
      module->CheckUniqueNamesAndIdsForComputationsAndInstructions());

  if (proto.has_schedule()) {
    TF_ASSIGN_OR_RETURN(
        HloSchedule schedule,
        HloSchedule::CreateFromProto(module.get(), proto.schedule()));
    TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
  }

  for (auto prefetch : proto.cross_program_prefetches()) {
    module->AddCrossProgramPrefetch(
        prefetch.parameter(),
        ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
  }

  return std::move(module);
}

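// A hedged sketch of building a config for an (s32) -> s32 program (the
// shapes and options here are hypothetical):
//   ProgramShape program_shape;
//   *program_shape.add_parameters() = ShapeUtil::MakeShape(S32, {});
//   *program_shape.mutable_result() = ShapeUtil::MakeShape(S32, {});
//   StatusOr<HloModuleConfig> config = HloModule::CreateModuleConfigFromShape(
//       program_shape, DebugOptions(), /*execution_options=*/nullptr);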
/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromShape(
    const ProgramShape& program_shape, const DebugOptions& debug_options,
    const ExecutionOptions* execution_options) {
  HloModuleConfig module_config(ProgramShape{program_shape});
  module_config.set_debug_options(debug_options);
  if (execution_options) {
    if (execution_options->num_replicas() > 0) {
      module_config.set_replica_count(execution_options->num_replicas());
    }
    if (execution_options->num_partitions() > 0) {
      module_config.set_num_partitions(execution_options->num_partitions());
    }
    module_config.set_use_spmd_partitioning(
        execution_options->use_spmd_partitioning());
    module_config.set_deduplicate_hlo(execution_options->deduplicate_hlo());
    module_config.set_broadcast_replicated_params(
        execution_options->broadcast_replicated_parameters_via_collectives());
    if (execution_options->has_device_assignment()) {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<DeviceAssignment> device_assignment,
                          DeviceAssignment::Deserialize(
                              execution_options->device_assignment()));
      module_config.set_static_device_assignment(*device_assignment);
      if (execution_options->num_replicas() > 0) {
        CHECK_EQ(module_config.static_device_assignment().replica_count(),
                 module_config.replica_count());
      }
      if (execution_options->num_partitions() > 0) {
        CHECK_EQ(module_config.static_device_assignment().computation_count(),
                 module_config.num_partitions());
      }
    }
  }

  // The module config is constructed with default layouts regardless of what
  // is passed in via the ProgramShape. Set the layouts to the appropriate
  // values.
  ComputationLayout* entry_layout =
      module_config.mutable_entry_computation_layout();
  for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
    TF_RETURN_IF_ERROR(
        entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            program_shape.parameters(i)));
  }
  TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
      program_shape.result()));
  return module_config;
}

/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
    const HloModuleProto& module, const DebugOptions& debug_options,
    const ExecutionOptions* execution_options) {
  TF_RET_CHECK(module.has_host_program_shape())
      << "No program shape found in the proto";
  ProgramShape program_shape(module.host_program_shape());
  return CreateModuleConfigFromShape(program_shape, debug_options,
                                     execution_options);
}

namespace {
// Returns whether `hlo` is used outside the given subcomputation.
// `instructions_in_subcomputation` is the instruction set of the given
// subcomputation.
bool IsUsedOutsideSubcomputation(const HloInstruction& hlo,
                                 const absl::flat_hash_set<HloInstruction*>&
                                     instructions_in_subcomputation) {
  return absl::c_any_of(hlo.users(), [&](HloInstruction* user) {
    return !instructions_in_subcomputation.contains(user);
  });
}
}  // anonymous namespace

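// Outlining sketch (hypothetical; `mul` and `add` are instructions of
// `computation`, listed in topological order, with `add` the sole result):
//   HloInstruction* call = module->OutlineExpressionFromComputation(
//       {mul, add}, "outlined_mul_add", computation);
// The outlined instructions are replaced by a single kCall to a new embedded
// computation whose root produces the former output of `add`.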
HloInstruction* HloModule::OutlineExpressionFromComputation(
    absl::Span<HloInstruction* const> instructions_to_outline,
    const string& outlined_computation_name, HloComputation* computation) {
  auto builder = HloComputation::Builder(outlined_computation_name);

  // A map from original instructions to their counterparts in the new outlined
  // function.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> outlined_instructions;
  // A set that contains all instructions to be outlined.
  absl::flat_hash_set<HloInstruction*> instruction_set_to_outline(
      instructions_to_outline.begin(), instructions_to_outline.end());
  std::vector<HloInstruction*> arguments;
  std::vector<HloInstruction*> outputs;
  int64 parameter_count = 0;
  for (HloInstruction* instruction_to_outline : instructions_to_outline) {
    // Clone the original instruction.
    HloInstruction* outlined_instruction =
        builder.AddInstruction(instruction_to_outline->Clone());

    // Replace its operands with their counterparts in the new function.
    for (int64 operand_num = 0;
         operand_num < outlined_instruction->operand_count(); ++operand_num) {
      HloInstruction* old_operand =
          outlined_instruction->mutable_operand(operand_num);

      HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
      if (*operand_slot == nullptr) {
        // Because instructions_to_outline is in topological order, if
        // old_operand is not in outlined_instructions, old_operand must be an
        // input of the outlined subcomputation and thus should be represented
        // as a parameter in the new function.
        arguments.push_back(old_operand);
        *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
            parameter_count, old_operand->shape(), "p"));
        ++parameter_count;
      }
      TF_CHECK_OK(
          outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot));
    }

    // Insert the new instruction into the outlined_instructions map.
    InsertOrDie(&outlined_instructions, instruction_to_outline,
                outlined_instruction);

    // Mark instruction_to_outline as an output if it is used outside the
    // subcomputation or is the output of the original computation (i.e. used
    // externally).
    if (instruction_to_outline->user_count() == 0 ||
        IsUsedOutsideSubcomputation(*instruction_to_outline,
                                    instruction_set_to_outline)) {
      outputs.push_back(instruction_to_outline);
    }
  }

  if (outputs.size() != 1) {
    string error_message =
        "The subcomputation to outline has multiple outputs:\n";
    for (HloInstruction* output : outputs) {
      absl::StrAppend(&error_message, output->ToString(), "\n");
    }
    LOG(FATAL) << error_message;
  }
  HloInstruction* output = outputs[0];

  // Create a call to the nested computation.
  HloComputation* nested_computation = AddEmbeddedComputation(
      builder.Build(FindOrDie(outlined_instructions, output)));
  HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
      output->shape(), arguments, nested_computation));

  VLOG(2) << "Outlining the following instructions";
  for (auto* instruction_to_outline : instructions_to_outline) {
    VLOG(2) << "  " << instruction_to_outline->ToString();
  }
  VLOG(2) << "as a call " << call->ToString();
  VLOG(2) << "to " << nested_computation->ToString();

  TF_CHECK_OK(output->ReplaceAllUsesWith(call));
  for (auto i = instructions_to_outline.rbegin();
       i != instructions_to_outline.rend(); ++i) {
    TF_CHECK_OK(computation->RemoveInstruction(*i));
  }

  return call;
}

int64 HloModule::instruction_count() const {
  int64 n = 0;
  for (const auto& computation : computations_) {
    n += computation->instruction_count();
  }
  return n;
}

std::vector<HloComputation*> HloModule::MakeComputationPostOrder(
    const absl::flat_hash_set<HloComputation*>& allow_list) const {
  std::vector<HloComputation*> filtered_post_order(allow_list.size());
  auto post_order = this->MakeComputationPostOrder();

  int filtered_idx = 0;
  for (auto& computation : post_order) {
    if (allow_list.contains(computation)) {
      filtered_post_order[filtered_idx] = computation;
      filtered_idx += 1;
    }
  }

  return filtered_post_order;
}

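// MakeComputationPostOrder returns callees before callers: for example, if
// ENTRY calls a computation %body (say, via kWhile), %body appears in the
// result before ENTRY. A root computation therefore always comes after the
// computations it (transitively) calls.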
std::vector<HloComputation*> HloModule::MakeComputationPostOrder() const {
  // First determine all root computations by building a set of nonroot
  // computations (computations which are called by an instruction in the
  // module).
  absl::flat_hash_set<HloComputation*> nonroot_computations;
  for (auto& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      for (HloComputation* called_computation :
           instruction->called_computations()) {
        nonroot_computations.insert(called_computation);
      }
    }
  }

  // Keep track of computations which have already been added to the post
  // order. This prevents duplication as an embedded computation may be called
  // from two different root computations.
  absl::flat_hash_set<HloComputation*> added_computations;
  std::vector<HloComputation*> post_order;
  for (auto& computation : computations_) {
    if (!nonroot_computations.contains(computation.get())) {
      for (HloComputation* embedded_computation :
           computation->MakeEmbeddedComputationsList()) {
        if (!added_computations.contains(embedded_computation)) {
          post_order.push_back(embedded_computation);
          added_computations.insert(embedded_computation);
        }
      }
      // Root computations should only be encountered once.
      CHECK(!added_computations.contains(computation.get()));
      post_order.push_back(computation.get());
      added_computations.insert(computation.get());
    }
  }
  if (post_order.size() != computations_.size()) {
    for (HloComputation* computation : post_order) {
      LOG(ERROR) << "Post Order: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    for (auto& computation : computations_) {
      LOG(ERROR) << "Computations: " << computation->name() << " ("
                 << computation->parent()->name() << ")";
    }
    LOG(FATAL) << "Mismatched computation count: post_order="
               << post_order.size()
               << " computation_count=" << computations_.size();
  }
  return post_order;
}

namespace {
bool CompareComputationsByContent(HloComputation* a, HloComputation* b) {
  if (a->instruction_count() != b->instruction_count()) {
    return a->instruction_count() < b->instruction_count();
  }
  return a->ToString(HloPrintOptions::Fingerprint()) <
         b->ToString(HloPrintOptions::Fingerprint());
}
}  // anonymous namespace

std::vector<HloComputation*> HloModule::MakeComputationSorted() const {
  std::vector<HloComputation*> result = MakeComputationPostOrder();
  if (config().content_aware_computation_sorting()) {
    absl::c_sort(result, CompareComputationsByContent);
  }
  return result;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
  std::vector<HloComputation*> result = MakeComputationPostOrder();
  result.erase(std::remove_if(
                   result.begin(), result.end(),
                   [](HloComputation* c) { return c->IsFusionComputation(); }),
               result.end());
  return result;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputationsSorted()
    const {
  auto result = MakeNonfusionComputations();
  if (config().content_aware_computation_sorting()) {
    absl::c_sort(result, CompareComputationsByContent);
  }
  return result;
}

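// Cloning sketch (hedged): the one-argument overload reuses this module's
// config, e.g.
//   std::unique_ptr<HloModule> copy = module->Clone("copy");
// produces a module named "<name>-copy" with a cloned entry computation and,
// when present, a matching schedule.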
std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
  return Clone(config(), suffix);
}

std::unique_ptr<HloModule> HloModule::Clone(const HloModuleConfig& config,
                                            const string& suffix) const {
  VLOG(1) << "Cloning module: " << name_ << " --> " << suffix << "\n";
  auto module = absl::make_unique<HloModule>(
      absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config);

  HloCloneContext context(module.get(), suffix);
  auto cloned_computation = entry_computation_->Clone(suffix, &context);
  module->AddEntryComputation(std::move(cloned_computation));
  module->input_output_alias_config() = input_output_alias_config();

  if (has_schedule() && schedule().Verify().ok()) {
    HloSchedule clone_schedule(module.get());
    for (HloComputation* computation : computations()) {
      if (schedule().is_computation_scheduled(computation)) {
        HloInstructionSequence& clone_sequence =
            clone_schedule.GetOrCreateSequence(
                context.GetComputation(computation));
        for (const HloInstruction* instruction :
             schedule().sequence(computation).instructions()) {
          clone_sequence.push_back(context.GetInstruction(instruction));
        }
      }
    }
    TF_CHECK_OK(module->set_schedule(std::move(clone_schedule)));
  }
  for (const auto& parameter_indices : CrossProgramPrefetches()) {
    const auto& parameter = parameter_indices.first;
    const auto& indices = parameter_indices.second;
    module->AddCrossProgramPrefetch(parameter, indices);
  }
  return module;
}

Status HloModule::RemoveUnusedComputations() {
  // Clone the entry computation into a throwaway module; any computation that
  // the clone context never visits is unreachable from the entry computation
  // and can be removed.
  std::string suffix = "tmp";
  auto module = absl::make_unique<HloModule>(
      absl::StrCat(name_, suffix.empty() ? "" : "-", suffix), config());
  HloCloneContext context(module.get(), suffix);
  entry_computation_->Clone(suffix, &context);
  std::vector<HloComputation*> to_remove;
  for (auto computation : computations()) {
    auto found_computation = context.FindComputation(computation);
    if (found_computation == nullptr) {
      to_remove.push_back(computation);
    }
  }
  for (auto computation : to_remove) {
    TF_RETURN_IF_ERROR(RemoveEmbeddedComputation(computation));
  }
  return Status::OK();
}

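// Deep-clone sketch (hedged; `comp` is a computation from another module):
//   HloCloneContext context(module.get(), "suffix");
//   HloComputation* cloned = module->DeepCloneComputation(comp, &context);
// With a context, an already-cloned computation is returned as-is instead of
// being cloned a second time.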
HloComputation* HloModule::DeepCloneComputation(HloComputation* computation,
                                                HloCloneContext* context) {
  HloComputation* new_computation;
  if (context != nullptr) {
    if ((new_computation = context->FindComputation(computation)) != nullptr) {
      return new_computation;
    }
    new_computation =
        AddEmbeddedComputation(computation->Clone(context->suffix(), context));
  } else {
    new_computation = AddEmbeddedComputation(computation->Clone(""));
  }
  return new_computation;
}

uint64 HloModule::RandomNew64() const {
  tensorflow::mutex_lock l(rng_mutex_);
  return rng_();
}

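// Name lookup sketch (hedged; "while_body" is a hypothetical name):
//   HloComputation* body = module->GetComputationWithName("while_body");
//   if (body == nullptr) { /* no computation with that name */ }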
HloComputation* HloModule::GetComputationWithName(absl::string_view name) {
  auto computations_in_module = computations();
  auto it = absl::c_find_if(
      computations_in_module,
      [&](HloComputation* computation) { return computation->name() == name; });
  return it == computations_in_module.end() ? nullptr : *it;
}

uint64 HloModule::Hash() const {
  uint64 result = entry_computation_layout().Hash();
  // Use MakeComputationSorted() instead of MakeComputationPostOrder()
  // because naming may affect the order of MakeComputationPostOrder() but not
  // MakeComputationSorted().
  for (auto* computation : MakeComputationSorted()) {
    for (auto* instruction : computation->MakeInstructionPostOrder()) {
      result = tensorflow::Hash64Combine(result, instruction->Hash());
    }
  }
  return result;
}

/* static */ std::atomic<int> HloModule::next_unique_module_id_(0);

}  // namespace xla