/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_module.h"

#include <algorithm>
#include <iterator>
#include <set>
#include <sstream>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

HloModule::HloModule(const string& name,
                     const VersionedComputationHandle& entry_computation_handle,
                     const HloModuleConfig& config)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(config),
      has_entry_computation_handle_(true),
      entry_computation_handle_(entry_computation_handle),
      unique_id_(next_unique_module_id_++) {}

HloModule::HloModule(const string& name)
    : name_(NameUniquer::GetSanitizedName(name)),
      unique_id_(next_unique_module_id_++) {}

HloModule::HloModule(const string& name, const HloModuleConfig& config)
    : name_(NameUniquer::GetSanitizedName(name)),
      config_(config),
      unique_id_(next_unique_module_id_++) {}

HloComputation* HloModule::AddComputationInternal(
    std::unique_ptr<HloComputation> computation, bool is_entry,
    bool uniquify_names) {
  if (is_entry) {
    CHECK_EQ(nullptr, entry_computation_);
    entry_computation_ = computation.get();

    // If the module configuration has no entry computation layout set, create
    // a default one based on the program shape.
    if (!config_.has_entry_computation_layout()) {
      config_.SetDefaultComputationLayout(
          entry_computation_->ComputeProgramShape());
    }
  }

  if (uniquify_names) {
    computation->UniquifyName(&computation_name_uniquer_);
    for (auto* instruction : computation->instructions()) {
      instruction->UniquifyName(&instruction_name_uniquer_);
    }
  } else {
    // Don't uniquify the names of the computation or its instructions, but we
    // must still run the names through the uniquifiers to prevent future name
    // collisions for computations and instructions created later.
    computation_name_uniquer_.GetUniqueName(computation->name());
    for (auto* instruction : computation->instructions()) {
      instruction_name_uniquer_.GetUniqueName(instruction->name());
    }
  }

  // Pick unique IDs for each instruction.
  for (auto* instruction : computation->instructions()) {
    instruction->SetUniqueId(NewUniqueInstructionId());
  }
  computation->set_parent(this);
  computations_.push_back(std::move(computation));
  return computations_.back().get();
}

HloComputation* HloModule::AddEntryComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/true,
                                /*uniquify_names=*/true);
}

Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
  auto it =
      std::find_if(computations_.begin(), computations_.end(),
                   [&to_remove](const std::unique_ptr<HloComputation>& comp) {
                     return comp.get() == to_remove;
                   });
  TF_RET_CHECK(it != computations_.end());
  TF_RET_CHECK(it->get() == to_remove);
  computations_.erase(it);
  return Status::OK();
}

HloComputation* HloModule::AddEmbeddedComputation(
    std::unique_ptr<HloComputation> computation) {
  return AddComputationInternal(std::move(computation), /*is_entry=*/false,
                                /*uniquify_names=*/true);
}

void HloModule::ReplaceComputations(
    const std::unordered_map<HloComputation*, HloComputation*>& replacements) {
  // Replace all uses of non-canonical computations with their
  // representatives.
  std::vector<std::unique_ptr<HloComputation>> new_computations;
  new_computations.reserve(computations_.size());

  for (std::unique_ptr<HloComputation>& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      switch (instruction->opcode()) {
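        // These opcodes each reference a single called computation via
        // to_apply().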
        case HloOpcode::kCall:
        case HloOpcode::kMap:
        case HloOpcode::kReduce:
        case HloOpcode::kReduceWindow: {
          HloComputation* new_arg = tensorflow::gtl::FindWithDefault(
              replacements, instruction->to_apply(), nullptr);
          if (new_arg != nullptr) {
            instruction->set_to_apply(new_arg);
          }
          break;
        }
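        // kWhile references two computations: the loop condition and body.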
        case HloOpcode::kWhile: {
          HloComputation* new_condition = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_condition(), nullptr);
          if (new_condition != nullptr) {
            instruction->set_while_condition(new_condition);
          }
          HloComputation* new_body = tensorflow::gtl::FindWithDefault(
              replacements, instruction->while_body(), nullptr);
          if (new_body != nullptr) {
            instruction->set_while_body(new_body);
          }
          break;
        }
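        // kConditional references the true and false branch computations.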
        case HloOpcode::kConditional: {
          HloComputation* new_true_computation =
              tensorflow::gtl::FindWithDefault(
                  replacements, instruction->true_computation(), nullptr);
          if (new_true_computation != nullptr) {
            instruction->set_true_computation(new_true_computation);
          }
          HloComputation* new_false_computation =
              tensorflow::gtl::FindWithDefault(
                  replacements, instruction->false_computation(), nullptr);
          if (new_false_computation != nullptr) {
            instruction->set_false_computation(new_false_computation);
          }
          break;
        }
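        // kSelectAndScatter references the select and scatter computations.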
        case HloOpcode::kSelectAndScatter: {
          HloComputation* new_select = tensorflow::gtl::FindWithDefault(
              replacements, instruction->select(), nullptr);
          if (new_select != nullptr) {
            instruction->set_select(new_select);
          }
          HloComputation* new_scatter = tensorflow::gtl::FindWithDefault(
              replacements, instruction->scatter(), nullptr);
          if (new_scatter != nullptr) {
            instruction->set_scatter(new_scatter);
          }
          break;
        }
        default:
          break;
      }
    }

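    // Computations that were replaced are not carried over to
    // new_computations; their owning unique_ptrs are destroyed when
    // computations_ is overwritten below.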
    if (replacements.find(computation.get()) == replacements.end()) {
      new_computations.push_back(std::move(computation));
    }
  }

  // Replace entry_computation if necessary.
  entry_computation_ = tensorflow::gtl::FindWithDefault(
      replacements, entry_computation_, entry_computation_);

  computations_ = std::move(new_computations);
}

string HloModule::ToString(const HloPrintOptions& options) const {
  std::ostringstream s;
  s << "HloModule " << name() << "\n\n";
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    if (computation == entry_computation()) {
      s << "ENTRY ";
    }
    s << computation->ToString(options) << "\n\n";
  }
  return s.str();
}

HloModuleProto HloModule::ToProto() const {
  HloModuleProto proto;
  proto.set_name(name_);
  proto.set_entry_computation_name(entry_computation_->name());
  for (const HloComputation* computation : MakeComputationPostOrder()) {
    // Fusion computations are added when the fusion instructions are created
    // by HloInstruction::CreateFromProto.
    if (computation->IsFusionComputation()) {
      continue;
    }
    HloComputationProto computation_proto = computation->ToProto();
    proto.add_computations()->Swap(&computation_proto);
  }
  return proto;
}

namespace {

// Construct a ProgramShape matching the shapes of the parameters and root of
// the given module's entry computation.
StatusOr<ProgramShape> ProgramShapeFromProto(const HloModuleProto& module) {
  const HloComputationProto* entry_computation = nullptr;
  for (const HloComputationProto& computation : module.computations()) {
    if (computation.name() == module.entry_computation_name()) {
      entry_computation = &computation;
      break;
    }
  }
  TF_RET_CHECK(entry_computation != nullptr)
      << "No computation with entry computation name "
      << module.entry_computation_name();

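  // Scan the entry computation's instructions once, locating the root
  // instruction and recording each parameter's name and shape keyed by
  // parameter number.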
  tensorflow::gtl::FlatMap<int64, std::pair<string, const Shape*>> parameters;
  const HloInstructionProto* root = nullptr;
  for (const HloInstructionProto& instruction :
       entry_computation->instructions()) {
    if (instruction.name() == entry_computation->root_name()) {
      TF_RET_CHECK(root == nullptr) << "Entry computation has more than "
                                       "one instruction with (root) name "
                                    << instruction.name();
      root = &instruction;
    }
    if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
      TF_RET_CHECK(!ContainsKey(parameters, instruction.parameter_number()))
          << "Entry computation has more than one parameter instruction "
             "with parameter number "
          << instruction.parameter_number();
      parameters[instruction.parameter_number()] = {instruction.name(),
                                                    &instruction.shape()};
    }
  }
  TF_RET_CHECK(root != nullptr)
      << "Entry computation is missing root instruction named "
      << entry_computation->root_name();

  ProgramShape program_shape;
  *program_shape.mutable_result() = root->shape();
  for (int64 i = 0; i < parameters.size(); ++i) {
    TF_RET_CHECK(ContainsKey(parameters, i))
        << "Entry computation missing parameter number " << i;
    const string& name = parameters.at(i).first;
    const Shape& shape = *parameters.at(i).second;
    *program_shape.add_parameters() = shape;
    program_shape.add_parameter_names(name);
  }

  return std::move(program_shape);
}

}  // namespace

/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
    const HloModuleProto& proto, const HloModuleConfig& module_config,
    const VersionedComputationHandle& entry_computation_handle) {
  // The ProgramShape in the passed-in module config must match the shapes of
  // the entry parameters and root.
  TF_ASSIGN_OR_RETURN(ProgramShape expected_program_shape,
                      ProgramShapeFromProto(proto));
  TF_RET_CHECK(expected_program_shape.parameters_size() ==
               module_config.entry_computation_layout().parameter_count());
  for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
    const Shape& parameter_shape =
        module_config.entry_computation_layout().parameter_layout(i).shape();
    TF_RET_CHECK(
        ShapeUtil::Equal(expected_program_shape.parameters(i), parameter_shape))
        << "HloModuleConfig has different shape for parameter " << i
        << " than the HLO module. Expected: "
        << ShapeUtil::HumanStringWithLayout(
               expected_program_shape.parameters(i))
        << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
  }
  const Shape& result_shape =
      module_config.entry_computation_layout().result_layout().shape();
  TF_RET_CHECK(ShapeUtil::Equal(expected_program_shape.result(), result_shape))
      << "HloModuleConfig has different result shape than the HLO module. "
         "Expected: "
      << ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
      << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);

  auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
                                      module_config);

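  // Maps computation names to reconstructed computations. ToProto serializes
  // computations in post order, so a computation's callees are already in the
  // map by the time the computation itself is deserialized.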
  tensorflow::gtl::FlatMap<string, HloComputation*> computation_map;
  for (const HloComputationProto& computation_proto : proto.computations()) {
    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<HloComputation> computation,
        HloComputation::CreateFromProto(
            module.get(), computation_proto, computation_map,
            /*add_fused_computation=*/
            [&module](std::unique_ptr<HloComputation> fused_computation) {
              module->AddComputationInternal(std::move(fused_computation),
                                             /*is_entry=*/false,
                                             /*uniquify_names=*/false);
            }));
    CHECK_NE(computation.get(), nullptr);
    TF_RET_CHECK(!ContainsKey(computation_map, computation->name()));
    string computation_name = computation->name();
    // Don't uniquify names because we want names to be stable across
    // serialization and deserialization.
    computation_map[computation_name] = module->AddComputationInternal(
        std::move(computation),
        /*is_entry=*/proto.entry_computation_name() == computation_name,
        /*uniquify_names=*/false);
  }
  TF_RET_CHECK(module->entry_computation_ != nullptr);

  // Because we didn't uniquify the names, double-check that the instruction
  // and computation names from the proto are unique.
  tensorflow::gtl::FlatSet<string> computation_names;
  tensorflow::gtl::FlatSet<string> instruction_names;
  for (HloComputation* computation : module->computations()) {
    if (computation->IsFusionComputation()) {
      continue;
    }

    TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
        << "Computation name is not unique: " << computation->name();
    computation_names.insert(computation->name());
    for (HloInstruction* instruction : computation->instructions()) {
      TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
          << "Instruction name is not unique: " << instruction->name();
      instruction_names.insert(instruction->name());
    }
  }

  return std::move(module);
}

/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
    const HloModuleProto& module) {
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                      ProgramShapeFromProto(module));

  HloModuleConfig module_config(program_shape);

  // The module config is constructed with default layouts regardless of what
  // is passed in via the ProgramShape. Set the layouts to the appropriate
  // values.
  ComputationLayout* entry_layout =
      module_config.mutable_entry_computation_layout();
  for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
    TF_RETURN_IF_ERROR(
        entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            program_shape.parameters(i)));
  }
  TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
      program_shape.result()));

  return module_config;
}

namespace {
// Returns whether `hlo` is used outside the given subcomputation.
// `instructions_in_subcomputation` is the instruction set of the given
// subcomputation.
bool IsUsedOutsideSubcomputation(
    const HloInstruction& hlo,
    const std::unordered_set<HloInstruction*>& instructions_in_subcomputation) {
  for (HloInstruction* user : hlo.users()) {
    if (!instructions_in_subcomputation.count(user)) {
      return true;
    }
  }
  return false;
}
}  // anonymous namespace

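// Note: `instructions_to_outline` must be in topological order; the
// operand-remapping logic below relies on every operand of an instruction
// either preceding it in the list or being an input of the outlined region.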
HloInstruction* HloModule::OutlineExpressionFromComputation(
    tensorflow::gtl::ArraySlice<HloInstruction*> instructions_to_outline,
    const string& outlined_computation_name, HloComputation* computation) {
  auto builder = HloComputation::Builder(outlined_computation_name);

  // A map from original instructions to their counterparts in the new outlined
  // function.
  std::unordered_map<HloInstruction*, HloInstruction*> outlined_instructions;
  // A set that contains all instructions to be outlined.
  std::unordered_set<HloInstruction*> instruction_set_to_outline(
      instructions_to_outline.begin(), instructions_to_outline.end());
  std::vector<HloInstruction*> arguments;
  std::vector<HloInstruction*> outputs;
  int64 parameter_count = 0;
  for (HloInstruction* instruction_to_outline : instructions_to_outline) {
    // Clone the original instruction.
    HloInstruction* outlined_instruction =
        builder.AddInstruction(instruction_to_outline->Clone());

    // Replace its operands with their counterparts in the new function.
    for (int64 operand_num = 0;
         operand_num < outlined_instruction->operand_count(); ++operand_num) {
      HloInstruction* old_operand =
          outlined_instruction->mutable_operand(operand_num);

      HloInstruction** operand_slot = &(outlined_instructions[old_operand]);
      if (*operand_slot == nullptr) {
        // Because instructions_to_outline is in topological order, if
        // old_operand is not in outlined_instructions, old_operand must be an
        // input of the outlined subcomputation and thus should be represented
        // as a parameter in the new function.
        arguments.push_back(old_operand);
        *operand_slot = builder.AddInstruction(HloInstruction::CreateParameter(
            parameter_count, old_operand->shape(), ""));
        ++parameter_count;
      }
      TF_CHECK_OK(
          outlined_instruction->ReplaceOperandWith(operand_num, *operand_slot));
    }

    // Insert the new instruction into the outlined_instructions map.
    InsertOrDie(&outlined_instructions, instruction_to_outline,
                outlined_instruction);

    // Mark instruction_to_outline as an output if it is used outside the
    // subcomputation or is the output of the original computation (i.e. used
    // externally).
    if (instruction_to_outline->user_count() == 0 ||
        IsUsedOutsideSubcomputation(*instruction_to_outline,
                                    instruction_set_to_outline)) {
      outputs.push_back(instruction_to_outline);
    }
  }

  if (outputs.size() != 1) {
    string error_message =
        "The subcomputation to outline has multiple outputs:\n";
    for (HloInstruction* output : outputs) {
      tensorflow::strings::StrAppend(&error_message, output->ToString(), "\n");
    }
    LOG(FATAL) << error_message;
  }
  HloInstruction* output = outputs[0];

  // Create a call to the nested computation.
  HloComputation* nested_computation = AddEmbeddedComputation(
      builder.Build(FindOrDie(outlined_instructions, output)));
  HloInstruction* call = computation->AddInstruction(HloInstruction::CreateCall(
      output->shape(), arguments, nested_computation));

  VLOG(2) << "Outlining the following instructions";
  for (auto* instruction_to_outline : instructions_to_outline) {
    VLOG(2) << "  " << instruction_to_outline->ToString();
  }
  VLOG(2) << "as a call " << call->ToString();
  VLOG(2) << "to " << nested_computation->ToString();

  TF_CHECK_OK(output->ReplaceAllUsesWith(call));
  for (auto i = instructions_to_outline.rbegin();
       i != instructions_to_outline.rend(); ++i) {
    TF_CHECK_OK(computation->RemoveInstruction(*i));
  }

  return call;
}

int64 HloModule::instruction_count() const {
  int64 n = 0;
  for (const auto& computation : computations_) {
    n += computation->instruction_count();
  }
  return n;
}

std::list<HloComputation*> HloModule::MakeComputationPostOrder() const {
  // First determine all root computations by building a set of nonroot
  // computations (computations which are called by an instruction in the
  // module).
  std::set<HloComputation*> nonroot_computations;
  for (auto& computation : computations_) {
    for (auto* instruction : computation->instructions()) {
      for (HloComputation* called_computation :
           instruction->called_computations()) {
        nonroot_computations.insert(called_computation);
      }
    }
  }

  // Keep track of computations which have already been added to the post
  // order. This prevents duplication as an embedded computation may be called
  // from two different root computations.
  std::set<HloComputation*> added_computations;
  std::list<HloComputation*> post_order;
  for (auto& computation : computations_) {
    if (nonroot_computations.count(computation.get()) == 0) {
      for (HloComputation* embedded_computation :
           computation->MakeEmbeddedComputationsList()) {
        if (added_computations.count(embedded_computation) == 0) {
          post_order.push_back(embedded_computation);
          added_computations.insert(embedded_computation);
        }
      }
      // Root computations should only be encountered once.
      CHECK_EQ(0, added_computations.count(computation.get()));
      post_order.push_back(computation.get());
      added_computations.insert(computation.get());
    }
  }
  CHECK_EQ(post_order.size(), computations_.size());
  return post_order;
}

std::vector<HloComputation*> HloModule::MakeNonfusionComputations() const {
  std::vector<HloComputation*> result;
  for (auto* c : computations()) {
    if (c->IsFusionComputation()) {
      continue;
    }
    result.push_back(c);
  }
  return result;
}

std::unique_ptr<HloModule> HloModule::Clone(const string& suffix) const {
  VLOG(1) << "Cloning module: " << name_ << " --> " << suffix << "\n";
  auto module = MakeUnique<HloModule>(name_ + "-" + suffix);
  module->config_ = config_;
  module->entry_computation_handle_ = entry_computation_handle_;
  module->has_entry_computation_handle_ = has_entry_computation_handle_;

  std::unordered_map<HloComputation*, HloComputation*> clone_map;
  for (auto& computation : computations_) {
    if (computation->IsFusionComputation()) {
      // Cloning of a fused computation is handled by its fusion instruction.
      continue;
    }

    // When cloning a computation, pass in the new module, so that for any
    // fusion instruction in this computation, the fused computation will be
    // deep cloned to the new module.
    auto cloned_computation = computation->Clone(suffix, module.get());
    InsertOrDie(&clone_map, computation.get(), cloned_computation.get());

    if (entry_computation_ == computation.get()) {
      module->AddEntryComputation(std::move(cloned_computation));
    } else {
      module->AddEmbeddedComputation(std::move(cloned_computation));
    }
  }

  for (auto& cloned_computation : module->computations_) {
    for (auto* instruction : cloned_computation->instructions()) {
      // Rewrite the instruction's called computations to point to the cloned
      // computations.
      instruction->ReplaceCalledComputations([&](HloComputation* hlo) {
        if (hlo->IsFusionComputation()) {
          // Cloning of a fused computation was already handled when its
          // fusion instruction was cloned, so this computation is already
          // the cloned one.
          return hlo;
        }
        return FindOrDie(clone_map, hlo);
      });
    }
  }
  return module;
}

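// Clones `computation` into this module and recursively deep-clones every
// computation it calls, rewriting the cloned instructions to call the clones.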
HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) {
  HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this));
  TF_CHECK_OK(
      clone->root_instruction()->Accept([this](HloInstruction* instruction) {
        instruction->ReplaceCalledComputations([this](HloComputation* callee) {
          return DeepCloneComputation(callee);
        });
        return Status::OK();
      }));
  return clone;
}

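// Returns a pseudorandom 64-bit value. The RNG is guarded by a mutex, so this
// is safe to call concurrently even though the method is const.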
uint64 HloModule::RandomNew64() const {
  tensorflow::mutex_lock l(rng_mutex_);
  return rng_();
}

/* static */ std::atomic<int> HloModule::next_unique_module_id_(0);

}  // namespace xla