1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // All HloInstruction subclasses are put in this file.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
19 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
20 
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/shape.h"
24 #include "tensorflow/compiler/xla/xla_data.pb.h"
25 
26 namespace xla {
27 
28 class HloBatchNormInstruction : public HloInstruction {
29  public:
30   // Returns feature_index field associated with the instruction. The index
31   // represents the index of the feature dimension.
feature_index()32   int64 feature_index() const { return feature_index_; }
33 
34   // Returns a epsilon value associated with the instruction. The is a small
35   // number added to the variance to avoid divide-by-zero error.
epsilon()36   float epsilon() const { return epsilon_; }
37 
38   // Returns a serialized representation of this instruction.
39   HloInstructionProto ToProto() const override;
40 
41  protected:
42   explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape,
43                                    HloInstruction* operand,
44                                    HloInstruction* scale, float epsilon,
45                                    int64 feature_index);
46 
47  private:
48   std::vector<string> ExtraAttributesToStringImpl(
49       const HloPrintOptions& options) const override;
50   bool IdenticalSlowPath(
51       const HloInstruction& other,
52       const std::function<bool(const HloComputation*, const HloComputation*)>&
53           eq_computations) const override;
54   // A small float number added to the variance to avoid divide-by-zero error.
55   float epsilon_ = 0.0f;
56 
57   // An integer value representing the index of the feature dimension.
58   int64 feature_index_ = -1;
59 };
60 
61 class HloBatchNormTrainingInstruction : public HloBatchNormInstruction {
62  public:
63   explicit HloBatchNormTrainingInstruction(const Shape& shape,
64                                            HloInstruction* operand,
65                                            HloInstruction* scale,
66                                            HloInstruction* offset,
67                                            float epsilon, int64 feature_index);
68 
69  private:
70   // Implementation for non-common logic of CloneWithNewOperands.
71   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
72       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
73       HloCloneContext* context) const override;
74 };
75 
76 class HloBatchNormInferenceInstruction : public HloBatchNormInstruction {
77  public:
78   explicit HloBatchNormInferenceInstruction(
79       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
80       HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
81       float epsilon, int64 feature_index);
82 
83  private:
84   // Implementation for non-common logic of CloneWithNewOperands.
85   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
86       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
87       HloCloneContext* context) const override;
88 };
89 
90 class HloBatchNormGradInstruction : public HloBatchNormInstruction {
91  public:
92   explicit HloBatchNormGradInstruction(
93       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
94       HloInstruction* mean, HloInstruction* variance,
95       HloInstruction* grad_output, float epsilon, int64 feature_index);
96 
97  private:
98   // Implementation for non-common logic of CloneWithNewOperands.
99   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
100       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
101       HloCloneContext* context) const override;
102 };
103 
104 class HloFftInstruction : public HloInstruction {
105  public:
106   explicit HloFftInstruction(const Shape& shape, HloInstruction* operand,
107                              FftType fft_type,
108                              absl::Span<const int64> fft_length);
fft_type()109   FftType fft_type() const { return fft_type_; }
110 
fft_length()111   const std::vector<int64>& fft_length() const { return fft_length_; }
112 
113   // Returns a serialized representation of this instruction.
114   HloInstructionProto ToProto() const override;
115 
116  private:
117   std::vector<string> ExtraAttributesToStringImpl(
118       const HloPrintOptions& options) const override;
119   bool IdenticalSlowPath(
120       const HloInstruction& other,
121       const std::function<bool(const HloComputation*, const HloComputation*)>&
122           eq_computations) const override;
123 
124   // Implementation for non-common logic of CloneWithNewOperands.
125   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
126       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
127       HloCloneContext* context) const override;
128 
129   // Describes FFT type for an FFT instruction.
130   FftType fft_type_ = FftType::FFT;
131 
132   // Indicates the FFT length for an FFT instruction.
133   std::vector<int64> fft_length_;
134 };
135 
136 class HloCopyStartInstruction : public HloInstruction {
137  public:
138   explicit HloCopyStartInstruction(const Shape& shape, HloInstruction* operand,
139                                    bool is_cross_program_prefetch);
140 
is_cross_program_prefetch()141   bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; }
142   HloInstructionProto ToProto() const override;
143 
144  private:
145   std::vector<string> ExtraAttributesToStringImpl(
146       const HloPrintOptions& options) const override;
147   bool IdenticalSlowPath(
148       const HloInstruction& other,
149       const std::function<bool(const HloComputation*, const HloComputation*)>&
150           eq_computations) const override;
151   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
152       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
153       HloCloneContext* context) const override;
154 
155   bool is_cross_program_prefetch_;
156 };
157 
158 class HloCompareInstruction : public HloInstruction {
159  public:
160   explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
161                                  HloInstruction* rhs,
162                                  ComparisonDirection direction,
163                                  absl::optional<Comparison::Type> type);
direction()164   ComparisonDirection direction() const { return compare_.GetDirection(); }
type()165   Comparison::Type type() const { return compare_.GetType(); }
166   HloInstructionProto ToProto() const override;
167 
168  private:
169   std::vector<string> ExtraAttributesToStringImpl(
170       const HloPrintOptions& options) const override;
171   bool IdenticalSlowPath(
172       const HloInstruction& other,
173       const std::function<bool(const HloComputation*, const HloComputation*)>&
174           eq_computations) const override;
175   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
176       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
177       HloCloneContext* context) const override;
178 
179   Comparison compare_;
180 };
181 
182 class HloTriangularSolveInstruction : public HloInstruction {
183  public:
184   explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a,
185                                          HloInstruction* b,
186                                          const TriangularSolveOptions& options);
triangular_solve_options()187   const TriangularSolveOptions& triangular_solve_options() const {
188     return triangular_solve_options_;
189   }
190 
191   // Returns a serialized representation of this instruction.
192   HloInstructionProto ToProto() const override;
193 
194  private:
195   std::vector<string> ExtraAttributesToStringImpl(
196       const HloPrintOptions& options) const override;
197   bool IdenticalSlowPath(
198       const HloInstruction& other,
199       const std::function<bool(const HloComputation*, const HloComputation*)>&
200           eq_computations) const override;
201 
202   // Implementation for non-common logic of CloneWithNewOperands.
203   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
204       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
205       HloCloneContext* context) const override;
206 
207   TriangularSolveOptions triangular_solve_options_;
208 };
209 
210 class HloCholeskyInstruction : public HloInstruction {
211  public:
212   explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a,
213                                   const CholeskyOptions& options);
cholesky_options()214   const CholeskyOptions& cholesky_options() const { return cholesky_options_; }
215 
216   // Returns a serialized representation of this instruction.
217   HloInstructionProto ToProto() const override;
218 
219  private:
220   std::vector<string> ExtraAttributesToStringImpl(
221       const HloPrintOptions& options) const override;
222   bool IdenticalSlowPath(
223       const HloInstruction& other,
224       const std::function<bool(const HloComputation*, const HloComputation*)>&
225           eq_computations) const override;
226 
227   // Implementation for non-common logic of CloneWithNewOperands.
228   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
229       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
230       HloCloneContext* context) const override;
231 
232   CholeskyOptions cholesky_options_;
233 };
234 
235 // Class that represents instructions that synchronize and transfer data between
236 // partitioned devices. Send/Recv and collective instructions (AllReduce,
237 // AllToAll, CollectivePermute) belong to this instruction type. A group of
238 // instructions (of the same opcode) with the same channel_id communicate during
239 // execution.
240 class HloChannelInstruction : public HloInstruction {
241  public:
242   // Returns the channel id associated with the instruction. The id is
243   // shared between each Send/Recv pair or a group of collective instructions
244   // and is globally unique to identify each channel.
channel_id()245   absl::optional<int64> channel_id() const { return channel_id_; }
246   void set_channel_id(const absl::optional<int64>& channel_id);
247 
248   // Whether this instruction is identical to `other` except for the values of
249   // channel IDs, as long as both have channel IDs or neither has a channel ID.
IdenticalSlowPathIgnoringChannelIdValues(const HloInstruction & other,const std::function<bool (const HloComputation *,const HloComputation *)> & eq_computations)250   virtual bool IdenticalSlowPathIgnoringChannelIdValues(
251       const HloInstruction& other,
252       const std::function<bool(const HloComputation*, const HloComputation*)>&
253           eq_computations) const {
254     return channel_id_.has_value() == other.channel_id().has_value();
255   }
256 
257  protected:
258   explicit HloChannelInstruction(HloOpcode opcode, const Shape& shape,
259                                  const absl::optional<int64>& channel_id);
260 
261   HloInstructionProto ToProto() const override;
262 
263   std::vector<string> ExtraAttributesToStringImpl(
264       const HloPrintOptions& options) const override;
265 
266   // Do not override IdenticalSlowPath(). Override
267   // IdenticalSlowPathIgnoringChannelIdValues() instead.
268   bool IdenticalSlowPath(
269       const HloInstruction& other,
270       const std::function<bool(const HloComputation*, const HloComputation*)>&
271           eq_computations) const final;
272 
273   absl::optional<int64> channel_id_;
274 };
275 
276 class HloSendRecvInstruction : public HloChannelInstruction {
277  public:
278   // Returns whether this send/recv instruction sends data to/from the host.
is_host_transfer()279   bool is_host_transfer() const { return is_host_transfer_; }
280 
281   // Returns a serialized representation of this instruction.
282   HloInstructionProto ToProto() const override;
283 
284  protected:
285   explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape,
286                                   int64 channel_id, bool is_host_transfer);
287 
288  private:
289   std::vector<string> ExtraAttributesToStringImpl(
290       const HloPrintOptions& options) const override;
291   bool IdenticalSlowPathIgnoringChannelIdValues(
292       const HloInstruction& other,
293       const std::function<bool(const HloComputation*, const HloComputation*)>&
294           eq_computations) const override;
295   // Whether this send/recv instruction sends data to/from the host.
296   bool is_host_transfer_;
297 };
298 
299 class HloSendInstruction : public HloSendRecvInstruction {
300  public:
301   explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token,
302                               int64 channel_id, bool is_host_transfer);
303 
304  private:
305   // Implementation for non-common logic of CloneWithNewOperands.
306   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
307       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
308       HloCloneContext* context) const override;
309 };
310 
311 class HloSendDoneInstruction : public HloSendRecvInstruction {
312  public:
313   explicit HloSendDoneInstruction(HloSendInstruction* operand,
314                                   bool is_host_transfer);
315 
316  private:
317   // Implementation for non-common logic of CloneWithNewOperands.
318   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
319       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
320       HloCloneContext* context) const override;
321 };
322 
323 class HloRecvInstruction : public HloSendRecvInstruction {
324  public:
325   explicit HloRecvInstruction(const Shape& shape, HloInstruction* token,
326                               int64 channel_id, bool is_host_transfer);
327 
328  private:
329   // Implementation for non-common logic of CloneWithNewOperands.
330   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
331       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
332       HloCloneContext* context) const override;
333 };
334 
335 class HloRecvDoneInstruction : public HloSendRecvInstruction {
336  public:
337   explicit HloRecvDoneInstruction(HloRecvInstruction* operand,
338                                   bool is_host_transfer);
339 
340  private:
341   // Implementation for non-common logic of CloneWithNewOperands.
342   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
343       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
344       HloCloneContext* context) const override;
345 };
346 
347 class HloCollectiveInstruction : public HloChannelInstruction {
348  public:
replica_groups()349   const std::vector<ReplicaGroup>& replica_groups() const {
350     return replica_groups_;
351   }
352 
353   // Returns true if the layout of the AllReduce is enforced by XLA client (as
354   // the layout set in the shape). The only reason for the client to set the
355   // layout is to separately compile computations that communicate with
356   // AllReduce. Since this field is only set `true` by the client, the compiler
357   // only needs to propagate existing values (e.g., Clone, X64Rewriter) or set
358   // `false` for all other cases.
359   //
360   // When this is `true`, there may be communication endpoints outside the
361   // current compilation unit, so the compiler considers this AllReduce as
362   // side-effecting to disable compiler transformations. The compiler is free to
363   // transform unconstrained AllReduces differently across compilation units.
364   // It is an error for an HloModule to have a mix of constrained and
365   // unconstrained AllReduce instructions (checked by HloVerifier).
constrain_layout()366   bool constrain_layout() const { return constrain_layout_; }
367 
368  protected:
369   explicit HloCollectiveInstruction(
370       HloOpcode opcode, const Shape& shape,
371       absl::Span<HloInstruction* const> operands,
372       const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
373       const absl::optional<int64>& channel_id);
374 
375   HloInstructionProto ToProto() const override;
376 
377   std::vector<string> ExtraAttributesToStringImpl(
378       const HloPrintOptions& options) const override;
379   bool IdenticalSlowPathIgnoringChannelIdValues(
380       const HloInstruction& other,
381       const std::function<bool(const HloComputation*, const HloComputation*)>&
382           eq_computations) const override;
383 
384   std::vector<ReplicaGroup> replica_groups_;
385   bool constrain_layout_;
386 };
387 
388 class HloAllGatherInstruction : public HloCollectiveInstruction {
389  public:
390   explicit HloAllGatherInstruction(
391       const Shape& shape, HloInstruction* operand, int64 all_gather_dimension,
392       const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
393       const absl::optional<int64>& channel_id, bool use_global_device_ids);
394   // Same as HloAllReduceInstruction::use_global_device_ids.
use_global_device_ids()395   bool use_global_device_ids() const { return use_global_device_ids_; }
396 
397   // The dimension on which data from different participants are concatenated.
all_gather_dimension()398   int64 all_gather_dimension() const { return all_gather_dimension_; }
399 
400  protected:
401   std::vector<string> ExtraAttributesToStringImpl(
402       const HloPrintOptions& options) const override;
403   HloInstructionProto ToProto() const override;
404 
405  private:
406   bool IdenticalSlowPathIgnoringChannelIdValues(
407       const HloInstruction& other,
408       const std::function<bool(const HloComputation*, const HloComputation*)>&
409           eq_computations) const override;
410 
411   // Implementation for non-common logic of CloneWithNewOperands.
412   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
413       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
414       HloCloneContext* context) const override;
415 
416   int64 all_gather_dimension_;
417   bool use_global_device_ids_;
418 };
419 
420 class HloAllReduceInstruction : public HloCollectiveInstruction {
421  public:
422   explicit HloAllReduceInstruction(
423       const Shape& shape, absl::Span<HloInstruction* const> operands,
424       HloComputation* reduce_computation,
425       const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
426       const absl::optional<int64>& channel_id, bool use_global_device_ids);
427 
428   // Returns true if the AllReduce does no communication, so it's equivalent
429   // to a mem copy.
430   bool IsNoop() const;
431 
432   // Returns true if the ids in the ReplicaGroup config represent a global id of
433   // (replica_id * partition_count + partition_id) instead of a replica id.
434   // This enables more flexible grouping of devices if this all-reduce is both
435   // cross-partition and cross-replica.
436   //
437   // For example with 2 replicas and 4 partitions,
438   // replica_groups={{0,1,4,5},{2,3,6,7}}, use_global_device_ids=true means that
439   // group[0] = (0,0), (0,1), (1,0), (1,1)
440   // group[1] = (0,2), (0,3), (1,2), (1,3)
441   // where each pair is (replica_id, partition_id).
use_global_device_ids()442   bool use_global_device_ids() const { return use_global_device_ids_; }
443 
444  protected:
445   std::vector<string> ExtraAttributesToStringImpl(
446       const HloPrintOptions& options) const override;
447   HloInstructionProto ToProto() const override;
448 
449  private:
450   bool IdenticalSlowPathIgnoringChannelIdValues(
451       const HloInstruction& other,
452       const std::function<bool(const HloComputation*, const HloComputation*)>&
453           eq_computations) const override;
454 
455   // Implementation for non-common logic of CloneWithNewOperands.
456   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
457       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
458       HloCloneContext* context) const override;
459 
460   bool use_global_device_ids_;
461 };
462 
463 class HloAllToAllInstruction : public HloCollectiveInstruction {
464  public:
465   explicit HloAllToAllInstruction(
466       const Shape& shape, absl::Span<HloInstruction* const> operands,
467       const std::vector<ReplicaGroup>& replica_groups, bool constrain_layout,
468       const absl::optional<int64>& channel_id,
469       const absl::optional<int64>& split_dimension);
470 
471   // AllToAll can optionally take a split dimension, which means that this
472   // AllToAll takes a single (flattened) array operand and produces an array
473   // output (instead of taking a list of operands and producing a tuple).
474   //
475   // split_dimension specifies which dimension in the operand is split across
476   // devices in each replica_group, and also means the concatenated dimension
477   // on the output (i.e., input and the output shapes are the same).
split_dimension()478   absl::optional<int64> split_dimension() const { return split_dimension_; }
set_split_dimension(int64 dim)479   void set_split_dimension(int64 dim) { split_dimension_ = dim; }
480 
481  protected:
482   std::vector<string> ExtraAttributesToStringImpl(
483       const HloPrintOptions& options) const override;
484   HloInstructionProto ToProto() const override;
485 
486  private:
487   bool IdenticalSlowPathIgnoringChannelIdValues(
488       const HloInstruction& other,
489       const std::function<bool(const HloComputation*, const HloComputation*)>&
490           eq_computations) const override;
491 
492   // Implementation for non-common logic of CloneWithNewOperands.
493   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
494       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
495       HloCloneContext* context) const override;
496 
497   absl::optional<int64> split_dimension_;
498 };
499 
500 class HloCollectivePermuteInstruction : public HloChannelInstruction {
501  public:
502   explicit HloCollectivePermuteInstruction(
503       HloOpcode opcode, const Shape& shape, HloInstruction* operand,
504       const std::vector<std::pair<int64, int64>>& source_target_pairs,
505       const absl::optional<int64>& channel_id);
506 
source_target_pairs()507   const std::vector<std::pair<int64, int64>>& source_target_pairs() const {
508     return source_target_pairs_;
509   }
510 
511   // Returns a serialized representation of this instruction.
512   HloInstructionProto ToProto() const override;
513 
514  private:
515   std::vector<string> ExtraAttributesToStringImpl(
516       const HloPrintOptions& options) const override;
517   bool IdenticalSlowPathIgnoringChannelIdValues(
518       const HloInstruction& other,
519       const std::function<bool(const HloComputation*, const HloComputation*)>&
520           eq_computations) const override;
521 
522   // Implementation for non-common logic of CloneWithNewOperands.
523   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
524       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
525       HloCloneContext* context) const override;
526 
527   const std::vector<std::pair<int64, int64>> source_target_pairs_;
528 };
529 
530 class HloReverseInstruction : public HloInstruction {
531  public:
532   explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
533                                  absl::Span<const int64> dimensions);
534   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()535   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)536   int64 dimensions(int64 index) const override { return dimensions()[index]; }
mutable_dimensions()537   std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
538   // Returns a serialized representation of this instruction.
539   HloInstructionProto ToProto() const override;
540 
541  private:
542   std::vector<string> ExtraAttributesToStringImpl(
543       const HloPrintOptions& options) const override;
544   bool IdenticalSlowPath(
545       const HloInstruction& other,
546       const std::function<bool(const HloComputation*, const HloComputation*)>&
547           eq_computations) const override;
548   // Implementation for non-common logic of CloneWithNewOperands.
549   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
550       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
551       HloCloneContext* context) const override;
552 
553   std::vector<int64> dimensions_;
554 };
555 
556 class HloConcatenateInstruction : public HloInstruction {
557  public:
558   explicit HloConcatenateInstruction(const Shape& shape,
559                                      absl::Span<HloInstruction* const> operands,
560                                      int64 dimension);
561   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()562   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)563   int64 dimensions(int64 index) const override { return dimensions()[index]; }
mutable_dimensions()564   std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
565   // Accessor for the dimension in which a concatenate HLO should occur.
concatenate_dimension()566   int64 concatenate_dimension() const { return dimensions(0); }
567   // Returns a serialized representation of this instruction.
568   HloInstructionProto ToProto() const override;
569 
570  private:
571   std::vector<string> ExtraAttributesToStringImpl(
572       const HloPrintOptions& options) const override;
573   bool IdenticalSlowPath(
574       const HloInstruction& other,
575       const std::function<bool(const HloComputation*, const HloComputation*)>&
576           eq_computations) const override;
577   // Implementation for non-common logic of CloneWithNewOperands.
578   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
579       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
580       HloCloneContext* context) const override;
581 
582   std::vector<int64> dimensions_;
583 };
584 
585 class HloReduceInstruction : public HloInstruction {
586  public:
587   explicit HloReduceInstruction(const Shape& shape,
588                                 absl::Span<HloInstruction* const> args,
589                                 absl::Span<const int64> dimensions_to_reduce,
590                                 HloComputation* reduce_computation);
591   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()592   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)593   int64 dimensions(int64 index) const override { return dimensions()[index]; }
mutable_dimensions()594   std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
595   // Returns a serialized representation of this instruction.
596   HloInstructionProto ToProto() const override;
597 
598   // Returns the number of input arrays (and, consequentially, the number of
599   // init values) this reduce has.
input_count()600   int64 input_count() const { return operand_count() / 2; }
601 
602   // Returns the input tensors to be reduced.
inputs()603   absl::Span<HloInstruction* const> inputs() const {
604     return absl::MakeSpan(operands()).subspan(0, input_count());
605   }
606 
607   // Returns the init values of the reduction.
init_values()608   absl::Span<HloInstruction* const> init_values() const {
609     return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
610   }
611 
612  private:
613   std::vector<string> ExtraAttributesToStringImpl(
614       const HloPrintOptions& options) const override;
615   bool IdenticalSlowPath(
616       const HloInstruction& other,
617       const std::function<bool(const HloComputation*, const HloComputation*)>&
618           eq_computations) const override;
619   // Implementation for non-common logic of CloneWithNewOperands.
620   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
621       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
622       HloCloneContext* context) const override;
623 
624   std::vector<int64> dimensions_;
625 };
626 
627 class HloSortInstruction : public HloInstruction {
628  public:
629   explicit HloSortInstruction(const Shape& shape, int64 dimension,
630                               absl::Span<HloInstruction* const> operands,
631                               HloComputation* compare, bool is_stable);
632   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()633   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)634   int64 dimensions(int64 index) const override { return dimensions()[index]; }
mutable_dimensions()635   std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
636   // Returns the sort dimension for this instruction
sort_dimension()637   int64 sort_dimension() const { return dimensions(0); }
638   // Returns a serialized representation of this instruction.
639   HloInstructionProto ToProto() const override;
640   // Returns the key operand to this instruction.
keys()641   const HloInstruction* keys() const { return operand(0); }
mutable_keys()642   HloInstruction* mutable_keys() { return mutable_operand(0); }
643   // Returns the number of value operands.
values_count()644   int64 values_count() const { return operand_count() - 1; }
is_stable()645   bool is_stable() const { return is_stable_; }
646 
647  private:
648   std::vector<string> ExtraAttributesToStringImpl(
649       const HloPrintOptions& options) const override;
650   bool IdenticalSlowPath(
651       const HloInstruction& other,
652       const std::function<bool(const HloComputation*, const HloComputation*)>&
653           eq_computations) const override;
654   // Implementation for non-common logic of CloneWithNewOperands.
655   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
656       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
657       HloCloneContext* context) const override;
658 
659   std::vector<int64> dimensions_;
660   bool is_stable_;
661 };
662 
663 class HloTransposeInstruction : public HloInstruction {
664  public:
665   explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand,
666                                    absl::Span<const int64> dimensions);
667   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()668   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)669   int64 dimensions(int64 index) const override { return dimensions()[index]; }
mutable_dimensions()670   std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
671   // Returns whether this instruction does a rank-2 transposition.
672   bool IsRank2Transpose() const;
673   // Returns a serialized representation of this instruction.
674   HloInstructionProto ToProto() const override;
675 
676  private:
677   std::vector<string> ExtraAttributesToStringImpl(
678       const HloPrintOptions& options) const override;
679   bool IdenticalSlowPath(
680       const HloInstruction& other,
681       const std::function<bool(const HloComputation*, const HloComputation*)>&
682           eq_computations) const override;
683   // Implementation for non-common logic of CloneWithNewOperands.
684   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
685       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
686       HloCloneContext* context) const override;
687 
688   std::vector<int64> dimensions_;
689 };
690 
691 class HloBroadcastInstruction : public HloInstruction {
692  public:
693   explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand,
694                                    absl::Span<const int64> broadcast_dimension);
695   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()696   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)697   int64 dimensions(int64 index) const override { return dimensions()[index]; }
mutable_dimensions()698   std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
699   // Returns a serialized representation of this instruction.
700   HloInstructionProto ToProto() const override;
701 
702  private:
703   std::vector<string> ExtraAttributesToStringImpl(
704       const HloPrintOptions& options) const override;
705   bool IdenticalSlowPath(
706       const HloInstruction& other,
707       const std::function<bool(const HloComputation*, const HloComputation*)>&
708           eq_computations) const override;
709   // Implementation for non-common logic of CloneWithNewOperands.
710   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
711       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
712       HloCloneContext* context) const override;
713 
714   std::vector<int64> dimensions_;
715 };
716 
717 class HloDynamicReshapeInstruction : public HloInstruction {
718  public:
719   explicit HloDynamicReshapeInstruction(
720       const Shape& shape, HloInstruction* data_operand,
721       absl::Span<HloInstruction* const> dim_sizes);
722 
723   // Returns the input dim sizes dimensions, which is operands[1:]
dim_sizes()724   absl::Span<HloInstruction* const> dim_sizes() const {
725     return absl::MakeSpan(operands()).subspan(1, operand_count());
726   }
727 
728   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
729       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
730       HloCloneContext* context) const override;
731 
732   // Returns the input dim size dimension, which is operands[1+i]
dim_sizes(int64 i)733   HloInstruction* dim_sizes(int64 i) const { return operands()[i + 1]; }
734 };
735 
736 class HloReshapeInstruction : public HloInstruction {
737  public:
738   explicit HloReshapeInstruction(const Shape& shape, HloInstruction* operand,
739                                  int64 inferred_dimension);
inferred_dimension()740   int64 inferred_dimension() const { return inferred_dimension_; }
741   HloInstructionProto ToProto() const override;
742 
743  private:
744   std::vector<string> ExtraAttributesToStringImpl(
745       const HloPrintOptions& options) const override;
746   bool IdenticalSlowPath(
747       const HloInstruction& other,
748       const std::function<bool(const HloComputation*, const HloComputation*)>&
749           eq_computations) const override;
750   // Implementation for non-common logic of CloneWithNewOperands.
751   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
752       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
753       HloCloneContext* context) const override;
754   int64 inferred_dimension_;
755 };
756 
757 class HloMapInstruction : public HloInstruction {
758  public:
759   explicit HloMapInstruction(const Shape& shape,
760                              absl::Span<HloInstruction* const> operands,
761                              HloComputation* map_computation);
762   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()763   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)764   int64 dimensions(int64 index) const override { return dimensions()[index]; }
mutable_dimensions()765   std::vector<int64>* mutable_dimensions() override { return &dimensions_; }
766   // Returns a serialized representation of this instruction.
767   HloInstructionProto ToProto() const override;
768 
769  private:
770   bool IsElementwiseImpl(
771       const absl::optional<int64>& operand_idx) const override;
772   std::vector<string> ExtraAttributesToStringImpl(
773       const HloPrintOptions& options) const override;
774   bool IdenticalSlowPath(
775       const HloInstruction& other,
776       const std::function<bool(const HloComputation*, const HloComputation*)>&
777           eq_computations) const override;
778   // Implementation for non-common logic of CloneWithNewOperands.
779   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
780       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
781       HloCloneContext* context) const override;
782 
783   std::vector<int64> dimensions_;
784 };
785 
786 class HloSliceInstruction : public HloInstruction {
787  public:
788   explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand,
789                                absl::Span<const int64> start_indices,
790                                absl::Span<const int64> limit_indices,
791                                absl::Span<const int64> strides);
792 
793   HloInstructionProto ToProto() const override;
794 
795   // Returns the start index in the given dimension for a slice node.
slice_starts(int64 dimension)796   int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; }
slice_starts()797   const std::vector<int64>& slice_starts() const { return slice_starts_; }
mutable_slice_starts()798   std::vector<int64>* mutable_slice_starts() { return &slice_starts_; }
799 
800   // Returns the (exclusive) limit index in the given dimension for a slice
801   // node.
slice_limits(int64 dimension)802   int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; }
slice_limits()803   const std::vector<int64>& slice_limits() const { return slice_limits_; }
mutable_slice_limits()804   std::vector<int64>* mutable_slice_limits() { return &slice_limits_; }
805 
806   // Returns the stride in the given dimension for a slice node.
slice_strides(int64 dimension)807   int64 slice_strides(int64 dimension) const {
808     return slice_strides_[dimension];
809   }
slice_strides()810   const std::vector<int64>& slice_strides() const { return slice_strides_; }
mutable_slice_strides()811   std::vector<int64>* mutable_slice_strides() { return &slice_strides_; }
812 
813  private:
814   std::vector<string> ExtraAttributesToStringImpl(
815       const HloPrintOptions& options) const override;
816   bool IdenticalSlowPath(
817       const HloInstruction& other,
818       const std::function<bool(const HloComputation*, const HloComputation*)>&
819           eq_computations) const override;
820   // Implementation for non-common logic of CloneWithNewOperands.
821   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
822       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
823       HloCloneContext* context) const override;
824 
825   // Describes the [begin, end) index range for a slice.
826   std::vector<int64> slice_starts_;
827   std::vector<int64> slice_limits_;
828   std::vector<int64> slice_strides_;
829 };
830 
831 class HloConstantInstruction : public HloInstruction {
832  public:
833   explicit HloConstantInstruction(Literal literal);
834   explicit HloConstantInstruction(Literal literal, const Shape& shape);
835   // Used when the literal is too large and dropped.
836   explicit HloConstantInstruction(const Shape& shape);
837   // Returns the literal associated with this instruction.
literal()838   const Literal& literal() const { return *literal_; }
839   // Returns the (mutable) literal associated with this instruction.
mutable_literal()840   Literal* mutable_literal() { return &literal_.value(); }
841   // Returns whether there is literal associated with this instruction.
HasLiteral()842   bool HasLiteral() const { return literal_.has_value(); }
843   // Returns a serialized representation of this instruction.
844   HloInstructionProto ToProto() const override;
845 
846   // Change the layout for an Constant Hlo instruction to match new_layout.  For
847   // tuple shaped constants shape_index is the path to the internal array
848   // subshape whose layout needs to be changed.
849   void RelayoutConstant(const Layout& new_layout,
850                         const ShapeIndex& shape_index = {});
851 
852  private:
853   bool IsElementwiseImpl(
854       const absl::optional<int64>& operand_idx) const override;
855   bool IdenticalSlowPath(
856       const HloInstruction& other,
857       const std::function<bool(const HloComputation*, const HloComputation*)>&
858           eq_computations) const override;
859   string OperandsToStringWithCanonicalNameMap(
860       const HloPrintOptions& options,
861       CanonicalNameMap* canonical_name_map) const override;
862   // Implementation for non-common logic of CloneWithNewOperands.
863   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
864       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
865       HloCloneContext* context) const override;
866   absl::optional<Literal> literal_;
867 };
868 
869 class HloTraceInstruction : public HloInstruction {
870  public:
871   explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
872   // Returns a tag to be used in tracing.
TracingTag()873   string TracingTag() const { return literal_.GetR1U8AsString(); }
874   // Returns a serialized representation of this instruction.
875   HloInstructionProto ToProto() const override;
876 
877  private:
878   bool IdenticalSlowPath(
879       const HloInstruction& other,
880       const std::function<bool(const HloComputation*, const HloComputation*)>&
881           eq_computations) const override;
882   // Implementation for non-common logic of CloneWithNewOperands.
883   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
884       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
885       HloCloneContext* context) const override;
886   Literal literal_;
887 };
888 
889 class HloFusionInstruction : public HloInstruction {
890  public:
891   explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
892                                 HloInstruction* fused_root);
893 
894   explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
895                                 absl::Span<HloInstruction* const> operands,
896                                 HloComputation* fusion_computation);
897 
898   string ToCategory() const override;
899   // Returns a serialized representation of this instruction.
900   HloInstructionProto ToProto() const override;
901 
902   // Adds a new operand the fusion instruction.
903   HloInstruction* AddFusionOperand(HloInstruction* new_operand);
904 
905   // Merges the fused instructions from 'instruction_to_merge' into the
906   // fused instruction set of 'this', updating operands as necessary.
907   //
908   // Precondition: 'instruction_to_merge' must be an operand of 'this'.
909   void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge);
910 
911   // Merges the fused instructions from instruction_to_merge into the fused
912   // instruction set of 'this' and generates multioutput fusion instructions.
913   // All the users of instruction_to_merge will be redirected to 'this'
914   // instruction. instruction_to_merge will be removed from its parent
915   // computation.
916   void MergeFusionInstructionIntoMultiOutput(
917       HloFusionInstruction* instruction_to_merge);
918 
919   // Fuses the given instruction in this fusion instruction. instruction_to_fuse
920   // is cloned and the clone is placed in the fusion
921   // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather
922   // than moved to cleanly handle the case where the instruction has a use
923   // outside the fusion instruction. Moving such an instruction into a fusion
924   // instruction would violate the single-result invariant of HLO instructions
925   // and significantly complicate code generation.
FuseInstruction(HloInstruction * instruction_to_fuse)926   HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
927     return FuseInstructionInternal(instruction_to_fuse);
928   }
929 
930   // Fuses the given instruction in this fusion instruction and generates a
931   // multioutput fusion instruction. A clone of the instruction_to_fuse will
932   // be part of the output of fusion instructions. The users of
933   // instruction_to_fuse will be redirected to this fusion instructions.
934   // instruction_to_fuse is unchanged otherwise.
FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)935   HloInstruction* FuseInstructionIntoMultiOutput(
936       HloInstruction* instruction_to_fuse) {
937     return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true);
938   }
939 
940   // Returns the computation for this fused instruction.
941   HloComputation* fused_instructions_computation() const;
942 
943   // Returns the root instruction of the fused expression contained within this
944   // fusion instruction.
945   HloInstruction* fused_expression_root() const;
946 
947   // Returns the list of fused instructions inside this fusion instruction.  The
948   // returned type is a range of HloInstruction*s.
949   const tensorflow::gtl::iterator_range<UnwrappingIterator<
950       std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
951   fused_instructions() const;
952 
953   const tensorflow::gtl::iterator_range<
954       UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
955   fused_instructions();
956 
957   // Gets the number of instructions inside this fusion instruction.
958   int64 fused_instruction_count() const;
959 
960   // Returns the fused parameter instruction in this fusion instruction
961   // corresponding to the given parameter number.
962   HloInstruction* fused_parameter(int64 parameter_number) const;
963 
964   // Returns the vector of fused parameters inside this fusion instruction.
965   const std::vector<HloInstruction*>& fused_parameters() const;
966 
967   // Returns true if this instruction is a fusion instruction that generates
968   // multiple outputs.
IsMultiOutputFusion()969   const bool IsMultiOutputFusion() const {
970     return fused_expression_root()->opcode() == HloOpcode::kTuple;
971   }
972 
fusion_kind()973   FusionKind fusion_kind() const { return fusion_kind_; }
974 
set_fusion_kind(FusionKind kind)975   void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; }
976 
977   // If multiple operands are the same instruction, keeps only one of them.
978   Status DeduplicateFusionOperands();
979 
980  private:
981   // Fuses the given instruction into this fusion instruction.
982   // instruction_to_fuse is cloned and the clone is placed in the fusion
983   // instruction.  The users of instruction_to_fuse will be redirected to this
984   // fusion instruction. instruction_to_fuse is unchanged otherwise. When
985   // add_output is true, a clone of the instruction_to_fuse will be added as
986   // additional output resulting in a multi-output fusion.
987   HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse,
988                                           bool add_output = false);
989   // Clones the given instruction_to_fuse and insert the clone into this fusion
990   // instruction. If add_output is true, a clone of instruction_to_fuse will
991   // be in the output of the this fusion instruction (part of the tuple of the
992   // fusion root).
993   HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse,
994                                        bool add_output = false);
995 
996   bool IsElementwiseImpl(
997       const absl::optional<int64>& operand_idx) const override;
998   std::vector<string> ExtraAttributesToStringImpl(
999       const HloPrintOptions& options) const override;
1000   bool IdenticalSlowPath(
1001       const HloInstruction& other,
1002       const std::function<bool(const HloComputation*, const HloComputation*)>&
1003           eq_computations) const override;
1004   uint64 InnerHash() const override;
1005 
1006   // Implementation for non-common logic of CloneWithNewOperands.
1007   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1008       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1009       HloCloneContext* context) const override;
1010 
1011   // The type of the fusion. Used by kFusion only.
1012   FusionKind fusion_kind_;
1013 };
1014 
1015 class HloRngInstruction : public HloInstruction {
1016  public:
1017   explicit HloRngInstruction(const Shape& shape,
1018                              RandomDistribution distribution,
1019                              absl::Span<HloInstruction* const> parameters);
1020   // Returns the random distribution for this rng node.
random_distribution()1021   RandomDistribution random_distribution() const { return distribution_; }
1022   // Returns a serialized representation of this instruction.
1023   HloInstructionProto ToProto() const override;
1024 
1025  private:
1026   bool IsElementwiseImpl(
1027       const absl::optional<int64>& operand_idx) const override;
1028   std::vector<string> ExtraAttributesToStringImpl(
1029       const HloPrintOptions& options) const override;
1030   bool IdenticalSlowPath(
1031       const HloInstruction& other,
1032       const std::function<bool(const HloComputation*, const HloComputation*)>&
1033           eq_computations) const override;
1034   // Implementation for non-common logic of CloneWithNewOperands.
1035   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1036       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1037       HloCloneContext* context) const override;
1038 
1039   // The distribution requested for random number generation.
1040   RandomDistribution distribution_;
1041 };
1042 
1043 class HloParameterInstruction : public HloInstruction {
1044  public:
1045   explicit HloParameterInstruction(int64 parameter_number, const Shape& shape,
1046                                    const string& name);
parameter_number()1047   int64 parameter_number() const { return parameter_number_; }
1048 
1049   // Sets and gets the whether all replicas will receive the same parameter data
1050   // for each leaf buffer in data parallelism.
set_parameter_replicated_at_leaf_buffers(absl::Span<const bool> parameter_replicated_at_leaf_buffers)1051   void set_parameter_replicated_at_leaf_buffers(
1052       absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
1053     CHECK_EQ(ShapeUtil::GetLeafCount(shape()),
1054              parameter_replicated_at_leaf_buffers.size());
1055     parameter_replicated_at_leaf_buffers_.emplace(
1056         parameter_replicated_at_leaf_buffers.begin(),
1057         parameter_replicated_at_leaf_buffers.end());
1058   }
set_parameter_replicated_at_leaf_buffers(const std::vector<bool> & parameter_replicated_at_leaf_buffers)1059   void set_parameter_replicated_at_leaf_buffers(
1060       const std::vector<bool>& parameter_replicated_at_leaf_buffers) {
1061     CHECK_EQ(ShapeUtil::GetLeafCount(shape()),
1062              parameter_replicated_at_leaf_buffers.size());
1063     parameter_replicated_at_leaf_buffers_ =
1064         parameter_replicated_at_leaf_buffers;
1065   }
1066   const absl::optional<std::vector<bool>>&
parameter_replicated_at_leaf_buffers()1067   parameter_replicated_at_leaf_buffers() const {
1068     return parameter_replicated_at_leaf_buffers_;
1069   }
1070 
1071   // Returns a serialized representation of this instruction.
1072   HloInstructionProto ToProto() const override;
1073 
1074  private:
1075   std::vector<string> ExtraAttributesToStringImpl(
1076       const HloPrintOptions& options) const override;
1077   bool IdenticalSlowPath(
1078       const HloInstruction& other,
1079       const std::function<bool(const HloComputation*, const HloComputation*)>&
1080           eq_computations) const override;
1081   string OperandsToStringWithCanonicalNameMap(
1082       const HloPrintOptions& options,
1083       CanonicalNameMap* canonical_name_map) const override;
1084   // Implementation for non-common logic of CloneWithNewOperands.
1085   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1086       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1087       HloCloneContext* context) const override;
1088 
1089   int64 parameter_number_ = 0;
1090 
1091   // Specifies whether each buffer has the same parameter value on all replicas
1092   // in data parallelism.
1093   absl::optional<std::vector<bool>> parameter_replicated_at_leaf_buffers_;
1094 };
1095 
1096 class HloGetTupleElementInstruction : public HloInstruction {
1097  public:
1098   explicit HloGetTupleElementInstruction(const Shape& shape,
1099                                          HloInstruction* operand, int64 index);
1100   // Returns the tuple index associated with this instruction.
tuple_index()1101   int64 tuple_index() const { return tuple_index_; }
1102   // Sets the tuple index associated with this instruction.
set_tuple_index(int64 new_tuple_index)1103   void set_tuple_index(int64 new_tuple_index) {
1104     tuple_index_ = new_tuple_index;
1105   }
1106   // Returns a serialized representation of this instruction.
1107   HloInstructionProto ToProto() const override;
1108 
1109  private:
1110   std::vector<string> ExtraAttributesToStringImpl(
1111       const HloPrintOptions& options) const override;
1112   bool IdenticalSlowPath(
1113       const HloInstruction& other,
1114       const std::function<bool(const HloComputation*, const HloComputation*)>&
1115           eq_computations) const override;
1116   // Implementation for non-common logic of CloneWithNewOperands.
1117   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1118       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1119       HloCloneContext* context) const override;
1120 
1121   int64 tuple_index_ = -1;
1122 };
1123 
1124 class HloReducePrecisionInstruction : public HloInstruction {
1125  public:
1126   explicit HloReducePrecisionInstruction(const Shape& shape,
1127                                          HloInstruction* operand,
1128                                          const int exponent_bits,
1129                                          const int mantissa_bits);
1130   // Returns the number of exponent bits for a reduce-precision node.
exponent_bits()1131   int32 exponent_bits() const { return exponent_bits_; }
1132   // Returns the number of mantissa bits for a reduce-precision node.
mantissa_bits()1133   int32 mantissa_bits() const { return mantissa_bits_; }
1134   // Returns a serialized representation of this instruction.
1135   HloInstructionProto ToProto() const override;
1136 
1137  private:
1138   std::vector<string> ExtraAttributesToStringImpl(
1139       const HloPrintOptions& options) const override;
1140   bool IdenticalSlowPath(
1141       const HloInstruction& other,
1142       const std::function<bool(const HloComputation*, const HloComputation*)>&
1143           eq_computations) const override;
1144   // Implementation for non-common logic of CloneWithNewOperands.
1145   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1146       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1147       HloCloneContext* context) const override;
1148 
1149   // The bit sizes for a reduce-precision operation.
1150   int32 exponent_bits_ = 0;
1151   int32 mantissa_bits_ = 0;
1152 };
1153 
1154 class HloInfeedInstruction : public HloInstruction {
1155  public:
1156   explicit HloInfeedInstruction(const Shape& infeed_shape,
1157                                 HloInstruction* token_operand,
1158                                 const string& config);
1159   // Returns the infeed configuration string. The infeed configuration includes
1160   // any metadata needed for the backend compiler (e.g., infeed buffer address)
1161   // and is target-dependent.
infeed_config()1162   string infeed_config() const { return infeed_config_; }
set_infeed_config(const string & config)1163   void set_infeed_config(const string& config) { infeed_config_ = config; }
1164   // Returns the shape of the data received by the infeed. This is not the same
1165   // as the shape of the infeed instruction which produces a tuple containing
1166   // the infeed data shape and a TOKEN.
infeed_shape()1167   const Shape& infeed_shape() const {
1168     TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape()));
1169     return ShapeUtil::GetSubshape(shape(), {0});
1170   }
1171   // Returns a serialized representation of this instruction.
1172   HloInstructionProto ToProto() const override;
1173 
1174  private:
1175   std::vector<string> ExtraAttributesToStringImpl(
1176       const HloPrintOptions& options) const override;
1177   bool IdenticalSlowPath(
1178       const HloInstruction& other,
1179       const std::function<bool(const HloComputation*, const HloComputation*)>&
1180           eq_computations) const override;
1181   // Implementation for non-common logic of CloneWithNewOperands.
1182   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1183       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1184       HloCloneContext* context) const override;
1185 
1186   // The string representation of the infeed configuration.
1187   string infeed_config_;
1188 };
1189 
1190 class HloOutfeedInstruction : public HloInstruction {
1191  public:
1192   explicit HloOutfeedInstruction(const Shape& outfeed_shape,
1193                                  HloInstruction* operand,
1194                                  HloInstruction* token_operand,
1195                                  absl::string_view outfeed_config);
1196   // Returns the shape for the Outfeed instruction.
outfeed_shape()1197   const Shape& outfeed_shape() const { return outfeed_shape_; }
1198   // Returns the mutable shape for the Outfeed instruction.
mutable_outfeed_shape()1199   Shape* mutable_outfeed_shape() { return &outfeed_shape_; }
1200   // Returns the config for the Outfeed instruction.
outfeed_config()1201   const string& outfeed_config() const { return outfeed_config_; }
set_outfeed_config(const string & config)1202   void set_outfeed_config(const string& config) { outfeed_config_ = config; }
1203   // Returns a serialized representation of this instruction.
1204   HloInstructionProto ToProto() const override;
1205 
1206  private:
1207   std::vector<string> ExtraAttributesToStringImpl(
1208       const HloPrintOptions& options) const override;
1209   bool IdenticalSlowPath(
1210       const HloInstruction& other,
1211       const std::function<bool(const HloComputation*, const HloComputation*)>&
1212           eq_computations) const override;
1213   // Implementation for non-common logic of CloneWithNewOperands.
1214   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1215       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1216       HloCloneContext* context) const override;
1217 
1218   // Shape of outfeed request.
1219   Shape outfeed_shape_;
1220   // Outfeed configuration information, only present for kOutfeed.
1221   string outfeed_config_;
1222 };
1223 
1224 class HloConvolutionInstruction : public HloInstruction {
1225  public:
1226   explicit HloConvolutionInstruction(
1227       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1228       int64 feature_group_count, int64 batch_group_count, const Window& window,
1229       const ConvolutionDimensionNumbers& dimension_numbers,
1230       const PrecisionConfig& precision_config);
window()1231   const Window& window() const override { return window_; }
set_window(const Window & window)1232   void set_window(const Window& window) override { window_ = window; }
convolution_dimension_numbers()1233   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
1234     return convolution_dimension_numbers_;
1235   }
set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1236   void set_convolution_dimension_numbers(
1237       const ConvolutionDimensionNumbers& dnums) {
1238     convolution_dimension_numbers_ = dnums;
1239   }
1240   // The number of feature groups. Must be a divisor of the input feature
1241   // dimension and output feature dimension.
feature_group_count()1242   int64 feature_group_count() const { return feature_group_count_; }
set_feature_group_count(int64 num_feature_groups)1243   void set_feature_group_count(int64 num_feature_groups) {
1244     feature_group_count_ = num_feature_groups;
1245   }
1246   // The number of batch groups. Must be a divisor of the input batch dimension.
batch_group_count()1247   int64 batch_group_count() const { return batch_group_count_; }
set_batch_group_count(int64 num_batch_groups)1248   void set_batch_group_count(int64 num_batch_groups) {
1249     batch_group_count_ = num_batch_groups;
1250   }
1251 
1252   // Returns the information used to tell the implementation information about
1253   // what sort of precision is requested. The meaning of the field is backend
1254   // specific. At the moment, it is only supported for kConvolution and kDot.
1255   // Transformations on one kDot or kConvolution to another will preserve this
1256   // information. Transformations to other HLOs will not preserve this
1257   // information but it is presumed that the alternate lowering is strictly
1258   // superior.
precision_config()1259   const PrecisionConfig& precision_config() const { return precision_config_; }
mutable_precision_config()1260   PrecisionConfig* mutable_precision_config() { return &precision_config_; }
1261 
1262   string ToCategory() const override;
1263   // Returns a serialized representation of this instruction.
1264   HloInstructionProto ToProto() const override;
1265 
1266  private:
1267   std::vector<string> ExtraAttributesToStringImpl(
1268       const HloPrintOptions& options) const override;
1269   bool IdenticalSlowPath(
1270       const HloInstruction& other,
1271       const std::function<bool(const HloComputation*, const HloComputation*)>&
1272           eq_computations) const override;
1273   // Implementation for non-common logic of CloneWithNewOperands.
1274   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1275       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1276       HloCloneContext* context) const override;
1277   // The number of feature groups. Must be a divisor of the input feature
1278   // dimension and output feature dimension.
1279   int64 feature_group_count_;
1280   // The number of batch groups. Must be a divisor of the input batch dimension.
1281   int64 batch_group_count_;
1282   // Describes the window used for a convolution.
1283   Window window_;
1284   // Describes the dimension numbers used for a convolution.
1285   ConvolutionDimensionNumbers convolution_dimension_numbers_;
1286   // Information used to communicate to the implementation about the algorithm
1287   // used to produce results. See the documentation on precision_config().
1288   PrecisionConfig precision_config_;
1289 };
1290 
1291 class HloReduceWindowInstruction : public HloInstruction {
1292  public:
1293   explicit HloReduceWindowInstruction(const Shape& shape,
1294                                       HloInstruction* operand,
1295                                       HloInstruction* init_value,
1296                                       const Window& window,
1297                                       HloComputation* reduce_computation);
1298   explicit HloReduceWindowInstruction(
1299       const Shape& shape, absl::Span<HloInstruction* const> operands,
1300       absl::Span<HloInstruction* const> init_values, const Window& window,
1301       HloComputation* reduce_computation);
window()1302   const Window& window() const override { return window_; }
set_window(const Window & window)1303   void set_window(const Window& window) override { window_ = window; }
1304   // Returns a serialized representation of this instruction.
1305   HloInstructionProto ToProto() const override;
1306   // Returns the number of input arrays (and, consequentially, the number of
1307   // init values) this reduce has.
input_count()1308   int64 input_count() const { return operand_count() / 2; }
1309   // Returns the input tensors to be reduced.
input_arrays()1310   absl::Span<HloInstruction* const> input_arrays() const {
1311     return absl::MakeSpan(operands()).subspan(0, input_count());
1312   }
1313   // Returns the init values of the reduction.
init_values()1314   absl::Span<HloInstruction* const> init_values() const {
1315     return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
1316   }
1317   // Returns the shapes of input tensors to be reduced.
input_array_shapes()1318   absl::InlinedVector<const Shape*, 2> input_array_shapes() const {
1319     absl::InlinedVector<const Shape*, 2> shapes;
1320     for (const auto* op : input_arrays()) {
1321       VLOG(2) << "Pushing input array shape for: " << op->ToString() << "\n";
1322       shapes.push_back(&op->shape());
1323       VLOG(2) << "Pushed shape: " << shapes.back()->ToString() << "\n";
1324     }
1325     return shapes;
1326   }
1327   // Returns the init values of the reduction.
init_value_shapes()1328   absl::InlinedVector<const Shape*, 2> init_value_shapes() const {
1329     absl::InlinedVector<const Shape*, 2> shapes;
1330     for (const auto* op : init_values()) {
1331       shapes.push_back(&op->shape());
1332     }
1333     return shapes;
1334   }
1335 
1336  private:
1337   std::vector<string> ExtraAttributesToStringImpl(
1338       const HloPrintOptions& options) const override;
1339   bool IdenticalSlowPath(
1340       const HloInstruction& other,
1341       const std::function<bool(const HloComputation*, const HloComputation*)>&
1342           eq_computations) const override;
1343   // Implementation for non-common logic of CloneWithNewOperands.
1344   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1345       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1346       HloCloneContext* context) const override;
1347 
1348   Window window_;
1349 };
1350 
1351 class HloSelectAndScatterInstruction : public HloInstruction {
1352  public:
1353   explicit HloSelectAndScatterInstruction(
1354       const Shape& shape, HloInstruction* operand, HloComputation* select,
1355       const Window& window, HloInstruction* source, HloInstruction* init_value,
1356       HloComputation* scatter);
window()1357   const Window& window() const override { return window_; }
set_window(const Window & window)1358   void set_window(const Window& window) override { window_ = window; }
1359   // Gets/sets the select or scatter HloComputation for SelectAndScatter. The
1360   // setters should only be called by HloModule or HloComputation methods.
select()1361   HloComputation* select() const {
1362     return called_computations()[kSelectComputationIndex];
1363   }
1364 
scatter()1365   HloComputation* scatter() const {
1366     return called_computations()[kScatterComputationIndex];
1367   }
1368 
set_select(HloComputation * computation)1369   void set_select(HloComputation* computation) {
1370     // Don't allow changing the computation for fused instructions so we don't
1371     // have to recompute called_instructions for the entire fusion instruction.
1372     CHECK(!IsFused());
1373     set_called_computation(kSelectComputationIndex, computation);
1374   }
1375 
set_scatter(HloComputation * computation)1376   void set_scatter(HloComputation* computation) {
1377     // Don't allow changing the computation for fused instructions so we don't
1378     // have to recompute called_instructions for the entire fusion instruction.
1379     CHECK(!IsFused());
1380     set_called_computation(kScatterComputationIndex, computation);
1381   }
1382   // Returns a serialized representation of this instruction.
1383   HloInstructionProto ToProto() const override;
1384 
1385  private:
1386   std::vector<string> ExtraAttributesToStringImpl(
1387       const HloPrintOptions& options) const override;
1388   bool IdenticalSlowPath(
1389       const HloInstruction& other,
1390       const std::function<bool(const HloComputation*, const HloComputation*)>&
1391           eq_computations) const override;
1392   // Implementation for non-common logic of CloneWithNewOperands.
1393   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1394       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1395       HloCloneContext* context) const override;
1396   Window window_;
1397 };
1398 
1399 class HloCustomCallInstruction : public HloInstruction {
1400  public:
1401   HloCustomCallInstruction(const Shape& shape,
1402                            absl::Span<HloInstruction* const> operands,
1403                            absl::string_view custom_call_target, string opaque);
1404 
1405   // Constructor for a custom call with constrained layout. 'shape' and
1406   // 'operands_with_layout' must all have layouts.
1407   HloCustomCallInstruction(const Shape& shape,
1408                            absl::Span<HloInstruction* const> operands,
1409                            absl::string_view custom_call_target, string opaque,
1410                            absl::Span<const Shape> operand_shapes_with_layout);
1411 
1412   // Constructor for a custom call with a to_apply computation.
1413   HloCustomCallInstruction(const Shape& shape,
1414                            absl::Span<HloInstruction* const> operands,
1415                            HloComputation* to_apply,
1416                            absl::string_view custom_call_target, string opaque);
1417 
1418   // Constructor for a custom call with multiple computations.
1419   HloCustomCallInstruction(
1420       const Shape& shape, absl::Span<HloInstruction* const> operands,
1421       absl::Span<HloComputation* const> called_computations,
1422       absl::string_view custom_call_target, string opaque);
1423 
window()1424   const Window& window() const override {
1425     CHECK(window_ != nullptr);
1426     return *window_;
1427   }
1428 
set_window(const Window & window)1429   void set_window(const Window& window) override {
1430     window_ = absl::make_unique<Window>(window);
1431   }
1432 
convolution_dimension_numbers()1433   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
1434     CHECK(convolution_dimension_numbers_ != nullptr);
1435     return *convolution_dimension_numbers_;
1436   }
1437 
set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1438   void set_convolution_dimension_numbers(
1439       const ConvolutionDimensionNumbers& dnums) {
1440     convolution_dimension_numbers_ =
1441         absl::make_unique<ConvolutionDimensionNumbers>(dnums);
1442   }
1443   // TODO(jpienaar): Remove this accessor in the follow up.
opaque()1444   const string& opaque() const { return raw_backend_config_string(); }
custom_call_target()1445   const string& custom_call_target() const { return custom_call_target_; }
set_feature_group_count(int64 feature_group_count)1446   void set_feature_group_count(int64 feature_group_count) {
1447     feature_group_count_ = feature_group_count;
1448   }
set_batch_group_count(int64 batch_group_count)1449   void set_batch_group_count(int64 batch_group_count) {
1450     batch_group_count_ = batch_group_count;
1451   }
1452   // Sets whether this custom call has a side-effect - by default a custom call
1453   // has no side-effects.
set_custom_call_has_side_effect(bool custom_call_has_side_effect)1454   void set_custom_call_has_side_effect(bool custom_call_has_side_effect) {
1455     custom_call_has_side_effect_ = custom_call_has_side_effect;
1456   }
feature_group_count()1457   int64 feature_group_count() const { return feature_group_count_; }
batch_group_count()1458   int64 batch_group_count() const { return batch_group_count_; }
custom_call_has_side_effect()1459   bool custom_call_has_side_effect() const {
1460     return custom_call_has_side_effect_;
1461   }
1462   // Returns padding type used for ops like convolution.
padding_type()1463   PaddingType padding_type() const { return padding_type_; }
1464 
set_padding_type(PaddingType padding_type)1465   void set_padding_type(PaddingType padding_type) {
1466     padding_type_ = padding_type;
1467   }
1468 
1469   // Returns the literal associated with this instruction.
literal()1470   const Literal& literal() const { return *literal_; }
1471   // Set the value of literal to a new one.
set_literal(Literal && literal)1472   void set_literal(Literal&& literal) { literal_.emplace(std::move(literal)); }
1473   // Returns whether there is literal associated with this instruction.
HasLiteral()1474   bool HasLiteral() const { return literal_.has_value(); }
1475 
precision_config()1476   const PrecisionConfig& precision_config() const { return precision_config_; }
mutable_precision_config()1477   PrecisionConfig* mutable_precision_config() { return &precision_config_; }
1478 
1479   // Returns a serialized representation of this instruction.
1480   HloInstructionProto ToProto() const override;
1481 
1482   // Returns whether the result and operand layouts are constrained.
layout_constrained()1483   bool layout_constrained() const { return layout_constrained_; }
1484 
1485   // Returns the shapes (with layout) of the operands. CHECKs if this custom
1486   // call does not have constrained layouts.
operand_shapes_with_layout()1487   const std::vector<Shape>& operand_shapes_with_layout() const {
1488     CHECK(layout_constrained());
1489     return operand_shapes_with_layout_;
1490   }
1491   // Gets a list of output/operand buffer pairs that alias each other, where the
1492   // output buffer is represented as a ShapeIndex, and the operand buffer is
1493   // represented as the operand index and the ShapeIndex. By default this list
1494   // is empty.
1495   const std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>&
output_to_operand_aliasing()1496   output_to_operand_aliasing() const {
1497     return output_to_operand_aliasing_;
1498   }
1499   // Sets the list of output/operand buffer pairs that alias each other.
set_output_to_operand_aliasing(std::vector<std::pair<ShapeIndex,std::pair<int64,ShapeIndex>>> aliasing)1500   void set_output_to_operand_aliasing(
1501       std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
1502           aliasing) {
1503     output_to_operand_aliasing_ = std::move(aliasing);
1504   }
1505 
1506  private:
1507   std::vector<string> ExtraAttributesToStringImpl(
1508       const HloPrintOptions& options) const override;
1509   bool IdenticalSlowPath(
1510       const HloInstruction& other,
1511       const std::function<bool(const HloComputation*, const HloComputation*)>&
1512           eq_computations) const override;
1513   // Implementation for non-common logic of CloneWithNewOperands.
1514   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1515       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1516       HloCloneContext* context) const override;
1517   // Name of a global symbol to call.
1518   string custom_call_target_;
1519   // Describes the window in a windowed operation such as convolution.
1520   std::unique_ptr<Window> window_;
1521   // Describes the dimension numbers used for a convolution.
1522   std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
1523   // The number of feature groups. This is used for grouped convolutions.
1524   int64 feature_group_count_;
1525   int64 batch_group_count_;
1526   // Whether the result and operand layouts are constrained.
1527   bool layout_constrained_;
1528   // Information used to communicate to the implementation about the algorithm
1529   // used to produce results for convolution instructions.
1530   PrecisionConfig precision_config_;
1531   // Describes the padding type for convolution instructions.
1532   PaddingType padding_type_;
1533   // For layout-constrained custom calls, this vector holds the shape with
1534   // layout for each operand.
1535   std::vector<Shape> operand_shapes_with_layout_;
1536   // Whether this custom call has a side-effect.
1537   bool custom_call_has_side_effect_;
1538   // A list of output/operand buffer pairs that alias each other. See comment of
1539   // output_to_operand_aliasing().
1540   std::vector<std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
1541       output_to_operand_aliasing_;
1542   absl::optional<Literal> literal_;
1543 };
1544 
1545 class HloPadInstruction : public HloInstruction {
1546  public:
1547   explicit HloPadInstruction(const Shape& shape, HloInstruction* operand,
1548                              HloInstruction* padding_value,
1549                              const PaddingConfig& padding_config);
1550   // Returns the padding configuration for a pad node.
padding_config()1551   const PaddingConfig& padding_config() const { return padding_config_; }
mutable_padding_config()1552   PaddingConfig* mutable_padding_config() { return &padding_config_; }
1553   // Returns the padding value.
padding_value()1554   const HloInstruction* padding_value() const { return operand(1); }
mutable_padding_value()1555   HloInstruction* mutable_padding_value() { return mutable_operand(1); }
1556   // Returns a serialized representation of this instruction.
1557   HloInstructionProto ToProto() const override;
1558 
1559  private:
1560   std::vector<string> ExtraAttributesToStringImpl(
1561       const HloPrintOptions& options) const override;
1562   bool IdenticalSlowPath(
1563       const HloInstruction& other,
1564       const std::function<bool(const HloComputation*, const HloComputation*)>&
1565           eq_computations) const override;
1566   // Implementation for non-common logic of CloneWithNewOperands.
1567   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1568       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1569       HloCloneContext* context) const override;
1570 
1571   // The padding configuration that describes the edge padding and interior
1572   // padding of this pad instruction.
1573   PaddingConfig padding_config_;
1574 };
1575 
1576 class HloDynamicIndexInstruction : public HloInstruction {
1577  public:
HloDynamicIndexInstruction(HloOpcode opcode,const Shape & shape)1578   explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape)
1579       : HloInstruction(opcode, shape) {}
1580   virtual int64 first_index_operand_number() const = 0;
1581 
1582   // Returns a subspan of operands which represent the start indices.
index_operands()1583   absl::Span<HloInstruction* const> index_operands() const {
1584     return absl::MakeSpan(operands()).subspan(first_index_operand_number());
1585   }
1586 
1587   // Returns the shapes of the index operands.
index_shapes()1588   std::vector<Shape> index_shapes() const {
1589     std::vector<Shape> shapes;
1590     auto indices = index_operands();
1591     for (const HloInstruction* index : indices) {
1592       shapes.push_back(index->shape());
1593     }
1594     return shapes;
1595   }
1596 };
1597 
1598 class HloDynamicSliceInstruction : public HloDynamicIndexInstruction {
1599  public:
1600   explicit HloDynamicSliceInstruction(const Shape& shape,
1601                                       HloInstruction* operand,
1602                                       HloInstruction* start_indices,
1603                                       absl::Span<const int64> slice_sizes);
1604   explicit HloDynamicSliceInstruction(
1605       const Shape& shape, HloInstruction* operand,
1606       absl::Span<HloInstruction* const> start_indices,
1607       absl::Span<const int64> slice_sizes);
1608   // Old methods kept for smooth subclassing transition END.
1609   // Returns the size of the slice in the given dimension for a dynamic
1610   // slice node.
slice_sizes(int64 dimension)1611   int64 slice_sizes(int64 dimension) const {
1612     return dynamic_slice_sizes_[dimension];
1613   }
dynamic_slice_sizes()1614   const std::vector<int64>& dynamic_slice_sizes() const {
1615     return dynamic_slice_sizes_;
1616   }
1617   // Returns a serialized representation of this instruction.
1618   HloInstructionProto ToProto() const override;
1619 
first_index_operand_number()1620   int64 first_index_operand_number() const override { return 1; }
1621 
1622  private:
1623   std::vector<string> ExtraAttributesToStringImpl(
1624       const HloPrintOptions& options) const override;
1625   bool IdenticalSlowPath(
1626       const HloInstruction& other,
1627       const std::function<bool(const HloComputation*, const HloComputation*)>&
1628           eq_computations) const override;
1629   // Implementation for non-common logic of CloneWithNewOperands.
1630   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1631       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1632       HloCloneContext* context) const override;
1633 
1634   // Describes the [start, start + size) range size for a dynamic slice
1635   // ('start' is specified dynamically in the second operand of the operation).
1636   std::vector<int64> dynamic_slice_sizes_;
1637 };
1638 
1639 class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction {
1640  public:
1641   explicit HloDynamicUpdateSliceInstruction(const Shape& shape,
1642                                             HloInstruction* operand,
1643                                             HloInstruction* update,
1644                                             HloInstruction* start_indices);
1645   explicit HloDynamicUpdateSliceInstruction(
1646       const Shape& shape, HloInstruction* operand, HloInstruction* update,
1647       absl::Span<HloInstruction* const> start_indices);
1648 
first_index_operand_number()1649   int64 first_index_operand_number() const override { return 2; }
1650 };
1651 
1652 class HloGatherInstruction : public HloInstruction {
1653  public:
1654   explicit HloGatherInstruction(
1655       const Shape& shape, HloInstruction* operand,
1656       HloInstruction* start_indices,
1657       const GatherDimensionNumbers& gather_dim_numbers,
1658       absl::Span<const int64> slice_sizes, bool indices_are_sorted);
gather_dimension_numbers()1659   const GatherDimensionNumbers& gather_dimension_numbers() const {
1660     CHECK(gather_dimension_numbers_ != nullptr);
1661     return *gather_dimension_numbers_;
1662   }
gather_slice_sizes()1663   absl::Span<const int64> gather_slice_sizes() const {
1664     return gather_slice_sizes_;
1665   }
indices_are_sorted()1666   bool indices_are_sorted() const { return indices_are_sorted_; }
set_indices_are_sorted(bool indices_are_sorted)1667   void set_indices_are_sorted(bool indices_are_sorted) {
1668     indices_are_sorted_ = indices_are_sorted;
1669   }
1670   // Returns a serialized representation of this instruction.
1671   HloInstructionProto ToProto() const override;
1672 
1673   // Creates an instance of GatherDimensionNumbers.
1674   static GatherDimensionNumbers MakeGatherDimNumbers(
1675       absl::Span<const int64> offset_dims,
1676       absl::Span<const int64> collapsed_slice_dims,
1677       absl::Span<const int64> start_index_map, int64 index_vector_dim);
1678   // Returns the dump string of the given gather dimension numbers.
1679   static string GatherDimensionNumbersToString(
1680       const GatherDimensionNumbers& gather_dimension_numbers);
1681 
1682  private:
1683   std::vector<string> ExtraAttributesToStringImpl(
1684       const HloPrintOptions& options) const override;
1685   bool IdenticalSlowPath(
1686       const HloInstruction& other,
1687       const std::function<bool(const HloComputation*, const HloComputation*)>&
1688           eq_computations) const override;
1689   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1690       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1691       HloCloneContext* context) const override;
1692 
1693   std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
1694   std::vector<int64> gather_slice_sizes_;
1695   bool indices_are_sorted_;
1696 };
1697 
1698 class HloScatterInstruction : public HloInstruction {
1699  public:
1700   explicit HloScatterInstruction(
1701       const Shape& shape, HloInstruction* operand,
1702       HloInstruction* scatter_indices, HloInstruction* updates,
1703       HloComputation* update_computation,
1704       const ScatterDimensionNumbers& scatter_dim_numbers,
1705       bool indices_are_sorted, bool unique_indices);
scatter_dimension_numbers()1706   const ScatterDimensionNumbers& scatter_dimension_numbers() const {
1707     CHECK(scatter_dimension_numbers_ != nullptr);
1708     return *scatter_dimension_numbers_;
1709   }
indices_are_sorted()1710   bool indices_are_sorted() const { return indices_are_sorted_; }
set_indices_are_sorted(bool indices_are_sorted)1711   void set_indices_are_sorted(bool indices_are_sorted) {
1712     indices_are_sorted_ = indices_are_sorted;
1713   }
unique_indices()1714   bool unique_indices() const override { return unique_indices_; }
1715   // Returns a serialized representation of this instruction.
1716   HloInstructionProto ToProto() const override;
1717 
1718   // Creates an instance of ScatterDimensionNumbers.
1719   static ScatterDimensionNumbers MakeScatterDimNumbers(
1720       absl::Span<const int64> update_window_dims,
1721       absl::Span<const int64> inserted_window_dims,
1722       absl::Span<const int64> scatter_dims_to_operand_dims,
1723       int64 index_vector_dim);
1724   // Returns the dump string of the given scatter dimension numbers.
1725   static string ScatterDimensionNumbersToString(
1726       const ScatterDimensionNumbers& scatter_dimension_numbers);
1727 
1728  private:
1729   std::vector<string> ExtraAttributesToStringImpl(
1730       const HloPrintOptions& options) const override;
1731   bool IdenticalSlowPath(
1732       const HloInstruction& other,
1733       const std::function<bool(const HloComputation*, const HloComputation*)>&
1734           eq_computations) const override;
1735   // Implementation for non-common logic of CloneWithNewOperands.
1736   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1737       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1738       HloCloneContext* context) const override;
1739 
1740   std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
1741   bool indices_are_sorted_;
1742   bool unique_indices_;
1743 };
1744 
1745 class HloIotaInstruction : public HloInstruction {
1746  public:
1747   explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension);
1748   // Returns the dimension sizes or numbers associated with this instruction.
iota_dimension()1749   int64 iota_dimension() const { return iota_dimension_; }
1750   // Returns a serialized representation of this instruction.
1751   HloInstructionProto ToProto() const override;
1752 
1753  private:
1754   std::vector<string> ExtraAttributesToStringImpl(
1755       const HloPrintOptions& options) const override;
1756   bool IdenticalSlowPath(
1757       const HloInstruction& other,
1758       const std::function<bool(const HloComputation*, const HloComputation*)>&
1759           eq_computations) const override;
1760   // Implementation for non-common logic of CloneWithNewOperands.
1761   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1762       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1763       HloCloneContext* context) const override;
1764 
1765   const int64 iota_dimension_;
1766 };
1767 
1768 class HloDotInstruction : public HloInstruction {
1769  public:
1770   // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
1771   // dimensions specified in 'dimension_numbers'.
1772   explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs,
1773                              HloInstruction* rhs,
1774                              const DotDimensionNumbers& dimension_numbers,
1775                              const PrecisionConfig& precision_config);
1776 
1777   // Returns data on the dimension numbers used for a dot operation.
dot_dimension_numbers()1778   const DotDimensionNumbers& dot_dimension_numbers() const {
1779     return dot_dimension_numbers_;
1780   }
1781 
1782   // Returns the information used to tell the implementation information about
1783   // what sort of precision is requested. The meaning of the field is backend
1784   // specific. At the moment, it is only supported for kConvolution and kDot.
1785   // Transformations on one kDot or kConvolution to another will preserve this
1786   // information. Transformations to other HLOs will not preserve this
1787   // information but it is presumed that the alternate lowering is strictly
1788   // superior.
precision_config()1789   const PrecisionConfig& precision_config() const { return precision_config_; }
mutable_precision_config()1790   PrecisionConfig* mutable_precision_config() { return &precision_config_; }
1791 
1792   // Returns a serialized representation of this instruction.
1793   HloInstructionProto ToProto() const override;
1794 
1795  private:
1796   std::vector<string> ExtraAttributesToStringImpl(
1797       const HloPrintOptions& options) const override;
1798   bool IdenticalSlowPath(
1799       const HloInstruction& other,
1800       const std::function<bool(const HloComputation*, const HloComputation*)>&
1801           eq_computations) const override;
1802   // Implementation for non-common logic of CloneWithNewOperands.
1803   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1804       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1805       HloCloneContext* context) const override;
1806   // Returns the dump string of the dot dimension numbers.
1807   string DotDimensionNumbersToString() const;
1808 
1809   // Describes the dimension numbers used for a dot.
1810   DotDimensionNumbers dot_dimension_numbers_;
1811 
1812   // Information used to communicate to the implementation about the algorithm
1813   // used to produce results. See the documentation on precision_config().
1814   PrecisionConfig precision_config_;
1815 };
1816 
1817 class HloDomainInstruction : public HloInstruction {
1818  public:
1819   explicit HloDomainInstruction(
1820       const Shape& shape, HloInstruction* operand,
1821       std::unique_ptr<DomainMetadata> operand_side_metadata,
1822       std::unique_ptr<DomainMetadata> user_side_metadata);
1823 
1824   // Returns a serialized representation of this instruction.
1825   HloInstructionProto ToProto() const override;
1826 
1827   // Retrieves the operand side metadata of a kDomain instruction.
operand_side_metadata()1828   const DomainMetadata& operand_side_metadata() const {
1829     return *operand_side_metadata_;
1830   }
1831   // Retrieves the user side metadata of a kDomain instruction.
user_side_metadata()1832   const DomainMetadata& user_side_metadata() const {
1833     return *user_side_metadata_;
1834   }
1835 
1836  private:
1837   std::vector<string> ExtraAttributesToStringImpl(
1838       const HloPrintOptions& options) const override;
1839   bool IdenticalSlowPath(
1840       const HloInstruction& other,
1841       const std::function<bool(const HloComputation*, const HloComputation*)>&
1842           eq_computations) const override;
1843   // Implementation for non-common logic of CloneWithNewOperands.
1844   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1845       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1846       HloCloneContext* context) const override;
1847 
1848   std::unique_ptr<DomainMetadata> operand_side_metadata_;
1849   std::unique_ptr<DomainMetadata> user_side_metadata_;
1850 };
1851 
1852 class HloGetDimensionSizeInstruction : public HloInstruction {
1853  public:
1854   explicit HloGetDimensionSizeInstruction(const Shape& shape,
1855                                           HloInstruction* operand,
1856                                           int64 dimension);
1857 
1858   // Returns the dimension sizes or numbers associated with this instruction.
dimension()1859   int64 dimension() const { return dimension_; }
1860   // Returns a serialized representation of this instruction.
1861   HloInstructionProto ToProto() const override;
1862 
1863  private:
1864   std::vector<string> ExtraAttributesToStringImpl(
1865       const HloPrintOptions& options) const override;
1866   bool IdenticalSlowPath(
1867       const HloInstruction& other,
1868       const std::function<bool(const HloComputation*, const HloComputation*)>&
1869           eq_computations) const override;
1870   // Implementation for non-common logic of CloneWithNewOperands.
1871   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1872       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1873       HloCloneContext* context) const override;
1874 
1875   int64 dimension_;
1876 };
1877 
1878 class HloSetDimensionSizeInstruction : public HloInstruction {
1879  public:
1880   explicit HloSetDimensionSizeInstruction(const Shape& shape,
1881                                           HloInstruction* operand,
1882                                           HloInstruction* val, int64 dimension);
1883 
1884   // Returns the dimension sizes or numbers associated with this instruction.
dimension()1885   int64 dimension() const { return dimension_; }
1886   // Returns a serialized representation of this instruction.
1887   HloInstructionProto ToProto() const override;
1888 
1889  private:
1890   std::vector<string> ExtraAttributesToStringImpl(
1891       const HloPrintOptions& options) const override;
1892   bool IdenticalSlowPath(
1893       const HloInstruction& other,
1894       const std::function<bool(const HloComputation*, const HloComputation*)>&
1895           eq_computations) const override;
1896   // Implementation for non-common logic of CloneWithNewOperands.
1897   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1898       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1899       HloCloneContext* context) const override;
1900 
1901   int64 dimension_;
1902 };
1903 
1904 class HloRngGetAndUpdateStateInstruction : public HloInstruction {
1905  public:
1906   explicit HloRngGetAndUpdateStateInstruction(const Shape& shape, int64 delta);
1907 
1908   // Returns the delta value.
delta()1909   int64 delta() const { return delta_; }
set_delta(int64 delta)1910   void set_delta(int64 delta) { delta_ = delta; }
1911   // Returns a serialized representation of this instruction.
1912   HloInstructionProto ToProto() const override;
1913 
1914  private:
1915   std::vector<string> ExtraAttributesToStringImpl(
1916       const HloPrintOptions& options) const override;
1917   bool IdenticalSlowPath(
1918       const HloInstruction& other,
1919       const std::function<bool(const HloComputation*, const HloComputation*)>&
1920           eq_computations) const override;
1921   // Implementation for non-common logic of CloneWithNewOperands.
1922   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1923       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1924       HloCloneContext* context) const override;
1925 
1926   int64 delta_;
1927 };
1928 
1929 class HloRngBitGeneratorInstruction : public HloInstruction {
1930  public:
1931   HloRngBitGeneratorInstruction(const Shape& shape, HloInstruction* state,
1932                                 RandomAlgorithm algorithm);
1933 
algorithm()1934   RandomAlgorithm algorithm() const { return algorithm_; }
1935   HloInstructionProto ToProto() const override;
1936 
1937  private:
1938   std::vector<string> ExtraAttributesToStringImpl(
1939       const HloPrintOptions& options) const override;
1940   bool IdenticalSlowPath(
1941       const HloInstruction& other,
1942       const std::function<bool(const HloComputation*, const HloComputation*)>&
1943           eq_computations) const override;
1944   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1945       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1946       HloCloneContext* context) const override;
1947 
1948   RandomAlgorithm algorithm_;
1949 };
1950 
1951 }  // namespace xla
1952 
1953 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
1954