1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // HLO instructions are in DAG form and represent the computations that the user
17 // has built up via the XLA service interface. They are ultimately lowered
18 // in a platform-aware way by traversing the HLO DAG and emitting a lowered
19 // form; e.g. see DfsHloVisitor.
20 
21 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
22 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
23 
24 #include <functional>
25 #include <iosfwd>
26 #include <list>
27 #include <memory>
28 #include <set>
29 #include <string>
30 #include <tuple>
31 #include <vector>
32 
33 #include "absl/container/flat_hash_map.h"
34 #include "absl/container/flat_hash_set.h"
35 #include "absl/container/inlined_vector.h"
36 #include "absl/memory/memory.h"
37 #include "absl/strings/str_cat.h"
38 #include "absl/strings/string_view.h"
39 #include "absl/types/span.h"
40 #include "tensorflow/compiler/xla/comparison_util.h"
41 #include "tensorflow/compiler/xla/iterator_util.h"
42 #include "tensorflow/compiler/xla/literal.h"
43 #include "tensorflow/compiler/xla/map_util.h"
44 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
45 #include "tensorflow/compiler/xla/service/hlo.pb.h"
46 #include "tensorflow/compiler/xla/service/hlo_clone_context.h"
47 #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
48 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
49 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
50 #include "tensorflow/compiler/xla/service/name_uniquer.h"
51 #include "tensorflow/compiler/xla/shape_tree.h"
52 #include "tensorflow/compiler/xla/types.h"
53 #include "tensorflow/compiler/xla/xla_data.pb.h"
54 #include "tensorflow/core/lib/core/status.h"
55 #include "tensorflow/core/lib/gtl/iterator_range.h"
56 #include "tensorflow/core/platform/logging.h"
57 #include "tensorflow/core/platform/macros.h"
58 #include "tensorflow/core/platform/protobuf.h"
59 #include "tensorflow/core/platform/types.h"
60 
61 namespace xla {
62 
63 class HloComputation;
64 class HloModule;
65 
// A bunch of switches that control how HLO text should be printed. Every
// setter returns *this, so options can be chained, e.g.
//   HloPrintOptions().set_print_metadata(false).set_indent_amount(2)
class HloPrintOptions {
 public:
  // How much of a called subcomputation is printed.
  enum class PrintSubcomputationMode {
    kOff,         // Do not print anything about subcomputations.
    kNameOnly,    // Only print the name of subcomputations.
    kFullBodies,  // Print the full bodies of subcomputations.
  };

  // Constructs the default print options (see the member initializers below):
  // large constants are not printed, subcomputations are printed by name only,
  // operands are not compacted, instruction names are not canonicalized, there
  // is no indentation, and the instruction is not marked as being inside a
  // nested computation. Metadata, backend config, operand shapes and names,
  // program shapes, '%' name prefixes and control dependencies are printed.
  HloPrintOptions() = default;

  // Options for short output: prints large constants and prints
  // subcomputations by name only, but omits metadata, backend config, operand
  // shapes, program shapes, '%' name prefixes and control dependencies.
  static HloPrintOptions ShortParsable() {
    return HloPrintOptions()
        .set_print_large_constants(true)
        .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_print_operand_shape(false)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false);
  }

  // Options to produce the canonical string representing an isomorphic
  // computation graph: full subcomputation bodies, compact operands printed
  // with shapes but without names, canonicalized instruction names, and no
  // metadata, backend config, program shapes, '%' prefixes or control
  // dependencies.
  static HloPrintOptions Canonical() {
    return HloPrintOptions()
        .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_compact_operands(true)
        .set_print_operand_names(false)
        .set_print_operand_shape(true)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_print_control_dependencies(false)
        .set_canonicalize_instruction_names(true);
  }

  // If true, large constants will be printed out.
  HloPrintOptions& set_print_large_constants(bool value) {
    print_large_constants_ = value;
    return *this;
  }

  // Selects how much of each called subcomputation is printed.
  HloPrintOptions& set_print_subcomputation_mode(
      PrintSubcomputationMode value) {
    print_subcomputation_mode_ = value;
    return *this;
  }

  // If true, metadata will be printed.
  HloPrintOptions& set_print_metadata(bool value) {
    print_metadata_ = value;
    return *this;
  }

  // If true, backend_config will be printed.
  HloPrintOptions& set_print_backend_config(bool value) {
    print_backend_config_ = value;
    return *this;
  }

  // If true, operands' shapes will be printed.
  HloPrintOptions& set_print_operand_shape(bool value) {
    print_operand_shape_ = value;
    return *this;
  }

  // If true, the operand names will be printed.
  HloPrintOptions& set_print_operand_names(bool value) {
    print_operand_names_ = value;
    return *this;
  }

  // If true, program shape of hlo computations will be printed.
  HloPrintOptions& set_print_program_shape(bool value) {
    print_program_shape_ = value;
    return *this;
  }

  // If true, names will be printed with prefix '%'.
  HloPrintOptions& set_print_percent(bool value) {
    print_percent_ = value;
    return *this;
  }

  // If true, control dependencies will be printed.
  HloPrintOptions& set_print_control_dependencies(bool value) {
    print_control_dependencies_ = value;
    return *this;
  }

  // If true, only a part of operands will be printed out (note that in this
  // case the text will not be parsable).
  HloPrintOptions& set_compact_operands(bool value) {
    compact_operands_ = value;
    return *this;
  }

  // If true, canonicalizes instructions' name. Instead of using "%foo.1" as
  // the name of an instruction, we use "%tmp_1", "%tmp_2" etc.
  HloPrintOptions& set_canonicalize_instruction_names(bool value) {
    canonicalize_instruction_names_ = value;
    return *this;
  }

  // The indent of the hlo text block.
  HloPrintOptions& set_indent_amount(int value) {
    indent_amount_ = value;
    return *this;
  }

  // If true, indicates the instruction being printed is inside a nested
  // computation.
  HloPrintOptions& set_is_in_nested_computation(bool value) {
    is_in_nested_computation_ = value;
    return *this;
  }

  bool print_large_constants() const { return print_large_constants_; }
  PrintSubcomputationMode print_subcomputation_mode() const {
    return print_subcomputation_mode_;
  }
  bool print_metadata() const { return print_metadata_; }
  bool print_backend_config() const { return print_backend_config_; }
  bool compact_operands() const { return compact_operands_; }
  bool print_operand_shape() const { return print_operand_shape_; }
  bool print_operand_names() const { return print_operand_names_; }
  bool print_program_shape() const { return print_program_shape_; }
  bool print_percent() const { return print_percent_; }
  bool print_control_dependencies() const {
    return print_control_dependencies_;
  }
  bool canonicalize_instruction_names() const {
    return canonicalize_instruction_names_;
  }
  int indent_amount() const { return indent_amount_; }
  // Returns bool to match the type of the stored flag and the other getters.
  bool is_in_nested_computation() const { return is_in_nested_computation_; }

 private:
  // Default values; the default constructor relies on these.
  bool print_large_constants_ = false;
  PrintSubcomputationMode print_subcomputation_mode_ =
      PrintSubcomputationMode::kNameOnly;
  bool print_metadata_ = true;
  bool print_backend_config_ = true;
  bool compact_operands_ = false;
  bool print_operand_shape_ = true;
  bool print_operand_names_ = true;
  bool print_program_shape_ = true;
  bool print_percent_ = true;
  bool print_control_dependencies_ = true;
  bool canonicalize_instruction_names_ = false;
  int indent_amount_ = 0;
  bool is_in_nested_computation_ = false;
};
236 
237 // For canonical string output, we need to have a canonical way to rename
238 // each instruction and its operands. Each operand is renamed as "tmp_<xxx>",
239 // where <xxx> is an index starting from 0.
240 class CanonicalNameMap {
241  public:
CanonicalNameMap()242   CanonicalNameMap() : index(0) {}
243 
LookupOrInsert(const string & old_name)244   string LookupOrInsert(const string& old_name) {
245     auto iter = canonical_name_map.find(old_name);
246     if (iter != canonical_name_map.end()) {
247       return iter->second;
248     }
249 
250     string new_name = absl::StrCat("tmp_", index++);
251     canonical_name_map[old_name] = new_name;
252     return new_name;
253   }
Clear()254   void Clear() {
255     canonical_name_map.clear();
256     index = 0;
257   }
258 
259  private:
260   int64 index;
261   absl::flat_hash_map<string, string> canonical_name_map;
262 };
263 
264 // HLO instructions are the atomic unit of the high-level compiler's IR.
265 //
266 // HloInstructions live inside of an HloComputation, which is analogous to a
267 // function in other programming languages.  Nodes have no total order within
268 // their computation.  Instead, they have a partial ordering determined by their
269 // data and control dependencies.
270 //
271 // HLO does not have basic blocks or explicit "branch" instructions.  Instead,
272 // certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode
273 // control flow.  For example, the kConditional HLO executes one of two possible
274 // computations, depending on the runtime value of a predicate.
275 //
276 // HLO is pure (mostly).  It has no concept of mutable state.  Instead, data
277 // values are produced by one HLO and flow into consumers across dependency
278 // edges.
279 class HloInstruction {
280  public:
281   // A fusion node computes the same value a call to its fusion computation
282   // would compute.  However, the choice of fusion kind dictates codegen
283   // strategy for the backend.
284   //
285   // To generate code for a kFusion HloInstruction, most backends do something
286   // like the following:
287   //
288   // 1) Identify the "primary" HloInstruction of the fused computation.
289   // 2) Emit code that does the work of the primary node, creating its inputs
290   //    and transforming its outputs as specified by the fused computation.
291   //
292   // In step (2), the code emitted is usually similar to the code that would be
293   // emitted for an *unfused* version of the primary node, except that
294   //
295   //  - when the primary node reads an element of one of its operands, instead
296   //    of loading the value from memory, it *computes* the value based on the
297   //    contents of the fused computation.
298   //  - when the primary node outputs a value, instead of storing it to memory,
299   //    it forwards the value to its users, which then perform additional
300   //    computations before the value is finally stored to memory at the root of
301   //    the fusion node.
302   //
303   // An HloInstruction's FusionKind helps us find the kFusion instruction's
304   // primary node, and can also affect how we generate code in step (2).
305   //
306   //  - kInput: The primary node is the root of the fused instruction.
307   //
308   //  - kOutput: The primary node is not the root of the fused instruction.
309   //    This fusion kind requires that one operand buffer of the fusion
310   //    instruction be able to alias the output buffer.  This constraint is
311   //    usually enough to let backends find the primary node unambiguously.
312   //
313   //  - kLoop: The primary node is the root of the fused computation, but,
314   //    unlike in input fusion, we prescribe a specific implementation for
315   //    codegen.  Rather than generating code that looks like the code we'd emit
316   //    for an unfused version of the primary/root node, we emit code that
317   //    generates one element of the root at a time.
318   //
319   //  - kCustom: Custom category for backend-specific fusions that don't fit
320   //    into the above patterns.
321   //
322   // Not all backends support all fusion kinds, and given a particular fused
323   // computation, it's not in general safe to change its fusion kind.  Creation
324   // of fusion nodes is always backend-specific.
325   //
326   // For elementwise ops (e.g. kAdd), most backends would emit a
327   // one-element-at-a-time implementation for the unfused version, so loop
328   // fusion and input fusion are probably equivalent if the root node is
329   // elementwise.  They're not necessarily equivalent e.g. for kReduce, where an
330   // implementation might emit something more sophisticated for an unfused or
331   // input-fusion reduce, but will emit the naive code that reduces one element
332   // at a time for loop fusion with a reduce as the root.
333   //
334   // Another way to think of loop fusion is that it's equivalent to input
335   // fusion, but where the root node is an implicit identity node, whose
336   // unfused implementation is "read one element, write one element".
337   //
338   // TODO(b/79869434): This categorization scheme is not great.  For one thing,
339   // input and loop fusion are basically the same thing: There is no reason for
340   // the HLO to encode backend-specific decisions about how e.g. a reduce that's
341   // the root of a fusion should be lowered.  In addition, this scheme as
342   // written doesn't work for multi-output fusion, where the primary node is
343   // never actually the root (which is a kTuple instruction that gathers the
344   // multiple outputs of the fusion).
345   enum class FusionKind {
346     kLoop,
347     kInput,
348     kOutput,
349     kCustom,
350   };
351 
352   virtual ~HloInstruction();
353 
354   // Creates an instruction from the given proto. Arguments:
355   //
356   //   proto: the proto to convert from.
357   //   instruction_map: a map from instruction id to HloInstruction*. This map
358   //     must contain all operands of the newly constructed instruction.
359   //   computation_map: a map from computation id to HloComputation*. This map
360   //     must contain all computations which the newly constructed instruction
361   //     calls.
362   static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
363       const HloInstructionProto& proto,
364       const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
365       const absl::flat_hash_map<int64, HloComputation*>& computation_map);
366 
367   // Creates a parameter-retrieving instruction.
368   static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
369                                                          const Shape& shape,
370                                                          const string& name);
371 
372   // Creates a literal constant instruction.
373   static std::unique_ptr<HloInstruction> CreateConstant(Literal literal);
374 
375   // Creates an Iota instruction.
376   static std::unique_ptr<HloInstruction> CreateIota(const Shape& shape,
377                                                     int64 iota_dimension);
378 
379   // Creates a get tuple element instruction.
380   static std::unique_ptr<HloInstruction> CreateGetTupleElement(
381       const Shape& shape, HloInstruction* operand, int64 index);
382 
383   // Creates a trace instruction that logs the input operand in the computation.
384   static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
385                                                      HloInstruction* operand);
386 
387   // Creates a random number generation instruction that fills a shape with
388   // random numbers from a given distribution.
389   //
390   // The parameters to the instruction are interpreted as follows:
391   //
392   //  - If `distribution` is RNG_UNIFORM, generates a number in range
393   //    [param0, param1).
394   //
395   //  - If `distribution` is RNG_NORMAL, generates a normally-distributed value
396   //    with mean `param0` and standard deviation `param1`.
397   static std::unique_ptr<HloInstruction> CreateRng(
398       const Shape& shape, RandomDistribution distribution,
399       absl::Span<HloInstruction* const> parameters);
400 
401   // Creates a unary instruction (one operand).
402   // Precondition: opcode must be a legitimate unary operation.
403   static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
404                                                      HloOpcode opcode,
405                                                      HloInstruction* operand);
406 
407   // Creates a binary instruction (two operands).
408   // Precondition: opcode must be a legitimate binary operation.
409   static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
410                                                       HloOpcode opcode,
411                                                       HloInstruction* lhs,
412                                                       HloInstruction* rhs);
413 
414   // Creates a ternary instruction (three operands).
415   // Precondition: opcode must be a legitimate ternary operation.
416   static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
417                                                        HloOpcode opcode,
418                                                        HloInstruction* lhs,
419                                                        HloInstruction* rhs,
420                                                        HloInstruction* ehs);
421 
422   // Creates a variadic instruction (variable number of operands).
423   // Precondition: opcode must be a legitimate variadic operation.
424   static std::unique_ptr<HloInstruction> CreateVariadic(
425       const Shape& shape, HloOpcode opcode,
426       absl::Span<HloInstruction* const> operands);
427 
  // Creates a map instruction, where `map_computation` is applied element-wise
  // to every element in operands (across the operands, at a given index).
