/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// HLO instructions are in DAG form and represent the computations that the user
// has built up via the XLA service interface. They are ultimately lowered
// in a platform-aware way by traversing the HLO DAG and emitting a lowered
// form; e.g. see DfsHloVisitor.

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_

#include <functional>
#include <iosfwd>
#include <list>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class HloComputation;
class HloModule;

string PrintName(const string& name, bool print_ids);

// A bunch of switches that control how the hlo text should be printed.
class HloPrintOptions {
 public:
  enum class PrintSubcomputationMode {
    kOff,         // Do not print anything about subcomputations.
    kNameOnly,    // Only print the name of subcomputations.
    kFullBodies,  // Print the full bodies of subcomputations.
  };

  // Constructs the default print options: don't print large constants, don't
  // compact operands, no indentation.
  HloPrintOptions()
      : print_large_constants_(false),
        print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly),
        print_metadata_(true),
        print_backend_config_(true),
        compact_operands_(false),
        include_layout_in_shapes_(true),
        print_result_shape_(true),
        print_operand_shape_(true),
        print_operand_names_(true),
        print_program_shape_(true),
        print_percent_(true),
        print_control_dependencies_(true),
        canonicalize_instruction_names_(false),
        indent_amount_(0),
        is_in_nested_computation_(false),
        print_ids_(true),
        canonicalize_computations_(false),
        print_extra_attributes_(true) {}

  static HloPrintOptions ShortParsable() {
    return HloPrintOptions()
        .set_print_large_constants(true)
        .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_print_operand_shape(false)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false);
  }

  // Options to produce the canonical string representing an isomorphic
  // computation graph.
  static HloPrintOptions Canonical() {
    return HloPrintOptions()
        .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_compact_operands(false)
        .set_print_operand_names(false)
        .set_print_operand_shape(true)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false)
        .set_canonicalize_instruction_names(true);
  }

  // Options to produce a fingerprint of an HLO.
  static HloPrintOptions Fingerprint() {
    return HloPrintOptions()
        .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_compact_operands(true)
        .set_print_operand_names(false)
        .set_print_operand_shape(true)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false)
        .set_canonicalize_instruction_names(true)
        .set_print_ids(false)
        .set_canonicalize_computations(true);
  }
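
  // Example of adjusting one of the presets above (illustrative sketch; the
  // setters return *this, so calls can be chained, and `instruction` is
  // assumed to be an existing HloInstruction*):
  //
  //   auto options = HloPrintOptions::ShortParsable()
  //                      .set_print_metadata(true)
  //                      .set_print_percent(true);
  //   string text = instruction->ToString(options);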

  // If true, large constants will be printed out.
  HloPrintOptions& set_print_large_constants(bool value) {
    print_large_constants_ = value;
    return *this;
  }

  HloPrintOptions& set_print_subcomputation_mode(
      PrintSubcomputationMode value) {
    print_subcomputation_mode_ = value;
    return *this;
  }

  // If true, metadata will be printed.
  HloPrintOptions& set_print_metadata(bool value) {
    print_metadata_ = value;
    return *this;
  }

  // If true, backend_config will be printed.
  HloPrintOptions& set_print_backend_config(bool value) {
    print_backend_config_ = value;
    return *this;
  }

  // If true, result shapes will be printed.
  HloPrintOptions& set_print_result_shape(bool value) {
    print_result_shape_ = value;
    return *this;
  }

  // If true, operands' shapes will be printed.
  HloPrintOptions& set_print_operand_shape(bool value) {
    print_operand_shape_ = value;
    return *this;
  }

  // If true, the operand names will be printed.
  HloPrintOptions& set_print_operand_names(bool value) {
    print_operand_names_ = value;
    return *this;
  }

  // If true, all printed names include unique identifiers.
  HloPrintOptions& set_print_ids(bool value) {
    print_ids_ = value;
    return *this;
  }

  // If true, the HLO includes its attributes.
  HloPrintOptions& set_print_extra_attributes(bool value) {
    print_extra_attributes_ = value;
    return *this;
  }

  // If true, program shape of hlo computations will be printed.
  HloPrintOptions& set_print_program_shape(bool value) {
    print_program_shape_ = value;
    return *this;
  }

  // If true, names will be printed with prefix '%'.
  HloPrintOptions& set_print_percent(bool value) {
    print_percent_ = value;
    return *this;
  }

  // If true, control dependencies will be printed.
  HloPrintOptions& set_print_control_dependencies(bool value) {
    print_control_dependencies_ = value;
    return *this;
  }

  // If true, only a part of operands will be printed out (note that in this
  // case the text will not be parsable).
  HloPrintOptions& set_compact_operands(bool value) {
    compact_operands_ = value;
    return *this;
  }

  // If true, include the layout in any shapes that are printed (instruction
  // and operands).
  HloPrintOptions& set_include_layout_in_shapes(bool value) {
    include_layout_in_shapes_ = value;
    return *this;
  }

  // If true, canonicalizes instructions' name. Instead of using "%foo.1" as
  // the name of an instruction, we use "%tmp_1", "%tmp_2" etc.
  HloPrintOptions& set_canonicalize_instruction_names(bool value) {
    canonicalize_instruction_names_ = value;
    return *this;
  }

  // If true, canonicalizes computations, sorting by computations' names.
  HloPrintOptions& set_canonicalize_computations(bool value) {
    canonicalize_computations_ = value;
    return *this;
  }

  // The indent of the hlo text block.
  HloPrintOptions& set_indent_amount(int value) {
    indent_amount_ = value;
    return *this;
  }

  // If true, indicates the instruction being printed is inside a nested
  // computation.
  HloPrintOptions& set_is_in_nested_computation(bool value) {
    is_in_nested_computation_ = value;
    return *this;
  }

  // Instructions are selected for printing by a predicate function
  // (`set_print_instructions`). We also print their surrounding instructions
  // for ease of reading. We print `leading_and_trailing_instructions_number`
  // instructions before and after the qualified ones inside a computation.
  HloPrintOptions& set_leading_and_trailing_instructions_number(int value) {
    leading_and_trailing_instructions_number_ = value;
    return *this;
  }

  // A callback which takes an HloInstruction*, its string representation,
  // the indentation level of the resulting block, and a
  // bool variable indicating whether the instruction is root or not. The return
  // value is a string which is used for this instruction during printing.
  using FormatInstructionFunc =
      std::function<string(const HloInstruction*, const string&, int, bool)>;

  HloPrintOptions& set_format_instruction(FormatInstructionFunc callback) {
    format_instruction_ = callback;
    return *this;
  }

  using HloInstructionPredicate = std::function<bool(const HloInstruction*)>;

  // A callback which takes an HloInstruction* and returns whether it should be
  // printed or not.
  HloPrintOptions& set_print_instruction(HloInstructionPredicate callback) {
    print_instruction_ = callback;
    return *this;
  }

  using HloComputationPredicate = std::function<bool(const HloComputation*)>;

  // A callback which takes an HloComputation* and returns whether it should be
  // printed or not.
  HloPrintOptions& set_print_computation(HloComputationPredicate callback) {
    print_computation_ = callback;
    return *this;
  }

  bool print_large_constants() const { return print_large_constants_; }
  PrintSubcomputationMode print_subcomputation_mode() const {
    return print_subcomputation_mode_;
  }
  bool print_metadata() const { return print_metadata_; }
  bool print_backend_config() const { return print_backend_config_; }
  bool compact_operands() const { return compact_operands_; }
  bool include_layout_in_shapes() const { return include_layout_in_shapes_; }
  bool print_result_shape() const { return print_result_shape_; }
  bool print_operand_shape() const { return print_operand_shape_; }
  bool print_operand_names() const { return print_operand_names_; }
  bool print_ids() const { return print_ids_; }
  bool print_program_shape() const { return print_program_shape_; }
  bool print_percent() const { return print_percent_; }
  bool print_control_dependencies() const {
    return print_control_dependencies_;
  }
  bool print_extra_attributes() const { return print_extra_attributes_; }
  bool canonicalize_instruction_names() const {
    return canonicalize_instruction_names_;
  }
  bool canonicalize_computations() const { return canonicalize_computations_; }
  int indent_amount() const { return indent_amount_; }
  bool is_in_nested_computation() const { return is_in_nested_computation_; }
  int leading_and_trailing_instructions_number() const {
    return leading_and_trailing_instructions_number_;
  }
  string format_instruction(const HloInstruction* instr,
                            const string& instr_name, int indent,
                            bool is_root) const {
    return format_instruction_(instr, instr_name, indent, is_root);
  }
  bool print_instruction(const HloInstruction* instr) const {
    return print_instruction_(instr);
  }
  bool print_computation(const HloComputation* comp) const {
    return print_computation_(comp);
  }

 private:
  bool print_large_constants_;
  PrintSubcomputationMode print_subcomputation_mode_;
  bool print_metadata_;
  bool print_backend_config_;
  bool compact_operands_;
  bool include_layout_in_shapes_;
  bool print_result_shape_;
  bool print_operand_shape_;
  bool print_operand_names_;
  bool print_program_shape_;
  bool print_percent_;
  bool print_control_dependencies_;
  bool canonicalize_instruction_names_;
  int indent_amount_;
  bool is_in_nested_computation_;
  bool print_ids_;
  bool canonicalize_computations_;
  bool print_extra_attributes_;
  int leading_and_trailing_instructions_number_ = 3;
  FormatInstructionFunc format_instruction_ = [](const HloInstruction* instr,
                                                 const string& instr_name,
                                                 int indent, bool is_root) {
    return absl::StrCat(string(2 * indent, ' '), is_root ? "ROOT " : "",
                        instr_name);
  };
  HloInstructionPredicate print_instruction_ = [](const HloInstruction* instr) {
    return true;
  };
  HloComputationPredicate print_computation_ = [](const HloComputation* comp) {
    return true;
  };
};

// For canonical string output, we need to have a canonical way to rename
// each instruction and its operands. Each operand is renamed as "tmp_<xxx>",
// where <xxx> is an index starting from 0.
class CanonicalNameMap {
 public:
  CanonicalNameMap() : index(0) {}

  string LookupOrInsert(const string& old_name) {
    auto iter = canonical_name_map.find(old_name);
    if (iter != canonical_name_map.end()) {
      return iter->second;
    }

    string new_name = absl::StrCat("tmp_", index++);
    canonical_name_map[old_name] = new_name;
    return new_name;
  }
  void Clear() {
    canonical_name_map.clear();
    index = 0;
  }

 private:
  int64 index;
  absl::flat_hash_map<string, string> canonical_name_map;
};
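
// Example of the renaming scheme above (illustrative sketch):
//
//   CanonicalNameMap map;
//   map.LookupOrInsert("param.0");  // returns "tmp_0"
//   map.LookupOrInsert("add.3");    // returns "tmp_1"
//   map.LookupOrInsert("param.0");  // returns "tmp_0" again, not a new name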

// HLO instructions are the atomic unit of the high-level compiler's IR.
//
// HloInstructions live inside of an HloComputation, which is analogous to a
// function in other programming languages.  Nodes have no total order within
// their computation.  Instead, they have a partial ordering determined by their
// data and control dependencies.
//
// HLO does not have basic blocks or explicit "branch" instructions.  Instead,
// certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode
// control flow.  For example, the kConditional HLO executes one of two possible
// computations, depending on the runtime value of a predicate.
//
// HLO is pure (mostly).  It has no concept of mutable state.  Instead, data
// values are produced by one HLO and flow into consumers across dependency
// edges.
class HloInstruction {
 public:
  // A fusion node computes the same value a call to its fusion computation
  // would compute.  However, the choice of fusion kind dictates codegen
  // strategy for the backend.
  //
  // To generate code for a kFusion HloInstruction, most backends do something
  // like the following:
  //
  // 1) Identify the "primary" HloInstruction of the fused computation.
  // 2) Emit code that does the work of the primary node, creating its inputs
  //    and transforming its outputs as specified by the fused computation.
  //
  // In step (2), the code emitted is usually similar to the code that would be
  // emitted for an *unfused* version of the primary node, except that
  //
  //  - when the primary node reads an element of one of its operands, instead
  //    of loading the value from memory, it *computes* the value based on the
  //    contents of the fused computation.
  //  - when the primary node outputs a value, instead of storing it to memory,
  //    it forwards the value to its users, which then perform additional
  //    computations before the value is finally stored to memory at the root of
  //    the fusion node.
  //
  // An HloInstruction's FusionKind helps us find the kFusion instruction's
  // primary node, and can also affect how we generate code in step (2).
  //
  //  - kInput: The primary node is the root of the fused instruction.
  //
  //  - kOutput: The primary node is not the root of the fused instruction.
  //    This fusion kind requires that one operand buffer of the fusion
  //    instruction be able to alias the output buffer.  This constraint is
  //    usually enough to let backends find the primary node unambiguously.
  //
  //  - kLoop: The primary node is the root of the fused computation, but,
  //    unlike in input fusion, we prescribe a specific implementation for
  //    codegen.  Rather than generating code that looks like the code we'd emit
  //    for an unfused version of the primary/root node, we emit code that
  //    generates one element of the root at a time.
  //
  //  - kCustom: Custom category for backend-specific fusions that don't fit
  //    into the above patterns.
  //
  // Not all backends support all fusion kinds, and given a particular fused
  // computation, it's not in general safe to change its fusion kind.  Creation
  // of fusion nodes is always backend-specific.
  //
  // For elementwise ops (e.g. kAdd), most backends would emit a
  // one-element-at-a-time implementation for the unfused version, so loop
  // fusion and input fusion are probably equivalent if the root node is
  // elementwise.  They're not necessarily equivalent e.g. for kReduce, where an
  // implementation might emit something more sophisticated for an unfused or
  // input-fusion reduce, but will emit the naive code that reduces one element
  // at a time for loop fusion with a reduce as the root.
  //
  // Another way to think of loop fusion is that it's equivalent to input
  // fusion, but where the root node is an implicit identity node, whose
  // unfused implementation is "read one element, write one element".
  //
  // TODO(b/79869434): This categorization scheme is not great.  For one thing,
  // input and loop fusion are basically the same thing: There is no reason for
  // the HLO to encode backend-specific decisions about how e.g. a reduce that's
  // the root of a fusion should be lowered.  In addition, this scheme as
  // written doesn't work for multi-output fusion, where the primary node is
  // never actually the root (which is a kTuple instruction that gathers the
  // multiple outputs of the fusion).
  enum class FusionKind {
    kLoop,
    kInput,
    kOutput,
    kCustom,
  };

  virtual ~HloInstruction() { DetachFromOperandsAndUsers(); }

  // Detaches an instruction from its operands and users. That is, remove the
  // instruction from each operand's user set and user's operand set.
  void DetachFromOperandsAndUsers();

  // Creates an instruction from the given proto. Arguments:
  //
  //   proto: the proto to convert from.
  //   instruction_map: a map from instruction id to HloInstruction*. This map
  //     must contain all operands of the newly constructed instruction.
  //   computation_map: a map from computation id to HloComputation*. This map
  //     must contain all computations which the newly constructed instruction
  //     calls.
  static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
      const HloInstructionProto& proto,
      const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
      const absl::flat_hash_map<int64, HloComputation*>& computation_map,
      bool prohibit_empty_literal = true);

  // Creates a parameter-retrieving instruction.
  static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
                                                         const Shape& shape,
                                                         const string& name);

  // Creates a literal constant instruction.
  static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);

  // Creates an Iota instruction.
  static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
                                                    int64 iota_dimension);

  // Creates a get tuple element instruction.
  static std::unique_ptr<HloInstruction> CreateGetTupleElement(
      const Shape& shape, HloInstruction* operand, int64 index);

  // Creates a trace instruction that logs the input operand in the computation.
  static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
                                                     HloInstruction* operand);

  // Creates a random number generation instruction that fills a shape with
  // random numbers from a given distribution.
  //
  // The parameters to the instruction are interpreted as follows:
  //
  //  - If `distribution` is RNG_UNIFORM, generates a number in range
  //    [param0, param1).
  //
  //  - If `distribution` is RNG_NORMAL, generates a normally-distributed value
  //    with mean `param0` and standard deviation `param1`.
  static std::unique_ptr<HloInstruction> CreateRng(
      const Shape& shape, RandomDistribution distribution,
      absl::Span<HloInstruction* const> parameters);
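
  // Example (illustrative sketch; `lo` and `hi` are assumed to be existing
  // scalar F32 constant instructions owned by the same computation):
  //
  //   // Fills an f32[16] with uniform values drawn from [lo, hi).
  //   auto rng = HloInstruction::CreateRng(
  //       ShapeUtil::MakeShape(F32, {16}), RNG_UNIFORM, {lo, hi});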

  // Creates a stateless random bit generator instruction that fills a shape
  // with random bits.
  static std::unique_ptr<HloInstruction> CreateRngBitGenerator(
      const Shape& shape, HloInstruction* state, RandomAlgorithm algorithm);

  // Creates an instruction to update the random number generator state to
  // reflect the new state after `delta` units of 32 random bits are generated
  // and returns the old state.
  static std::unique_ptr<HloInstruction> CreateRngGetAndUpdateState(
      const Shape& shape, int64 delta);

  // Creates a unary instruction (one operand).
  // Precondition: opcode must be a legitimate unary operation.
  static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
                                                     HloOpcode opcode,
                                                     HloInstruction* operand);

  // Creates a binary instruction (two operands).
  // Precondition: opcode must be a legitimate binary operation.
  static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
                                                      HloOpcode opcode,
                                                      HloInstruction* lhs,
                                                      HloInstruction* rhs);

  // Creates a ternary instruction (three operands).
  // Precondition: opcode must be a legitimate ternary operation.
  static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
                                                       HloOpcode opcode,
                                                       HloInstruction* lhs,
                                                       HloInstruction* rhs,
                                                       HloInstruction* ehs);

  // Creates a variadic instruction (variable number of operands).
  // Precondition: opcode must be a legitimate variadic operation.
  static std::unique_ptr<HloInstruction> CreateVariadic(
      const Shape& shape, HloOpcode opcode,
      absl::Span<HloInstruction* const> operands);

  // Creates a map instruction, where the computation (given by the handle) is
  // applied element-wise to every element in operands (across the operands,
  // at a given index)
  static std::unique_ptr<HloInstruction> CreateMap(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* map_computation);

  // Creates a convolution op, where rhs is the convolutional filter
  // and window describes how the filter is applied to lhs.
  static std::unique_ptr<HloInstruction> CreateConvolve(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      int64 feature_group_count, int64 batch_group_count, const Window& window,
      const ConvolutionDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);

  // Creates an FFT op, of the type indicated by fft_type.
  static std::unique_ptr<HloInstruction> CreateFft(
      const Shape& shape, HloInstruction* operand, FftType fft_type,
      absl::Span<const int64> fft_length);

  // Creates a copy-start op, indicating whether this is a cross-program
  // prefetch or not.
  static std::unique_ptr<HloInstruction> CreateCopyStart(
      const Shape& shape, HloInstruction* operand,
      bool is_cross_program_prefetch = false);

  // Creates a compare op, performing the comparison specified in direction.
  static std::unique_ptr<HloInstruction> CreateCompare(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      Comparison::Direction direction,
      absl::optional<Comparison::Type> type = absl::nullopt);

  static std::unique_ptr<HloInstruction> CreateTriangularSolve(
      const Shape& shape, HloInstruction* a, HloInstruction* b,
      const TriangularSolveOptions& options);

  static std::unique_ptr<HloInstruction> CreateCholesky(
      const Shape& shape, HloInstruction* a, const CholeskyOptions& options);

  // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
  // dimensions specified in 'dimension_numbers'.
  static std::unique_ptr<HloInstruction> CreateDot(
      const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
      const DotDimensionNumbers& dimension_numbers,
      const PrecisionConfig& precision_config);

  // Creates a reduce-precision op, where operand is the data to reduce in
  // precision, and exponent_bits and mantissa_bits describe the precision to
  // reduce it to.
  static std::unique_ptr<HloInstruction> CreateReducePrecision(
      const Shape& shape, HloInstruction* operand, const int exponent_bits,
      const int mantissa_bits);

  // Creates an all-gather op, which concats the operands of all participants
  // along all_gather_dimension. The replica_groups, channel_id, and
  // use_global_device_ids arguments are identical to those in all-reduce,
  // except that the order of the group members determines the concatenation
  // order of inputs from different participants.
  static std::unique_ptr<HloInstruction> CreateAllGather(
      const Shape& shape, HloInstruction* operand, int64 all_gather_dimension,
      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id, bool use_global_device_ids);

  // Creates a cross replica reduction op.
  //
  // `reduction_computation`: the reduction function.
  //
  // `replica_groups`: each ReplicaGroup contains a list of replica ids. If
  // empty, all replicas belong to one group in the order of 0 - (n-1).
  // Allreduce will be applied within subgroups.
  // For example, if we have 4 replicas, then replica_groups={{0,2},{1,3}} means
  // replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in subgroup 1.
  //
  // `channel_id`: for Allreduce nodes from different modules, if
  // they have the same channel_id, they will be 'Allreduce'd. If
  // empty, Allreduce will not be applied cross modules.
  static std::unique_ptr<HloInstruction> CreateAllReduce(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* reduce_computation,
      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id, bool use_global_device_ids);

  // An all-to-all op takes N array operands of the same shape and scatters them
  // to N replicas.  Each replica gathers the results into a tuple.
  //
  // For example, suppose we have 3 replicas, with replica i passing inputs
  // [a_i, b_i, c_i] to its all-to-all op.  Then the resulting tuples are
  //
  //   replica 0: (a_0, a_1, a_2)
  //   replica 1: (b_0, b_1, b_2)
  //   replica 2: (c_0, c_1, c_2).
  //
  // If replica_groups is set, the op is sharded and the replicas are permuted.
  // To explain by way of example, suppose we have replica_groups={{1,2},{3,0}}.
  // Then each replica passes two operands, say [a_i, b_i], and the result is
  //
  //   replica 0: (b_3, b_0)
  //   replica 1: (a_1, a_2)
  //   replica 2: (b_1, b_2)
  //   replica 3: (a_3, a_0).
  //
  // All replica groups must have the same number of elements, and the number of
  // operands must be equal to the size of one replica group.  Each replica must
  // appear in exactly one group.
  //
  // Note that this instruction is different than the all-to-all op in
  // xla_builder.h.  The version in XlaBuilder takes one input and slices it,
  // and then concatenates the results into a single array.  This instruction
  // takes multiple inputs and returns a tuple; it doesn't slice or concatenate.
  // It is used to implement the higher-level instruction in XlaBuilder.
  static std::unique_ptr<HloInstruction> CreateAllToAll(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
      const absl::optional<int64>& channel_id,
      const absl::optional<int64>& split_dimension = absl::nullopt);

  // Creates a communication instruction that permutes data cross replicas.
  // Data is sent/received according to the (source_replica_id,
  // target_replica_id) pairs in `source_target_pairs`. If a replica id is not a
  // target_replica_id in any pair, the output on that replica is a tensor
  // consisting of 0(s) in `shape`.
  static std::unique_ptr<HloInstruction> CreateCollectivePermute(
      const Shape& shape, HloInstruction* operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs,
      const absl::optional<int64>& channel_id);

  // Creates a communication instruction that initiates the start of
  // CollectivePermute.
  static std::unique_ptr<HloInstruction> CreateCollectivePermuteStart(
      const Shape& shape, HloInstruction* operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs,
      const absl::optional<int64>& channel_id);

  // Creates an instruction that returns a U32 replica ID.
  static std::unique_ptr<HloInstruction> CreateReplicaId(
      const Shape& shape = ShapeUtil::MakeShape(U32, {}));

  // Creates an instruction that returns a U32 partition ID.
  static std::unique_ptr<HloInstruction> CreatePartitionId(
      const Shape& shape = ShapeUtil::MakeShape(U32, {}));

  // Creates a conversion instruction, where operand is the data to convert and
  // shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a bitcast instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcast(const Shape& shape,
                                                       HloInstruction* operand);

  // Creates a bitcast conversion instruction, where operand is the data to
  // convert and shape is the target shape for the conversion.
  static std::unique_ptr<HloInstruction> CreateBitcastConvert(
      const Shape& shape, HloInstruction* operand);

  // Creates an infeed instruction, which reads data of the given shape from the
  // Infeed interface of the device. infeed_shape is the shape of the data
  // received from the infeed *not* the shape of the infeed instruction which
  // is a tuple containing the infeed_shape and the TOKEN.
  static std::unique_ptr<HloInstruction> CreateInfeed(
      const Shape& infeed_shape, HloInstruction* token_operand,
      const string& config);

  // Creates an outfeed instruction, which outputs data. outfeed_shape is the
  // shape of the data being outfed *not* the shape of the outfeed instruction
  // which is a TOKEN.
  static std::unique_ptr<HloInstruction> CreateOutfeed(
      const Shape& outfeed_shape, HloInstruction* operand,
      HloInstruction* token_operand, absl::string_view outfeed_config);

  // Creates an asynchronous send instruction with the given channel id, which
  // initiates sending the operand data to a unique receive instruction in
  // another computation that has the same channel id. If is_host_transfer is
  // true, then this Send operation transfers data to the host.
  static std::unique_ptr<HloInstruction> CreateSend(
      HloInstruction* operand, HloInstruction* token, int64 channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Send instruction (operand) is complete.
  // The operand must be kSend.
  static std::unique_ptr<HloInstruction> CreateSendDone(
      HloInstruction* operand, bool is_host_transfer = false);

  // Creates an asynchronous receive instruction with the given channel id,
  // which allocates resources to receive data of the given shape from a unique
  // send instruction in another computation that has the same channel id.  If
  // is_host_transfer is true, then this Recv operation transfers data from the
  // host.
  static std::unique_ptr<HloInstruction> CreateRecv(
      const Shape& shape, HloInstruction* token, int64 channel_id,
      bool is_host_transfer = false);

  // Blocks until data transfer for the Recv instruction (operand) is complete
  // and returns the receive buffer. The operand must be kRecv.
  static std::unique_ptr<HloInstruction> CreateRecvDone(
      HloInstruction* operand, bool is_host_transfer = false);

  // Creates a slice instruction, where the operand is sliced by the given
  // start/limit indices.
  static std::unique_ptr<HloInstruction> CreateSlice(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> start_indices,
      absl::Span<const int64> limit_indices, absl::Span<const int64> strides);

  // Creates a slice instruction, where the first operand is sliced by
  // start indices specified in the second operand, and by size specified in
  // 'slice_sizes'.
  static std::unique_ptr<HloInstruction> CreateDynamicSlice(
      const Shape& shape, HloInstruction* operand,
      absl::Span<HloInstruction* const> start_indices,
      absl::Span<const int64> slice_sizes);

  // Creates a dynamic update slice instruction, which updates a slice
  // of 'operand' with 'update' and 'start_indices'.
  static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
      const Shape& shape, HloInstruction* operand, HloInstruction* update,
      absl::Span<HloInstruction* const> start_indices);

  // Creates a concatenate instruction, where the operands are concatenated on
  // the provided dimension.
  static std::unique_ptr<HloInstruction> CreateConcatenate(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      int64 dimension);

  // Creates a reduce instruction, where the computation (given by the handle)
  // is applied successively to every element in operand. For example, let f be
  // the function to apply, which takes 2 arguments, an accumulator and the
  // current value. Let init be an initial value (which is normally chosen to be
  // the identity element for f, e.g. 0 if f is addition).
  // Then the reduce HLO will compute:
  // f(... f(f(init, value0), value1) ..., valueN)
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      absl::Span<const int64> dimensions_to_reduce,
      HloComputation* reduce_computation);
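
  // Example (illustrative sketch; `input` is assumed to be an existing
  // f32[8,16] instruction, `zero` an f32[] constant instruction, and `add` an
  // existing HloComputation that adds two scalars):
  //
  //   // Sums away dimension 1, producing an f32[8] result.
  //   auto sum = HloInstruction::CreateReduce(
  //       ShapeUtil::MakeShape(F32, {8}), input, zero,
  //       /*dimensions_to_reduce=*/{1}, add);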

  // A more general, multiple-argument version of the above.
  // The function to apply, f, now takes N arguments:
  // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
  // valueN], and returns an N-tuple. The performed computation is (for
  // commutative and associative f operators) equivalent to:
  //
  // f_1 = f(init0, ...  initN, input0.value0, ..., inputN.value0)
  // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N), input0.value1,
  // ..., inputN.value1)
  // ...
  static std::unique_ptr<HloInstruction> CreateReduce(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloInstruction* const> init_values,
      absl::Span<const int64> dimensions_to_reduce,
      HloComputation* reduce_computation);

  // Creates a reduce-window instruction, where the computation (given
  // by the handle) is applied window-wise at each valid window
  // position in the operand.
  static std::unique_ptr<HloInstruction> CreateReduceWindow(
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      const Window& window, HloComputation* reduce_computation);

  // A more general, multiple-argument version of the above.
  // The reduce_computation being applied now takes N arguments:
  // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
  // valueN], and returns an N-tuple. The operands and init_values now each
  // contain a span of N input arrays and N initial values.
  static std::unique_ptr<HloInstruction> CreateReduceWindow(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloInstruction* const> init_values, const Window& window,
      HloComputation* reduce_computation);

  // Creates a batch-norm-training instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, float epsilon, int64 feature_index);

  // Creates a batch-norm-inference instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormInference(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
      float epsilon, int64 feature_index);

  // Creates a batch-norm-grad instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
      HloInstruction* mean, HloInstruction* variance,
      HloInstruction* grad_output, float epsilon, int64 feature_index);

  // Creates a scatter computation that scatters the `source` array to the
  // selected indices of each window.
  static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
      const Shape& shape, HloInstruction* operand, HloComputation* select,
      const Window& window, HloInstruction* source, HloInstruction* init_value,
      HloComputation* scatter);

  // Creates a broadcast instruction.
  static std::unique_ptr<HloInstruction> CreateBroadcast(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> broadcast_dimensions);

  // Creates a sequence of instructions that performs an explicit broadcast of
  // the operand to the target shape.
  //
  // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
  // returned as a unique_ptr for API consistency with other factory methods in
  // this interface.
  //
  // TODO(b/72173833) Ideally HloComputations would always be present, and so
  // the adder being passed by the caller would not be necessary.
  static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
      const Shape& output_shape, HloInstruction* operand,
      const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
          adder);

  // Creates a pad instruction, where the operand is padded on the edges and
  // between the elements with the given padding value.
  static std::unique_ptr<HloInstruction> CreatePad(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* padding_value, const PaddingConfig& padding_config);

  // Creates a reshape instruction, where the operand is flattened row-major
  // order and then reshaped to the given result shape.
  static std::unique_ptr<HloInstruction> CreateReshape(
      const Shape& shape, HloInstruction* operand,
      int64 inferred_dimension = -1);

  // Creates a dynamic reshape instruction. Similar to reshape but dynamic
  // dimensions sizes are provided as additional variadic arguments.
  //
  // Precondition: dim_sizes.size() == shape.rank()
  static std::unique_ptr<HloInstruction> CreateDynamicReshape(
      const Shape& shape, HloInstruction* data_operand,
      absl::Span<HloInstruction* const> dim_sizes);

  // Creates a transpose instruction which permutes the operand dimensions.
  static std::unique_ptr<HloInstruction> CreateTranspose(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> dimensions);

  // Creates an n-ary sort op with a 'compare' computation which is used for
  // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters,
  // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at
  // specific index positions which should be compared, and should return a
  // PRED. 'is_stable' specifies whether stable sorting is required.
  static std::unique_ptr<HloInstruction> CreateSort(
      const Shape& shape, int64 dimension,
      absl::Span<HloInstruction* const> operands, HloComputation* compare,
      bool is_stable);

  // Creates a while instruction, given a condition computation, a body
  // computation, and the initial value for the input of the computations. For
  // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
  // corresponds to the C code below.
  // int32 i = 1; int32 result = while(i < 1000) { i = i * 2 }
  static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
                                                     HloComputation* condition,
                                                     HloComputation* body,
                                                     HloInstruction* init);

  static std::unique_ptr<HloInstruction> CreateConditional(
      const Shape& shape, HloInstruction* pred,
      HloInstruction* true_computation_arg, HloComputation* true_computation,
      HloInstruction* false_computation_arg, HloComputation* false_computation);

  static std::unique_ptr<HloInstruction> CreateConditional(
      const Shape& shape, HloInstruction* branch_index,
      absl::Span<HloComputation* const> branch_computations,
      absl::Span<HloInstruction* const> branch_computation_args);

  static std::unique_ptr<HloInstruction> CreateGather(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* start_indices,
      const GatherDimensionNumbers& gather_dim_numbers,
      absl::Span<const int64> slice_sizes, bool indices_are_sorted);

  static std::unique_ptr<HloInstruction> CreateScatter(
      const Shape& shape, HloInstruction* operand,
      HloInstruction* scatter_indices, HloInstruction* updates,
      HloComputation* update_computation,
      const ScatterDimensionNumbers& scatter_dim_numbers,
      bool indices_are_sorted, bool unique_indices);

  // Creates a kDomain instruction which delimits an HLO domain which has
  // the provided user and operand side metadata.
  static std::unique_ptr<HloInstruction> CreateDomain(
      const Shape& shape, HloInstruction* operand,
      std::unique_ptr<DomainMetadata> operand_side_metadata,
      std::unique_ptr<DomainMetadata> user_side_metadata);

  // Creates a fusion instruction. A fusion instruction contains one or more
  // fused instructions forming an expression with a single root
  // "fused_root". Additional instructions can be added to the fusion
  // instruction with the method FuseInstruction.
  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);

  static std::unique_ptr<HloInstruction> CreateFusion(
      const Shape& shape, FusionKind fusion_kind,
      absl::Span<HloInstruction* const> operands,
      HloComputation* fusion_computation);

  // Creates a call instruction that applies the given computation on the given
  // operands. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* computation);

  // Creates a custom call instruction that applies the given custom call target
  // to the given operands. "opaque" can be an arbitrary string with a
  // backend-specific interpretation. "shape" is the resultant shape.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::string_view custom_call_target, string opaque = "");

  // Overload with a to_apply computation.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      HloComputation* to_apply, absl::string_view custom_call_target,
      string opaque = "");

  // Overload with multiple computations. The called computations can have
  // different function signatures.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloComputation* const> called_computations,
      absl::string_view custom_call_target, string opaque = "");

  // Overload which constrains the layouts of the operand and result. 'shape'
  // and 'operand_shapes_with_layout' must have layouts.
  // 'operand_shapes_with_layout' must have a compatible element for each
  // operand.
  static std::unique_ptr<HloInstruction> CreateCustomCall(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::string_view custom_call_target,
      absl::Span<const Shape> operand_shapes_with_layout, string opaque = "");

  // Creates a tuple instruction with the given elements. This is a convenience
  // wrapper around CreateVariadic.
  static std::unique_ptr<HloInstruction> CreateTuple(
      absl::Span<HloInstruction* const> elements);

  // Creates a reverse instruction, which reverses the order of the elements
  // in the specified dimensions.
  static std::unique_ptr<HloInstruction> CreateReverse(
      const Shape& shape, HloInstruction* operand,
      absl::Span<const int64> dimensions);

  // Creates an AfterAll instruction used for joining or creating new values of
  // token type which thread through side-effecting operations. Operands must
  // all be tokens, and there must be at least one operand.
  static std::unique_ptr<HloInstruction> CreateAfterAll(
      absl::Span<HloInstruction* const> operands);

  // Creates an AfterAll instruction which creates a token type out of thin air
  // (no operands). This is a separate method from CreateAfterAll to facilitate
1029   // the removal of operand-less AfterAll instructions.
1030   // TODO(b/110532604): Remove this capability of creating a token from nothing
1031   // when we plumb a primordial token from the entry computation.
1032   static std::unique_ptr<HloInstruction> CreateToken();
1033 
1034   static std::unique_ptr<HloInstruction> CreateGetDimensionSize(
1035       const Shape& shape, HloInstruction* operand, int64 dimension);
1036 
1037   static std::unique_ptr<HloInstruction> CreateSetDimensionSize(
1038       const Shape& shape, HloInstruction* operand, HloInstruction* val,
1039       int64 dimension);
1040 
1041   static std::unique_ptr<HloInstruction> CreateAddDependency(
1042       HloInstruction* data_operand, HloInstruction* token_operand);
1043 
1044   // Returns the opcode for this instruction.
opcode()1045   HloOpcode opcode() const { return opcode_; }
1046 
1047   // Returns true if this instruction has a side effect, irrespective of whether
1048   // any called computations may contain an instruction with side effects.
1049   bool HasSideEffectNoRecurse() const;
1050 
1051   // Returns true if this instruction has a side effect. An instruction has a
1052   // side effect if it uses certain opcodes or calls a computation with a side
1053   // effect.
1054   bool HasSideEffect() const;
1055 
1056   // Returns the result shape of this instruction.
1057   const Shape& shape() const;
1058 
1059   // Returns the (mutable) result shape of this instruction.
mutable_shape()1060   Shape* mutable_shape() { return &shape_; }
1061 
1062   // Returns the ith operand to this instruction.
1063   const HloInstruction* operand(int64 i) const;
1064 
1065   // Returns the ith operand to this instruction.
1066   HloInstruction* mutable_operand(int64 i);
1067 
1068   // Returns the number of operands to this instruction.
operand_count()1069   int64 operand_count() const { return operands_.size(); }
1070 
1071   // Returns the vector of operands of this instruction.
1072   using InstructionVector = absl::InlinedVector<HloInstruction*, 2>;
operands()1073   const InstructionVector& operands() const { return operands_; }
1074 
1075   // Returns the vector of unique operands, in the same order they are found
1076   // within the operand vector.
1077   InstructionVector unique_operands() const;
1078 
1079   // Returns the index of 'target' in the operands sequence.
1080   // Precondition: target must be an operand (or a fatal error will occur).
1081   int64 operand_index(const HloInstruction* target) const;
1082 
1083   // Returns the number of users of this instruction.
user_count()1084   int64 user_count() const { return users_.size(); }
1085 
1086   // Returns the users of this instruction.
users()1087   const std::vector<HloInstruction*>& users() const { return users_; }
1088 
1089   // Returns the index of the user in the users() vector.
1090   //
1091   // Precondition: `user` is a user of the instruction.
1092   int64 UserId(HloInstruction* user);
1093 
1094   // Returns true if this instruction is a user of 'instruction'.
IsUserOf(const HloInstruction * instruction)1095   bool IsUserOf(const HloInstruction* instruction) const {
1096     return ContainsKey(instruction->user_map_, this);
1097   }
1098 
1099   // Adds a control dependency from this instruction to the given
1100   // instruction. This instruction becomes a control predecessor of
1101   // 'instruction', and 'instruction' becomes a control successor of this
1102   // instruction. Returns an error status if either of the given instructions
1103   // does not belong to the same computation.
1104   //
1105   // This is used to enforce an additional ordering requirement that is not
1106   // captured by normal data dependencies, such as ordering among Send or Recv
1107   // operations to avoid deadlock.
1108   Status AddControlDependencyTo(HloInstruction* instruction);
1109 
1110   // Removes a previously added control dependency from this instruction to
1111   // 'instruction'.
1112   Status RemoveControlDependencyTo(HloInstruction* instruction);
1113 
1114   // Drops all control predecessors and successors from this HLO instruction.
1115   Status DropAllControlDeps();
1116 
1117   // Copies the control predecessors and successors on this HLO instruction to
1118   // `inst`.  Does not do a deep copy so this makes sense only if `inst` and
1119   // this HLO are in the same module.
1120   //
1121   // Depending on the use cases we see in practice, in the future we may
1122   // consider folding the logic here into Clone, CloneWithNewOperands and
1123   // ReplaceAllUsesWith by treating control dependencies like data dependencies.
1124   Status CopyAllControlDepsFrom(const HloInstruction* inst);
1125 
1126   // Returns the set of control predecessors (successors) of this
1127   // instruction. Control predecessors (successors) must execute before (after)
1128   // the current instruction.
1129   const std::vector<HloInstruction*>& control_predecessors() const {
1130     return control_predecessors_;
1131   }
1132   const std::vector<HloInstruction*>& control_successors() const {
1133     return control_successors_;
1134   }
1135 
1136   // Returns true if "other" performs the same computation as this instruction.
1137   bool Identical(
1138       const HloInstruction& other,
1139       const std::function<bool(const HloInstruction*, const HloInstruction*)>&
1140           eq_operands = std::equal_to<const HloInstruction*>(),
1141       const std::function<bool(const HloComputation*, const HloComputation*)>&
1142           eq_computations = std::equal_to<const HloComputation*>(),
1143       bool layout_sensitive = true) const {
1144     return IdenticalInternal(other, eq_operands, eq_computations,
1145                              layout_sensitive,
1146                              /*ignore_channel_id_values=*/false);
1147   }
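  // Example (illustrative): a layout-insensitive comparison of two
  // hypothetical instructions `a` and `b`, keeping the default operand and
  // computation comparators:
  //
  //   bool same = a->Identical(*b,
  //                            std::equal_to<const HloInstruction*>(),
  //                            std::equal_to<const HloComputation*>(),
  //                            /*layout_sensitive=*/false);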
1148 
1149   // Same as Identical() but ignores channel ID value mismatches, as long as
1150   // both have channel IDs or neither has a channel ID.
1151   bool IdenticalIgnoringChannelIdValues(
1152       const HloInstruction& other,
1153       const std::function<bool(const HloInstruction*, const HloInstruction*)>&
1154           eq_operands = std::equal_to<const HloInstruction*>(),
1155       const std::function<bool(const HloComputation*, const HloComputation*)>&
1156           eq_computations = std::equal_to<const HloComputation*>(),
1157       bool layout_sensitive = true) const {
1158     return IdenticalInternal(other, eq_operands, eq_computations,
1159                              layout_sensitive,
1160                              /*ignore_channel_id_values=*/true);
1161   }
1162 
1163   // Generates a hash value for an HLO instruction. The hash considers
1164   // information on the opcode, shape, operands, and typically a root instruction.
1165   // This function returns the same hash value for HLO instructions that are
1166   // equivalent with respect to the HloInstruction::Identical() method.
1167   //
1168   // Uses hash_operand function to compute hash values of its operands.
1169   // At the very top level, hash_operand should be non-recursive to prevent
1170   // non-termination.
1171   uint64 Hash(
1172       const std::function<uint64(const HloInstruction*)>& hash_operand) const;
1173 
1174   // Calls the above method with non-recursive hash_operand function.
1175   uint64 Hash() const;
1176 
1177   // Returns whether the instruction has a constant operand.
1178   bool HasConstantOperand() const;
1179 
1180   // Replaces the use of this instruction in "user" with "new_producer". Note
1181   // that there might be multiple uses of this instruction in "user"; all will
1182   // be replaced.
1183   //
1184   // If user is a fusion instruction, this function will remove any of its
1185   // duplicated operands that could be created by this replacement.
1186   Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
1187 
1188   // Same as ReplaceUseWith(), but new_producer can have a different shape.
1189   Status ReplaceUseWithDifferentShape(HloInstruction* user,
1190                                       HloInstruction* new_producer);
1191 
1192   // Replaces the specified operand with new_operand. The old and new operands
1193   // must have compatible shapes ignoring floating-point precision.
1194   //
1195   // This function does NOT remove duplicated operands even if this instruction
1196   // is a fusion, so that the existing operand numbers do not change.
1197   Status ReplaceOperandWith(int64 operand_num, HloInstruction* new_operand);
1198 
1199   // Same as ReplaceOperandWith(), but new_operand can have a different shape.
1200   Status ReplaceOperandWithDifferentShape(int64 operand_num,
1201                                           HloInstruction* new_operand);
1202 
1203   // Replaces all uses of this instruction with the new producer. If
1204   // new_producer is itself a user of this instruction, then new_producer keeps
1205   // its use of this instruction to avoid introducing cycles into the graph.
1206   //
1207   // If this instruction is the root of its computation, sets the computation's
1208   // root to new_producer.
1209   //
1210   // The new producer must have a compatible shape ignoring floating-point
1211   // precision.
1212   //
1213   // If a user is a fusion instruction, this function will remove any of its
1214   // duplicated operands that could be created by this replacement.
1215   Status ReplaceAllUsesWith(HloInstruction* new_producer);
1216 
1217   // Same as ReplaceAllUsesWith, but new_producer can have a different shape.
1218   Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer);
1219 
1220   // Same as ReplaceAllUsesWith, but only replace given set of users.
1221   Status ReplaceUsesWith(absl::Span<HloInstruction* const> users,
1222                          HloInstruction* new_producer);
1223   Status ReplaceAllUsesWithDifferentShape(
1224       absl::Span<HloInstruction* const> users, HloInstruction* new_producer);
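  // Example (illustrative sketch): replacing a hypothetical instruction
  // `old_inst` everywhere with a shape-compatible `new_inst`:
  //
  //   TF_RETURN_IF_ERROR(old_inst->ReplaceAllUsesWith(new_inst));
  //
  // Afterwards old_inst has no users (other than new_inst itself, if new_inst
  // used it), and if old_inst was the computation root, new_inst now is.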
1225 
1226   // Performs a postorder DFS visit using this node as the root. If
1227   // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
1228   // complete. If ignore_control_predecessors is true, instructions only
1229   // reachable via control dependencies will not be visited, and the postorder
1230   // will not take control dependencies into account. It is as if the control
1231   // dependencies didn't exist in the graph at all.
1232   template <typename HloInstructionPtr>
1233   Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
1234                 bool call_finish_visit = true,
1235                 bool ignore_control_predecessors = false);
1236   Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
1237                 bool ignore_control_predecessors = false) const {
1238     return const_cast<HloInstruction*>(this)->Accept(
1239         visitor, call_finish_visit, ignore_control_predecessors);
1240   }
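  // Example (illustrative; `MyVisitor` is a hypothetical DfsHloVisitor
  // subclass): visiting the transitive operands of `root` in postorder:
  //
  //   MyVisitor visitor;
  //   TF_RETURN_IF_ERROR(root->Accept(&visitor));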
1241 
1242   // Same as Accept() above, but the order of operand and control predecessor
1243   // visitation is determined by the given operand order; if compare(A, B) ==
1244   // true, A is visited before B.
1245   using CompareFunction =
1246       std::function<bool(const HloInstruction*, const HloInstruction*)>;
1247   Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
1248                                 const CompareFunction& operand_order,
1249                                 bool call_finish_visit = true);
1250 
1251   // Visit this instruction and only this instruction with the given visitor.
1252   template <typename HloInstructionPtr>
1253   Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
1254 
1255   // Returns the first non-GetTupleElement ancestor instruction of 'hlo'.
1256   // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the
1257   // (possibly nested) tuple indices used on the path from ancestor to 'hlo'.
1258   std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
1259       const;
1260 
1261   std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
1262     auto rv =
1263         const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex();
1264     return {const_cast<HloInstruction*>(rv.first), rv.second};
1265   }
1266 
1267   // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction.
1268   const HloInstruction* LatestNonGteAncestor() const;
1269 
1270   HloInstruction* LatestNonGteAncestor() {
1271     return const_cast<HloInstruction*>(
1272         const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
1273   }
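  // Example (illustrative): for a hypothetical chain
  //
  //   p  = parameter(0)                          // tuple-shaped
  //   g1 = get-tuple-element(p), index=1
  //   g0 = get-tuple-element(g1), index=0
  //
  // g0->LatestNonGteAncestor() returns p, and
  // g0->LatestNonGteAncestorAndIndex() returns {p, ShapeIndex{1, 0}}.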
1274 
1275   // Returns whether this instruction is effectively a bitcast. Currently,
1276   // this means it either is a bitcast, or it is a transpose that is effectively
1277   // a bitcast.
1278   bool IsEffectiveBitcast() const;
1279 
1280   // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
1281   // The setter should only be called by HloModule or HloComputation methods.
1282   //
1283   // Precondition: The instruction has a valid to_apply_ field.
1284   HloComputation* to_apply() const;
1285   void set_to_apply(HloComputation* to_apply);
1286 
1287   // Gets/sets the while_condition or while_body HloComputation for While. The
1288   // setters should only be called by HloModule or HloComputation methods.
1289   //
1290   // Precondition: The instruction is a While instruction.
1291   HloComputation* while_condition() const;
1292   HloComputation* while_body() const;
1293   void set_while_condition(HloComputation* while_condition);
1294   void set_while_body(HloComputation* while_body);
1295 
1296   HloInstruction* while_init() const;
1297 
1298   // Gets/sets the true and false HloComputation for Conditional.
1299   //
1300   // Precondition: The instruction is a predicated Conditional instruction.
1301   HloComputation* true_computation() const;
1302   HloComputation* false_computation() const;
1303 
1304   // Gets the branch HloComputations for Conditional.
1305   //
1306   // Precondition: The instruction is a Conditional instruction.
1307   const std::vector<HloComputation*>& branch_computations() const;
1308   int branch_count() const;
1309   HloComputation* branch_computation(int b) const;
1310   // Sets a branch HloComputation for Conditional.
1311   // The setter should only be called by HloModule or HloComputation methods.
1312   //
1313   // Precondition: The instruction is a Conditional instruction.
1314   void set_branch_computation(int b, HloComputation* computation);
1315 
1316   // Returns a string for the signature of this instruction if considered as a
1317   // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
1318   string SignatureString() const;
1319 
1320   // Returns a debugging string that represents this instruction.
1321   //
1322   // (We express the default options using an overload rather than a default
1323   // param because gdb ignores default params, but does resolve overloads.)
1324   //
1325   // TODO(b/73348663): Make ToString() adaptive to the size of the string by
1326   // default, backing off on providing full information for very large strings,
1327   // or provide a different name for a ToString-like function that does that.
1328   string ToString() const { return ToString(HloPrintOptions()); }
1329   string ToString(const HloPrintOptions& options) const;
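  // Example (illustrative, assuming the corresponding HloPrintOptions
  // setters): printing a hypothetical instruction `inst` without metadata or
  // large constants:
  //
  //   string text = inst->ToString(HloPrintOptions()
  //                                    .set_print_metadata(false)
  //                                    .set_print_large_constants(false));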
1330 
1331   // Components of the ToString() representation:
1332 
1333   // Returns a string representation of the operand list.
1334   string OperandsToString(const HloPrintOptions& options) const;
1335 
1336   // Returns string representation of op-specific attributes.
1337   std::vector<string> ExtraAttributesToString(
1338       const HloPrintOptions& options) const;
1339 
1340   // As ToString, but returns a shorter string.
1341   string ToShortString() const;
1342 
1343   // Prints an instruction to a string.
1344   //
1345   // The canonical string representation needs to name operands and instruction
1346   // names in a consistent way. This is implemented through the
1347   // canonical_name_map.
1348   string ToStringWithCanonicalNameMap(
1349       const HloPrintOptions& options,
1350       CanonicalNameMap* canonical_name_map) const;
1351 
1352   // Returns a serialized representation of this instruction.
1353   virtual HloInstructionProto ToProto() const;
1354 
1355   // Returns a category for the HLO. This could be something like "convolution"
1356   // or "elementwise".
1357   virtual string ToCategory() const;
1358 
1359   // Returns a logging instruction, if the output of this instruction is logged.
1360   //
1361   // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
1362   HloInstruction* tracing() const;
1363   void set_tracing(HloInstruction* trace_instruction);
1364 
1365   // Returns true if this instruction is fused, i.e., contained within a fusion
1366   // instruction.
1367   bool IsFused() const;
1368 
1369   bool IsLoopFusion() const;
1370   bool IsInputFusion() const;
1371   bool IsOutputFusion() const;
1372   bool IsCustomFusion() const;
1373 
1374   // Returns true if this instruction can be legally fused into a fusion
1375   // instruction.
1376   bool IsFusible() const;
1377 
1378   bool IsCustomCall(absl::string_view target) const;
1379 
1380   // Returns the sharding applied to this operator.
1381   // REQUIRES: has_sharding() is true.
1382   const HloSharding& sharding() const {
1383     CHECK(has_sharding());
1384     return *sharding_;
1385   }
1386   std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; }
1387 
1388   // Returns the sharding applied to this operator, or default_ if none exists.
1389   const HloSharding& sharding_or_default(const HloSharding& default_) const {
1390     return sharding_ ? *sharding_ : default_;
1391   }
1392   // Returns the sharding unique device, if any.
1393   absl::optional<int64> sharding_unique_device() const {
1394     if (sharding_ == nullptr) {
1395       return absl::optional<int64>();
1396     }
1397     return sharding_->UniqueDevice();
1398   }
1399   // Sets the sharding of this operator. Should only be called by HloModule or
1400   // HloComputation methods.
1401   void set_sharding(const HloSharding& sharding) {
1402     sharding_ = std::make_shared<const HloSharding>(sharding);
1403   }
1404   void set_sharding(std::shared_ptr<const HloSharding> sharding) {
1405     sharding_ = std::move(sharding);
1406   }
1407   void set_single_sharding(const HloSharding& sharding);
1408   // Sets a sharding that assigns the current instruction to device.
1409   void set_device_sharding(int64 device) {
1410     set_single_sharding(HloSharding::AssignDevice(device));
1411   }
1412   // Remove any sharding from this operator.
1413   void clear_sharding() { sharding_ = nullptr; }
1414   // Return true if this operator has a sharding assigned.
1415   bool has_sharding() const { return sharding_ != nullptr; }
1416   // Checks whether the instruction has compatible sharding with the other
1417   // instruction.
1418   bool has_compatible_sharding(const HloInstruction* other) const {
1419     if (!has_sharding()) {
1420       return !other->has_sharding();
1421     }
1422     return other->has_sharding() ? sharding() == other->sharding() : false;
1423   }
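  // Example (illustrative): pinning a hypothetical instruction `inst` to
  // device 0 and reading the assignment back:
  //
  //   inst->set_device_sharding(0);
  //   CHECK(inst->has_sharding());
  //   absl::optional<int64> device = inst->sharding_unique_device();  // 0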
1424 
1425   // When creating a new instruction which either replaces, or shifts up (kCopy
1426   // insertion case), another instruction, we need to make sure that certain
1427   // properties of this instruction are copied into the derived one. As of
1428   // today, the metadata and sharding are propagated to the derived
1429   // instruction.
1430   void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
1431 
1432   // Clones the HLO instruction. The clone will have the same opcode, shape, and
1433   // operands. After creation the clone has no uses. "this" (the instruction
1434   // cloned from) is not changed. Suffix is the string to append to the name of
1435   // the instruction to form the name of the cloned instruction.
1436   // Ignores the control predecessors and successors of this HLO instruction.
1437   std::unique_ptr<HloInstruction> Clone(
1438       const string& suffix = "clone", HloCloneContext* context = nullptr) const;
1439 
1440   // Clones the HLO instruction as above but with new shape.
1441   std::unique_ptr<HloInstruction> CloneWithNewShape(
1442       const Shape& shape, const string& suffix = "clone",
1443       HloCloneContext* context = nullptr) const;
1444 
1445   // Clones the HLO instruction as above but with new shape and operands.
1446   std::unique_ptr<HloInstruction> CloneWithNewOperands(
1447       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1448       HloCloneContext* context = nullptr) const;
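  // Example (illustrative): cloning a hypothetical binary instruction `add`
  // with replacement operands `x` and `y` of the same shape:
  //
  //   std::unique_ptr<HloInstruction> clone =
  //       add->CloneWithNewOperands(add->shape(), {x, y});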
1449 
1450   // Returns the computations this instruction directly calls (if any).
1451   const std::vector<HloComputation*>& called_computations() const {
1452     return called_computations_;
1453   }
1454 
1455   // Replaces all called computations based on a map function. This is needed
1456   // when we clone hlo_computations and want the instructions to point to the
1457   // newly cloned nodes.
1458   void ReplaceCalledComputations(
1459       std::function<HloComputation*(HloComputation*)> map_function) {
1460     for (int64 i = 0; i < called_computations_.size(); ++i) {
1461       called_computations_[i] = map_function(called_computations_[i]);
1462     }
1463   }
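  // Example (illustrative; `clone_map` is a hypothetical map from original to
  // cloned computations built while cloning a module):
  //
  //   instruction->ReplaceCalledComputations(
  //       [&](HloComputation* old) { return clone_map.at(old); });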
1464 
1465   // Clears out the called computations.
1466   //
1467   // This is, in particular, necessary when inlining function bodies into their
1468   // caller. If there were side-effecting operations in the called computations,
1469   // the call itself is considered side-effecting and thus cannot be removed. By
1470   // clearing out the computations, we reflect the fact that all side-effecting
1471   // properties have been reflected in the caller, and make the call HLO
1472   // removable.
1473   void ClearCalledComputations() { called_computations_.clear(); }
1474 
1475   // Returns true if this instruction performs an elementwise operation on
1476   // `operand_idx`-th operand. An instruction is elementwise on an operand iff,
1477   // to compute the output at index {i_0,i_1,...,i_n}, the only element required
1478   // from the operand (if any) is the element at {i_0,i_1,...,i_n}.
1479   //
1480   // Note on performance: when this instruction is kFusion, this method, in the
1481   // worst case, scans all fused instructions. We could speed this up by
1482   // caching.
1483   bool IsElementwiseOnOperand(int64 operand_idx) const;
1484 
1485   // Returns true if this instruction is elementwise on all its operands.
1486   bool IsElementwise() const;
1487 
1488   static bool IsOpElementwise(HloOpcode opcode);
1489 
1490   // Returns true if this is a cross module all-reduce instruction.
1491   bool IsCrossModuleAllReduce() const;
1492 
1493   // Returns true if this is a cross-replica all-reduce instruction.
1494   bool IsCrossReplicaAllReduce() const;
1495 
1496   // Returns true if this instruction is binary and elementwise.
1497   bool IsElementwiseBinary() const;
1498 
1499   // Returns whether this instruction may reuse elements of its `i`th operand.
1500   bool ReusesOperandElements(int64 i) const {
1501     return OperandElementUse(i) == UseKind::kReuse;
1502   }
1503 
1504   // Returns the indices at which the given operand appears in the operand list
1505   // of this instruction. Note that an instruction can use the same operand
1506   // multiple times.
1507   absl::InlinedVector<int64, 4> OperandIndices(
1508       const HloInstruction* operand) const;
1509 
1510   // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
1511   // this reshape merely inserts or deletes 1-sized dimensions, return the input
1512   // indices of the deleted dimensions and the output indices of the inserted
1513   // dimensions.
1514   //
1515   // Precondition: this op must be a reshape.
1516   std::tuple<bool, std::vector<int64>, std::vector<int64>>
1517   ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
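  // Example (illustrative): for a reshape from f32[4,1,3] to f32[4,3], this
  // would return (true, {1}, {}): input dimension 1 was deleted and no output
  // dimensions were inserted.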
1518 
1519   // Gets the string identifier for this instruction.
1520   const string& name() const { return name_; }
1521 
1522   // Sets the string identifier for this instruction. Name will be sanitized to
1523   // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
1524   void SetAndSanitizeName(const string& name) {
1525     name_ = NameUniquer::GetSanitizedName(name);
1526   }
1527 
1528   // Use the given NameUniquer to select a unique name for the instruction based
1529   // on the instruction's existing name.
1530   void UniquifyName(NameUniquer* name_uniquer);
1531 
1532   // Clear the unique ID of the instruction so that it can be re-assigned, such
1533   // as for the purpose of compacting the instruction unique IDs.
1534   void ClearUniqueIdInternal() { unique_id_ = -1; }
1535 
1536   // Set the unique id for this instruction to "id"
1537   void SetUniqueId(int id) {
1538     CHECK_EQ(unique_id_, -1);  // Should not be assigned already
1539     CHECK_GE(id, 0);
1540     unique_id_ = id;
1541   }
1542 
1543   // Return the unique ID assigned to this node via SetUniqueId (or -1
1544   // if no id has been assigned yet).
1545   int unique_id() const { return unique_id_; }
1546 
1547   // Returns the backend-specific configuration for how a backend should compile
1548   // this HLO. The meaning of the field is backend specific. Not for use before
1549   // or during general HLO optimization, since HLO optimizations do not preserve
1550   // this field and they cannot interpret it due to its meaning being backend
1551   // specific. The exception is CustomCall, for which this field is preserved
1552   // and no general HLO optimization needs to interpret it.
1553   //
1554   // ConfigProto should be a protobuf Message type.
1555   template <typename ConfigProto>
1556   StatusOr<ConfigProto> backend_config() const {
1557     ConfigProto proto;
1558     TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto));
1559     return std::move(proto);
1560   }
1561   Status set_backend_config(const tensorflow::protobuf::Message& proto);
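  // Example (illustrative sketch; `MyBackendConfig` stands in for a
  // backend-defined proto message and `some_field` for one of its fields):
  //
  //   TF_ASSIGN_OR_RETURN(MyBackendConfig cfg,
  //                       inst->backend_config<MyBackendConfig>());
  //   cfg.set_some_field(42);
  //   TF_RETURN_IF_ERROR(inst->set_backend_config(cfg));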
1562 
1563   void set_frontend_attributes(FrontendAttributes frontend_attributes) {
1564     frontend_attributes_ = std::move(frontend_attributes);
1565   }
1566 
1567   const FrontendAttributes& frontend_attributes() const {
1568     return frontend_attributes_;
1569   }
1570 
1571   // Getter/setter for raw JSON-encoded backend config.  Prefer the
1572   // functions above that deal in proto Messages where possible.
1573   const string& raw_backend_config_string() const { return backend_config_; }
1574   void set_raw_backend_config_string(string config_str) {
1575     backend_config_ = std::move(config_str);
1576   }
1577 
1578   bool is_default_config() const { return is_default_config_; }
1579   void set_default_config() { is_default_config_ = true; }
1580 
1581   // Returns a string representation of a proto in the format used by
1582   // raw_backend_config_string.
1583   //
1584   // This is morally equivalent to:
1585   //
1586   //   HloInstruction instr;
1587   //   TF_RETURN_IF_ERROR(instr.set_backend_config(proto));
1588   //   return instr.raw_backend_config_string();
1589   //
1590   static StatusOr<string> BackendConfigToRawString(
1591       const tensorflow::protobuf::Message& proto);
1592 
1593   // Returns the information used to tell the implementation information about
1594   // what sort of precision is requested. The meaning of the field is backend
1595   // specific. At the moment, it is only supported for kConvolution and kDot.
1596   // Transformations on one kDot or kConvolution to another will preserve this
1597   // information. Transformations to other HLOs will not preserve this
1598   // information but it is presumed that the alternate lowering is strictly
1599   // superior.
1600   // Precondition: opcode must be kConvolution or kDot.
1601   const PrecisionConfig& precision_config() const;
1602   PrecisionConfig* mutable_precision_config();
1603 
1604   // Sets the debug metadata for this instruction, excluding creation_pass_id,
1605   // which should never be copied anywhere.
1606   void set_metadata(const OpMetadata& metadata) {
1607     int64 creation_pass_id = metadata_.creation_pass_id();
1608     metadata_ = metadata;
1609     metadata_.set_creation_pass_id(creation_pass_id);
1610   }
1611   void set_creation_pass_id(int64 pass_id) {
1612     metadata_.set_creation_pass_id(pass_id);
1613   }
1614   void set_metadata_op_name(const std::string& name) {
1615     metadata_.set_op_name(name);
1616   }
1617   void set_logical_creation_pass_id(int64 pass_id) {
1618     metadata_.set_logical_creation_pass_id(pass_id);
1619   }
1620   const OpMetadata& metadata() const { return metadata_; }
1621 
1622   // Set/get the computation containing this instruction. set_parent should only
1623   // be called by HloComputation methods which add instructions to or remove
1624   // them from computations.
1625   void set_parent(HloComputation* computation) { parent_ = computation; }
1626   const HloComputation* parent() const { return parent_; }
1627   HloComputation* parent() { return parent_; }
1628 
1629   // Returns the module for this instruction.
1630   HloModule* GetModule() const;
1631 
1632   // Get/Set the number of partitions per outer dimension (in order, starting
1633   // with outer-most dimension first). Currently used by the parallel cpu
1634   // backend to partition HLOs into parallel tasks.
1635   //
1636   // TODO(b/62783254) Replace these methods with a more general way to
1637   // annotate HLOs with backend-specific information.
1638   const std::vector<int64>& outer_dimension_partitions() const {
1639     return outer_dimension_partitions_;
1640   }
1641   void set_outer_dimension_partitions(
1642       const std::vector<int64>& outer_dimension_partitions);
1643 
1644   // Old methods kept for smooth subclassing transition BEGIN.
1645   // TODO(b/80131774): Remove this code.
1646 
1647   // Delegates to HloBatchNormInstruction::feature_index.
1648   int64 feature_index() const;
1649 
1650   // Delegates to HloBatchNormInstruction::epsilon.
1651   float epsilon() const;
1652 
1653   // Delegates to HloFftInstruction::fft_type.
1654   FftType fft_type() const;
1655 
1656   // Delegates to HloFftInstruction::fft_length.
1657   const std::vector<int64>& fft_length() const;
1658 
1659   // Delegates to HloChannelInstruction::channel_id.
1660   absl::optional<int64> channel_id() const;
1661   void set_channel_id(const absl::optional<int64>& channel_id);
1662 
1663   // Returns the dimension sizes or numbers associated with this instruction.
1664   virtual const std::vector<int64>& dimensions() const {
1665     LOG(FATAL) << "Unimplemented method.";
1666   }
1667   virtual int64 dimensions(int64 index) const {
1668     LOG(FATAL) << "Unimplemented method.";
1669   }
1670   virtual std::vector<int64>* mutable_dimensions() {
1671     LOG(FATAL) << "Unimplemented method.";
1672   }
1673 
1674   // Delegates to HloConcatenateInstruction::concatenate_dimension.
1675   int64 concatenate_dimension() const;
1676 
1677   // Delegates to HloGetDimensionSizeInstruction::dimension.
1678   int64 dimension() const;
1679 
1680   // Delegates to HloReshapeInstruction::inferred_dimension.
1681   int64 inferred_dimension() const;
1682 
1683   // Returns whether this instruction does a rank-2 transposition.
1684   bool IsRank2Transpose() const;
1685 
1686   // Delegates to HloSliceInstruction::slice_start.
1687   int64 slice_starts(int64 dimension) const;
1688   const std::vector<int64>& slice_starts() const;
1689   std::vector<int64>* mutable_slice_starts();
1690 
1691   // Delegates to HloSliceInstruction::slice_limits.
1692   int64 slice_limits(int64 dimension) const;
1693   const std::vector<int64>& slice_limits() const;
1694   std::vector<int64>* mutable_slice_limits();
1695 
1696   // Delegates to HloSliceInstruction::slice_strides.
1697   int64 slice_strides(int64 dimension) const;
1698   const std::vector<int64>& slice_strides() const;
1699   std::vector<int64>* mutable_slice_strides();
1700 
1701   // Returns the literal associated with this instruction.
1702   const Literal& literal() const;
1703 
1704   // Returns whether the instruction is a constant.
1705   bool IsConstant() const;
1706 
1707   // Delegate to HloConstantInstruction::RelayoutConstant.
1708   void RelayoutConstant(const Layout& new_layout,
1709                         const ShapeIndex& shape_index = {});
1710 
1711   // Delegates to HloTraceInstruction::TracingTag.
1712   string TracingTag() const;
1713 
1714   // Delegates to HloFusionInstruction::AddFusionOperand.
1715   HloInstruction* AddFusionOperand(HloInstruction* new_operand);
1716 
1717   // Delegates to HloFusionInstruction::MergeFusionInstruction.
1718   void MergeFusionInstruction(HloInstruction* instruction_to_merge);
1719 
1720   // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
1721   void MergeFusionInstructionIntoMultiOutput(
1722       HloInstruction* instruction_to_merge);
1723 
1724   // Delegates to HloFusionInstruction::FuseInstruction.
1725   HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse);
1726 
1727   // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput.
1728   HloInstruction* FuseInstructionIntoMultiOutput(
1729       HloInstruction* instruction_to_fuse);
1730 
1731   // Delegates to HloFusionInstruction::fused_instruction.
1732   HloComputation* fused_instructions_computation() const;
1733 
1734   // Delegates to HloFusionInstruction::fused_expression_root.
1735   HloInstruction* fused_expression_root() const;
1736 
1737   // Delegates to HloFusionInstruction::fused_instructions.
1738   const tensorflow::gtl::iterator_range<UnwrappingIterator<
1739       std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
1740   fused_instructions() const;
1741 
1742   const tensorflow::gtl::iterator_range<
1743       UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
1744   fused_instructions();
1745 
1746   // Delegates to HloFusionInstruction::fused_instruction_count.
1747   int64 fused_instruction_count() const;
1748 
1749   // Delegates to HloFusionInstruction::fused_parameter.
1750   HloInstruction* fused_parameter(int64 parameter_number) const;
1751 
1752   // Delegates to HloFusionInstruction::fused_parameters.
1753   const std::vector<HloInstruction*>& fused_parameters() const;
1754 
1755   // Returns true if this instruction is a fusion instruction that generates
1756   // multiple outputs.
1757   bool IsMultiOutputFusion() const;
1758 
1759   // Delegates to HloFusionInstruction::fusion_kind.
1760   FusionKind fusion_kind() const;
1761 
1762   // Delegates to HloFusionInstruction::set_fusion_kind.
1763   void set_fusion_kind(FusionKind kind);
1764 
1765   // Delegates to HloRngInstruction::random_distribution.
1766   RandomDistribution random_distribution() const;
1767 
1768   // Delegates to HloParameterInstruction::parameter_number.
1769   int64 parameter_number() const;
1770 
1771   // Delegates to
1772   // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers.
1773   void set_parameter_replicated_at_leaf_buffers(
1774       absl::Span<const bool> parameter_replicated_at_leaf_buffers);
1775   void set_parameter_replicated_at_leaf_buffers(
1776       const std::vector<bool>& parameter_replicated_at_leaf_buffers);
1777 
1778   // Delegates to HloParameterInstruction::parameter_replicated_at_leaf_buffers.
1779   const absl::optional<std::vector<bool>>&
1780   parameter_replicated_at_leaf_buffers() const;
1781 
1782   // Delegates to HloGetTupleElementInstruction::tuple_index.
1783   int64 tuple_index() const;
1784 
1785   // Delegates to HloGetTupleElementInstruction::set_tuple_index.
1786   void set_tuple_index(int64 new_tuple_index);
1787 
1788   // Delegates to HloReducePrecisionInstruction::exponent_bits.
1789   int32 exponent_bits() const;
1790 
1791   // Delegates to HloReducePrecisionInstruction::mantissa_bits.
1792   int32 mantissa_bits() const;
1793 
1794   // Delegates to HloInfeedInstruction::infeed_config.
1795   string infeed_config() const;
1796 
1797   // Delegates to HloInfeedInstruction::set_infeed_config.
1798   void set_infeed_config(const string& config);
1799 
1800   // Returns the config for the Outfeed instruction.
1801   const string& outfeed_config() const;
1802 
1803   // Delegates to HloOutfeedInstruction::set_outfeed_config.
1804   void set_outfeed_config(const string& config);
1805 
1806   // Returns the shape for the Outfeed instruction.
1807   const Shape& outfeed_shape() const;
1808 
1809   // Returns the mutable shape for the Outfeed instruction.
1810   Shape* mutable_outfeed_shape();
1811 
1812   // Delegates to HloCollectiveInstruction::replica_groups.
1813   const std::vector<ReplicaGroup>& replica_groups() const;
1814 
1815   // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
1816   const std::vector<std::pair<int64, int64>>& source_target_pairs() const;
1817 
1818   // Returns data on the window in a windowed operation such as
1819   // convolution.
1820   virtual const Window& window() const {
1821     LOG(FATAL) << "Unimplemented method.";
1822   }
1823 
1824   // Sets the window data in a windowed operation such as convolution.
1825   virtual void set_window(const Window& window) {
1826     LOG(FATAL) << "Unimplemented method.";
1827   }
1828 
1829   // Returns the unique_indices field.
1830   virtual bool unique_indices() const { LOG(FATAL) << "Unimplemented method."; }
1831 
1832   // Returns data on the dimension numbers used for a convolution operation,
1833   // which may be a kConvolution instruction or a kCustomCall that implements a
1834   // convolution.
1835   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const;
1836 
1837   // Sets the convolution dimension numbers on this instruction.  In general you
1838   // shouldn't need to call this; instead, specify the convolution dimension
1839   // numbers when you create the instruction.
1840   void set_convolution_dimension_numbers(
1841       const ConvolutionDimensionNumbers& dnums);
1842 
1843   // The number of feature groups. Must be a divisor of the input feature
1844   // dimension and output feature dimension.
1845   int64 feature_group_count() const;
1846 
1847   void set_feature_group_count(int64 feature_group_count);
1848 
1849   // The number of batch groups. Must be a divisor of the input batch dimension.
1850   int64 batch_group_count() const;
1851 
1852   void set_batch_group_count(int64 batch_group_count);
1853 
1854   // Delegates to HloSelectAndScatterInstruction::select.
1855   HloComputation* select() const;
1856 
1857   // Delegates to HloSelectAndScatterInstruction::scatter.
1858   HloComputation* scatter() const;
1859 
1860   // Delegates to HloSelectAndScatterInstruction::set_select.
1861   void set_select(HloComputation* computation);
1862 
1863   // Delegates to HloSelectAndScatterInstruction::set_scatter.
1864   void set_scatter(HloComputation* computation);
1865 
1866   // Delegates to HloCustomCallInstruction::custom_call_target.
1867   const string& custom_call_target() const;
1868 
1869   // Delegates to HloPadInstruction::padding_config.
1870   const PaddingConfig& padding_config() const;
1871   PaddingConfig* mutable_padding_config();
1872 
1873   // Delegates to HloConvolutionInstruction::padding_type.
1874   PaddingType padding_type() const;
1875 
1876   // Delegates to HloDynamicSliceInstruction::slice_sizes.
1877   int64 slice_sizes(int64 dimension) const;
1878 
1879   // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes.
1880   const std::vector<int64>& dynamic_slice_sizes() const;
1881 
1882   // Delegates to HloGatherInstruction::gather_dimension_numbers.
1883   const GatherDimensionNumbers& gather_dimension_numbers() const;
1884   // Delegates to HloGatherInstruction::gather_slice_sizes.
1885   absl::Span<const int64> gather_slice_sizes() const;
1886 
1887   // Delegates to HloScatterInstruction::scatter_dimension_numbers().
1888   const ScatterDimensionNumbers& scatter_dimension_numbers() const;
1889 
1890   // Delegates to HloDotInstruction::dot_dimension_numbers().
1891   const DotDimensionNumbers& dot_dimension_numbers() const;
1892 
1893   // Delegates to HloDomainInstruction::operand_side_metadata().
1894   const DomainMetadata& operand_side_metadata() const;
1895 
1896   // Delegates to HloDomainInstruction::user_side_metadata().
1897   const DomainMetadata& user_side_metadata() const;
1898 
1899   // Delegates to HloCopyStartInstruction::is_cross_program_prefetch().
1900   bool is_cross_program_prefetch() const;
1901 
1902   // Delegates to HloCompareInstruction::direction().
1903   ComparisonDirection comparison_direction() const;
1904 
1905   // Delegates to HloTriangularSolveInstruction::triangular_solve_options().
1906   const TriangularSolveOptions& triangular_solve_options() const;
1907 
1908   // Delegates to HloCholeskyInstruction::cholesky_options().
1909   const CholeskyOptions& cholesky_options() const;
1910 
1911   // Appends operand to the list of operands and adds this instruction as a user
1912   // of the operand.
1913   void AppendOperand(HloInstruction* operand);
1914 
1915   // Old methods kept for smooth subclassing transition END.
1916 
1917  protected:
1918   // Indicates how an instruction uses a value (such as an operand).
1919   //
1920   // Does it (a) not use it, (b) use it, or (c) use it multiple times?
1921   //
1922   // In the kUse case (i.e. (b)) we may either (i) use the value elementwise, or
1923   // (ii) use it after having permuted it somehow, e.g. through a reshape.  If
1924   // the use is a permuting use, we set permutation_instr to the instruction
1925   // that did the permuting.
1926   struct UseKind {
1927     enum Kind { kReuse, kUse, kNoUse };
1928 
1929     // Creates a UseKind that represents a use that permutes an instruction's
1930     // elements according to the given instruction.
1931     static UseKind Permuting(const HloInstruction* permutation_instr) {
1932       UseKind k(kUse);
1933       k.permutation_instr = permutation_instr;
1934       return k;
1935     }
1936 
1937     UseKind(Kind kind)  // NOLINT intentionally nonexplicit
1938         : kind(kind), permutation_instr(nullptr) {}
1939 
1940     friend bool operator==(UseKind a, Kind b) { return a.kind == b; }
1941     friend bool operator==(Kind a, UseKind b) { return b == a; }
1942 
1943     Kind kind;
1944     const HloInstruction* permutation_instr;
1945   };
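  // Example (illustrative; `reshape_instr` is a hypothetical permuting
  // instruction):
  //
  //   UseKind use = UseKind::Permuting(reshape_instr);
  //   CHECK(use == UseKind::kUse);  // a permuting use is still a plain use
  //   CHECK(use.permutation_instr == reshape_instr);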
1946 
1947   // Helper class for computing OperandElementUse for kFusion.
1948   class FusionReusesParamElements;
1949 
1950   // Internal constructor for a given opcode/shape; other fields must be filled
1951   // by factory methods.
1952   HloInstruction(HloOpcode opcode, const Shape& shape);
1953 
1954   void RemoveAllOperands() { operands_.clear(); }
1955 
1956   void RemoveOperandAt(int index) {
1957     operands_.erase(operands_.begin() + index);
1958   }
1959 
1960   // Removes a list of operands with the given indices in ascending order.
1961   void RemoveOperandsAtAscendingIndices(
1962       absl::Span<const int> ascending_indices);
1963 
1964   void AppendComputation(HloComputation* computation) {
1965     called_computations_.push_back(computation);
1966   }
1967 
1968   void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); }
1969 
1970   void set_called_computation(int index, HloComputation* computation) {
1971     called_computations_[index] = computation;
1972   }
1973   // Indices of computations in called_computations_ for instructions which call
1974   // multiple computations.
1975   enum {
1976     // kWhile computations.
1977     kBodyComputationIndex = 0,
1978     kConditionComputationIndex = 1,
1979 
1980     // kSelectAndScatter computations.
1981     kSelectComputationIndex = 0,
1982     kScatterComputationIndex = 1,
1983 
1984     // kConditional computations.
1985     kTrueComputationIndex = 0,
1986     kFalseComputationIndex = 1,
1987   };
1988 
1989  private:
1990   friend class HloComputation;
1991 
1992   bool IdenticalInternal(
1993       const HloInstruction& other,
1994       const std::function<bool(const HloInstruction*, const HloInstruction*)>&
1995           eq_operands,
1996       const std::function<bool(const HloComputation*, const HloComputation*)>&
1997           eq_computations,
1998       bool layout_sensitive, bool ignore_channel_id_values) const;
1999 
2000   // Implementation for non-common logic of CloneWithNewOperands.
2001   virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
2002       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
2003       HloCloneContext* context) const {
2004     // TODO(b/80131774): This should be pure virtual.
2005     LOG(FATAL) << "Unimplemented method.";
2006   }
2007 
2008   // Implementation for non-common logic of ExtraAttributesToString.
2009   virtual std::vector<string> ExtraAttributesToStringImpl(
2010       const HloPrintOptions& options) const {
2011     return {};
2012   }
2013 
2014   // Implementation for IsElementwise if operand_idx is nullopt, and for
2015   // IsElementwiseOnOperand otherwise.
2016   //
2017   // NOTE: For all instructions other than kFusion, being elementwise on one of
2018   // the operands is equivalent to being elementwise on all the operands.
2019   virtual bool IsElementwiseImpl(
2020       const absl::optional<int64>& operand_idx) const;
2021 
2022   // Prints an operand to a string. Accessed by friend class HloInstruction.
2023   virtual string OperandsToStringWithCanonicalNameMap(
2024       const HloPrintOptions& options,
2025       CanonicalNameMap* canonical_name_map) const;
2026 
2027   // See comments on Identical().
2028   virtual bool IdenticalSlowPath(
2029       const HloInstruction& other,
2030       const std::function<bool(const HloComputation*, const HloComputation*)>&
2031           eq_computations) const;
2032 
2033   // Generates a hash value specific to a particular type of an instruction.
2034   // This function typically considers the inner root instruction.
2035   virtual uint64 InnerHash() const;
2036 
2037   // Creates an n-ary elementwise operation.
2038   static std::unique_ptr<HloInstruction> CreateNary(
2039       const Shape& shape, HloOpcode opcode,
2040       absl::Span<HloInstruction* const> operands);
2041 
2042   // Adds a user for this instruction.
2043   void AddUser(HloInstruction* user);
2044 
2045   // Removes a user for this instruction.
2046   void RemoveUser(HloInstruction* user);
2047 
2048   // Returns how this instruction uses elements of its operand at operand_num.
2049   UseKind OperandElementUse(int64 operand_num) const;
2050 
2051   // Helper for implementing backend_config().  Parses backend_config_ into the
2052   // given proto.
2053   Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const;
2054 
2055   // Mark this instruction as dead. Accessed by friend class HloComputation.
2056   void MarkAsDead() { marked_as_dead_ = true; }
2057 
2058   // Has this instruction been marked as dead? Accessed by friend class
2059   // HloComputation.
2060   bool IsMarkedAsDead() const { return marked_as_dead_; }
2061 
2062   int unique_id_;  // Unique to this HloInstruction within a HloModule
2063 
2064   // Opcode for this instruction.
2065   HloOpcode opcode_;
2066 
2067   // Instruction operands.
2068   InstructionVector operands_;
2069 
2070   // The set of control predecessors of this instruction.
2071   // Note that the order of the instructions in the vector influences the order
2072   // computed in HloComputation::ComputeInstructionPostOrder, which may
2073   // influence the result of the compilation by changing the scheduling. We are
2074   // not sure if it matters.
2075   std::vector<HloInstruction*> control_predecessors_;
2076 
2077   // The users of this instruction. Users are HLOs where this instruction is an
2078   // operand. The vector users_ and the map user_map_ contain identical members.
2079   // The map enables fast membership testing and the vector enables fast, stable
2080   // iteration. The value in the map contains the index of the instruction in
2081   // the vector, which enables fast removal.
2082   std::vector<HloInstruction*> users_;
2083   absl::flat_hash_map<const HloInstruction*, int64> user_map_;
2084 
2085   // The set of control successors of this instruction.
2086   std::vector<HloInstruction*> control_successors_;
2087 
2088   // The computation in which this instruction is contained.
2089   HloComputation* parent_ = nullptr;
2090 
2091   // Result shape of this instruction.
2092   Shape shape_;
2093 
2094   // The sharding, if one exists.
2095   // Uses std::shared_ptr to allow reuse of the same sharding object between
2096   // HloInstructions and other components as HloSharding can be very large for
2097   // many element tuples.
2098   std::shared_ptr<const HloSharding> sharding_;
2099 
2100   // Computations called by this instruction.
2101   std::vector<HloComputation*> called_computations_;
2102 
2103   // A trace instruction that consumes this instruction.
2104   //
2105   // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as
2106   // an operand.
2107   HloInstruction* trace_instruction_ = nullptr;
2108 
2109   // The backend-specific configuration for how a backend should compile this
2110   // HLO. See the documentation on backend_config().
2111   string backend_config_;
2112 
2113   // Attributes passed from the frontend to give hints to the backend about
2114   // how to compile this HLO.
2115   // HLO -> HLO transforms are expected to preserve these attributes on a
2116   // "best effort" basis only.
2117   // For example:
2118   //    x = const(10), frontend_attributes={x}
2119   //    y = const(10), frontend_attributes={y}
2120   //    z = add(x,y), frontend_attributes={y}
2121   // Could be simplified to:
2122   //    z' = const(20), frontend_attributes={?}
2123   FrontendAttributes frontend_attributes_;
2124 
2125   // This field is assigned to true when backend_config_ is assigned to
2126   // a default configuration.
2127   bool is_default_config_ = false;
2128 
2129   // True if this instruction has already been detached from its users and
2130   // operands.
2131   bool cleaned_up_ = false;
2132 
2133   // String identifier for instruction.
2134   string name_;
2135 
2136   // Metadata for debugging.
2137   OpMetadata metadata_;
2138 
2139   // The number of partitions per outer dimension (listed in order from
2140   // outer-most dimension first).
2141   std::vector<int64> outer_dimension_partitions_;
2142 
2143   // Intrusive flag used by HloComputation, whether this instruction has
2144   // been marked as dead.
2145   bool marked_as_dead_;
2146 
2147   TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
2148 };
2149 
2150 // Explicit instantiations in hlo_instruction.cc.
2151 extern template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
2152 extern template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
2153 extern template Status HloInstruction::Visit(DfsHloVisitor* visitor);
2154 extern template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor);
2155 
2156 string ToString(HloInstruction::FusionKind kind);
2157 StatusOr<HloInstruction::FusionKind> StringToFusionKind(
2158     const string& kind_name);
2159 
2160 // Custom (de)stringification functions for protos that live inside
2161 // HloInstruction.
2162 string PaddingConfigToString(const PaddingConfig& padding);
2163 string FrontendAttributesToString(
2164     const FrontendAttributes& frontend_attributes);
2165 string RandomAlgorithmToString(const RandomAlgorithm& algorithm);
2166 string RandomDistributionToString(const RandomDistribution& distribution);
2167 string PrecisionToString(const PrecisionConfig::Precision& precision);
2168 string ConvolutionDimensionNumbersToString(
2169     const ConvolutionDimensionNumbers& dnums);
2170 string ReplicaGroupsToString(const std::vector<ReplicaGroup>& replica_groups);
2171 
2172 StatusOr<RandomAlgorithm> StringToRandomAlgorithm(const string& name);
2173 StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
2174 StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name);
2175 
2176 std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
2177 
2178 // Map classes that guarantee a deterministic iteration order when the key is
2179 // an HloInstruction* or a const HloInstruction*.
2180 // To make the iteration order over the map deterministic, the comparator
2181 // should not be using the pointer values, but rather an intrinsic property of
2182 // the hlo. Exception: null pointer values compare less than non-null.
2183 struct HloPtrComparator {
2184   bool operator()(const HloInstruction* const& lhs,
2185                   const HloInstruction* const& rhs) const;
2186 };
2187 
2188 template <typename ValueT>
2189 using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;
2190 
2191 template <typename ValueT>
2192 using ConstHloInstructionMap =
2193     std::map<const HloInstruction*, ValueT, HloPtrComparator>;
2194 
2195 using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
2196 using ConstHloInstructionSet =
2197     std::set<const HloInstruction*, HloPtrComparator>;
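// Example (illustrative; `some_instruction` is a hypothetical
// HloInstruction*): accumulating per-instruction data with an iteration order
// that does not depend on pointer values:
//
//   HloInstructionMap<int64> flop_count;
//   flop_count[some_instruction] += 42;
//   for (const auto& entry : flop_count) {
//     // entry.first is the HloInstruction*, entry.second is the count.
//   }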
2198 
2199 }  // namespace xla
2200 
2201 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
2202