1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // HLO instructions are in DAG form and represent the computations that the user
17 // has built up via the XLA service interface. They are ultimately lowered
18 // in a platform-aware way by traversing the HLO DAG and emitting a lowered
19 // form; e.g. see DfsHloVisitor.
20 
21 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
22 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
23 
24 #include <functional>
25 #include <iosfwd>
26 #include <list>
27 #include <memory>
28 #include <set>
29 #include <string>
30 #include <tuple>
31 #include <unordered_map>
32 #include <unordered_set>
33 #include <vector>
34 
35 #include "tensorflow/compiler/xla/iterator_util.h"
36 #include "tensorflow/compiler/xla/literal_util.h"
37 #include "tensorflow/compiler/xla/map_util.h"
38 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
39 #include "tensorflow/compiler/xla/service/hlo.pb.h"
40 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
41 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
42 #include "tensorflow/compiler/xla/service/name_uniquer.h"
43 #include "tensorflow/compiler/xla/types.h"
44 #include "tensorflow/compiler/xla/xla_data.pb.h"
45 #include "tensorflow/core/lib/core/status.h"
46 #include "tensorflow/core/lib/core/stringpiece.h"
47 #include "tensorflow/core/lib/gtl/array_slice.h"
48 #include "tensorflow/core/lib/gtl/flatmap.h"
49 #include "tensorflow/core/lib/gtl/inlined_vector.h"
50 #include "tensorflow/core/lib/gtl/iterator_range.h"
51 #include "tensorflow/core/platform/logging.h"
52 #include "tensorflow/core/platform/macros.h"
53 #include "tensorflow/core/platform/types.h"
54 
55 namespace xla {
56 
57 class HloComputation;
58 class HloModule;
59 
60 // A bunch of switches that control how the hlo text should be printed.
61 class HloPrintOptions {
62  public:
63   // Constructs the default print options: don't print large constants, don't
64   // compact operands, no indentation.
  HloPrintOptions()
66       : print_large_constants_(false),
67         print_subcomputation_references_(true),
68         print_metadata_(true),
69         compact_operands_(false),
70         print_operand_shape_(true),
71         print_program_shape_(true),
72         print_percent_(true),
73         indent_amount_(0) {}
74 
  static HloPrintOptions ShortParsable() {
76     return HloPrintOptions()
77         .set_print_large_constants(true)
78         .set_print_subcomputation_references(true)
79         .set_print_metadata(false)
80         .set_print_operand_shape(false)
81         .set_print_program_shape(false)
82         .set_print_percent(false);
83   }
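
  // Illustrative sketch (not part of the API): the setters below return
  // *this, so options are typically built up fluently and handed to
  // HloInstruction::ToString(). `instr` is an assumed HloInstruction*.
  //
  //   HloPrintOptions options = HloPrintOptions()
  //                                 .set_print_metadata(false)
  //                                 .set_print_operand_shape(false);
  //   string text = instr->ToString(options);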
84 
85   // If true, large constants will be printed out.
  HloPrintOptions& set_print_large_constants(bool value) {
87     print_large_constants_ = value;
88     return *this;
89   }
90 
  // If false, the names of subcomputations (e.g. a fusion node's fused
  // computation) won't be printed, which makes the resulting text not
  // parsable.
  //
  // A CustomCall's call target is printed even if
  // print_subcomputation_references is false, because the call target isn't an
  // HloComputation.
  HloPrintOptions& set_print_subcomputation_references(bool value) {
98     print_subcomputation_references_ = value;
99     return *this;
100   }
101 
  // If true, metadata will be printed.
  HloPrintOptions& set_print_metadata(bool value) {
104     print_metadata_ = value;
105     return *this;
106   }
107 
108   // If true, operands' shapes will be printed.
  HloPrintOptions& set_print_operand_shape(bool value) {
110     print_operand_shape_ = value;
111     return *this;
112   }
113 
114   // If true, program shape of hlo computations will be printed.
  HloPrintOptions& set_print_program_shape(bool value) {
116     print_program_shape_ = value;
117     return *this;
118   }
119 
120   // If true, names will be printed with prefix '%'.
  HloPrintOptions& set_print_percent(bool value) {
122     print_percent_ = value;
123     return *this;
124   }
125 
  // If true, only some of the operands will be printed out, and their names
  // will be omitted (note that in this case the text will not be parsable).
  HloPrintOptions& set_compact_operands(bool value) {
129     compact_operands_ = value;
130     return *this;
131   }
132 
133   // The indent of the hlo text block.
  HloPrintOptions& set_indent_amount(int value) {
135     indent_amount_ = value;
136     return *this;
137   }
138 
  bool print_large_constants() const { return print_large_constants_; }
  bool print_subcomputation_references() const {
    return print_subcomputation_references_;
  }
  bool print_metadata() const { return print_metadata_; }
  bool compact_operands() const { return compact_operands_; }
  bool print_operand_shape() const { return print_operand_shape_; }
  bool print_program_shape() const { return print_program_shape_; }
  bool print_percent() const { return print_percent_; }
  int indent_amount() const { return indent_amount_; }
149 
150  private:
151   bool print_large_constants_;
152   bool print_subcomputation_references_;
153   bool print_metadata_;
154   bool compact_operands_;
155   bool print_operand_shape_;
156   bool print_program_shape_;
157   bool print_percent_;
158   int indent_amount_;
159 };
160 
161 // HLO instructions are the IR used by the high-level compiler.
162 class HloInstruction {
163  public:
164   enum class FusionKind {
165     kLoop,          // Fused into a loop.
166     kInput,         // Op's input is fused into the op itself.
167     kOutput,        // Op's output is fused into the op itself.
168                     // REQUIRES: At least one operand buffer must be able
169                     // to alias the output buffer.
170     kTransposeDot,  // Fused into a dot with transposed operands.
171     kCustom,        // Custom category for backend-specific fusions that
172                     // do not match any of the more specific ones.
173   };
174 
175   ~HloInstruction();
176 
177   // Creates an instruction from the given proto. Arguments:
178   //
179   //   module: the module which will contain the instruction. The newly created
180   //     instruction is *not* added to the module or any computation, however.
181   //   proto: the proto to convert from.
182   //   instruction_map: a map from instruction name to HloInstruction*. This map
183   //     must contain all operands of the newly constructed instruction.
184   //   computation_map: a map from computation name to HloComputation*. This map
185   //     must contain all computations which the newly constructed instruction
186   //     calls.
187   //   add_fused_computation: A function to call to add a fused
188   //     computation. Used (clearly) when the instruction is a fusion
189   //     instruction.
190   static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
191       HloModule* module, const HloInstructionProto& proto,
192       const tensorflow::gtl::FlatMap<string, HloInstruction*>& instruction_map,
193       const tensorflow::gtl::FlatMap<string, HloComputation*>& computation_map,
194       const std::function<void(std::unique_ptr<HloComputation>)>&
195           add_fused_computation);
196 
197   // Creates a parameter-retrieving instruction.
198   static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
199                                                          const Shape& shape,
200                                                          const string& name);
201 
202   // Creates a literal constant instruction.
203   static std::unique_ptr<HloInstruction> CreateConstant(
204       std::unique_ptr<Literal> literal);
205 
206   // Creates a get tuple element instruction.
207   static std::unique_ptr<HloInstruction> CreateGetTupleElement(
208       const Shape& shape, HloInstruction* operand, int64 index);
209 
210   // Creates a trace instruction that logs the input operand in the computation.
211   static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
212                                                      HloInstruction* operand);
213 
214   // Creates a random number generation instruction that fills a shape with
215   // random numbers from a given distribution.
216   static std::unique_ptr<HloInstruction> CreateRng(
217       const Shape& shape, RandomDistribution distribution,
218       tensorflow::gtl::ArraySlice<HloInstruction*> parameters);
219 
220   // Creates a unary instruction (one operand).
221   // Precondition: opcode must be a legitimate unary operation.
222   static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
223                                                      HloOpcode opcode,
224                                                      HloInstruction* operand);
225 
226   // Creates a binary instruction (two operands).
227   // Precondition: opcode must be a legitimate binary operation.
228   static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
229                                                       HloOpcode opcode,
230                                                       HloInstruction* lhs,
231                                                       HloInstruction* rhs);
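
  // Illustrative sketch of how these factories are typically used together
  // with HloComputation::Builder (assumed to come from hlo_computation.h);
  // the Literal::CreateR0 helper is also an assumption of this sketch. The
  // builder takes ownership of each created instruction:
  //
  //   HloComputation::Builder builder("add_one");
  //   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
  //   HloInstruction* x = builder.AddInstruction(
  //       HloInstruction::CreateParameter(0, r0f32, "x"));
  //   HloInstruction* one = builder.AddInstruction(
  //       HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
  //   HloInstruction* add = builder.AddInstruction(
  //       HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, x, one));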
232 
233   // Creates a ternary instruction (three operands).
234   // Precondition: opcode must be a legitimate ternary operation.
235   static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
236                                                        HloOpcode opcode,
237                                                        HloInstruction* lhs,
238                                                        HloInstruction* rhs,
239                                                        HloInstruction* ehs);
240 
241   // Creates a variadic instruction (variable number of operands).
242   // Precondition: opcode must be a legitimate variadic operation.
243   static std::unique_ptr<HloInstruction> CreateVariadic(
244       const Shape& shape, HloOpcode opcode,
245       tensorflow::gtl::ArraySlice<HloInstruction*> operands);
246 
247   // Creates a map instruction, where the computation (given by the handle) is
248   // applied element-wise to every element in operands (across the operands,
249   // at a given index) with the same `static_operands`.
250   static std::unique_ptr<HloInstruction> CreateMap(
251       const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
252       HloComputation* map_computation,
253       tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {});
254 
255   // Creates a convolution op, where rhs is the convolutional filter
256   // and window describes how the filter is applied to lhs.
257   static std::unique_ptr<HloInstruction> CreateConvolve(
258       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
259       const Window& window,
260       const ConvolutionDimensionNumbers& dimension_numbers);
261 
262   // Creates an FFT op, of the type indicated by fft_type.
263   static std::unique_ptr<HloInstruction> CreateFft(
264       const Shape& shape, HloInstruction* operand, FftType fft_type,
265       tensorflow::gtl::ArraySlice<int64> fft_length);
266 
267   // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
268   // dimensions specified in 'dimension_numbers'.
269   static std::unique_ptr<HloInstruction> CreateDot(
270       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
271       const DotDimensionNumbers& dimension_numbers);
272 
273   // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
274   // of the LHS with dimension 0 of the RHS with no batch dimensions.  Both LHS
275   // and the RHS must be of rank 2.
276   static std::unique_ptr<HloInstruction> CreateCanonicalDot(
277       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);
278 
279   // Creates a reduce-precision op, where operand is the data to reduce in
280   // precision, and exponent_bits and mantissa_bits describe the precision to
281   // reduce it to.
282   static std::unique_ptr<HloInstruction> CreateReducePrecision(
283       const Shape& shape, HloInstruction* operand, const int exponent_bits,
284       const int mantissa_bits);
285 
286   // Creates a cross replica sum op.
287   static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
288       const Shape& shape,
289       tensorflow::gtl::ArraySlice<HloInstruction*> operands);
290 
291   // Creates a conversion instruction, where operand is the data to convert and
292   // shape is the target shape for the conversion.
293   static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
294                                                        HloInstruction* operand);
295 
296   // Creates a bitcast conversion instruction, where operand is the data to
297   // convert and shape is the target shape for the conversion.
298   static std::unique_ptr<HloInstruction> CreateBitcastConvert(
299       const Shape& shape, HloInstruction* operand);
300 
301   // Creates an infeed instruction, which reads data of the given shape from the
302   // Infeed interface of the device.
303   static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape,
304                                                       const string& config);
305 
306   // Creates an outfeed instruction, which outputs data.
307   static std::unique_ptr<HloInstruction> CreateOutfeed(
308       const Shape& shape, HloInstruction* operand,
309       tensorflow::StringPiece outfeed_config);
310 
311   // Creates an asynchronous send instruction with the given channel id, which
312   // initiates sending the operand data to a unique receive instruction in
313   // another computation that has the same channel id.
314   static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
315                                                     int64 channel_id);
316 
317   // Blocks until data transfer for the Send instruction (operand) is complete.
318   // The operand must be kSend.
319   static std::unique_ptr<HloInstruction> CreateSendDone(
320       HloInstruction* operand);
321 
322   // Creates an asynchronous receive instruction with the given channel id,
323   // which allocates resources to receive data of the given shape from a unique
324   // send instruction in another computation that has the same channel id.
325   static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
326                                                     int64 channel_id);
327 
328   // Blocks until data transfer for the Recv instruction (operand) is complete
329   // and returns the receive buffer. The operand must be kRecv.
330   static std::unique_ptr<HloInstruction> CreateRecvDone(
331       HloInstruction* operand);
332 
333   // Creates a slice instruction, where the operand is sliced by the given
334   // start/limit indices.
335   static std::unique_ptr<HloInstruction> CreateSlice(
336       const Shape& shape, HloInstruction* operand,
337       tensorflow::gtl::ArraySlice<int64> start_indices,
338       tensorflow::gtl::ArraySlice<int64> limit_indices,
339       tensorflow::gtl::ArraySlice<int64> strides);
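
  // For example (a sketch, with `builder` and a 1D operand `v` of shape
  // f32[10] assumed): start={2}, limit={8}, stride={2} keeps elements
  // 2, 4, and 6, so the result shape is f32[3].
  //
  //   HloInstruction* slice =
  //       builder.AddInstruction(HloInstruction::CreateSlice(
  //           ShapeUtil::MakeShape(F32, {3}), v, /*start_indices=*/{2},
  //           /*limit_indices=*/{8}, /*strides=*/{2}));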
340 
341   // Creates a slice instruction, where the first operand is sliced by
342   // start indices specified in the second operand, and by size specified in
343   // 'slice_sizes'.
344   static std::unique_ptr<HloInstruction> CreateDynamicSlice(
345       const Shape& shape, HloInstruction* operand,
346       HloInstruction* start_indices,
347       tensorflow::gtl::ArraySlice<int64> slice_sizes);
348 
349   // Creates a dynamic update slice instruction, which updates a slice
350   // of 'operand' with 'update' and 'start_indices'.
351   static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
352       const Shape& shape, HloInstruction* operand, HloInstruction* update,
353       HloInstruction* start_indices);
354 
355   // Creates a concatenate instruction, where the operands are concatenated on
356   // the provided dimension.
357   static std::unique_ptr<HloInstruction> CreateConcatenate(
358       const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
359       int64 dimension);
360 
361   // Creates a reduce instruction, where the computation (given by the handle)
362   // is applied successively to every element in operand. That is, if f is the
363   // function to apply (which either takes 2 [accumulator, value] or 3
364   // [accumulator, index, value] arguments) and init is a reduction operator
365   // specified initial value (for example, 0 for addition), then this operation
366   // will compute:
  //   f(f(f(init, [index0], value0), [index1], value1), ...)
368   static std::unique_ptr<HloInstruction> CreateReduce(
369       const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
370       tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
371       HloComputation* reduce_computation);
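
  // A sketch of a sum-reduction over dimension 0 of an f32[4,8] `operand`
  // (assumes a scalar-add HloComputation `add` of type (F32, F32) -> F32 and
  // a `builder` as in the earlier sketch; Literal::CreateR0 is an assumption):
  //
  //   HloInstruction* zero = builder.AddInstruction(
  //       HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0f)));
  //   HloInstruction* sum =
  //       builder.AddInstruction(HloInstruction::CreateReduce(
  //           ShapeUtil::MakeShape(F32, {8}), operand, zero,
  //           /*dimensions_to_reduce=*/{0}, add));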
372 
373   // Creates a reduce-window instruction, where the computation (given
374   // by the handle) is applied window-wise at each valid window
375   // position in the operand.
376   static std::unique_ptr<HloInstruction> CreateReduceWindow(
377       const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
378       const Window& window, HloComputation* reduce_computation);
379 
380   // Creates a batch-norm-training instruction.
381   static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
382       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
383       HloInstruction* offset, float epsilon, int64 feature_index);
384 
385   // Creates a batch-norm-inference instruction.
386   static std::unique_ptr<HloInstruction> CreateBatchNormInference(
387       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
388       HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
389       float epsilon, int64 feature_index);
390 
391   // Creates a batch-norm-grad instruction.
392   static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
393       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
394       HloInstruction* mean, HloInstruction* variance,
395       HloInstruction* grad_output, float epsilon, int64 feature_index);
396 
397   // Creates a scatter computation that scatters the `source` array to the
398   // selected indices of each window.
399   static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
400       const Shape& shape, HloInstruction* operand, HloComputation* select,
401       const Window& window, HloInstruction* source, HloInstruction* init_value,
402       HloComputation* scatter);
403 
404   // Creates a broadcast instruction.
405   static std::unique_ptr<HloInstruction> CreateBroadcast(
406       const Shape& shape, HloInstruction* operand,
407       tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
408 
409   // Creates a sequence of instructions that performs an explicit broadcast of
410   // the operand to the target shape.
411   //
412   // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
413   // returned as a unique_ptr for API consistency with other factory methods in
414   // this interface.
415   //
416   // TODO(b/72173833) Ideally HloComputations would always be present, and so
417   // the adder being passed by the caller would not be necessary.
418   static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
419       const Shape& output_shape, HloInstruction* operand,
420       const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
421           adder);
422 
423   // Creates a pad instruction, where the operand is padded on the edges and
424   // between the elements with the given padding value.
425   static std::unique_ptr<HloInstruction> CreatePad(
426       const Shape& shape, HloInstruction* operand,
427       HloInstruction* padding_value, const PaddingConfig& padding_config);
428 
429   // Creates a reshape instruction, where the operand is flattened row-major
430   // order and then reshaped to the given result shape.
431   static std::unique_ptr<HloInstruction> CreateReshape(const Shape& shape,
432                                                        HloInstruction* operand);
433 
434   // Creates a transpose instruction which permutes the operand dimensions.
435   static std::unique_ptr<HloInstruction> CreateTranspose(
436       const Shape& shape, HloInstruction* operand,
437       tensorflow::gtl::ArraySlice<int64> dimensions);
438 
439   // Creates a while instruction, given a condition computation, a body
440   // computation, and the initial value for the input of the computations. For
441   // example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
442   // corresponds to the C code below.
443   // int32 i = 1; int32 result = while(i < 1000) { i = i * 2 }
444   static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
445                                                      HloComputation* condition,
446                                                      HloComputation* body,
447                                                      HloInstruction* init);
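
  // A sketch matching the example above (assumes `condition` and `body` are
  // HloComputations of type s32[] -> pred[] and s32[] -> s32[], and `init` is
  // an s32[] constant instruction):
  //
  //   HloInstruction* result =
  //       builder.AddInstruction(HloInstruction::CreateWhile(
  //           ShapeUtil::MakeShape(S32, {}), condition, body, init));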
448 
449   static std::unique_ptr<HloInstruction> CreateConditional(
450       const Shape& shape, HloInstruction* pred,
451       HloInstruction* true_computation_arg, HloComputation* true_computation,
452       HloInstruction* false_computation_arg, HloComputation* false_computation);
453 
454   static std::unique_ptr<HloInstruction> CreateGather(
455       const Shape& shape, HloInstruction* operand,
456       HloInstruction* gather_indices,
457       const GatherDimensionNumbers& gather_dim_numbers,
458       tensorflow::gtl::ArraySlice<int64> window_bounds);
459 
460   // Creates a fusion instruction. A fusion instruction contains one or more
461   // fused instructions forming an expression with a single root
462   // "fused_root". Additional instructions can be added to the fusion
463   // instruction with the method FuseInstruction.
464   static std::unique_ptr<HloInstruction> CreateFusion(
465       const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);
466 
467   static std::unique_ptr<HloInstruction> CreateFusion(
468       const Shape& shape, FusionKind fusion_kind,
469       tensorflow::gtl::ArraySlice<HloInstruction*> operands,
470       HloComputation* fusion_computation);
471 
472   // Creates a call instruction that applies the given computation on the given
473   // operands. "shape" is the resultant shape.
474   static std::unique_ptr<HloInstruction> CreateCall(
475       const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
476       HloComputation* computation);
477 
478   // Creates a custom call instruction that applies the given custom call target
479   // to the given operands. "shape" is the resultant shape.
480   static std::unique_ptr<HloInstruction> CreateCustomCall(
481       const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
482       tensorflow::StringPiece custom_call_target);
483 
484   // Creates a HostCompute instruction, which records host-side control and
485   // data dependencies for use in instruction scheduling.
486   static std::unique_ptr<HloInstruction> CreateHostCompute(
487       const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
488       tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);
489 
490   // Creates a tuple instruction with the given elements. This is a convenience
491   // wrapper around CreateVariadic.
492   static std::unique_ptr<HloInstruction> CreateTuple(
493       tensorflow::gtl::ArraySlice<HloInstruction*> elements);
494 
495   // Creates a reverse instruction, which reverses the order of the elements
496   // in the specified dimensions.
497   static std::unique_ptr<HloInstruction> CreateReverse(
498       const Shape& shape, HloInstruction* operand,
499       tensorflow::gtl::ArraySlice<int64> dimensions);
500 
501   // Creates an instance of GatherDimensionNumbers.
502   static GatherDimensionNumbers MakeGatherDimNumbers(
503       tensorflow::gtl::ArraySlice<int64> output_window_dims,
504       tensorflow::gtl::ArraySlice<int64> elided_window_dims,
505       tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims);
506 
507   // Returns the opcode for this instruction.
  HloOpcode opcode() const { return opcode_; }
509 
510   // Returns true if this instruction has a side effect. An instruction has a
511   // side effect if it uses certain opcodes or calls a computation with a side
512   // effect.
513   bool HasSideEffect() const;
514 
515   // Returns the result shape of this instruction.
516   const Shape& shape() const;
517 
518   // Returns the (mutable) result shape of this instruction.
  Shape* mutable_shape() { return &shape_; }
520 
521   // Returns the ith operand to this instruction.
522   const HloInstruction* operand(int64 i) const;
523 
524   // Returns the ith operand to this instruction.
525   HloInstruction* mutable_operand(int64 i);
526 
527   // Returns the number of operands to this instruction.
  int64 operand_count() const { return operands_.size(); }
529 
530   // Returns the vector of operands of this instruction.
531   using InstructionVector = tensorflow::gtl::InlinedVector<HloInstruction*, 2>;
  const InstructionVector& operands() const { return operands_; }
533 
534   // Returns the index of 'target' in the operands sequence.
535   // Precondition: target must be an operand (or a fatal error will occur).
536   int64 operand_index(const HloInstruction* target) const;
537 
538   // Returns the number of users of this instruction.
  int64 user_count() const { return users_.size(); }
540 
541   // Returns the users of this instruction.
  const std::vector<HloInstruction*>& users() const { return users_; }
543 
544   // Returns true if this instruction is a user of 'instruction'.
  bool IsUserOf(const HloInstruction* instruction) const {
546     return ContainsKey(instruction->user_set_, this);
547   }
548 
549   // Adds a control dependency from this instruction to the given
550   // instruction. This instruction becomes a control predecessor of
551   // 'instruction', and 'instruction' becomes a control successor of this
552   // instruction. Returns an error status if either of the given instructions
553   // does not belong to the same computation.
554   //
555   // This is used to enforce an additional ordering requirement that is not
556   // captured by normal data dependencies, such as ordering among Send or Recv
557   // operations to avoid deadlock.
558   Status AddControlDependencyTo(HloInstruction* instruction);
559 
560   // Removes a previously added control dependency from this instruction to
561   // 'instruction'.
562   Status RemoveControlDependencyTo(HloInstruction* instruction);
563 
564   // Returns the set of control predecessors (successors) of this
565   // instruction. Control predecessors (successors) must execute before (after)
566   // the current instruction.
  const std::vector<HloInstruction*>& control_predecessors() const {
568     return control_predecessors_;
569   }
  const std::vector<HloInstruction*>& control_successors() const {
571     return control_successors_;
572   }
573 
574   // Returns true if "other" performs the same computation as this instruction.
575   bool Identical(
576       const HloInstruction& other,
577       const std::function<bool(const HloInstruction*, const HloInstruction*)>&
578           eq_operands = std::equal_to<const HloInstruction*>(),
579       const std::function<bool(const HloComputation*, const HloComputation*)>&
580           eq_computations = std::equal_to<const HloComputation*>(),
581       bool layout_sensitive = true) const {
582     // An instruction is always identical to itself.
583     if (this == &other) {
584       return true;
585     }
586 
    // Identical instructions must have the same opcode, shape, and identical
588     // operands.
589     if (opcode() != other.opcode()) {
590       return false;
591     }
592     using EqShapeFuncType = bool (*)(const Shape&, const Shape&);
593     EqShapeFuncType eq_shapes =
594         layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible;
595     if (!eq_shapes(shape(), other.shape())) {
596       return false;
597     }
598     if (operands().size() != other.operands().size()) {
599       return false;
600     }
601 
602     // Use an explicit loop rather than ContainerEquals, because copying around
603     // std::functions may be too expensive in some cases.
604     for (size_t i = 0; i < operands().size(); ++i) {
605       if (!eq_operands(operand(i), other.operand(i))) {
606         return false;
607       }
608     }
609 
610     return IdenticalSlowPath(other, eq_computations, eq_shapes);
611   }
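
  // Illustrative uses (a sketch; `a` and `b` are assumed HloInstruction*s):
  // the default arguments compare operands by pointer identity, while a
  // custom eq_operands can compare them with Identical instead.
  //
  //   bool same = a->Identical(*b);
  //   bool same_one_level_deep = a->Identical(
  //       *b, [](const HloInstruction* x, const HloInstruction* y) {
  //         return x->Identical(*y);
  //       });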
612 
613   // Returns whether the instruction has a constant operand.
614   bool HasConstantOperand() const;
615 
616   // Returns whether this instruction does a rank-2 transposition.
617   bool IsRank2Transpose() const;
618 
619   // Replaces the use of this instruction in "user" with "new_producer". Note
620   // that there might be multiple uses of this instruction in "user"; all will
621   // be replaced.
622   Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
623 
624   // Replaces the specified operand with new_operand.
625   Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);
626 
627   // Replaces all uses of this instruction with the new producer. If
628   // new_producer is a user of this instruction then new_producer remains a use
629   // of this instruction to avoid introducing cycles into the graph.
630   //
631   // If this instruction is the root of its computation, sets the computation's
632   // root to new_producer.
633   Status ReplaceAllUsesWith(HloInstruction* new_producer);
634 
635   // Detaches an instruction from its operands. That is, remove the instruction
636   // from each operand's user set. This should only be called prior to
637   // deallocating the instruction.
638   void DetachFromOperands();
639 
640   // Performs a postorder DFS visit using this node as the root. If
641   // call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
642   // complete. If ignore_control_predecessors is true, instructions only
643   // reachable via control dependencies will not be visited, and the postorder
644   // will not take control dependencies into account. It is as if the control
645   // dependencies didn't exist in the graph at all.
646   template <typename HloInstructionPtr>
647   Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
648                 bool call_finish_visit = true,
649                 bool ignore_control_predecessors = false);
650   Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
651                 bool ignore_control_predecessors = false) const {
652     return const_cast<HloInstruction*>(this)->Accept(
653         visitor, call_finish_visit, ignore_control_predecessors);
654   }
655 
656   // Same as Accept() above, but the order of operand and control predecessor
657   // visitation is determined by the given operand order; if compare(A, B) ==
658   // true, A is visited before B.
659   using CompareFunction =
660       std::function<bool(const HloInstruction*, const HloInstruction*)>;
661   Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
662                                 const CompareFunction& operand_order,
663                                 bool call_finish_visit = true);
664 
665   // Performs a postorder DFS visit using this node as the root. Calls the given
666   // visitor function at each instruction.
667   Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
668   Status Accept(
669       const std::function<Status(const HloInstruction*)>& visitor_func) const;
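
  // A sketch using the function-visitor overload above (`root` is an assumed
  // HloInstruction*; errors returned by the lambda are propagated):
  //
  //   Status status = root->Accept([](HloInstruction* hlo) {
  //     VLOG(2) << "visiting " << hlo->name();
  //     return Status::OK();
  //   });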
670 
671   // Visits all instructions rooted at this instruction using the given visitor
672   // in the given order. 'order' must contain at least the set of instructions
673   // rooted at this node (ie, those accessible from a DFS traversal from this
674   // instruction). Instructions contained in 'order' which are not in the set of
675   // instructions rooted at this node are ignored. 'order' must also be a valid
676   // topological sort of these instructions (defs appear before uses) though
677   // need not be a DFS post-order.
678   Status AcceptOrdered(DfsHloVisitor* visitor,
679                        const std::vector<const HloInstruction*>& order);
680 
681   // Visit this instruction and only this instruction with the given visitor.
682   template <typename HloInstructionPtr>
683   Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
684 
685   // Returns the literal associated with this instruction.
686   //
687   // Note: only constant and parameter opcodes have an associated literal.
688   const Literal& literal() const;
689 
690   // Returns the parameter number associated with this instruction.
691   //
692   // Note: only parameter opcodes have an associated parameter number.
  int64 parameter_number() const {
694     CHECK_EQ(HloOpcode::kParameter, opcode_);
695     return parameter_number_;
696   }
697 
698   // Returns the dimension sizes or numbers associated with this instruction.
699   //
700   // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape,
701   // and reverse.
702   const std::vector<int64>& dimensions() const;
703   int64 dimensions(int64 index) const;
704 
705   // Accessor for the dimension in which a concatenate HLO should occur.
706   // Precondition: opcode() == HloOpcode::kConcatenate
707   int64 concatenate_dimension() const;
708 
709   // Returns the tuple index associated with this instruction.
710   //
711   // Precondition: opcode() == HloOpcode::kGetTupleElement
712   int64 tuple_index() const;
713 
714   // Returns the first non-GetTupleElement ancestor instruction of 'hlo'.
715   // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the
716   // (possibly nested) tuple indices used on the path from ancestor to 'hlo'.
717   std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
718       const;
719 
  std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
721     auto rv =
722         const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex();
723     return {const_cast<HloInstruction*>(rv.first), rv.second};
724   }
725 
726   // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction.
727   const HloInstruction* LatestNonGteAncestor() const;
728 
  HloInstruction* LatestNonGteAncestor() {
730     return const_cast<HloInstruction*>(
731         const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
732   }
733 
734   // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
735   // The setter should only be called by HloModule or HloComputation methods.
736   //
737   // Precondition: The instruction has a valid to_apply_ field.
738   HloComputation* to_apply() const;
739   void set_to_apply(HloComputation* to_apply);
740 
741   // Returns the custom_call_target for CustomCall.
742   // Precondition: opcode() == HloOpcode::kCustomCall
743   const string& custom_call_target() const;
744 
745   // Returns the config for the Outfeed instruction.
746   // Precondition: opcode() == HloOpcode::kOutfeed
747   const string& outfeed_config() const;
748 
749   // Returns the shape for the Outfeed instruction.
750   // Precondition: opcode() == HloOpcode::kOutfeed
751   const Shape& outfeed_shape() const;
752 
753   // Gets/sets the while_condition or while_body HloComputation for While. The
754   // setters should only be called by HloModule or HloComputation methods.
755   //
756   // Precondition: The instruction is a While instruction.
757   HloComputation* while_condition() const;
758   HloComputation* while_body() const;
759   void set_while_condition(HloComputation* while_condition);
760   void set_while_body(HloComputation* while_body);
761 
762   // Gets/sets the select or scatter HloComputation for SelectAndScatter. The
763   // setters should only be called by HloModule or HloComputation methods.
764   //
765   // Precondition: opcode() == HloOpcode::kSelectAndScatter.
766   HloComputation* select() const;
767   HloComputation* scatter() const;
768   void set_select(HloComputation* select);
769   void set_scatter(HloComputation* scatter);
770 
771   // Gets/sets the true and false HloComputation for Conditional. The setters
772   // should only be called by HloModule or HloComputation methods.
773   //
774   // Precondition: The instruction is a Conditional instruction.
775   HloComputation* true_computation() const;
776   HloComputation* false_computation() const;
777   void set_true_computation(HloComputation* true_computation);
778   void set_false_computation(HloComputation* false_computation);
779 
780   // Returns a string for the signature of this instruction if considered as a
781   // function, e.g. the signature of an F32 add is (F32, F32) -> F32.
782   string SignatureString() const;
783 
784   // Returns a debugging string that represents this instruction.
785   //
786   // (We express the default options using an overload rather than a default
787   // param because gdb ignores default params, but does resolve overloads.)
788   //
789   // TODO(b/73348663): Make ToString() adaptive to the size of the string by
790   // default, backing off on providing full information for very large strings,
791   // or provide a different name for a ToString-like function that does that.
  string ToString() const { return ToString(HloPrintOptions()); }
793   string ToString(const HloPrintOptions& options) const;
794 
795   // Components of the ToString() representation:
796 
797   // Returns a string representation of the operand list.
798   string OperandsToString(const HloPrintOptions& options) const;
799 
800   // Returns string representation of op-specific attributes.
801   std::vector<string> ExtraAttributesToString(
802       const HloPrintOptions& options) const;
803 
804   // As ToString, but returns a shorter string.
805   string ToShortString() const;
806 
807   // Returns a serialized representation of this instruction.
808   HloInstructionProto ToProto() const;
809 
810   // Returns a category for the HLO. This could be something like "convolution"
811   // or "elementwise".
812   string ToCategory() const;
813 
814   // Returns a logging instruction, if the output of this instruction is logged.
815   //
816   // Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
817   HloInstruction* tracing() const;
818   void set_tracing(HloInstruction* trace_instruction);
819 
820   // Returns the channel id associated with the instruction. The id is
821   // shared between each Send/Recv pair and is globally unique to identify each
822   // channel.
823   //
824   // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv
  int64 channel_id() const { return channel_id_; }
826 
  // Returns the feature_index field associated with the instruction. The index
828   // represents the index of the feature dimension.
829   //
830   // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
831   // or kBatchNormGrad.
  int64 feature_index() const { return feature_index_; }
833 
  // Returns the epsilon value associated with the instruction. This is a small
  // number added to the variance to avoid divide-by-zero errors.
836   //
837   // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference,
838   // or kBatchNormGrad.
  float epsilon() const { return epsilon_; }
840 
841   // Returns the infeed configuration string. The infeed configuration includes
842   // any metadata needed for the backend compiler (e.g., infeed buffer address)
843   // and is target-dependent.
  string infeed_config() const { return infeed_config_; }
  void set_infeed_config(const string& config) { infeed_config_ = config; }
846 
847   // Returns a tag to be used in tracing.
848   //
849   // Precondition: opcode() == HloOpcode::kTrace
850   string TracingTag() const;
851 
852   // Returns whether the instruction is a constant.
853   bool IsConstant() const;
854 
  // Returns true if this instruction is fused, i.e., contained within a fusion
856   // instruction.
857   bool IsFused() const;
858 
859   // Returns the computation for this fused instruction.
860   //
861   // Precondition: opcode() == HloOpcode::kFusion
862   HloComputation* fused_instructions_computation() const;
863 
864   // Returns true if this instruction can be legally fused into a fusion
865   // instruction.
866   bool IsFusable() const;
867 
868   // Returns the root instruction of the fused expression contained within this
869   // fusion instruction.
870   //
871   // Precondition: opcode() == HloOpcode::kFusion
872   HloInstruction* fused_expression_root() const;
873 
874   // Returns the list of fused instructions inside this fusion instruction.  The
875   // returned type is a range of HloInstruction*s.
876   //
877   // Precondition: opcode() == HloOpcode::kFusion
878   const tensorflow::gtl::iterator_range<UnwrappingIterator<
879       std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
880   fused_instructions() const;
881 
882   const tensorflow::gtl::iterator_range<
883       UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
884   fused_instructions();
885 
886   // Gets the number of instructions inside this fusion instruction.
887   //
888   // Precondition: opcode() == HloOpcode::kFusion
889   int64 fused_instruction_count() const;
890 
891   // Returns the fused parameter instruction in this fusion instruction
892   // corresponding to the given parameter number.
893   //
894   // Precondition: opcode() == HloOpcode::kFusion
895   HloInstruction* fused_parameter(int64 parameter_number) const;
896 
897   // Returns the vector of fused parameters inside this fusion instruction.
898   //
899   // Precondition: opcode() == HloOpcode::kFusion
900   const std::vector<HloInstruction*>& fused_parameters() const;
901 
902   // Returns true if this instruction is a fusion instruction that generates
903   // multiple outputs.
  const bool IsMultiOutputFusion() const {
905     return opcode() == HloOpcode::kFusion &&
906            fused_expression_root()->opcode() == HloOpcode::kTuple;
907   }
908 
  FusionKind fusion_kind() const {
910     CHECK_EQ(HloOpcode::kFusion, opcode_);
911     return fusion_kind_;
912   }
913 
  void set_fusion_kind(FusionKind kind) {
915     CHECK_EQ(HloOpcode::kFusion, opcode_);
916     fusion_kind_ = kind;
917   }
918 
919   // Returns the sharding applied to this operator.
920   // REQUIRES: has_sharding() is true.
  const HloSharding& sharding() const {
922     CHECK(has_sharding());
923     return *sharding_;
924   }
925   // Returns the sharding applied to this operator, or default_ if none exists.
  const HloSharding& sharding_or_default(const HloSharding& default_) const {
927     return sharding_ ? *sharding_ : default_;
928   }
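
  // Typical use (a sketch; `instr` is an assumed HloInstruction*): callers
  // that want a concrete sharding without checking has_sharding() first pass
  // a default, e.g.
  //
  //   HloSharding replicated = HloSharding::Replicate();
  //   const HloSharding& sharding = instr->sharding_or_default(replicated);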
929   // Sets the sharding of this operator. Should only be called by HloModule or
930   // HloComputation methods.
  void set_sharding(const HloSharding& sharding) {
932     sharding_ = MakeUnique<HloSharding>(sharding);
933   }
934   // Remove any sharding from this operator.
  void clear_sharding() { sharding_ = nullptr; }
936   // Return true if this operator has a sharding assigned.
  bool has_sharding() const { return sharding_ != nullptr; }
938 
  // Adds a new operand to the fusion instruction.
940   HloInstruction* AddFusionOperand(HloInstruction* new_operand);
941 
942   // Merges the fused instructions from 'instruction_to_merge' into the
943   // fused instruction set of 'this', updating operands as necessary.
944   //
945   // Precondition: opcode() == HloOpcode::kFusion
  // Precondition: 'instruction_to_merge' must be an operand of 'this'.
947   void MergeFusionInstruction(HloInstruction* instruction_to_merge);
948 
949   // Merges the fused instructions from instruction_to_merge into the fused
950   // instruction set of 'this' and generates multioutput fusion instructions.
951   // All the users of instruction_to_merge will be redirected to 'this'
952   // instruction. instruction_to_merge will be removed from its parent
953   // computation.
954   //
955   // Precondition: opcode() == HloOpcode::kFusion
956   void MergeFusionInstructionIntoMultiOutput(
957       HloInstruction* instruction_to_merge);
958 
  // Fuses the given instruction into this fusion instruction. instruction_to_fuse
960   // is cloned and the clone is placed in the fusion
961   // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather
962   // than moved to cleanly handle the case where the instruction has a use
963   // outside the fusion instruction. Moving such an instruction into a fusion
964   // instruction would violate the single-result invariant of HLO instructions
965   // and significantly complicate code generation.
966   //
967   // Precondition: this->opcode() == HloOpcode::kFusion
  HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
969     return FuseInstructionInternal(instruction_to_fuse);
970   }
971 
  // Fuses the given instruction into this fusion instruction and generates a
  // multioutput fusion instruction. A clone of instruction_to_fuse will
  // be part of the output of the fusion instruction. The users of
  // instruction_to_fuse will be redirected to this fusion instruction.
976   // instruction_to_fuse will be removed from its parent computation.
977   //
978   // Precondition: this->opcode() == HloOpcode::kFusion
  HloInstruction* FuseInstructionIntoMultiOutput(
980       HloInstruction* instruction_to_fuse) {
981     return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true);
982   }
983 
984   // Returns the start index in the given dimension for a slice node.
985   //
986   // Precondition: opcode() == HloOpcode::kSlice
  int64 slice_starts(int64 dimension) const {
988     CHECK_EQ(HloOpcode::kSlice, opcode_);
989     return slice_starts_[dimension];
990   }
  const std::vector<int64>& slice_starts() const { return slice_starts_; }
992 
993   // Returns the (exclusive) limit index in the given dimension for a slice
994   // node.
995   //
996   // Precondition: opcode() == HloOpcode::kSlice
  int64 slice_limits(int64 dimension) const {
998     CHECK_EQ(HloOpcode::kSlice, opcode_);
999     return slice_limits_[dimension];
1000   }
  const std::vector<int64>& slice_limits() const {
1002     CHECK_EQ(HloOpcode::kSlice, opcode_);
1003     return slice_limits_;
1004   }
1005 
1006   // Returns the stride in the given dimension for a slice node.
1007   //
1008   // Precondition: opcode() == HloOpcode::kSlice
  int64 slice_strides(int64 dimension) const {
1010     CHECK_EQ(HloOpcode::kSlice, opcode_);
1011     return slice_strides_[dimension];
1012   }
  const std::vector<int64>& slice_strides() const { return slice_strides_; }
1014 
1015   // Returns the flag that describes whether a slice must be lowered into an
1016   // offset into the original operand.
  bool IsInPlaceSlice() const { return is_in_place_slice_; }
1018 
1019   // Sets and returns the flag that describes whether a slice must be lowered
1020   // into an offset into the original operand.
  bool SetIsInPlaceSlice(bool value) {
1022     is_in_place_slice_ = value;
1023     return value;
1024   }
1025 
1026   // Returns the size of the slice in the given dimension for a dynamic
1027   // slice node.
1028   //
1029   // Precondition: opcode() == HloOpcode::kDynamicSlice
  int64 slice_sizes(int64 dimension) const {
1031     CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
1032     return dynamic_slice_sizes_[dimension];
1033   }
  const std::vector<int64>& dynamic_slice_sizes() const {
1035     CHECK_EQ(HloOpcode::kDynamicSlice, opcode_);
1036     return dynamic_slice_sizes_;
1037   }
1038 
1039   // Returns the number of exponent bits for a reduce-precision node.
1040   //
1041   // Precondition: opcode() == HloOpcode::kReducePrecision
  int32 exponent_bits() const {
1043     CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
1044     return exponent_bits_;
1045   }
1046 
1047   // Returns the number of mantissa bits for a reduce-precision node.
1048   //
1049   // Precondition: opcode() == HloOpcode::kReducePrecision
  int32 mantissa_bits() const {
1051     CHECK_EQ(HloOpcode::kReducePrecision, opcode_);
1052     return mantissa_bits_;
1053   }
1054 
1055   // Returns data on the window in a windowed operation such as
1056   // convolution.
  const Window& window() const {
1058     CHECK(window_ != nullptr);
1059     return *window_;
1060   }
1061 
1062   // Sets the window data in a windowed operation such as convolution.
  void set_window(const Window& window) {
1064     window_ = MakeUnique<Window>(window);
1065   }
1066 
1067   // Returns the padding configuration for a pad node.
1068   //
1069   // Precondition: opcode() == HloOpcode::kPad
  const PaddingConfig& padding_config() const {
1071     CHECK(padding_config_ != nullptr);
1072     return *padding_config_;
1073   }
1074 
1075   // Returns data on the dimension numbers used for a convolution operation,
1076   // which may be a kConvolution instruction or a kCustomCall that implements a
1077   // convolution.
  const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
1079     CHECK(convolution_dimension_numbers_ != nullptr);
1080     return *convolution_dimension_numbers_;
1081   }
1082 
1083   // Sets the convolution dimension numbers on this instruction.  In general you
1084   // shouldn't need to call this; instead, specify the convolution dimension
1085   // numbers when you create the instruction.
  void set_convolution_dimension_numbers(
1087       const ConvolutionDimensionNumbers& dnums) {
1088     convolution_dimension_numbers_ =
1089         MakeUnique<ConvolutionDimensionNumbers>(dnums);
1090   }
1091 
  FftType fft_type() const {
1093     CHECK_EQ(HloOpcode::kFft, opcode_);
1094     return fft_type_;
1095   }
1096 
  const std::vector<int64>& fft_length() const {
1098     CHECK_EQ(HloOpcode::kFft, opcode_);
1099     return fft_length_;
1100   }
1101 
1102   // Returns the dump string of the convolution dimension numbers.
1103   string ConvolutionDimensionNumbersToString() const;
1104 
1105   // Returns data on the dimension numbers used for a dot operation.
  const DotDimensionNumbers& dot_dimension_numbers() const {
1107     CHECK(dot_dimension_numbers_ != nullptr);
1108     return *dot_dimension_numbers_;
1109   }
1110 
1111   // Returns the dump string of the dot dimension numbers.
1112   string DotDimensionNumbersToString() const;
1113 
  const GatherDimensionNumbers& gather_dimension_numbers() const {
1115     CHECK(gather_dimension_numbers_ != nullptr);
1116     return *gather_dimension_numbers_;
1117   }
1118 
  tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
1120     CHECK_EQ(opcode(), HloOpcode::kGather);
1121     return gather_window_bounds_;
1122   }
1123 
1124   // Returns the dump string of the gather dimension numbers.
1125   string GatherDimensionNumbersToString() const;
1126 
1127   // Returns the random distribution for this rng node.
1128   //
1129   // Precondition: opcode() == HloOpcode::kRng
1130   RandomDistribution random_distribution() const;
1131 
1132   // Clones the HLO instruction. The clone will have the same opcode, shape, and
1133   // operands. After creation the clone has no uses. "this" (the instruction
1134   // cloned from) is not changed. Suffix is the string to append to the name of
1135   // the instruction to form the name of the cloned instruction.
  // If the module pointer is not nullptr, it is the module to which the
  // cloned computations will be added (in order to support deep cloning).
1139   std::unique_ptr<HloInstruction> Clone(const string& suffix = "clone",
1140                                         HloModule* module = nullptr) const;
1141 
1142   // Clones the HLO instruction as above but with new shape and operands.
  // If the module pointer is not nullptr, it is the module to which the
  // cloned computations will be added (in order to support deep cloning).
1146   std::unique_ptr<HloInstruction> CloneWithNewOperands(
1147       const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
1148       HloModule* module = nullptr) const;
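
  // A sketch (assumes `add`, `new_lhs`, and `new_rhs` are HloInstruction*s
  // and the new operands' shapes are compatible with the originals):
  //
  //   std::unique_ptr<HloInstruction> clone =
  //       add->CloneWithNewOperands(add->shape(), {new_lhs, new_rhs});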
1149 
1150   // Returns the computations this instruction directly calls (if any).
  const std::vector<HloComputation*>& called_computations() const {
1152     return called_computations_;
1153   }
1154 
  // Replaces all called computations based on a map function. This is needed
  // when we clone hlo_computations and want the instructions to point
  // to the newly cloned nodes.
  void ReplaceCalledComputations(
1159       std::function<HloComputation*(HloComputation*)> map_function) {
1160     for (int64 i = 0; i < called_computations_.size(); ++i) {
1161       called_computations_[i] = map_function(called_computations_[i]);
1162     }
1163   }
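
  // A sketch (assumes `clone_map` maps original computations to their clones,
  // e.g. an std::unordered_map<HloComputation*, HloComputation*>):
  //
  //   instr->ReplaceCalledComputations([&clone_map](HloComputation* c) {
  //     return clone_map.at(c);
  //   });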
1164 
1165   // Clears out the called computations.
1166   //
1167   // This is, in particular, necessary when inlining function bodies into their
1168   // caller. If there were side-effecting operations in the called computations,
1169   // the call itself is considered side-effecting and thus cannot be removed. By
1170   // clearing out the computations, we reflect the fact that all side-effecting
1171   // properties have been reflected in the caller, and make the call HLO
1172   // removable.
  void ClearCalledComputations() { called_computations_.clear(); }

  // Returns true if this instruction performs an elementwise operation on
  // `operand_idx`-th operand. An instruction is elementwise on an operand iff,
  // after performing necessary implicit broadcast
  // (cs/IrArray::EmitArrayElementAddress), to compute the output at index
  // {i_0,i_1,...,i_n}, the only element required from the operand (if any) is
  // the element at {i_0,i_1,...,i_n}.
  //
  // Note on performance: when this instruction is kFusion, this method, in the
  // worst case, scans all fused instructions. We could speed this up by
  // caching.
  bool IsElementwiseOnOperand(int64 operand_idx) const;

  // Returns true if this instruction is elementwise on all its operands.
  bool IsElementwise() const;

  // Returns true if this elementwise instruction implicitly broadcasts operand
  // `operand_idx`.
  //
  // Precondition: this instruction should be an elementwise operation.
  bool ImplicitlyBroadcastsOperand(int64 operand_idx) const;

  // Returns true if this instruction is binary and elementwise.
  bool IsElementwiseBinary() const;

  // Returns whether this instruction may reuse elements of its `i`th operand.
  bool ReusesOperandElements(int64 i) const {
    return OperandElementUse(i) == UseKind::kReuse;
  }
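
  // Example of a guard an HLO pass might use before rewriting an operation
  // element-by-element (a sketch; `instruction` is a hypothetical binary op):
  //
  //   if (instruction->IsElementwiseBinary() &&
  //       !instruction->ImplicitlyBroadcastsOperand(0) &&
  //       !instruction->ImplicitlyBroadcastsOperand(1)) {
  //     // All operands share the output's index space, so output element
  //     // {i_0,...,i_n} depends only on operand elements {i_0,...,i_n}.
  //   }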

  // Returns the indices at which the given operand appears in the operand list
  // of this instruction. Note that an instruction can use the same operand
  // multiple times.
  std::vector<int64> OperandIndices(const HloInstruction* operand) const;

  // Convenience helper for ShapeUtil::InsertedOrDeleted1SizedDimensions. If
  // this reshape merely inserts or deletes 1-sized dimensions, return the input
  // indices of the deleted dimensions and the output indices of the inserted
  // dimensions.
  //
  // Precondition: this op must be a reshape.
  std::tuple<bool, std::vector<int64>, std::vector<int64>>
  ReshapeMerelyInsertsOrDeletes1SizedDimensions() const;
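
  // Example of unpacking the result (a sketch; `reshape` is a hypothetical
  // kReshape instruction):
  //
  //   bool trivial;
  //   std::vector<int64> deleted_dims, inserted_dims;
  //   std::tie(trivial, deleted_dims, inserted_dims) =
  //       reshape->ReshapeMerelyInsertsOrDeletes1SizedDimensions();
  //   if (trivial) {
  //     // Only 1-sized dimensions were inserted or deleted, so the underlying
  //     // element order is unchanged.
  //   }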

  // Gets/sets the string identifier for this instruction.
  const string& name() const { return name_; }
  void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); }

  // Use the given NameUniquer to select a unique name for the instruction based
  // on the instruction's existing name.
  void UniquifyName(NameUniquer* name_uniquer);

  // Set the unique id for this instruction to "id".
  void SetUniqueId(int id) {
    CHECK_EQ(unique_id_, -1);  // Should not be assigned already
    CHECK_GE(id, 0);
    unique_id_ = id;
  }

  // Return the unique ID assigned to this node via SetUniqueId (or -1
  // if no id has been assigned yet).
  int unique_id() const { return unique_id_; }

  // Sets the debug metadata for this instruction.
  void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
  const OpMetadata& metadata() const { return metadata_; }

  // Set/get the computation containing this instruction. set_parent should only
  // be called by HloComputation methods which add/remove instructions to/from
  // computations.
  void set_parent(HloComputation* computation) { parent_ = computation; }
  const HloComputation* parent() const { return parent_; }
  HloComputation* parent() { return parent_; }

  // Returns the module for this instruction.
  HloModule* GetModule() const;

  // Returns whether we could assign input and output layouts to this
  // instruction to make it a bitcast.
  bool CouldBeBitcast() const;

  // Get/Set the number of partitions per outer dimension (in order, starting
  // with outer-most dimension first). Currently used by the parallel cpu
  // backend to partition HLOs into parallel tasks.
  // TODO(b/62783254) Replace these methods with a more general way to
  // annotate HLOs with backend-specific information.
  const std::vector<int64>& outer_dimension_partitions() const {
    return outer_dimension_partitions_;
  }
  void set_outer_dimension_partitions(
      const std::vector<int64>& outer_dimension_partitions);

  // Changes the layout of a constant HLO instruction to match new_layout. For
  // tuple-shaped constants, shape_index is the path to the internal array
  // subshape whose layout needs to be changed.
  void RelayoutConstant(const Layout& new_layout,
                        const ShapeIndex& shape_index = {});
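
  // Example of relayouting one array inside a tuple-shaped constant (a sketch;
  // `constant` and the chosen layout/index are assumptions for illustration):
  //
  //   constant->RelayoutConstant(LayoutUtil::MakeLayout({0, 1}),
  //                              /*shape_index=*/{1});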

 private:
  enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };

  // Helper class for computing OperandElementUse for kFusion.
  class FusionReusesParamElements;

  // See comments on Identical().
  // eq_shapes() is used to check shapes for equality, and would normally be
  // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on
  // whether we want a layout-sensitive check or not.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations,
      const std::function<bool(const Shape&, const Shape&)>& eq_shapes) const;

  // Creates an n-ary elementwise operation.
  static std::unique_ptr<HloInstruction> CreateNary(
      const Shape& shape, HloOpcode opcode,
      tensorflow::gtl::ArraySlice<HloInstruction*> operands);

  // Appends operand to the list of operands and adds this instruction as a user
  // of the operand.
  void AppendOperand(HloInstruction* operand);

  // Adds a user for this instruction.
  void AddUser(HloInstruction* user);

  // Removes a user for this instruction.
  void RemoveUser(HloInstruction* user);

  // Internal constructor for a given opcode/shape, other fields must be filled
  // by factory methods.
  HloInstruction(HloOpcode opcode, const Shape& shape);

  // Fuses the given instruction into this fusion instruction. When add_output
  // is false (which is the default), instruction_to_fuse is cloned and the
  // clone is placed in the fusion instruction. instruction_to_fuse is
  // unchanged.
  //
  // When add_output is true, a clone of instruction_to_fuse will be part
  // of the output of this fusion instruction. The users of instruction_to_fuse
  // will be redirected to this fusion instruction. instruction_to_fuse will
  // be removed from its parent computation.
  //
  // Precondition: this->opcode() == HloOpcode::kFusion
  HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse,
                                          bool add_output = false);

  // Clones the given instruction_to_fuse and inserts the clone into this fusion
  // instruction. If add_output is true, a clone of instruction_to_fuse will
  // be in the output of this fusion instruction (part of the tuple of the
  // fusion root).
  //
  // Precondition: opcode() == HloOpcode::kFusion
  HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse,
                                       bool add_output = false);

  // Clones a fusion instruction with a new shape and operands.
  std::unique_ptr<HloInstruction> CloneFusionWithNewOperands(
      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
      HloModule* module = nullptr) const;

  // Returns true if this instruction can legally have the dimensions field
  // set. Used for checking precondition of dimensions field accessors.
  bool CanHaveDimensionsField() const;

  // Returns how this instruction uses elements of its `i`th operand.
  UseKind OperandElementUse(int64 i) const;

  int unique_id_;  // Unique to this HloInstruction within a HloModule

  // Opcode for this instruction.
  HloOpcode opcode_;

  // Instruction operands.
  InstructionVector operands_;

  // The set of control predecessors of this instruction.
  std::vector<HloInstruction*> control_predecessors_;

  // The users of this instruction. Users are HLOs where this instruction is an
  // operand. The vector users_ and the set user_set_ contain identical
  // members. The set enables fast membership testing and the vector enables
  // fast, stable iteration.
  std::vector<HloInstruction*> users_;
  std::unordered_set<const HloInstruction*> user_set_;

  // The set of control successors of this instruction.
  std::vector<HloInstruction*> control_successors_;

  // The computation in which this instruction is contained.
  HloComputation* parent_ = nullptr;

  // Shape of outfeed request.
  Shape outfeed_shape_;

  // Result shape of this instruction.
  Shape shape_;

  // Literal, only present for kConstant.
  std::unique_ptr<Literal> literal_;

  // Constant index, only present for kGetTupleElement.
  int64 tuple_index_ = -1;

  // Dimensions present for some operations that require reshaping or
  // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse.
  std::vector<int64> dimensions_;

  // Describes the window in a windowed operation such as convolution.
  std::unique_ptr<Window> window_;

  // Describes the dimension numbers used for a convolution.
  std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;

  // Describes the dimension numbers used for a dot.
  std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;

  std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
  std::vector<int64> gather_window_bounds_;

  // Describes FFT type for an FFT instruction.
  FftType fft_type_ = FftType::FFT;

  // Indicates the FFT length for an FFT instruction.
  std::vector<int64> fft_length_;

  // Describes the [begin, end) index range for a slice.
  std::vector<int64> slice_starts_;
  std::vector<int64> slice_limits_;
  std::vector<int64> slice_strides_;

  // Describes whether the slice can be lowered to an offset into the operand.
  bool is_in_place_slice_ = false;

  // The bit sizes for a reduce-precision operation.
  int32 exponent_bits_ = 0;
  int32 mantissa_bits_ = 0;

  // Describes the [start, start + size) range size for a dynamic slice
  // ('start' is specified dynamically in the second operand of the operation).
  std::vector<int64> dynamic_slice_sizes_;

  // The padding configuration that describes the edge padding and interior
  // padding of this pad instruction. Only set for pad instructions.
  std::unique_ptr<PaddingConfig> padding_config_;

  // The type of the fusion. Used by kFusion only.
  FusionKind fusion_kind_;

  // The sharding, if one exists.
  std::unique_ptr<HloSharding> sharding_;

  // For parameter instructions this field holds the parameter number.
  int64 parameter_number_ = 0;

  // Name of a global symbol to call, only present for kCustomCall.
  string custom_call_target_;

  // Name to use for host send/recv channels, only present for kHostCompute.
  string channel_name_;

  // Estimate of the duration of a host computation in nanoseconds.
  int64 cost_estimate_ns_;

  // Computations called by this instruction.
  std::vector<HloComputation*> called_computations_;

  // Indices of computations in called_computations_ for instructions which call
  // multiple computations.
  enum {
    // kWhile computations.
    kBodyComputationIndex = 0,
    kConditionComputationIndex = 1,

    // kSelectAndScatter computations.
    kSelectComputationIndex = 0,
    kScatterComputationIndex = 1,

    // kConditional computations.
    kTrueComputationIndex = 0,
    kFalseComputationIndex = 1,
  };

  // Outfeed configuration information, only present for kOutfeed.
  string outfeed_config_;

  // A trace instruction that consumes this instruction.
  //
  // Invariant: if trace_instruction_ != nullptr, trace_instruction has this as
  // an operand.
  HloInstruction* trace_instruction_ = nullptr;

  // The distribution requested for random number generation.
  // Only present for kRng.
  RandomDistribution distribution_;

  // A small float number added to the variance to avoid divide-by-zero error.
  // Only present for kBatchNormTraining.
  float epsilon_ = 0.0f;

  // An integer value representing the index of the feature dimension.
  // Only present for kBatchNormTraining.
  int64 feature_index_ = -1;

  // Represents a unique identifier for each Send/Recv instruction pair.
  // Only present for kSend or kRecv.
  int64 channel_id_ = -1;

  // The string representation of the infeed configuration.
  string infeed_config_;

  // String identifier for instruction.
  string name_;

  // Metadata for debugging.
  OpMetadata metadata_;

  // The number of partitions per outer dimension (listed in order from
  // outer-most dimension first).
  std::vector<int64> outer_dimension_partitions_;

  TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction);
};

string ToString(HloInstruction::FusionKind kind);
StatusOr<HloInstruction::FusionKind> StringToFusionKind(
    const string& kind_name);
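
// Round-trip sketch (inside a function that can propagate a Status; the
// TF_ASSIGN_OR_RETURN macro and FusionKind::kLoop are standard XLA names, but
// this usage is illustrative rather than taken from this header):
//
//   string text = ToString(HloInstruction::FusionKind::kLoop);
//   TF_ASSIGN_OR_RETURN(HloInstruction::FusionKind kind,
//                       StringToFusionKind(text));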

// Custom (de)stringification functions for protos that live inside
// HloInstruction.
string PaddingConfigToString(const PaddingConfig& padding);
string OpMetadataToString(const OpMetadata& metadata);
string RandomDistributionToString(const RandomDistribution& distribution);
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);

std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);

// Map classes that guarantee a deterministic iteration order when the key is
// an HloInstruction* or a const HloInstruction*.
// To make the iteration order over the map deterministic, the comparator
// should not be using the pointer values, but rather an intrinsic property of
// the hlo.
//
// Note that this cannot be used for HLO instructions across multiple modules
// since the ids of HLO instructions are only unique within each HLO module.
struct HloPtrComparator {
  bool operator()(const HloInstruction* const& lhs,
                  const HloInstruction* const& rhs) const {
    return lhs->unique_id() < rhs->unique_id();
  }
};

template <typename ValueT>
using HloInstructionMap = std::map<HloInstruction*, ValueT, HloPtrComparator>;

template <typename ValueT>
using ConstHloInstructionMap =
    std::map<const HloInstruction*, ValueT, HloPtrComparator>;

using HloInstructionSet = std::set<HloInstruction*, HloPtrComparator>;
using ConstHloInstructionSet =
    std::set<const HloInstruction*, HloPtrComparator>;
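
// Example of relying on the deterministic iteration order (a sketch; the
// instruction and the use-count bookkeeping are hypothetical):
//
//   HloInstructionMap<int64> use_counts;
//   for (HloInstruction* operand : instruction->operands()) {
//     ++use_counts[operand];
//   }
//   // Iteration visits keys in increasing unique_id() order, independent of
//   // pointer values, so passes that emit output while walking the map stay
//   // deterministic across runs.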

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_