431   static std::unique_ptr<HloInstruction> CreateMap(
432       const Shape& shape, absl::Span<HloInstruction* const> operands,
433       HloComputation* map_computation);
434 
435   // Creates a convolution op, where rhs is the convolutional filter
436   // and window describes how the filter is applied to lhs.
437   static std::unique_ptr<HloInstruction> CreateConvolve(
438       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
439       int64 feature_group_count, int64 batch_group_count, const Window& window,
440       const ConvolutionDimensionNumbers& dimension_numbers,
441       const PrecisionConfig& precision_config);
442 
443   // Creates an FFT op, of the type indicated by fft_type.
444   static std::unique_ptr<HloInstruction> CreateFft(
445       const Shape& shape, HloInstruction* operand, FftType fft_type,
446       absl::Span<const int64> fft_length);
447 
448   // Creates a compare op, performing the comparison specified in direction.
449   static std::unique_ptr<HloInstruction> CreateCompare(
450       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
451       ComparisonDirection direction);
452 
453   static std::unique_ptr<HloInstruction> CreateTriangularSolve(
454       const Shape& shape, HloInstruction* a, HloInstruction* b,
455       const TriangularSolveOptions& options);
456 
457   static std::unique_ptr<HloInstruction> CreateCholesky(
458       const Shape& shape, HloInstruction* a, const CholeskyOptions& options);
459 
460   // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
461   // dimensions specified in 'dimension_numbers'.
462   static std::unique_ptr<HloInstruction> CreateDot(
463       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
464       const DotDimensionNumbers& dimension_numbers,
465       const PrecisionConfig& precision_config);
466 
467   // Creates a reduce-precision op, where operand is the data to reduce in
468   // precision, and exponent_bits and mantissa_bits describe the precision to
469   // reduce it to.
470   static std::unique_ptr<HloInstruction> CreateReducePrecision(
471       const Shape& shape, HloInstruction* operand, const int exponent_bits,
472       const int mantissa_bits);
473 
  // Creates a cross-replica reduction op.
  //
  // `reduce_computation`: the reduction function.
  //
  // `replica_groups`: each ReplicaGroup contains a list of replica ids. If
  // empty, all replicas belong to one group in the order 0 .. (n-1).
  // Allreduce will be applied within subgroups.
  // For example, with 4 replicas, replica_groups={{0,2},{1,3}} means that
  // replicas 0 and 2 are in subgroup 0 and replicas 1 and 3 are in subgroup 1.
  //
  // `all_reduce_id`: Allreduce nodes from different modules that have the same
  // all_reduce_id will be 'Allreduce'd together. If empty, Allreduce will not
  // be applied across modules.
487   static std::unique_ptr<HloInstruction> CreateAllReduce(
488       const Shape& shape, absl::Span<HloInstruction* const> operands,
489       HloComputation* reduce_computation,
490       const std::vector<ReplicaGroup>& replica_groups,
491       absl::string_view barrier, const absl::optional<int64>& all_reduce_id);
492 
  // This op handles the communication of an Alltoall operation. On each core,
  // the operands are N ops in the same shape, where N is the number of cores
  // participating in the Alltoall. Then the N operands are scattered to N
  // cores, i.e., the ith operand is sent to the ith core. Then each core
  // gathers the received data into a tuple.
  //
  // - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
  // empty, all replicas belong to one group in the order 0 .. (n-1). Alltoall
  // will be applied within subgroups in the specified order. For example,
  // replica groups = {{1,2,3},{4,5,0}} means that an Alltoall will be applied
  // within replicas 1, 2, 3, and in the gather phase, the received blocks will
  // be concatenated in the order of 1, 2, 3; another Alltoall will be applied
  // within replicas 4, 5, 0, and the concatenation order is 4, 5, 0.
506   static std::unique_ptr<HloInstruction> CreateAllToAll(
507       const Shape& shape, absl::Span<HloInstruction* const> operands,
508       const std::vector<ReplicaGroup>& replica_groups);
509 
  // Creates a communication instruction that permutes data across replicas.
  // Data is sent/received according to the (source_replica_id,
  // target_replica_id) pairs in `source_target_pairs`. If a replica id is not
  // a target_replica_id in any pair, the output on that replica is a tensor of
  // zeros with shape `shape`.
515   static std::unique_ptr<HloInstruction> CreateCollectivePermute(
516       const Shape& shape, HloInstruction* operand,
517       const std::vector<std::pair<int64, int64>>& source_target_pairs);
518 
519   // Creates an instruction that returns a U32 replica ID.
520   static std::unique_ptr<HloInstruction> CreateReplicaId();
521 
522   // Creates a conversion instruction, where operand is the data to convert and
523   // shape is the target shape for the conversion.
524   static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
525                                                        HloInstruction* operand);
526 
527   // Creates a bitcast conversion instruction, where operand is the data to
528   // convert and shape is the target shape for the conversion.
529   static std::unique_ptr<HloInstruction> CreateBitcastConvert(
530       const Shape& shape, HloInstruction* operand);
531 
532   // Creates an infeed instruction, which reads data of the given shape from the
533   // Infeed interface of the device. infeed_shape is the shape of the data
534   // received from the infeed *not* the shape of the infeed instruction which
535   // is a tuple containing the infeed_shape and the TOKEN.
536   static std::unique_ptr<HloInstruction> CreateInfeed(
537       const Shape& infeed_shape, HloInstruction* token_operand,
538       const string& config);
539 
540   // Creates an outfeed instruction, which outputs data. outfeed_shape is the
541   // shape of the data being outfed *not* the shape of the outfeed instruction
542   // which is a TOKEN.
543   static std::unique_ptr<HloInstruction> CreateOutfeed(
544       const Shape& outfeed_shape, HloInstruction* operand,
545       HloInstruction* token_operand, absl::string_view outfeed_config);
546 
547   // Creates an asynchronous send instruction with the given channel id, which
548   // initiates sending the operand data to a unique receive instruction in
549   // another computation that has the same channel id. If is_host_transfer is
550   // true, then this Send operation transfers data to the host.
551   static std::unique_ptr<HloInstruction> CreateSend(
552       HloInstruction* operand, HloInstruction* token, int64 channel_id,
553       bool is_host_transfer = false);
554 
555   // Blocks until data transfer for the Send instruction (operand) is complete.
556   // The operand must be kSend.
557   static std::unique_ptr<HloInstruction> CreateSendDone(
558       HloInstruction* operand, bool is_host_transfer = false);
559 
  // Creates an asynchronous receive instruction with the given channel id,
  // which allocates resources to receive data of the given shape from a unique
  // send instruction in another computation that has the same channel id. If
  // is_host_transfer is true, then this Recv operation transfers data from the
  // host.
565   static std::unique_ptr<HloInstruction> CreateRecv(
566       const Shape& shape, HloInstruction* token, int64 channel_id,
567       bool is_host_transfer = false);
568 
569   // Blocks until data transfer for the Recv instruction (operand) is complete
570   // and returns the receive buffer. The operand must be kRecv.
571   static std::unique_ptr<HloInstruction> CreateRecvDone(
572       HloInstruction* operand, bool is_host_transfer = false);
573 
574   // Creates a slice instruction, where the operand is sliced by the given
575   // start/limit indices.
576   static std::unique_ptr<HloInstruction> CreateSlice(
577       const Shape& shape, HloInstruction* operand,
578       absl::Span<const int64> start_indices,
579       absl::Span<const int64> limit_indices, absl::Span<const int64> strides);
580 
  // Creates a dynamic-slice instruction, where `operand` is sliced starting at
  // the indices given by the `start_indices` instructions, with the slice size
  // specified in `slice_sizes`.
584   static std::unique_ptr<HloInstruction> CreateDynamicSlice(
585       const Shape& shape, HloInstruction* operand,
586       absl::Span<HloInstruction* const> start_indices,
587       absl::Span<const int64> slice_sizes);
588 
589   // Creates a dynamic update slice instruction, which updates a slice
590   // of 'operand' with 'update' and 'start_indices'.
591   static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
592       const Shape& shape, HloInstruction* operand, HloInstruction* update,
593       absl::Span<HloInstruction* const> start_indices);
594 
595   // Creates a concatenate instruction, where the operands are concatenated on
596   // the provided dimension.
597   static std::unique_ptr<HloInstruction> CreateConcatenate(
598       const Shape& shape, absl::Span<HloInstruction* const> operands,
599       int64 dimension);
600 
  // Creates a reduce instruction, where `reduce_computation` is applied
  // successively to every element in operand. For example, let f be the
  // function to apply, which takes 2 arguments, an accumulator and the current
  // value. Let init be an initial value (which is normally chosen to be the
  // identity element for f, e.g. 0 if f is addition).
  // Then the reduce HLO will compute:
  // f(f(f(init, value0), value1), ...)
608   static std::unique_ptr<HloInstruction> CreateReduce(
609       const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
610       absl::Span<const int64> dimensions_to_reduce,
611       HloComputation* reduce_computation);
612 
  // A more general, multiple-argument version of the above.
  // For N operands, the function to apply, f, now takes 2N arguments:
  // [accumulator0, ..., accumulator(N-1), value0, ..., value(N-1)], and
  // returns an N-tuple. The performed computation is (for commutative and
  // associative f operators) equivalent to:
  //
  // f_1 = f(init0, ..., init(N-1), input0.value0, ..., input(N-1).value0)
  // f_2 = f(f_1.tuple_element(0), ..., f_1.tuple_element(N-1), input0.value1,
  // ..., input(N-1).value1)
  // ...
623   static std::unique_ptr<HloInstruction> CreateReduce(
624       const Shape& shape, absl::Span<HloInstruction* const> operands,
625       absl::Span<HloInstruction* const> init_values,
626       absl::Span<const int64> dimensions_to_reduce,
627       HloComputation* reduce_computation);
628 
  // Creates a reduce-window instruction, where `reduce_computation` is applied
  // window-wise at each valid window position in the operand.
632   static std::unique_ptr<HloInstruction> CreateReduceWindow(
633       const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
634       const Window& window, HloComputation* reduce_computation);
635 
636   // Creates a batch-norm-training instruction.
637   static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
638       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
639       HloInstruction* offset, float epsilon, int64 feature_index);
640 
641   // Creates a batch-norm-inference instruction.
642   static std::unique_ptr<HloInstruction> CreateBatchNormInference(
643       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
644       HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
645       float epsilon, int64 feature_index);
646 
647   // Creates a batch-norm-grad instruction.
648   static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
649       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
650       HloInstruction* mean, HloInstruction* variance,
651       HloInstruction* grad_output, float epsilon, int64 feature_index);
652 
653   // Creates a scatter computation that scatters the `source` array to the
654   // selected indices of each window.
655   static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
656       const Shape& shape, HloInstruction* operand, HloComputation* select,
657       const Window& window, HloInstruction* source, HloInstruction* init_value,
658       HloComputation* scatter);
659 
660   // Creates a broadcast instruction.
661   static std::unique_ptr<HloInstruction> CreateBroadcast(
662       const Shape& shape, HloInstruction* operand,
663       absl::Span<const int64> broadcast_dimensions);
664 
665   // Creates a sequence of instructions that performs an explicit broadcast of
666   // the operand to the target shape.
667   //
668   // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
669   // returned as a unique_ptr for API consistency with other factory methods in
670   // this interface.
671   //
672   // TODO(b/72173833) Ideally HloComputations would always be present, and so
673   // the adder being passed by the caller would not be necessary.
674   static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
675       const Shape& output_shape, HloInstruction* operand,
676       const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
677           adder);
678 
679   // Creates a pad instruction, where the operand is padded on the edges and
680   // between the elements with the given padding value.
681   static std::unique_ptr<HloInstruction> CreatePad(
682       const Shape& shape, HloInstruction* operand,
683       HloInstruction* padding_value, const PaddingConfig& padding_config);
684 
685   // Creates a reshape instruction, where the operand is flattened row-major
686   // order and then reshaped to the given result shape.
687   static std::unique_ptr<HloInstruction> CreateReshape(const Shape& shape,
688                                                        HloInstruction* operand);
689 
690   // Creates a transpose instruction which permutes the operand dimensions.
691   static std::unique_ptr<HloInstruction> CreateTranspose(
692       const Shape& shape, HloInstruction* operand,
693       absl::Span<const int64> dimensions);
694 
695   // Creates a n-ary sort op with a 'compare' computation which is used for
696   // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters,
697   // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at
698   // specific index positions which should be compared, and should return a
699   // PRED. 'is_stable' specifies whether stable sorting is required.
700   static std::unique_ptr<HloInstruction> CreateSort(
701       const Shape& shape, int64 dimension,
702       absl::Span<HloInstruction* const> operands, HloComputation* compare,
703       bool is_stable);
704 
705   // Creates a while instruction, given a condition computation, a body
706   // computation, and the initial value for the input of the computations. For
707   // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
708   // corresponds to the C code below.
709   // int32 i = 1; int32 result = while(i < 1000) { i = i * 2 }
710   static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
711                                                      HloComputation* condition,
712                                                      HloComputation* body,
713                                                      HloInstruction* init);
714 
715   static std::unique_ptr<HloInstruction> CreateConditional(
716       const Shape& shape, HloInstruction* pred,
717       HloInstruction* true_computation_arg, HloComputation* true_computation,
718       HloInstruction* false_computation_arg, HloComputation* false_computation);
719 
720   static std::unique_ptr<HloInstruction> CreateConditional(
721       const Shape& shape, HloInstruction* branch_index,
722       absl::Span<HloComputation* const> branch_computations,
723       absl::Span<HloInstruction* const> branch_computation_args);
724 
725   static std::unique_ptr<HloInstruction> CreateGather(
726       const Shape& shape, HloInstruction* operand,
727       HloInstruction* start_indices,
728       const GatherDimensionNumbers& gather_dim_numbers,
729       absl::Span<const int64> slice_sizes);
730 
731   static std::unique_ptr<HloInstruction> CreateScatter(
732       const Shape& shape, HloInstruction* operand,
733       HloInstruction* scatter_indices, HloInstruction* updates,
734       HloComputation* update_computation,
735       const ScatterDimensionNumbers& scatter_dim_numbers);
736 
  // Creates a kDomain instruction which delimits an HLO domain that has
  // the provided user and operand side metadata.
739   static std::unique_ptr<HloInstruction> CreateDomain(
740       const Shape& shape, HloInstruction* operand,
741       std::unique_ptr<DomainMetadata> operand_side_metadata,
742       std::unique_ptr<DomainMetadata> user_side_metadata);
743 
744   // Creates a fusion instruction. A fusion instruction contains one or more
745   // fused instructions forming an expression with a single root
746   // "fused_root". Additional instructions can be added to the fusion
747   // instruction with the method FuseInstruction.
748   static std::unique_ptr<HloInstruction> CreateFusion(
749       const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);
750 
751   static std::unique_ptr<HloInstruction> CreateFusion(
752       const Shape& shape, FusionKind fusion_kind,
753       absl::Span<HloInstruction* const> operands,
754       HloComputation* fusion_computation);
755 
756   // Creates a call instruction that applies the given computation on the given
757   // operands. "shape" is the resultant shape.
758   static std::unique_ptr<HloInstruction> CreateCall(
759       const Shape& shape, absl::Span<HloInstruction* const> operands,
760       HloComputation* computation);
761 
762   // Creates a custom call instruction that applies the given custom call target
763   // to the given operands. "opaque" can be an arbitrary string with a
764   // backend-specific interpretation. "shape" is the resultant shape.
765   static std::unique_ptr<HloInstruction> CreateCustomCall(
766       const Shape& shape, absl::Span<HloInstruction* const> operands,
767       absl::string_view custom_call_target, absl::string_view opaque = "");
768 
769   // Overload which constrains the layouts of the operand and result. 'shape'
770   // and 'operand_shapes_with_layout' must have layouts.
771   // 'operand_shapes_with_layout' must have a compatible element for each
772   // operand.
773   static std::unique_ptr<HloInstruction> CreateCustomCall(
774       const Shape& shape, absl::Span<HloInstruction* const> operands,
775       absl::string_view custom_call_target,
776       absl::Span<const Shape> operand_shapes_with_layout,
777       absl::string_view opaque = "");
778 
779   // Creates a tuple instruction with the given elements. This is a convenience
780   // wrapper around CreateVariadic.
781   static std::unique_ptr<HloInstruction> CreateTuple(
782       absl::Span<HloInstruction* const> elements);
783 
784   // Creates a reverse instruction, which reverses the order of the elements
785   // in the specified dimensions.
786   static std::unique_ptr<HloInstruction> CreateReverse(
787       const Shape& shape, HloInstruction* operand,
788       absl::Span<const int64> dimensions);
789 
  // Creates an AfterAll instruction used for joining or creating new values of
  // token type which thread through side-effecting operations. Operands must
  // all be tokens, and there must be at least one operand.
793   static std::unique_ptr<HloInstruction> CreateAfterAll(
794       absl::Span<HloInstruction* const> operands);
795 
  // Creates an AfterAll instruction which creates a token type out of thin air
  // (no operands). This is a separate method from CreateAfterAll to facilitate
  // the removal of operand-less AfterAll instructions.
799   // TODO(b/110532604): Remove this capability of creating a token from nothing
800   // when we plumb a primordial token from the entry computation.
801   static std::unique_ptr<HloInstruction> CreateToken();
802 
803   static std::unique_ptr<HloInstruction> CreateGetDimensionSize(
804       const Shape& shape, HloInstruction* operand, int64 dimension);
805 
806   static std::unique_ptr<HloInstruction> CreateAddDependency(
807       HloInstruction* data_operand, HloInstruction* token_operand);
808 
809   // Returns the opcode for this instruction.
opcode()810   HloOpcode opcode() const { return opcode_; }
811 
812   // Returns true if this instruction has a side effect, irrespective of whether
813   // any called computations may contain an instruction with side effects.
814   bool HasSideEffectNoRecurse() const;
815 
816   // Returns true if this instruction has a side effect. An instruction has a
817   // side effect if it uses certain opcodes or calls a computation with a side
818   // effect.
819   bool HasSideEffect() const;
820 
821   // Returns the result shape of this instruction.
822   const Shape& shape() const;
823 
824   // Returns the (mutable) result shape of this instruction.
mutable_shape()825   Shape* mutable_shape() { return &shape_; }
826 
827   // Returns the ith operand to this instruction.
828   const HloInstruction* operand(int64 i) const;
829 
830   // Returns the ith operand to this instruction.
831   HloInstruction* mutable_operand(int64 i);
832 
833   // Returns the number of operands to this instruction.
operand_count()834   int64 operand_count() const { return operands_.size(); }
835 
836   // Returns the vector of operands of this instruction.
837   using InstructionVector = absl::InlinedVector<HloInstruction*, 2>;
operands()838   const InstructionVector& operands() const { return operands_; }
839 
840   // Returns the vector of unique operands, in the same order they are found
841   // within the operand vector.
842   InstructionVector unique_operands() const;
843 
844   // Returns the index of 'target' in the operands sequence.
845   // Precondition: target must be an operand (or a fatal error will occur).
846   int64 operand_index(const HloInstruction* target) const;
847 
848   // Returns the number of users of this instruction.
user_count()849   int64 user_count() const { return users_.size(); }
850 
851   // Returns the users of this instruction.
users()852   const std::vector<HloInstruction*>& users() const { return users_; }
853 
854   // Returns true if this instruction is a user of 'instruction'.
IsUserOf(const HloInstruction * instruction)855   bool IsUserOf(const HloInstruction* instruction) const {
856     return ContainsKey(instruction->user_set_, this);
857   }
858 
859   // Adds a control dependency from this instruction to the given
860   // instruction. This instruction becomes a control predecessor of
861   // 'instruction', and 'instruction' becomes a control successor of this
862   // instruction. Returns an error status if either of the given instructions
863   // does not belong to the same computation.
864   //
865   // This is used to enforce an additional ordering requirement that is not
866   // captured by normal data dependencies, such as ordering among Send or Recv
867   // operations to avoid deadlock.
868   Status AddControlDependencyTo(HloInstruction* instruction);
869 
870   // Removes a previously added control dependency from this instruction to
871   // 'instruction'.
872   Status RemoveControlDependencyTo(HloInstruction* instruction);
873 
874   // Drops all control predecessors and successors from this HLO instruction.
875   Status DropAllControlDeps();
876 
877   // Copies the control predecessors and successors on this HLO instruction to
878   // `inst`.  Does not do a deep copy so this makes sense only if `inst` and
879   // this HLO are in the same module.
880   //
881   // Depending on the use cases we see in practice, in the future we may
882   // consider folding the logic here into Clone, CloneWithNewOperands and
883   // ReplaceAllUsesWith by treating control dependencies like data dependencies.
884   Status CopyAllControlDepsFrom(const HloInstruction* inst);
885 
886   // Returns the set of control predecessors (successors) of this
887   // instruction. Control predecessors (successors) must execute before (after)
888   // the current instruction.
control_predecessors()889   const std::vector<HloInstruction*>& control_predecessors() const {
890     return control_predecessors_;
891   }
control_successors()892   const std::vector<HloInstruction*>& control_successors() const {
893     return control_successors_;
894   }
895 
896   // Returns true if "other" performs the same computation as this instruction.
897   bool Identical(
898       const HloInstruction& other,
899       const std::function<bool(const HloInstruction*, const HloInstruction*)>&
900           eq_operands = std::equal_to<const HloInstruction*>(),
901       const std::function<bool(const HloComputation*, const HloComputation*)>&
902           eq_computations = std::equal_to<const HloComputation*>(),
903       bool layout_sensitive = true) const {
904     // An instruction is always identical to itself.
905     if (this == &other) {
906       return true;
907     }
908 
909     // Identical instruction must have the same opcode, shape, and identical
910     // operands.
911     if (opcode() != other.opcode()) {
912       return false;
913     }
914     if (!(layout_sensitive ? ShapeUtil::Equal(shape(), other.shape())
915                            : ShapeUtil::Compatible(shape(), other.shape()))) {
916       return false;
917     }
918     if (operands().size() != other.operands().size()) {
919       return false;
920     }
921 
922     // Two AllReduces are Identical if they have the same all_reduce_id.
923     // Their operands don't have to be Identical.
924     if (!IsCrossModuleAllReduce()) {
925       // Use an explicit loop rather than ContainerEquals, because copying
926       // around std::functions may be too expensive in some cases.
927       for (size_t i = 0; i < operands().size(); ++i) {
928         if (!eq_operands(operand(i), other.operand(i))) {
929           return false;
930         }
931       }
932     }
933 
934     if (backend_config_ != other.backend_config_) {
935       return false;
936     }
937 
938     return IdenticalSlowPath(other, eq_computations);
939   }
940 
941   // Generates a hash value of an HLO instruction. Hash considers
942   // information on opcode, shape, operands, and typically a root instruction.
943   // This function returns the same hash value for equivalent HLO instructions,
944   // with respect to HloInstruction::Identical() method.
945   //
946   // Uses hash_operand function to compute hash values of its operands.
947   // At the very top level, hash_operand should be non-recursive to prevent
948   // non-termination.
949   uint64 Hash(
950       const std::function<uint64(const HloInstruction*)>& hash_operand) const;
951 
952   // Calls the above method with non-recursive hash_operand function.
953   uint64 Hash() const;
954 
955   // Returns whether the instruction has a constant operand.
956   bool HasConstantOperand() const;
957 
958   // Replaces the use of this instruction in "user" with "new_producer". Note
959   // that there might be multiple uses of this instruction in "user"; all will
960   // be replaced.
961   //
962   // If user is a fusion instruction, this function will remove any duplicated
963   // operands of it which could be created due to this replacement.
964   Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
965 
966   // Same as ReplaceUseWith(), but new_producer can have a different shape.
967   Status ReplaceUseWithDifferentShape(HloInstruction* user,
968                                       HloInstruction* new_producer);
969 
970   // Replaces the specified operand with new_operand. The old and new operands
971   // must have compatible shapes ignoring floating-point precision.
972   //
973   // This function does NOT remove duplicated operands even if this instruction
974   // is a fusion, so that the existing operand numbers do not change.
975   Status ReplaceOperandWith(int64 operand_num, HloInstruction* new_operand);
976 
977   // Same as ReplaceOperandWith(), but new_operand can have a different shape.
978   Status ReplaceOperandWithDifferentShape(int64 operand_num,
979                                           HloInstruction* new_operand);
980 
  // Replaces all uses of this instruction with the new producer. If
  // new_producer is a user of this instruction then new_producer remains a user
  // of this instruction to avoid introducing cycles into the graph.
984   //
985   // If this instruction is the root of its computation, sets the computation's
986   // root to new_producer.
987   //
988   // The new producer must have a compatible shape ignoring floating-point
989   // precision.
990   //
991   // If a user is a fusion instruction, this function will remove any duplicated
992   // operands of it which could be created due to this replacement.
993   Status ReplaceAllUsesWith(HloInstruction* new_producer);
994 
995   // Same as ReplaceAllUsesWith, but new_producer can have a different shape.
996   Status ReplaceAllUsesWithDifferentShape(HloInstruction* new_producer);
997 
998   // Performs a postorder DFS visit using this node as the root. If
999   // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
1000   // complete. If ignore_control_predecessors is true, instructions only
1001   // reachable via control dependencies will not be visited, and the postorder
1002   // will not take control dependencies into account. It is as if the control
1003   // dependencies didn't exist in the graph at all.
1004   template <typename HloInstructionPtr>
1005   Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
1006                 bool call_finish_visit = true,
1007                 bool ignore_control_predecessors = false);
1008   Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
1009                 bool ignore_control_predecessors = false) const {
1010     return const_cast<HloInstruction*>(this)->Accept(
1011         visitor, call_finish_visit, ignore_control_predecessors);
1012   }
1013 
1014   // Same as Accept() above, but the order of operand and control predecessor
1015   // visitation is determined by the given operand order; if compare(A, B) ==
1016   // true, A is visited before B.
1017   using CompareFunction =
1018       std::function<bool(const HloInstruction*, const HloInstruction*)>;
1019   Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
1020                                 const CompareFunction& operand_order,
1021                                 bool call_finish_visit = true);
1022 
1023   // Performs a postorder DFS visit using this node as the root. Calls the given
1024   // visitor function at each instruction.
1025   Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
1026   Status Accept(
1027       const std::function<Status(const HloInstruction*)>& visitor_func) const;
1028 
1029   // Visit this instruction and only this instruction with the given visitor.
1030   template <typename HloInstructionPtr>
1031   Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
1032 
  // Returns the first non-GetTupleElement ancestor of this instruction. If that
  // ancestor is tuple-shaped, the returned ShapeIndex holds the (possibly
  // nested) tuple indices used on the path from the ancestor to this
  // instruction.
1036   std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
1037       const;
1038 
LatestNonGteAncestorAndIndex()1039   std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
1040     auto rv =
1041         const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex();
1042     return {const_cast<HloInstruction*>(rv.first), rv.second};
1043   }
1044 
1045   // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction.
1046   const HloInstruction* LatestNonGteAncestor() const;
1047 
LatestNonGteAncestor()1048   HloInstruction* LatestNonGteAncestor() {
1049     return const_cast<HloInstruction*>(
1050         const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
1051   }
1052 
1053   // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
1054   // The setter should only be called by HloModule or HloComputation methods.
1055   //
1056   // Precondition: The instruction has a valid to_apply_ field.
1057   HloComputation* to_apply() const;
1058   void set_to_apply(HloComputation* to_apply);
1059 
1060   // Gets/sets the while_condition or while_body HloComputation for While. The
1061   // setters should only be called by HloModule or HloComputation methods.
1062   //
1063   // Precondition: The instruction is a While instruction.
1064   HloComputation* while_condition() const;
1065   HloComputation* while_body() const;
1066   void set_while_condition(HloComputation* while_condition);
1067   void set_while_body(HloComputation* while_body);
1068 
1069   HloInstruction* while_init() const;
1070 
  // Gets the true and false HloComputations for Conditional.
1072   //
1073   // Precondition: The instruction is a predicated Conditional instruction.
1074   HloComputation* true_computation() const;
1075   HloComputation* false_computation() const;
1076 
1077   // Gets the branch HloComputations for Conditional.
1078   //
1079   // Precondition: The instruction is a Conditional instruction.
1080   const std::vector<HloComputation*>& branch_computations() const;
1081   int branch_count() const;
1082   HloComputation* branch_computation(int b) const;
1083   // Sets a branch HloComputation for Conditional.
1084   // The setter should only be called by HloModule or HloComputation methods.
1085   //
1086   // Precondition: The instruction is a Conditional instruction.
1087   void set_branch_computation(int b, HloComputation* computation);
1088 
1089   // Returns a string for the signature of this instruction if considered as a
1090   // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
1091   string SignatureString() const;
1092 
1093   // Returns a debugging string that represents this instruction.
1094   //
1095   // (We express the default options using an overload rather than a default
1096   // param because gdb ignores default params, but does resolve overloads.)
1097   //
1098   // TODO(b/73348663): Make ToString() adaptive to the size of the string by
1099   // default, backing off on providing full information for very large strings,
1100   // or provide a different name for a ToString-like function that does that.
ToString()1101   string ToString() const { return ToString(HloPrintOptions()); }
1102   string ToString(const HloPrintOptions& options) const;
1103 
1104   // Components of the ToString() representation:
1105 
1106   // Returns a string representation of the operand list.
1107   string OperandsToString(const HloPrintOptions& options) const;
1108 
1109   // Returns string representation of op-specific attributes.
1110   std::vector<string> ExtraAttributesToString(
1111       const HloPrintOptions& options) const;
1112 
1113   // As ToString, but returns a shorter string.
1114   string ToShortString() const;
1115 
1116   // Returns a serialized representation of this instruction.
1117   virtual HloInstructionProto ToProto() const;
1118 
1119   // Returns a category for the HLO. This could be something like "convolution"
1120   // or "elementwise".
1121   virtual string ToCategory() const;
1122 
1123   // Returns a logging instruction, if the output of this instruction is logged.
1124   //
1125   // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
1126   HloInstruction* tracing() const;
1127   void set_tracing(HloInstruction* trace_instruction);
1128 
  // Returns true if this instruction is fused, i.e. contained within a fusion
  // instruction.
1131   bool IsFused() const;
1132 
1133   // Returns true if this instruction can be legally fused into a fusion
1134   // instruction.
1135   bool IsFusible() const;
1136 
1137   // Returns the sharding applied to this operator.
1138   // REQUIRES: has_sharding() is true.
sharding()1139   const HloSharding& sharding() const {
1140     CHECK(has_sharding());
1141     return *sharding_;
1142   }
sharding_ptr()1143   std::shared_ptr<const HloSharding> sharding_ptr() const { return sharding_; }
1144 
1145   // Returns the sharding applied to this operator, or default_ if none exists.
sharding_or_default(const HloSharding & default_)1146   const HloSharding& sharding_or_default(const HloSharding& default_) const {
1147     return sharding_ ? *sharding_ : default_;
1148   }
1149   // Returns the sharding unique device, if any.
sharding_unique_device()1150   absl::optional<int64> sharding_unique_device() const {
1151     if (sharding_ == nullptr) {
1152       return absl::optional<int64>();
1153     }
1154     return sharding_->UniqueDevice();
1155   }
1156   // Sets the sharding of this operator. Should only be called by HloModule or
1157   // HloComputation methods.
set_sharding(const HloSharding & sharding)1158   void set_sharding(const HloSharding& sharding) {
1159     sharding_ = std::make_shared<const HloSharding>(sharding);
1160   }
set_sharding(std::shared_ptr<const HloSharding> sharding)1161   void set_sharding(std::shared_ptr<const HloSharding> sharding) {
1162     sharding_ = std::move(sharding);
1163   }
1164   void set_single_sharding(const HloSharding& sharding);
1165   // Sets a sharding that assigns the current instruction to device.
set_device_sharding(int64 device)1166   void set_device_sharding(int64 device) {
1167     set_single_sharding(HloSharding::AssignDevice(device));
1168   }
1169   // Remove any sharding from this operator.
clear_sharding()1170   void clear_sharding() { sharding_ = nullptr; }
1171   // Return true if this operator has a sharding assigned.
has_sharding()1172   bool has_sharding() const { return sharding_ != nullptr; }
1173   // Checks whether the instruction has compatible sharding with the other
1174   // instruction.
has_compatible_sharding(const HloInstruction * other)1175   bool has_compatible_sharding(const HloInstruction* other) const {
1176     if (!has_sharding()) {
1177       return !other->has_sharding();
1178     }
1179     return other->has_sharding() ? sharding() == other->sharding() : false;
1180   }
1181 
  // When creating a new instruction which either replaces, or shifts up (kCopy
  // insertion case), another instruction, we need to make sure that certain
  // properties of the original instruction are copied into the derived one. As
  // of today, the metadata and sharding will be propagated to the derived
  // instruction.
1187   void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
1188 
1189   // Clones the HLO instruction. The clone will have the same opcode, shape, and
1190   // operands. After creation the clone has no uses. "this" (the instruction
1191   // cloned from) is not changed. Suffix is the string to append to the name of
1192   // the instruction to form the name of the cloned instruction.
1193   // Ignores the control predecessors and successors of this HLO instruction.
1194   std::unique_ptr<HloInstruction> Clone(
1195       const string& suffix = "clone", HloCloneContext* context = nullptr) const;
1196 
1197   // Clones the HLO instruction as above but with new shape and operands.
1198   std::unique_ptr<HloInstruction> CloneWithNewOperands(
1199       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1200       HloCloneContext* context = nullptr) const;
1201 
1202   // Returns the computations this instruction directly calls (if any).
called_computations()1203   const std::vector<HloComputation*>& called_computations() const {
1204     return called_computations_;
1205   }
1206 
1207   // Replaces all called computations based on a map function. This is needed
1208   // when we clone hlo_computations and want to let the instructions to point
1209   // to the newly cloned nodes.
ReplaceCalledComputations(std::function<HloComputation * (HloComputation *)> map_function)1210   void ReplaceCalledComputations(
1211       std::function<HloComputation*(HloComputation*)> map_function) {
1212     for (int64 i = 0; i < called_computations_.size(); ++i) {
1213       called_computations_[i] = map_function(called_computations_[i]);
1214     }
1215   }
1216 
1217   // Clears out the called computations.
1218   //
1219   // This is, in particular, necessary when inlining function bodies into their
1220   // caller. If there were side-effecting operations in the called computations,
1221   // the call itself is considered side-effecting and thus cannot be removed. By
1222   // clearing out the computations, we reflect the fact that all side-effecting
1223   // properties have been reflected in the caller, and make the call HLO
1224   // removable.
ClearCalledComputations()1225   void ClearCalledComputations() { called_computations_.clear(); }
1226 
1227   // Returns true if this instruction performs an elementwise operation on
1228   // `operand_idx`-th operand. An instruction is elementwise on an operand iff,
1229   // to compute the output at index {i_0,i_1,...,i_n}, the only element required
1230   // from the operand (if any) is the element at {i_0,i_1,...,i_n}.
1231   //
1232   // Note on performance: when this instruction is kFusion, this method, in the
1233   // worst case, scans all fused instructions. We could speed this up by
1234   // caching.
1235   bool IsElementwiseOnOperand(int64 operand_idx) const;
1236 
1237   // Returns true if this instruction is elementwise on all its operands.
1238   bool IsElementwise() const;
1239 
1240   // Returns true if this is a cross module all-reduce instruction.
1241   bool IsCrossModuleAllReduce() const;
1242 
1243   // Returns true if this is a cross-replica all-reduce instruction.
1244   bool IsCrossReplicaAllReduce() const;
1245 
1246   // Returns true if this instruction is binary and elementwise.
1247   bool IsElementwiseBinary() const;
1248 
1249   // Returns whether this instruction may reuse elements of its `i`th operand.
ReusesOperandElements(int64 i)1250   bool ReusesOperandElements(int64 i) const {
1251     return OperandElementUse(i) == UseKind::kReuse;
1252   }
1253 
  // Returns the indices at which the given operand appears in the operand list
  // of this instruction. Note that an instruction can use the same operand
  // multiple times.
1257   std::vector<int64> OperandIndices(const HloInstruction* operand) const;
1258 
1259   // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
1260   // this reshape merely inserts or deletes 1-sized dimensions, return the input
1261   // indices of the deleted dimensions and the output indices of the inserted
1262   // dimensions.
1263   //
1264   // Precondition: this op must be a reshape.
1265   std::tuple<bool, std::vector<int64>, std::vector<int64>>
1266   ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
1267 
1268   // Gets the string identifier for this instruction.
name()1269   const string& name() const { return name_; }
1270 
1271   // Sets the string identifier for this instruction. Name will be sanitized to
1272   // match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
SetAndSanitizeName(const string & name)1273   void SetAndSanitizeName(const string& name) {
1274     name_ = NameUniquer::GetSanitizedName(name);
1275   }
1276 
1277   // Use the given NameUniquer to select a unique name for the instruction based
1278   // on the instruction's existing name.
1279   void UniquifyName(NameUniquer* name_uniquer);
1280 
1281   // Clear the unique ID of the instruction so that it can be re-assigned, such
1282   // as for the purpose of compacting the instruction unique IDs.
ClearUniqueIdInternal()1283   void ClearUniqueIdInternal() { unique_id_ = -1; }
1284 
1285   // Set the unique id for this instruction to "id"
SetUniqueId(int id)1286   void SetUniqueId(int id) {
1287     CHECK_EQ(unique_id_, -1);  // Should not be assigned already
1288     CHECK_GE(id, 0);
1289     unique_id_ = id;
1290   }
1291 
1292   // Return the unique ID assigned to this node via SetUniqueId (or -1
1293   // if no id has been assigned yet).
unique_id()1294   int unique_id() const { return unique_id_; }
1295 
1296   // Returns the backend-specific configuration for how a backend should compile
1297   // this HLO. The meaning of the field is backend specific. Not for use before
1298   // or during general HLO optimization, since HLO optimizations do not preserve
1299   // this field and they cannot interpret it due to its meaning being backend
1300   // specific.
1301   //
1302   // ConfigProto should be a protobuf Message type.
1303   template <typename ConfigProto>
backend_config()1304   StatusOr<ConfigProto> backend_config() const {
1305     ConfigProto proto;
1306     TF_RETURN_IF_ERROR(GetBackendConfigInternal(&proto));
1307     return std::move(proto);
1308   }
1309   Status set_backend_config(const tensorflow::protobuf::Message& proto);
1310 
1311   // Getter/setter for raw JSON-encoded backend config.  Prefer the
1312   // functions above that deal in proto Messages where possible.
raw_backend_config_string()1313   const string& raw_backend_config_string() const { return backend_config_; }
set_raw_backend_config_string(string config_str)1314   void set_raw_backend_config_string(string config_str) {
1315     backend_config_ = std::move(config_str);
1316   }
1317 
is_default_config()1318   bool is_default_config() const { return is_default_config_; }
set_default_config()1319   void set_default_config() { is_default_config_ = true; }
1320 
1321   // Returns a string representation of a proto in the format used by
1322   // raw_backend_config_string.
1323   //
1324   // This is morally equivalent to:
1325   //
1326   //   HloInstruction instr;
1327   //   TF_RETURN_IF_ERROR(instr.set_backend_config(proto));
1328   //   return instr.raw_backend_config_string();
1329   //
1330   static StatusOr<string> BackendConfigToRawString(
1331       const tensorflow::protobuf::Message& proto);
1332 
  // Returns the information that tells the implementation what sort of
  // precision is requested. The meaning of the field is backend specific. At
  // the moment, it is only supported for kConvolution and kDot.
  // Transformations from one kDot or kConvolution to another will preserve this
1337   // information. Transformations to other HLOs will not preserve this
1338   // information but it is presumed that the alternate lowering is strictly
1339   // superior.
1340   // Precondition: opcode must be kConvolution or kDot.
1341   const PrecisionConfig& precision_config() const;
1342   PrecisionConfig* mutable_precision_config();
1343 
1344   // Sets the debug metadata for this instruction.
set_metadata(const OpMetadata & metadata)1345   void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
metadata()1346   const OpMetadata& metadata() const { return metadata_; }
1347 
1348   // Set/get the computation containing this instruction. set_parent should only
1349   // be called by HloComputation methods which add/remove instructions to
1350   // computations.
set_parent(HloComputation * computation)1351   void set_parent(HloComputation* computation) { parent_ = computation; }
parent()1352   const HloComputation* parent() const { return parent_; }
parent()1353   HloComputation* parent() { return parent_; }
1354 
1355   // Returns the module for this instruction.
1356   HloModule* GetModule() const;
1357 
1358   // Returns whether we could assign input and output layouts to this
1359   // instruction to make it a bitcast.
1360   bool CouldBeBitcast() const;
1361 
1362   // Get/Set the number of partitions per outer dimension (in order, starting
1363   // with outer-most dimension first). Currently used by the parallel cpu
1364   // backend to partition HLOs into parallel tasks.
1365   //
1366   // TODO(b/62783254) Replace these methods with a more general way to
1367   // annotate HLOs with backend-specific information.
outer_dimension_partitions()1368   const std::vector<int64>& outer_dimension_partitions() const {
1369     return outer_dimension_partitions_;
1370   }
1371   void set_outer_dimension_partitions(
1372       const std::vector<int64>& outer_dimension_partitions);
1373 
1374   // Old methods kept for smooth subclassing transition BEGIN.
1375   // TODO(b/80131774): Remove this code.
1376 
1377   // Delegates to HloBatchNormInstruction::feature_index.
1378   int64 feature_index() const;
1379 
1380   // Delegates to HloBatchNormInstruction::epsilon.
1381   float epsilon() const;
1382 
1383   // Delegates to HloFftInstruction::fft_type.
1384   FftType fft_type() const;
1385 
1386   // Delegates to HloFftInstruction::fft_length.
1387   const std::vector<int64>& fft_length() const;
1388 
1389   // Delegates to HloSendRecvInstruction::channel_id.
1390   int64 channel_id() const;
1391 
1392   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()1393   virtual const std::vector<int64>& dimensions() const {
1394     LOG(FATAL) << "Unimplemented method.";
1395   }
dimensions(int64 index)1396   virtual int64 dimensions(int64 index) const {
1397     LOG(FATAL) << "Unimplemented method.";
1398   }
1399 
1400   // Delegates to HloConcatenateInstruction::concatenate_dimension.
1401   int64 concatenate_dimension() const;
1402 
1403   // Delegates to HloGetDimensionSizeInstruction::dimension.
1404   int64 dimension() const;
1405 
1406   // Returns whether this instruction does a rank-2 transposition.
1407   bool IsRank2Transpose() const;
1408 
  // Delegates to HloSliceInstruction::slice_starts.
1410   int64 slice_starts(int64 dimension) const;
1411   const std::vector<int64>& slice_starts() const;
1412 
1413   // Delegates to HloSliceInstruction::slice_limits.
1414   int64 slice_limits(int64 dimension) const;
1415   const std::vector<int64>& slice_limits() const;
1416 
1417   // Delegates to HloSliceInstruction::slice_strides.
1418   int64 slice_strides(int64 dimension) const;
1419   const std::vector<int64>& slice_strides() const;
1420 
1421   // Returns the literal associated with this instruction.
1422   const Literal& literal() const;
1423 
1424   // Returns whether the instruction is a constant.
1425   bool IsConstant() const;
1426 
1427   // Delegate to HloConstantInstruction::RelayoutConstant.
1428   void RelayoutConstant(const Layout& new_layout,
1429                         const ShapeIndex& shape_index = {});
1430 
1431   // Delegates to HloTraceInstruction::TracingTag.
1432   string TracingTag() const;
1433 
1434   // Delegates to HloFusionInstruction::AddFusionOperand.
1435   HloInstruction* AddFusionOperand(HloInstruction* new_operand);
1436 
1437   // Delegates to HloFusionInstruction::MergeFusionInstruction.
1438   void MergeFusionInstruction(HloInstruction* instruction_to_merge);
1439 
1440   // Delegates to HloFusionInstruction::MergeFusionInstructionIntoMultiOutput.
1441   void MergeFusionInstructionIntoMultiOutput(
1442       HloInstruction* instruction_to_merge);
1443 
1444   // Delegates to HloFusionInstruction::FuseInstruction.
1445   HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse);
1446 
1447   // Delegates to HloFusionInstruction::FuseInstructionIntoMultiOutput.
1448   HloInstruction* FuseInstructionIntoMultiOutput(
1449       HloInstruction* instruction_to_fuse);
1450 
  // Delegates to HloFusionInstruction::fused_instructions_computation.
1452   HloComputation* fused_instructions_computation() const;
1453 
1454   // Delegates to HloFusionInstruction::fused_expression_root.
1455   HloInstruction* fused_expression_root() const;
1456 
1457   // Delegates to HloFusionInstruction::fused_instructions.
1458   const tensorflow::gtl::iterator_range<UnwrappingIterator<
1459       std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
1460   fused_instructions() const;
1461 
1462   const tensorflow::gtl::iterator_range<
1463       UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
1464   fused_instructions();
1465 
1466   // Delegates to HloFusionInstruction::fused_instruction_count.
1467   int64 fused_instruction_count() const;
1468 
1469   // Delegates to HloFusionInstruction::fused_parameter.
1470   HloInstruction* fused_parameter(int64 parameter_number) const;
1471 
1472   // Delegates to HloFusionInstruction::fused_parameters.
1473   const std::vector<HloInstruction*>& fused_parameters() const;
1474 
1475   // Returns true if this instruction is a fusion instruction that generates
1476   // multiple outputs.
1477   const bool IsMultiOutputFusion() const;
1478 
1479   // Delegates to HloFusionInstruction::fusion_kind.
1480   FusionKind fusion_kind() const;
1481 
1482   // Delegates to HloFusionInstruction::set_fusion_kind.
1483   void set_fusion_kind(FusionKind kind);
1484 
1485   // Delegates to HloRngInstruction::random_distribution.
1486   RandomDistribution random_distribution() const;
1487 
1488   // Delegates to HloParameterInstruction::parameter_number.
1489   int64 parameter_number() const;
1490 
1491   // Delegates to
1492   // HloParameterInstruction::set_parameter_replicated_at_leaf_buffers.
1493   void set_parameter_replicated_at_leaf_buffers(
1494       absl::Span<const bool> parameter_replicated_at_leaf_buffers);
1495 
1496   // Delegates to HloParameterInstruction::parameter_replicated_at_leaf_buffers.
1497   const absl::optional<std::vector<bool>>&
1498   parameter_replicated_at_leaf_buffers() const;
1499 
1500   // Delegates to HloGetTupleElementInstruction::tuple_index.
1501   int64 tuple_index() const;
1502 
1503   // Delegates to HloReducePrecisionInstruction::exponent_bits.
1504   int32 exponent_bits() const;
1505 
1506   // Delegates to HloReducePrecisionInstruction::mantissa_bits.
1507   int32 mantissa_bits() const;
1508 
1509   // Delegates to HloInfeedInstruction::infeed_config.
1510   string infeed_config() const;
1511 
1512   // Delegates to HloInfeedInstruction::set_infeed_config.
1513   void set_infeed_config(const string& config);
1514 
1515   // Returns the config for the Outfeed instruction.
1516   const string& outfeed_config() const;
1517 
1518   // Returns the shape for the Outfeed instruction.
1519   const Shape& outfeed_shape() const;
1520 
1521   // Delegates to HloCollectiveInstruction::replica_groups.
1522   const std::vector<ReplicaGroup>& replica_groups() const;
1523 
1524   // Delegates to HloCollectivePermuteInstruction::source_target_pairs.
1525   const std::vector<std::pair<int64, int64>>& source_target_pairs() const;
1526 
1527   // Delegates to HloAllReduceInstruction::all_reduce_barrier.
1528   string all_reduce_barrier() const;
1529   void set_all_reduce_barrier(const string& barrier);
1530 
1531   // Delegates to HloAllReduceInstruction::all_reduce_id.
1532   absl::optional<int64> all_reduce_id() const;
1533   void set_all_reduce_id(const absl::optional<int64>& all_reduce_id);
1534 
1535   // Returns data on the window in a windowed operation such as
1536   // convolution.
window()1537   virtual const Window& window() const {
1538     LOG(FATAL) << "Unimplemented method.";
1539   }
1540 
1541   // Sets the window data in a windowed operation such as convolution.
set_window(const Window & window)1542   virtual void set_window(const Window& window) {
1543     LOG(FATAL) << "Unimplemented method.";
1544   }
1545 
1546   // Returns data on the dimension numbers used for a convolution operation,
1547   // which may be a kConvolution instruction or a kCustomCall that implements a
1548   // convolution.
1549   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const;
1550 
1551   // Sets the convolution dimension numbers on this instruction.  In general you
1552   // shouldn't need to call this; instead, specify the convolution dimension
1553   // numbers when you create the instruction.
1554   void set_convolution_dimension_numbers(
1555       const ConvolutionDimensionNumbers& dnums);
1556 
1557   // The number of feature groups. Must be a divisor of the input feature
1558   // dimension and output feature dimension.
1559   int64 feature_group_count() const;
1560 
1561   void set_feature_group_count(int64 feature_group_count);
1562 
  // The number of batch groups. Must be a divisor of the input batch
  // dimension.
1564   int64 batch_group_count() const;
1565 
1566   void set_batch_group_count(int64 batch_group_count);
1567 
1568   // Delegates to HloSelectAndScatterInstruction::select.
1569   HloComputation* select() const;
1570 
1571   // Delegates to HloSelectAndScatterInstruction::scatter.
1572   HloComputation* scatter() const;
1573 
1574   // Delegates to HloSelectAndScatterInstruction::set_select.
1575   void set_select(HloComputation* computation);
1576 
1577   // Delegates to HloSelectAndScatterInstruction::set_scatter.
1578   void set_scatter(HloComputation* computation);
1579 
1580   // Delegates to HloCustomCallInstruction::custom_call_target.
1581   const string& custom_call_target() const;
1582 
1583   // Delegates to HloPadInstruction::padding_config.
1584   const PaddingConfig& padding_config() const;
1585 
1586   // Delegates to HloDynamicSliceInstruction::slice_sizes.
1587   int64 slice_sizes(int64 dimension) const;
1588 
1589   // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes.
1590   const std::vector<int64>& dynamic_slice_sizes() const;
1591 
1592   // Delegates to HloGatherInstruction::gather_dimension_numbers.
1593   const GatherDimensionNumbers& gather_dimension_numbers() const;
1594   // Delegates to HloGatherInstruction::gather_slice_sizes.
1595   absl::Span<const int64> gather_slice_sizes() const;
1596 
1597   // Delegates to HloScatterInstruction::scatter_dimension_numbers().
1598   const ScatterDimensionNumbers& scatter_dimension_numbers() const;
1599 
1600   // Delegates to HloDotInstruction::dot_dimension_numbers().
1601   const DotDimensionNumbers& dot_dimension_numbers() const;
1602 
1603   // Delegates to HloDomainInstruction::operand_side_metadata().
1604   const DomainMetadata& operand_side_metadata() const;
1605 
1606   // Delegates to HloDomainInstruction::user_side_metadata().
1607   const DomainMetadata& user_side_metadata() const;
1608 
1609   // Delegates to HloCompareInstruction::direction().
1610   ComparisonDirection comparison_direction() const;
1611 
1612   // Delegates to HloTriangularSolveInstruction::triangular_solve_options().
1613   const TriangularSolveOptions& triangular_solve_options() const;
1614 
1615   // Delegates to HloCholeskyInstruction::cholesky_options().
1616   const CholeskyOptions& cholesky_options() const;
1617 
1618   // Old methods kept for smooth subclassing transition END.
1619 
1620  protected:
1621   enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
1622   // Helper class for computing OperandElementUse for kFusion.
1623   class FusionReusesParamElements;
1624 
  // Internal constructor for a given opcode/shape; all other fields must be
  // filled in by factory methods.
1627   HloInstruction(HloOpcode opcode, const Shape& shape);
1628 
1629   // Appends operand to the list of operands and adds this instruction as a user
1630   // of the operand.
1631   void AppendOperand(HloInstruction* operand);
1632 
RemoveOperandAt(int index)1633   void RemoveOperandAt(int index) {
1634     operands_.erase(operands_.begin() + index);
1635   }
1636 
1637   // Removes a list of operands with the given indices in ascending order.
1638   void RemoveOperandsAtAscendingIndices(
1639       absl::Span<const int> ascending_indices);
1640 
AppendComputation(HloComputation * computation)1641   void AppendComputation(HloComputation* computation) {
1642     called_computations_.push_back(computation);
1643   }
1644 
DetachFrom(HloInstruction * usee)1645   void DetachFrom(HloInstruction* usee) { usee->RemoveUser(this); }
1646 
set_called_computation(int index,HloComputation * computation)1647   void set_called_computation(int index, HloComputation* computation) {
1648     called_computations_[index] = computation;
1649   }
  // Indices of computations in called_computations_ for instructions which call
  // multiple computations.
  //
  // The enumerators deliberately reuse the values 0 and 1: each pair is only
  // meaningful for its own opcode, so e.g. index 0 is the body computation for
  // kWhile but the select computation for kSelectAndScatter.
  enum {
    // kWhile computations.
    kBodyComputationIndex = 0,
    kConditionComputationIndex = 1,

    // kSelectAndScatter computations.
    kSelectComputationIndex = 0,
    kScatterComputationIndex = 1,

    // kConditional computations.
    kTrueComputationIndex = 0,
    kFalseComputationIndex = 1,
  };
1665 
1666  private:
1667   // Implementation for non-common logic of CloneWithNewOperands.
CloneWithNewOperandsImpl(const Shape & shape,absl::Span<HloInstruction * const> new_operands,HloCloneContext * context)1668   virtual std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1669       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1670       HloCloneContext* context) const {
1671     // TODO(b/80131774): This should be pure virtual.
1672     LOG(FATAL) << "Unimplemented method.";
1673   }
1674 
1675   // Implementation for non-common logic of ExtraAttributesToString.
ExtraAttributesToStringImpl(const HloPrintOptions & options)1676   virtual std::vector<string> ExtraAttributesToStringImpl(
1677       const HloPrintOptions& options) const {
1678     return {};
1679   }
1680 
  // Implementation for IsElementwise if operand_idx is nullopt, and for
  // IsElementwiseOnOperand otherwise.
1683   //
1684   // NOTE: For all instructions other than kFusion, being elementwise on one of
1685   // the operands is equivalent to being elementwise on all the operands.
1686   virtual bool IsElementwiseImpl(
1687       const absl::optional<int64>& operand_idx) const;
1688   // Prints an instruction to a string.
1689   //
1690   // The canonical string representation needs to name operands and instruction
1691   // names in a consistent way. This is implemented through the
1692   // canonical_name_map.
1693   string ToStringWithCanonicalNameMap(
1694       const HloPrintOptions& options,
1695       CanonicalNameMap* canonical_name_map) const;
1696 
  // Prints the operands of this instruction to a string.
1698   virtual string OperandsToStringWithCanonicalNameMap(
1699       const HloPrintOptions& options,
1700       CanonicalNameMap* canonical_name_map) const;
1701 
  // Allow HloComputation to access the ToStringWithCanonicalNameMap() and
  // OperandsToStringWithCanonicalNameMap() functions.
1704   friend class HloComputation;
1705 
1706   // See comments on Identical().
1707   virtual bool IdenticalSlowPath(
1708       const HloInstruction& other,
1709       const std::function<bool(const HloComputation*, const HloComputation*)>&
1710           eq_computations) const;
1711 
1712   // Generates a hash value specific to a particular type of an instruction.
1713   // This function typically considers the inner root instruction.
1714   virtual uint64 InnerHash() const;
1715 
1716   // Creates an n-ary elementwise operation.
1717   static std::unique_ptr<HloInstruction> CreateNary(
1718       const Shape& shape, HloOpcode opcode,
1719       absl::Span<HloInstruction* const> operands);
1720 
1721   // Adds a user for this instruction.
1722   void AddUser(HloInstruction* user);
1723 
1724   // Removes a user for this instruction.
1725   void RemoveUser(HloInstruction* user);
1726 
1727   // Returns how this instruction uses elements of its `i`th operand.
1728   UseKind OperandElementUse(int64 i) const;
1729 
1730   // Helper for implementing backend_config().  Parses backend_config_ into the
1731   // given proto.
1732   Status GetBackendConfigInternal(tensorflow::protobuf::Message* proto) const;
1733 
1734   int unique_id_;  // Unique to this HloInstruction within a HloModule
1735 
1736   // Opcode for this instruction.
1737   HloOpcode opcode_;
1738 
1739   // Instruction operands.
1740   InstructionVector operands_;
1741 
1742   // The set of control predecessors of this instruction.
1743   // Note that the order of the instructions in the vector influences the order
1744   // computed in HloComputation::ComputeInstructionPostOrder, which may
1745   // influence the result of the compilation by changing the scheduling. We are
1746   // not sure if it matters.
1747   std::vector<HloInstruction*> control_predecessors_;
1748 
1749   // The users of this instruction. Users are HLOs where this instruction is an
1750   // operand. The vector users_ and the set user_set_ contain identical
1751   // members. The set enables fast membership testing and the vector enables
1752   // fast, stable iteration.
1753   std::vector<HloInstruction*> users_;
1754   absl::flat_hash_set<const HloInstruction*> user_set_;
1755 
1756   // The set of control successors of this instruction.
1757   std::vector<HloInstruction*> control_successors_;
1758 
1759   // The computation in which this instruction is contained.
1760   HloComputation* parent_ = nullptr;
1761 
1762   // Result shape of this instruction.
1763   Shape shape_;
1764 
1765   // The sharding, if one exists.
1766   // Uses std::shared_ptr to allow reuse of the same sharding object between
1767   // HloInstructions and other components as HloSharding can be very large for
1768   // many element tuples.
1769   std::shared_ptr<const HloSharding> sharding_;
1770 
1771   // Computations called by this instruction.
1772   std::vector<HloComputation*> called_computations_;
1773 
1774   // A trace instruction that consumes this instruction.
1775   //
  // Invariant: if trace_instruction_ != nullptr, trace_instruction_ has this as
  // an operand.
1778   HloInstruction* trace_instruction_ = nullptr;
1779 
1780   // The backend-specific configuration for how a backend should compile this
1781   // HLO. See the documentation on backend_config().
1782   string backend_config_;
1783 
  // This field is set to true when backend_config_ holds a default
  // configuration.
1786   bool is_default_config_ = false;
1787 
1788   // String identifier for instruction.
1789   string name_;
1790 
1791   // Metadata for debugging.
1792   OpMetadata metadata_;
1793 
1794   // The number of partitions per outer dimension (listed in order from
1795   // outer-most dimension first).
1796   std::vector<int64> outer_dimension_partitions_;
1797 
1798   TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
1799 };
1800 
// Explicit instantiations in hlo_instruction.cc.
// These `extern template` declarations suppress implicit instantiation of
// Accept for the two visitor types in every translation unit that includes
// this header; the definitions come from those explicit instantiations.
extern template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool);
extern template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool);
1804 
// FusionKind <-> string conversion. StringToFusionKind presumably inverts
// ToString and reports unrecognized names through its Status -- confirm.
string ToString(HloInstruction::FusionKind kind);
StatusOr<HloInstruction::FusionKind> StringToFusionKind(
    const string& kind_name);

// Custom (de)stringification functions for protos that live inside
// HloInstruction.
string PaddingConfigToString(const PaddingConfig& padding);
string OpMetadataToString(const OpMetadata& metadata);
string RandomDistributionToString(const RandomDistribution& distribution);
string PrecisionToString(const PrecisionConfig::Precision& precision);
string ConvolutionDimensionNumbersToString(
    const ConvolutionDimensionNumbers& dnums);

// Parsers for the enum names above; presumably they accept the strings
// produced by RandomDistributionToString / PrecisionToString -- confirm.
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
StatusOr<PrecisionConfig::Precision> StringToPrecision(const string& name);

// Streams a FusionKind (presumably its ToString form -- confirm).
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);
1822 
// Map and set classes that guarantee a deterministic iteration order when the
// key is an HloInstruction* or a const HloInstruction*.
// To make the iteration order over these containers deterministic, the
// comparator should not be using the pointer values, but rather an intrinsic
// property of the hlo. Exception: null pointer values compare less than
// non-null.
struct HloPtrComparator {
  // Used as the Compare type of the std::map / std::set aliases below, so it
  // must define a strict weak ordering over instruction pointers.
  bool operator()(const HloInstruction* const& lhs,
                  const HloInstruction* const& rhs) const;
};
1832 
1833 template <typename ValueT>
1834 using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;
1835 
1836 template <typename ValueT>
1837 using ConstHloInstructionMap =
1838     std::map<const HloInstruction*, ValueT, HloPtrComparator>;
1839 
1840 using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
1841 using ConstHloInstructionSet =
1842     std::set<const HloInstruction*, HloPtrComparator>;
1843 
1844 }  // namespace xla
1845 
1846 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
1847