1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // All HloInstruction subclasses are put in this file.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
19 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
20 
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 
24 namespace xla {
25 
26 class HloBatchNormInstruction : public HloInstruction {
27  public:
28   // Returns feature_index field associated with the instruction. The index
29   // represents the index of the feature dimension.
feature_index()30   int64 feature_index() const { return feature_index_; }
31 
32   // Returns a epsilon value associated with the instruction. The is a small
33   // number added to the variance to avoid divide-by-zero error.
epsilon()34   float epsilon() const { return epsilon_; }
35 
36   // Returns a serialized representation of this instruction.
37   HloInstructionProto ToProto() const override;
38 
39  protected:
40   explicit HloBatchNormInstruction(HloOpcode opcode, const Shape& shape,
41                                    HloInstruction* operand,
42                                    HloInstruction* scale, float epsilon,
43                                    int64 feature_index);
44 
45  private:
46   std::vector<string> ExtraAttributesToStringImpl(
47       const HloPrintOptions& options) const override;
48   bool IdenticalSlowPath(
49       const HloInstruction& other,
50       const std::function<bool(const HloComputation*, const HloComputation*)>&
51           eq_computations) const override;
52   // A small float number added to the variance to avoid divide-by-zero error.
53   float epsilon_ = 0.0f;
54 
55   // An integer value representing the index of the feature dimension.
56   int64 feature_index_ = -1;
57 };
58 
59 class HloBatchNormTrainingInstruction : public HloBatchNormInstruction {
60  public:
61   explicit HloBatchNormTrainingInstruction(const Shape& shape,
62                                            HloInstruction* operand,
63                                            HloInstruction* scale,
64                                            HloInstruction* offset,
65                                            float epsilon, int64 feature_index);
66 
67  private:
68   // Implementation for non-common logic of CloneWithNewOperands.
69   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
70       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
71       HloCloneContext* context) const override;
72 };
73 
74 class HloBatchNormInferenceInstruction : public HloBatchNormInstruction {
75  public:
76   explicit HloBatchNormInferenceInstruction(
77       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
78       HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
79       float epsilon, int64 feature_index);
80 
81  private:
82   // Implementation for non-common logic of CloneWithNewOperands.
83   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
84       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
85       HloCloneContext* context) const override;
86 };
87 
88 class HloBatchNormGradInstruction : public HloBatchNormInstruction {
89  public:
90   explicit HloBatchNormGradInstruction(
91       const Shape& shape, HloInstruction* operand, HloInstruction* scale,
92       HloInstruction* mean, HloInstruction* variance,
93       HloInstruction* grad_output, float epsilon, int64 feature_index);
94 
95  private:
96   // Implementation for non-common logic of CloneWithNewOperands.
97   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
98       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
99       HloCloneContext* context) const override;
100 };
101 
102 class HloFftInstruction : public HloInstruction {
103  public:
104   explicit HloFftInstruction(const Shape& shape, HloInstruction* operand,
105                              FftType fft_type,
106                              absl::Span<const int64> fft_length);
fft_type()107   FftType fft_type() const { return fft_type_; }
108 
fft_length()109   const std::vector<int64>& fft_length() const { return fft_length_; }
110 
111   // Returns a serialized representation of this instruction.
112   HloInstructionProto ToProto() const override;
113 
114  private:
115   std::vector<string> ExtraAttributesToStringImpl(
116       const HloPrintOptions& options) const override;
117   bool IdenticalSlowPath(
118       const HloInstruction& other,
119       const std::function<bool(const HloComputation*, const HloComputation*)>&
120           eq_computations) const override;
121 
122   // Implementation for non-common logic of CloneWithNewOperands.
123   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
124       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
125       HloCloneContext* context) const override;
126 
127   // Describes FFT type for an FFT instruction.
128   FftType fft_type_ = FftType::FFT;
129 
130   // Indicates the FFT length for an FFT instruction.
131   std::vector<int64> fft_length_;
132 };
133 
134 class HloCompareInstruction : public HloInstruction {
135  public:
136   explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
137                                  HloInstruction* rhs,
138                                  ComparisonDirection direction);
direction()139   ComparisonDirection direction() const { return direction_; }
140   HloInstructionProto ToProto() const override;
141 
142  private:
143   std::vector<string> ExtraAttributesToStringImpl(
144       const HloPrintOptions& options) const override;
145   bool IdenticalSlowPath(
146       const HloInstruction& other,
147       const std::function<bool(const HloComputation*, const HloComputation*)>&
148           eq_computations) const override;
149   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
150       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
151       HloCloneContext* context) const override;
152 
153   ComparisonDirection direction_;
154 };
155 
156 class HloTriangularSolveInstruction : public HloInstruction {
157  public:
158   explicit HloTriangularSolveInstruction(const Shape& shape, HloInstruction* a,
159                                          HloInstruction* b,
160                                          const TriangularSolveOptions& options);
triangular_solve_options()161   const TriangularSolveOptions& triangular_solve_options() const {
162     return triangular_solve_options_;
163   }
164 
165   // Returns a serialized representation of this instruction.
166   HloInstructionProto ToProto() const override;
167 
168  private:
169   std::vector<string> ExtraAttributesToStringImpl(
170       const HloPrintOptions& options) const override;
171   bool IdenticalSlowPath(
172       const HloInstruction& other,
173       const std::function<bool(const HloComputation*, const HloComputation*)>&
174           eq_computations) const override;
175 
176   // Implementation for non-common logic of CloneWithNewOperands.
177   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
178       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
179       HloCloneContext* context) const override;
180 
181   TriangularSolveOptions triangular_solve_options_;
182 };
183 
184 class HloCholeskyInstruction : public HloInstruction {
185  public:
186   explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a,
187                                   const CholeskyOptions& options);
cholesky_options()188   const CholeskyOptions& cholesky_options() const { return cholesky_options_; }
189 
190   // Returns a serialized representation of this instruction.
191   HloInstructionProto ToProto() const override;
192 
193  private:
194   std::vector<string> ExtraAttributesToStringImpl(
195       const HloPrintOptions& options) const override;
196   bool IdenticalSlowPath(
197       const HloInstruction& other,
198       const std::function<bool(const HloComputation*, const HloComputation*)>&
199           eq_computations) const override;
200 
201   // Implementation for non-common logic of CloneWithNewOperands.
202   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
203       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
204       HloCloneContext* context) const override;
205 
206   CholeskyOptions cholesky_options_;
207 };
208 
209 class HloSendRecvInstruction : public HloInstruction {
210  public:
211   // Returns the channel id associated with the instruction. The id is
212   // shared between each Send/Recv pair and is globally unique to identify each
213   // channel.
channel_id()214   int64 channel_id() const { return channel_id_; }
215 
216   // Returns whether this send/recv instruction sends data to/from the host.
is_host_transfer()217   bool is_host_transfer() const { return is_host_transfer_; }
218 
219   // Returns a serialized representation of this instruction.
220   HloInstructionProto ToProto() const override;
221 
222  protected:
223   explicit HloSendRecvInstruction(HloOpcode opcode, const Shape& shape,
224                                   int64 channel_id, bool is_host_transfer);
225 
226  private:
227   std::vector<string> ExtraAttributesToStringImpl(
228       const HloPrintOptions& options) const override;
229   bool IdenticalSlowPath(
230       const HloInstruction& other,
231       const std::function<bool(const HloComputation*, const HloComputation*)>&
232           eq_computations) const override;
233   // Represents a unique identifier for each Send/Recv instruction pair.
234   int64 channel_id_;
235 
236   // Whether this send/recv instruction sends data to/from the host.
237   bool is_host_transfer_;
238 };
239 
240 class HloSendInstruction : public HloSendRecvInstruction {
241  public:
242   explicit HloSendInstruction(HloInstruction* operand, HloInstruction* token,
243                               int64 channel_id, bool is_host_transfer);
244 
245  private:
246   // Implementation for non-common logic of CloneWithNewOperands.
247   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
248       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
249       HloCloneContext* context) const override;
250 };
251 
252 class HloSendDoneInstruction : public HloSendRecvInstruction {
253  public:
254   explicit HloSendDoneInstruction(HloSendInstruction* operand,
255                                   bool is_host_transfer);
256 
257  private:
258   // Implementation for non-common logic of CloneWithNewOperands.
259   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
260       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
261       HloCloneContext* context) const override;
262 };
263 
264 class HloRecvInstruction : public HloSendRecvInstruction {
265  public:
266   explicit HloRecvInstruction(const Shape& shape, HloInstruction* token,
267                               int64 channel_id, bool is_host_transfer);
268 
269  private:
270   // Implementation for non-common logic of CloneWithNewOperands.
271   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
272       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
273       HloCloneContext* context) const override;
274 };
275 
276 class HloRecvDoneInstruction : public HloSendRecvInstruction {
277  public:
278   explicit HloRecvDoneInstruction(HloRecvInstruction* operand,
279                                   bool is_host_transfer);
280 
281  private:
282   // Implementation for non-common logic of CloneWithNewOperands.
283   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
284       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
285       HloCloneContext* context) const override;
286 };
287 
288 class HloCollectiveInstruction : public HloInstruction {
289  public:
replica_groups()290   const std::vector<ReplicaGroup>& replica_groups() const {
291     return replica_groups_;
292   }
293 
294  protected:
295   explicit HloCollectiveInstruction(
296       HloOpcode opcode, const Shape& shape,
297       absl::Span<HloInstruction* const> operands,
298       const std::vector<ReplicaGroup>& replica_groups);
299 
300   HloInstructionProto ToProto() const override;
301 
302   std::vector<string> ExtraAttributesToStringImpl(
303       const HloPrintOptions& options) const override;
304   bool IdenticalSlowPath(
305       const HloInstruction& other,
306       const std::function<bool(const HloComputation*, const HloComputation*)>&
307           eq_computations) const override;
308 
309   std::vector<ReplicaGroup> replica_groups_;
310 };
311 
312 class HloAllReduceInstruction : public HloCollectiveInstruction {
313  public:
314   explicit HloAllReduceInstruction(
315       const Shape& shape, absl::Span<HloInstruction* const> operands,
316       HloComputation* reduce_computation,
317       const std::vector<ReplicaGroup>& replica_groups,
318       absl::string_view barrier, const absl::optional<int64>& all_reduce_id);
319 
320   // Returns the barrier config used for the AllReduce implementation of
321   // each backend.
all_reduce_barrier()322   string all_reduce_barrier() const { return all_reduce_barrier_; }
set_all_reduce_barrier(string barrier)323   void set_all_reduce_barrier(string barrier) { all_reduce_barrier_ = barrier; }
324 
all_reduce_id()325   absl::optional<int64> all_reduce_id() const { return all_reduce_id_; }
326   void set_all_reduce_id(const absl::optional<int64>& all_reduce_id);
327 
328   // Returns a serialized representation of this instruction.
329   HloInstructionProto ToProto() const override;
330 
331   // Returns true if the AllReduce does no communication, so it's equivalent
332   // to a mem copy.
333   bool IsNoop() const;
334 
335  private:
336   std::vector<string> ExtraAttributesToStringImpl(
337       const HloPrintOptions& options) const override;
338   bool IdenticalSlowPath(
339       const HloInstruction& other,
340       const std::function<bool(const HloComputation*, const HloComputation*)>&
341           eq_computations) const override;
342 
343   // Implementation for non-common logic of CloneWithNewOperands.
344   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
345       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
346       HloCloneContext* context) const override;
347 
348   // The string representation of the barrier config used for AllReduce.
349   string all_reduce_barrier_;
350 
351   // For Allreduce nodes from different modules, if they have the same
352   // all_reduce_id, they will be 'Allreduce'd. If empty, Allreduce will not be
353   // applied cross modules.
354   absl::optional<int64> all_reduce_id_;
355 };
356 
357 class HloAllToAllInstruction : public HloCollectiveInstruction {
358  public:
359   explicit HloAllToAllInstruction(
360       const Shape& shape, absl::Span<HloInstruction* const> operands,
361       const std::vector<ReplicaGroup>& replica_groups);
362 
363  private:
364   // Implementation for non-common logic of CloneWithNewOperands.
365   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
366       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
367       HloCloneContext* context) const override;
368 };
369 
370 class HloCollectivePermuteInstruction : public HloInstruction {
371  public:
372   explicit HloCollectivePermuteInstruction(
373       const Shape& shape, HloInstruction* operand,
374       const std::vector<std::pair<int64, int64>>& source_target_pairs);
375 
source_target_pairs()376   const std::vector<std::pair<int64, int64>>& source_target_pairs() const {
377     return source_target_pairs_;
378   }
379 
380   // Returns a serialized representation of this instruction.
381   HloInstructionProto ToProto() const override;
382 
383  private:
384   std::vector<string> ExtraAttributesToStringImpl(
385       const HloPrintOptions& options) const override;
386   bool IdenticalSlowPath(
387       const HloInstruction& other,
388       const std::function<bool(const HloComputation*, const HloComputation*)>&
389           eq_computations) const override;
390 
391   // Implementation for non-common logic of CloneWithNewOperands.
392   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
393       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
394       HloCloneContext* context) const override;
395 
396   const std::vector<std::pair<int64, int64>> source_target_pairs_;
397 };
398 
399 class HloReverseInstruction : public HloInstruction {
400  public:
401   explicit HloReverseInstruction(const Shape& shape, HloInstruction* operand,
402                                  absl::Span<const int64> dimensions);
403   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()404   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)405   int64 dimensions(int64 index) const override { return dimensions()[index]; }
406   // Returns a serialized representation of this instruction.
407   HloInstructionProto ToProto() const override;
408 
409  private:
410   std::vector<string> ExtraAttributesToStringImpl(
411       const HloPrintOptions& options) const override;
412   bool IdenticalSlowPath(
413       const HloInstruction& other,
414       const std::function<bool(const HloComputation*, const HloComputation*)>&
415           eq_computations) const override;
416   // Implementation for non-common logic of CloneWithNewOperands.
417   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
418       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
419       HloCloneContext* context) const override;
420 
421   std::vector<int64> dimensions_;
422 };
423 
424 class HloConcatenateInstruction : public HloInstruction {
425  public:
426   explicit HloConcatenateInstruction(const Shape& shape,
427                                      absl::Span<HloInstruction* const> operands,
428                                      int64 dimension);
429   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()430   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)431   int64 dimensions(int64 index) const override { return dimensions()[index]; }
432   // Accessor for the dimension in which a concatenate HLO should occur.
concatenate_dimension()433   int64 concatenate_dimension() const { return dimensions(0); }
434   // Returns a serialized representation of this instruction.
435   HloInstructionProto ToProto() const override;
436 
437  private:
438   std::vector<string> ExtraAttributesToStringImpl(
439       const HloPrintOptions& options) const override;
440   bool IdenticalSlowPath(
441       const HloInstruction& other,
442       const std::function<bool(const HloComputation*, const HloComputation*)>&
443           eq_computations) const override;
444   // Implementation for non-common logic of CloneWithNewOperands.
445   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
446       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
447       HloCloneContext* context) const override;
448 
449   std::vector<int64> dimensions_;
450 };
451 
452 class HloReduceInstruction : public HloInstruction {
453  public:
454   explicit HloReduceInstruction(const Shape& shape,
455                                 absl::Span<HloInstruction* const> args,
456                                 absl::Span<const int64> dimensions_to_reduce,
457                                 HloComputation* reduce_computation);
458   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()459   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)460   int64 dimensions(int64 index) const override { return dimensions()[index]; }
461   // Returns a serialized representation of this instruction.
462   HloInstructionProto ToProto() const override;
463 
464   // Returns the number of input arrays (and, consequentially, the number of
465   // init values) this reduce has.
input_count()466   int64 input_count() const { return operand_count() / 2; }
467 
468   // Returns the input tensors to be reduced.
inputs()469   absl::Span<HloInstruction* const> inputs() const {
470     return absl::MakeSpan(operands()).subspan(0, input_count());
471   }
472 
473   // Returns the init values of the reduction.
init_values()474   absl::Span<HloInstruction* const> init_values() const {
475     return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
476   }
477 
478  private:
479   std::vector<string> ExtraAttributesToStringImpl(
480       const HloPrintOptions& options) const override;
481   bool IdenticalSlowPath(
482       const HloInstruction& other,
483       const std::function<bool(const HloComputation*, const HloComputation*)>&
484           eq_computations) const override;
485   // Implementation for non-common logic of CloneWithNewOperands.
486   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
487       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
488       HloCloneContext* context) const override;
489 
490   std::vector<int64> dimensions_;
491 };
492 
493 class HloSortInstruction : public HloInstruction {
494  public:
495   explicit HloSortInstruction(const Shape& shape, int64 dimension,
496                               absl::Span<HloInstruction* const> operands,
497                               HloComputation* compare, bool is_stable);
498   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()499   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)500   int64 dimensions(int64 index) const override { return dimensions()[index]; }
501   // Returns the sort dimension for this instruction
sort_dimension()502   int64 sort_dimension() const { return dimensions(0); }
503   // Returns a serialized representation of this instruction.
504   HloInstructionProto ToProto() const override;
505   // Returns the key operand to this instruction.
keys()506   const HloInstruction* keys() const { return operand(0); }
mutable_keys()507   HloInstruction* mutable_keys() { return mutable_operand(0); }
508   // Returns the number of value operands.
values_count()509   int64 values_count() const { return operand_count() - 1; }
is_stable()510   bool is_stable() const { return is_stable_; }
511 
512  private:
513   std::vector<string> ExtraAttributesToStringImpl(
514       const HloPrintOptions& options) const override;
515   bool IdenticalSlowPath(
516       const HloInstruction& other,
517       const std::function<bool(const HloComputation*, const HloComputation*)>&
518           eq_computations) const override;
519   // Implementation for non-common logic of CloneWithNewOperands.
520   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
521       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
522       HloCloneContext* context) const override;
523 
524   std::vector<int64> dimensions_;
525   bool is_stable_;
526 };
527 
528 class HloTransposeInstruction : public HloInstruction {
529  public:
530   explicit HloTransposeInstruction(const Shape& shape, HloInstruction* operand,
531                                    absl::Span<const int64> dimensions);
532   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()533   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)534   int64 dimensions(int64 index) const override { return dimensions()[index]; }
535   // Returns whether this instruction does a rank-2 transposition.
536   bool IsRank2Transpose() const;
537   // Returns a serialized representation of this instruction.
538   HloInstructionProto ToProto() const override;
539 
540  private:
541   std::vector<string> ExtraAttributesToStringImpl(
542       const HloPrintOptions& options) const override;
543   bool IdenticalSlowPath(
544       const HloInstruction& other,
545       const std::function<bool(const HloComputation*, const HloComputation*)>&
546           eq_computations) const override;
547   // Implementation for non-common logic of CloneWithNewOperands.
548   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
549       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
550       HloCloneContext* context) const override;
551 
552   std::vector<int64> dimensions_;
553 };
554 
555 class HloBroadcastInstruction : public HloInstruction {
556  public:
557   explicit HloBroadcastInstruction(const Shape& shape, HloInstruction* operand,
558                                    absl::Span<const int64> broadcast_dimension);
559   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()560   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)561   int64 dimensions(int64 index) const override { return dimensions()[index]; }
562   // Returns a serialized representation of this instruction.
563   HloInstructionProto ToProto() const override;
564 
565  private:
566   std::vector<string> ExtraAttributesToStringImpl(
567       const HloPrintOptions& options) const override;
568   bool IdenticalSlowPath(
569       const HloInstruction& other,
570       const std::function<bool(const HloComputation*, const HloComputation*)>&
571           eq_computations) const override;
572   // Implementation for non-common logic of CloneWithNewOperands.
573   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
574       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
575       HloCloneContext* context) const override;
576 
577   std::vector<int64> dimensions_;
578 };
579 
580 class HloMapInstruction : public HloInstruction {
581  public:
582   explicit HloMapInstruction(const Shape& shape,
583                              absl::Span<HloInstruction* const> operands,
584                              HloComputation* map_computation);
585   // Returns the dimension sizes or numbers associated with this instruction.
dimensions()586   const std::vector<int64>& dimensions() const override { return dimensions_; }
dimensions(int64 index)587   int64 dimensions(int64 index) const override { return dimensions()[index]; }
588   // Returns a serialized representation of this instruction.
589   HloInstructionProto ToProto() const override;
590 
591  private:
592   bool IsElementwiseImpl(
593       const absl::optional<int64>& operand_idx) const override;
594   std::vector<string> ExtraAttributesToStringImpl(
595       const HloPrintOptions& options) const override;
596   bool IdenticalSlowPath(
597       const HloInstruction& other,
598       const std::function<bool(const HloComputation*, const HloComputation*)>&
599           eq_computations) const override;
600   // Implementation for non-common logic of CloneWithNewOperands.
601   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
602       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
603       HloCloneContext* context) const override;
604 
605   std::vector<int64> dimensions_;
606 };
607 
608 class HloSliceInstruction : public HloInstruction {
609  public:
610   explicit HloSliceInstruction(const Shape& shape, HloInstruction* operand,
611                                absl::Span<const int64> start_indices,
612                                absl::Span<const int64> limit_indices,
613                                absl::Span<const int64> strides);
614 
615   HloInstructionProto ToProto() const override;
616 
617   // Returns the start index in the given dimension for a slice node.
slice_starts(int64 dimension)618   int64 slice_starts(int64 dimension) const { return slice_starts_[dimension]; }
slice_starts()619   const std::vector<int64>& slice_starts() const { return slice_starts_; }
620 
621   // Returns the (exclusive) limit index in the given dimension for a slice
622   // node.
slice_limits(int64 dimension)623   int64 slice_limits(int64 dimension) const { return slice_limits_[dimension]; }
slice_limits()624   const std::vector<int64>& slice_limits() const { return slice_limits_; }
625 
626   // Returns the stride in the given dimension for a slice node.
slice_strides(int64 dimension)627   int64 slice_strides(int64 dimension) const {
628     return slice_strides_[dimension];
629   }
slice_strides()630   const std::vector<int64>& slice_strides() const { return slice_strides_; }
631 
632  private:
633   std::vector<string> ExtraAttributesToStringImpl(
634       const HloPrintOptions& options) const override;
635   bool IdenticalSlowPath(
636       const HloInstruction& other,
637       const std::function<bool(const HloComputation*, const HloComputation*)>&
638           eq_computations) const override;
639   // Implementation for non-common logic of CloneWithNewOperands.
640   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
641       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
642       HloCloneContext* context) const override;
643 
644   // Describes the [begin, end) index range for a slice.
645   std::vector<int64> slice_starts_;
646   std::vector<int64> slice_limits_;
647   std::vector<int64> slice_strides_;
648 };
649 
650 class HloConstantInstruction : public HloInstruction {
651  public:
652   explicit HloConstantInstruction(Literal literal);
653   // Used when the literal is too large and dropped.
654   explicit HloConstantInstruction(const Shape& shape);
655   // Returns the literal associated with this instruction.
literal()656   const Literal& literal() const { return *literal_; }
657   // Returns whether there is literal associated with this instruction.
HasLiteral()658   bool HasLiteral() const { return literal_.has_value(); }
659   // Returns a serialized representation of this instruction.
660   HloInstructionProto ToProto() const override;
661 
662   // Change the layout for an Constant Hlo instruction to match new_layout.  For
663   // tuple shaped constants shape_index is the path to the internal array
664   // subshape whose layout needs to be changed.
665   void RelayoutConstant(const Layout& new_layout,
666                         const ShapeIndex& shape_index = {});
667 
668  private:
669   bool IsElementwiseImpl(
670       const absl::optional<int64>& operand_idx) const override;
671   bool IdenticalSlowPath(
672       const HloInstruction& other,
673       const std::function<bool(const HloComputation*, const HloComputation*)>&
674           eq_computations) const override;
675   string OperandsToStringWithCanonicalNameMap(
676       const HloPrintOptions& options,
677       CanonicalNameMap* canonical_name_map) const override;
678   // Implementation for non-common logic of CloneWithNewOperands.
679   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
680       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
681       HloCloneContext* context) const override;
682   absl::optional<Literal> literal_;
683 };
684 
685 class HloTraceInstruction : public HloInstruction {
686  public:
687   explicit HloTraceInstruction(const string& tag, HloInstruction* operand);
688   // Returns a tag to be used in tracing.
TracingTag()689   string TracingTag() const { return literal_.GetR1U8AsString(); }
690   // Returns a serialized representation of this instruction.
691   HloInstructionProto ToProto() const override;
692 
693  private:
694   bool IdenticalSlowPath(
695       const HloInstruction& other,
696       const std::function<bool(const HloComputation*, const HloComputation*)>&
697           eq_computations) const override;
698   // Implementation for non-common logic of CloneWithNewOperands.
699   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
700       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
701       HloCloneContext* context) const override;
702   Literal literal_;
703 };
704 
705 class HloFusionInstruction : public HloInstruction {
706  public:
707   explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
708                                 HloInstruction* fused_root);
709 
710   explicit HloFusionInstruction(const Shape& shape, FusionKind fusion_kind,
711                                 absl::Span<HloInstruction* const> operands,
712                                 HloComputation* fusion_computation);
713 
714   string ToCategory() const override;
715   // Returns a serialized representation of this instruction.
716   HloInstructionProto ToProto() const override;
717 
718   // Adds a new operand the fusion instruction.
719   HloInstruction* AddFusionOperand(HloInstruction* new_operand);
720 
721   // Merges the fused instructions from 'instruction_to_merge' into the
722   // fused instruction set of 'this', updating operands as necessary.
723   //
724   // Predondition: 'instruction_to_merge' must be an operand of 'this'.
725   void MergeFusionInstruction(HloFusionInstruction* instruction_to_merge);
726 
727   // Merges the fused instructions from instruction_to_merge into the fused
728   // instruction set of 'this' and generates multioutput fusion instructions.
729   // All the users of instruction_to_merge will be redirected to 'this'
730   // instruction. instruction_to_merge will be removed from its parent
731   // computation.
732   void MergeFusionInstructionIntoMultiOutput(
733       HloFusionInstruction* instruction_to_merge);
734 
735   // Fuses the given instruction in this fusion instruction. instruction_to_fuse
736   // is cloned and the clone is placed in the fusion
737   // instruction. instruction_to_fuse is unchanged. Instruction is cloned rather
738   // than moved to cleanly handle the case where the instruction has a use
739   // outside the fusion instruction. Moving such an instruction into a fusion
740   // instruction would violate the single-result invariant of HLO instructions
741   // and significantly complicate code generation.
FuseInstruction(HloInstruction * instruction_to_fuse)742   HloInstruction* FuseInstruction(HloInstruction* instruction_to_fuse) {
743     return FuseInstructionInternal(instruction_to_fuse);
744   }
745 
746   // Fuses the given instruction in this fusion instruction and generate
747   // multioutput fusion instruction. A clone of the instruction_to_fuse will
748   // be part of the output of fusion instructions. The users of
749   // instruction_to_fuse will be redirected to this fusion instructions.
750   // instruction_to_fuse will be removed from its parent computation.
FuseInstructionIntoMultiOutput(HloInstruction * instruction_to_fuse)751   HloInstruction* FuseInstructionIntoMultiOutput(
752       HloInstruction* instruction_to_fuse) {
753     return FuseInstructionInternal(instruction_to_fuse, /* add_output */ true);
754   }
755 
756   // Returns the computation for this fused instruction.
757   HloComputation* fused_instructions_computation() const;
758 
759   // Returns the root instruction of the fused expression contained within this
760   // fusion instruction.
761   HloInstruction* fused_expression_root() const;
762 
763   // Returns the list of fused instructions inside this fusion instruction.  The
764   // returned type is a range of HloInstruction*s.
765   const tensorflow::gtl::iterator_range<UnwrappingIterator<
766       std::list<std::unique_ptr<HloInstruction>>::const_iterator>>
767   fused_instructions() const;
768 
769   const tensorflow::gtl::iterator_range<
770       UnwrappingIterator<std::list<std::unique_ptr<HloInstruction>>::iterator>>
771   fused_instructions();
772 
773   // Gets the number of instructions inside this fusion instruction.
774   int64 fused_instruction_count() const;
775 
776   // Returns the fused parameter instruction in this fusion instruction
777   // corresponding to the given parameter number.
778   HloInstruction* fused_parameter(int64 parameter_number) const;
779 
780   // Returns the vector of fused parameters inside this fusion instruction.
781   const std::vector<HloInstruction*>& fused_parameters() const;
782 
783   // Returns true if this instruction is a fusion instruction that generates
784   // multiple outputs.
IsMultiOutputFusion()785   const bool IsMultiOutputFusion() const {
786     return fused_expression_root()->opcode() == HloOpcode::kTuple;
787   }
788 
fusion_kind()789   FusionKind fusion_kind() const { return fusion_kind_; }
790 
set_fusion_kind(FusionKind kind)791   void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; }
792 
793   // If multiple operands are the same instruction, keeps only one of them.
794   Status DeduplicateFusionOperands();
795 
796  private:
797   // Fuses the given instruction into this fusion instruction. When add_output
798   // is false (which is the default), instruction_to_fuse is cloned and the
799   // clone is placed in the fusion instruction. instruction_to_fuse is
800   // unchanged.
801   //
802   // When add_output is true, a clone of the instruction_to_fuse will be part
803   // of the output of fusion instructions. The users of instruction_to_fuse
804   // will be redirected to this fusion instructions. instruction_to_fuse will
805   // be removed from its parent computation.
806   HloInstruction* FuseInstructionInternal(HloInstruction* instruction_to_fuse,
807                                           bool add_output = false);
808   // Clones the given instruction_to_fuse and insert the clone into this fusion
809   // instruction. If add_output is true, a clone of instruction_to_fuse will
810   // be in the output of the this fusion instruction (part of the tuple of the
811   // fusion root).
812   HloInstruction* CloneAndFuseInternal(HloInstruction* instruction_to_fuse,
813                                        bool add_output = false);
814 
815   bool IsElementwiseImpl(
816       const absl::optional<int64>& operand_idx) const override;
817   std::vector<string> ExtraAttributesToStringImpl(
818       const HloPrintOptions& options) const override;
819   bool IdenticalSlowPath(
820       const HloInstruction& other,
821       const std::function<bool(const HloComputation*, const HloComputation*)>&
822           eq_computations) const override;
823   uint64 InnerHash() const override;
824 
825   // Implementation for non-common logic of CloneWithNewOperands.
826   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
827       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
828       HloCloneContext* context) const override;
829 
830   // The type of the fusion. Used by kFusion only.
831   FusionKind fusion_kind_;
832 };
833 
834 class HloRngInstruction : public HloInstruction {
835  public:
836   explicit HloRngInstruction(const Shape& shape,
837                              RandomDistribution distribution,
838                              absl::Span<HloInstruction* const> parameters);
839   // Returns the random distribution for this rng node.
random_distribution()840   RandomDistribution random_distribution() const { return distribution_; }
841   // Returns a serialized representation of this instruction.
842   HloInstructionProto ToProto() const override;
843 
844  private:
845   bool IsElementwiseImpl(
846       const absl::optional<int64>& operand_idx) const override;
847   std::vector<string> ExtraAttributesToStringImpl(
848       const HloPrintOptions& options) const override;
849   bool IdenticalSlowPath(
850       const HloInstruction& other,
851       const std::function<bool(const HloComputation*, const HloComputation*)>&
852           eq_computations) const override;
853   // Implementation for non-common logic of CloneWithNewOperands.
854   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
855       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
856       HloCloneContext* context) const override;
857 
858   // The distribution requested for random number generation.
859   RandomDistribution distribution_;
860 };
861 
862 class HloParameterInstruction : public HloInstruction {
863  public:
864   explicit HloParameterInstruction(int64 parameter_number, const Shape& shape,
865                                    const string& name);
parameter_number()866   int64 parameter_number() const { return parameter_number_; }
867 
868   // Sets and gets the whether all replicas will receive the same parameter data
869   // for each leaf buffer in data parallelism.
set_parameter_replicated_at_leaf_buffers(absl::Span<const bool> parameter_replicated_at_leaf_buffers)870   void set_parameter_replicated_at_leaf_buffers(
871       absl::Span<const bool> parameter_replicated_at_leaf_buffers) {
872     CHECK_EQ(ShapeUtil::GetLeafCount(shape()),
873              parameter_replicated_at_leaf_buffers.size());
874     parameter_replicated_at_leaf_buffers_.emplace(
875         parameter_replicated_at_leaf_buffers.begin(),
876         parameter_replicated_at_leaf_buffers.end());
877   }
878   const absl::optional<std::vector<bool>>&
parameter_replicated_at_leaf_buffers()879   parameter_replicated_at_leaf_buffers() const {
880     return parameter_replicated_at_leaf_buffers_;
881   }
882 
883   // Returns a serialized representation of this instruction.
884   HloInstructionProto ToProto() const override;
885 
886  private:
887   std::vector<string> ExtraAttributesToStringImpl(
888       const HloPrintOptions& options) const override;
889   bool IdenticalSlowPath(
890       const HloInstruction& other,
891       const std::function<bool(const HloComputation*, const HloComputation*)>&
892           eq_computations) const override;
893   string OperandsToStringWithCanonicalNameMap(
894       const HloPrintOptions& options,
895       CanonicalNameMap* canonical_name_map) const override;
896   // Implementation for non-common logic of CloneWithNewOperands.
897   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
898       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
899       HloCloneContext* context) const override;
900 
901   int64 parameter_number_ = 0;
902 
903   // Specifies whether each buffer has the same parameter value on all replicas
904   // in data parallelism.
905   absl::optional<std::vector<bool>> parameter_replicated_at_leaf_buffers_;
906 };
907 
908 class HloGetTupleElementInstruction : public HloInstruction {
909  public:
910   explicit HloGetTupleElementInstruction(const Shape& shape,
911                                          HloInstruction* operand, int64 index);
912   // Returns the tuple index associated with this instruction.
tuple_index()913   int64 tuple_index() const { return tuple_index_; }
914   // Returns a serialized representation of this instruction.
915   HloInstructionProto ToProto() const override;
916 
917  private:
918   std::vector<string> ExtraAttributesToStringImpl(
919       const HloPrintOptions& options) const override;
920   bool IdenticalSlowPath(
921       const HloInstruction& other,
922       const std::function<bool(const HloComputation*, const HloComputation*)>&
923           eq_computations) const override;
924   // Implementation for non-common logic of CloneWithNewOperands.
925   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
926       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
927       HloCloneContext* context) const override;
928 
929   int64 tuple_index_ = -1;
930 };
931 
932 class HloReducePrecisionInstruction : public HloInstruction {
933  public:
934   explicit HloReducePrecisionInstruction(const Shape& shape,
935                                          HloInstruction* operand,
936                                          const int exponent_bits,
937                                          const int mantissa_bits);
938   // Returns the number of exponent bits for a reduce-precision node.
exponent_bits()939   int32 exponent_bits() const { return exponent_bits_; }
940   // Returns the number of mantissa bits for a reduce-precision node.
mantissa_bits()941   int32 mantissa_bits() const { return mantissa_bits_; }
942   // Returns a serialized representation of this instruction.
943   HloInstructionProto ToProto() const override;
944 
945  private:
946   std::vector<string> ExtraAttributesToStringImpl(
947       const HloPrintOptions& options) const override;
948   bool IdenticalSlowPath(
949       const HloInstruction& other,
950       const std::function<bool(const HloComputation*, const HloComputation*)>&
951           eq_computations) const override;
952   // Implementation for non-common logic of CloneWithNewOperands.
953   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
954       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
955       HloCloneContext* context) const override;
956 
957   // The bit sizes for a reduce-precision operation.
958   int32 exponent_bits_ = 0;
959   int32 mantissa_bits_ = 0;
960 };
961 
962 class HloInfeedInstruction : public HloInstruction {
963  public:
964   explicit HloInfeedInstruction(const Shape& infeed_shape,
965                                 HloInstruction* token_operand,
966                                 const string& config);
967   // Returns the infeed configuration string. The infeed configuration includes
968   // any metadata needed for the backend compiler (e.g., infeed buffer address)
969   // and is target-dependent.
infeed_config()970   string infeed_config() const { return infeed_config_; }
set_infeed_config(const string & config)971   void set_infeed_config(const string& config) { infeed_config_ = config; }
972   // Returns the shape of the data received by the infeed. This is not the same
973   // as the shape of the infeed instruction which produces a tuple containing
974   // the infeed data shape and a TOKEN.
infeed_shape()975   const Shape& infeed_shape() const {
976     TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape()));
977     return ShapeUtil::GetSubshape(shape(), {0});
978   }
979   // Returns a serialized representation of this instruction.
980   HloInstructionProto ToProto() const override;
981 
982  private:
983   std::vector<string> ExtraAttributesToStringImpl(
984       const HloPrintOptions& options) const override;
985   bool IdenticalSlowPath(
986       const HloInstruction& other,
987       const std::function<bool(const HloComputation*, const HloComputation*)>&
988           eq_computations) const override;
989   // Implementation for non-common logic of CloneWithNewOperands.
990   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
991       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
992       HloCloneContext* context) const override;
993 
994   // The string representation of the infeed configuration.
995   string infeed_config_;
996 };
997 
998 class HloOutfeedInstruction : public HloInstruction {
999  public:
1000   explicit HloOutfeedInstruction(const Shape& outfeed_shape,
1001                                  HloInstruction* operand,
1002                                  HloInstruction* token_operand,
1003                                  absl::string_view outfeed_config);
1004   // Returns the shape for the Outfeed instruction.
outfeed_shape()1005   const Shape& outfeed_shape() const { return outfeed_shape_; }
1006   // Returns the config for the Outfeed instruction.
outfeed_config()1007   const string& outfeed_config() const { return outfeed_config_; }
1008   // Returns a serialized representation of this instruction.
1009   HloInstructionProto ToProto() const override;
1010 
1011  private:
1012   std::vector<string> ExtraAttributesToStringImpl(
1013       const HloPrintOptions& options) const override;
1014   bool IdenticalSlowPath(
1015       const HloInstruction& other,
1016       const std::function<bool(const HloComputation*, const HloComputation*)>&
1017           eq_computations) const override;
1018   // Implementation for non-common logic of CloneWithNewOperands.
1019   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1020       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1021       HloCloneContext* context) const override;
1022 
1023   // Shape of outfeed request.
1024   Shape outfeed_shape_;
1025   // Outfeed configuration information, only present for kOutfeed.
1026   string outfeed_config_;
1027 };
1028 
1029 class HloConvolutionInstruction : public HloInstruction {
1030  public:
1031   explicit HloConvolutionInstruction(
1032       const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
1033       int64 feature_group_count, int64 batch_group_count, const Window& window,
1034       const ConvolutionDimensionNumbers& dimension_numbers,
1035       const PrecisionConfig& precision_config);
window()1036   const Window& window() const override { return window_; }
set_window(const Window & window)1037   void set_window(const Window& window) override { window_ = window; }
convolution_dimension_numbers()1038   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
1039     return convolution_dimension_numbers_;
1040   }
set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1041   void set_convolution_dimension_numbers(
1042       const ConvolutionDimensionNumbers& dnums) {
1043     convolution_dimension_numbers_ = dnums;
1044   }
1045   // The number of feature groups. Must be a divisor of the input feature
1046   // dimension and output feature dimension.
feature_group_count()1047   int64 feature_group_count() const { return feature_group_count_; }
1048 
1049   // The number of feature groups. Must be a divisor of the input batch
1050   // dimension.
batch_group_count()1051   int64 batch_group_count() const { return batch_group_count_; }
1052 
1053   // Returns the information used to tell the implementation information about
1054   // what sort of precision is requested. The meaning of the field is backend
1055   // specific. At the moment, it is only supported for kConvolution and kDot.
1056   // Transformations on one kDot or kConvolution to another will preserve this
1057   // information. Transformations to other HLOs will not preserve this
1058   // information but it is presumed that the alternate lowering is strictly
1059   // superior.
precision_config()1060   const PrecisionConfig& precision_config() const { return precision_config_; }
mutable_precision_config()1061   PrecisionConfig* mutable_precision_config() { return &precision_config_; }
1062 
1063   string ToCategory() const override;
1064   // Returns a serialized representation of this instruction.
1065   HloInstructionProto ToProto() const override;
1066 
1067  private:
1068   std::vector<string> ExtraAttributesToStringImpl(
1069       const HloPrintOptions& options) const override;
1070   bool IdenticalSlowPath(
1071       const HloInstruction& other,
1072       const std::function<bool(const HloComputation*, const HloComputation*)>&
1073           eq_computations) const override;
1074   // Implementation for non-common logic of CloneWithNewOperands.
1075   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1076       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1077       HloCloneContext* context) const override;
1078   // The number of feature groups. Must be a divisor of the input feature
1079   // dimension and output feature dimension.
1080   int64 feature_group_count_;
1081   // The number of feature groups. Must be a divisor of the input batch
1082   // dimension.
1083   int64 batch_group_count_;
1084   // Describes the window used for a convolution.
1085   Window window_;
1086   // Describes the dimension numbers used for a convolution.
1087   ConvolutionDimensionNumbers convolution_dimension_numbers_;
1088   // Information used to communicate to the implementation about the algorithm
1089   // used to produce results. See the documentation on precision_config().
1090   PrecisionConfig precision_config_;
1091 };
1092 
1093 class HloReduceWindowInstruction : public HloInstruction {
1094  public:
1095   explicit HloReduceWindowInstruction(const Shape& shape,
1096                                       HloInstruction* operand,
1097                                       HloInstruction* init_value,
1098                                       const Window& window,
1099                                       HloComputation* reduce_computation);
window()1100   const Window& window() const override { return window_; }
set_window(const Window & window)1101   void set_window(const Window& window) override { window_ = window; }
1102   // Returns a serialized representation of this instruction.
1103   HloInstructionProto ToProto() const override;
1104 
1105  private:
1106   std::vector<string> ExtraAttributesToStringImpl(
1107       const HloPrintOptions& options) const override;
1108   bool IdenticalSlowPath(
1109       const HloInstruction& other,
1110       const std::function<bool(const HloComputation*, const HloComputation*)>&
1111           eq_computations) const override;
1112   // Implementation for non-common logic of CloneWithNewOperands.
1113   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1114       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1115       HloCloneContext* context) const override;
1116   Window window_;
1117 };
1118 
1119 class HloSelectAndScatterInstruction : public HloInstruction {
1120  public:
1121   explicit HloSelectAndScatterInstruction(
1122       const Shape& shape, HloInstruction* operand, HloComputation* select,
1123       const Window& window, HloInstruction* source, HloInstruction* init_value,
1124       HloComputation* scatter);
window()1125   const Window& window() const override { return window_; }
set_window(const Window & window)1126   void set_window(const Window& window) override { window_ = window; }
1127   // Gets/sets the select or scatter HloComputation for SelectAndScatter. The
1128   // setters should only be called by HloModule or HloComputation methods.
select()1129   HloComputation* select() const {
1130     return called_computations()[kSelectComputationIndex];
1131   }
1132 
scatter()1133   HloComputation* scatter() const {
1134     return called_computations()[kScatterComputationIndex];
1135   }
1136 
set_select(HloComputation * computation)1137   void set_select(HloComputation* computation) {
1138     // Don't allow changing the computation for fused instructions so we don't
1139     // have to recompute called_instructions for the entire fusion instruction.
1140     CHECK(!IsFused());
1141     set_called_computation(kSelectComputationIndex, computation);
1142   }
1143 
set_scatter(HloComputation * computation)1144   void set_scatter(HloComputation* computation) {
1145     // Don't allow changing the computation for fused instructions so we don't
1146     // have to recompute called_instructions for the entire fusion instruction.
1147     CHECK(!IsFused());
1148     set_called_computation(kScatterComputationIndex, computation);
1149   }
1150   // Returns a serialized representation of this instruction.
1151   HloInstructionProto ToProto() const override;
1152 
1153  private:
1154   std::vector<string> ExtraAttributesToStringImpl(
1155       const HloPrintOptions& options) const override;
1156   bool IdenticalSlowPath(
1157       const HloInstruction& other,
1158       const std::function<bool(const HloComputation*, const HloComputation*)>&
1159           eq_computations) const override;
1160   // Implementation for non-common logic of CloneWithNewOperands.
1161   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1162       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1163       HloCloneContext* context) const override;
1164   Window window_;
1165 };
1166 
1167 class HloCustomCallInstruction : public HloInstruction {
1168  public:
1169   HloCustomCallInstruction(const Shape& shape,
1170                            absl::Span<HloInstruction* const> operands,
1171                            absl::string_view custom_call_target,
1172                            absl::string_view opaque);
1173 
1174   // Constructor for a custom call with constrained layout. 'shape' and
1175   // 'operands_with_layout' must all have layouts.
1176   HloCustomCallInstruction(const Shape& shape,
1177                            absl::Span<HloInstruction* const> operands,
1178                            absl::string_view custom_call_target,
1179                            absl::string_view opaque,
1180                            absl::Span<const Shape> operand_shapes_with_layout);
1181 
window()1182   const Window& window() const override {
1183     CHECK(window_ != nullptr);
1184     return *window_;
1185   }
1186 
set_window(const Window & window)1187   void set_window(const Window& window) override {
1188     window_ = absl::make_unique<Window>(window);
1189   }
1190 
convolution_dimension_numbers()1191   const ConvolutionDimensionNumbers& convolution_dimension_numbers() const {
1192     CHECK(convolution_dimension_numbers_ != nullptr);
1193     return *convolution_dimension_numbers_;
1194   }
1195 
set_convolution_dimension_numbers(const ConvolutionDimensionNumbers & dnums)1196   void set_convolution_dimension_numbers(
1197       const ConvolutionDimensionNumbers& dnums) {
1198     convolution_dimension_numbers_ =
1199         absl::make_unique<ConvolutionDimensionNumbers>(dnums);
1200   }
opaque()1201   const string& opaque() const { return opaque_; }
custom_call_target()1202   const string& custom_call_target() const { return custom_call_target_; }
set_feature_group_count(int64 feature_group_count)1203   void set_feature_group_count(int64 feature_group_count) {
1204     feature_group_count_ = feature_group_count;
1205   }
set_batch_group_count(int64 batch_group_count)1206   void set_batch_group_count(int64 batch_group_count) {
1207     batch_group_count_ = batch_group_count;
1208   }
feature_group_count()1209   int64 feature_group_count() const { return feature_group_count_; }
batch_group_count()1210   int64 batch_group_count() const { return batch_group_count_; }
1211   // Returns a serialized representation of this instruction.
1212   HloInstructionProto ToProto() const override;
1213 
1214   // Returns whether the result and operand layouts are constrained.
layout_constrained()1215   bool layout_constrained() const { return layout_constrained_; }
1216 
1217   // Returns the shapes (with layout) of the operands. CHECKs if this custom
1218   // call does not have constrained layouts.
operand_shapes_with_layout()1219   const std::vector<Shape>& operand_shapes_with_layout() const {
1220     CHECK(layout_constrained());
1221     return operand_shapes_with_layout_;
1222   }
1223 
1224  private:
1225   std::vector<string> ExtraAttributesToStringImpl(
1226       const HloPrintOptions& options) const override;
1227   bool IdenticalSlowPath(
1228       const HloInstruction& other,
1229       const std::function<bool(const HloComputation*, const HloComputation*)>&
1230           eq_computations) const override;
1231   // Implementation for non-common logic of CloneWithNewOperands.
1232   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1233       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1234       HloCloneContext* context) const override;
1235   // Name of a global symbol to call.
1236   string custom_call_target_;
1237   // Opaque string interpreted by the backend.
1238   string opaque_;
1239   // Describes the window in a windowed operation such as convolution.
1240   std::unique_ptr<Window> window_;
1241   // Describes the dimension numbers used for a convolution.
1242   std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
1243   // The number of feature groups. This is used for grouped convolutions.
1244   int64 feature_group_count_;
1245   int64 batch_group_count_;
1246   // Whether the result and operand layouts are constrained.
1247   bool layout_constrained_;
1248   // For layout-constrained custom calls, this vector holds the shape with
1249   // layout for each operand.
1250   std::vector<Shape> operand_shapes_with_layout_;
1251 };
1252 
1253 class HloPadInstruction : public HloInstruction {
1254  public:
1255   explicit HloPadInstruction(const Shape& shape, HloInstruction* operand,
1256                              HloInstruction* padding_value,
1257                              const PaddingConfig& padding_config);
1258   // Returns the padding configuration for a pad node.
padding_config()1259   const PaddingConfig& padding_config() const { return padding_config_; }
1260   // Returns the padding value.
padding_value()1261   const HloInstruction* padding_value() const { return operand(1); }
mutable_padding_value()1262   HloInstruction* mutable_padding_value() { return mutable_operand(1); }
1263   // Returns a serialized representation of this instruction.
1264   HloInstructionProto ToProto() const override;
1265 
1266  private:
1267   std::vector<string> ExtraAttributesToStringImpl(
1268       const HloPrintOptions& options) const override;
1269   bool IdenticalSlowPath(
1270       const HloInstruction& other,
1271       const std::function<bool(const HloComputation*, const HloComputation*)>&
1272           eq_computations) const override;
1273   // Implementation for non-common logic of CloneWithNewOperands.
1274   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1275       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1276       HloCloneContext* context) const override;
1277 
1278   // The padding configuration that describes the edge padding and interior
1279   // padding of this pad instruction.
1280   PaddingConfig padding_config_;
1281 };
1282 
1283 class HloDynamicIndexInstruction : public HloInstruction {
1284  public:
HloDynamicIndexInstruction(HloOpcode opcode,const Shape & shape)1285   explicit HloDynamicIndexInstruction(HloOpcode opcode, const Shape& shape)
1286       : HloInstruction(opcode, shape) {}
1287   virtual int64 first_index_operand_number() const = 0;
1288 
1289   // Returns a subspan of operands which represent the start indices.
index_operands()1290   absl::Span<HloInstruction* const> index_operands() const {
1291     return absl::MakeSpan(operands()).subspan(first_index_operand_number());
1292   }
1293 
1294   // Returns the shapes of the index operands.
index_shapes()1295   std::vector<Shape> index_shapes() const {
1296     std::vector<Shape> shapes;
1297     auto indices = index_operands();
1298     for (const HloInstruction* index : indices) {
1299       shapes.push_back(index->shape());
1300     }
1301     return shapes;
1302   }
1303 };
1304 
1305 class HloDynamicSliceInstruction : public HloDynamicIndexInstruction {
1306  public:
1307   explicit HloDynamicSliceInstruction(const Shape& shape,
1308                                       HloInstruction* operand,
1309                                       HloInstruction* start_indices,
1310                                       absl::Span<const int64> slice_sizes);
1311   explicit HloDynamicSliceInstruction(
1312       const Shape& shape, HloInstruction* operand,
1313       absl::Span<HloInstruction* const> start_indices,
1314       absl::Span<const int64> slice_sizes);
1315   // Old methods kept for smooth subclassing transition END.
1316   // Returns the size of the slice in the given dimension for a dynamic
1317   // slice node.
slice_sizes(int64 dimension)1318   int64 slice_sizes(int64 dimension) const {
1319     return dynamic_slice_sizes_[dimension];
1320   }
dynamic_slice_sizes()1321   const std::vector<int64>& dynamic_slice_sizes() const {
1322     return dynamic_slice_sizes_;
1323   }
1324   // Returns a serialized representation of this instruction.
1325   HloInstructionProto ToProto() const override;
1326 
first_index_operand_number()1327   int64 first_index_operand_number() const override { return 1; }
1328 
1329  private:
1330   std::vector<string> ExtraAttributesToStringImpl(
1331       const HloPrintOptions& options) const override;
1332   bool IdenticalSlowPath(
1333       const HloInstruction& other,
1334       const std::function<bool(const HloComputation*, const HloComputation*)>&
1335           eq_computations) const override;
1336   // Implementation for non-common logic of CloneWithNewOperands.
1337   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1338       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1339       HloCloneContext* context) const override;
1340 
1341   // Describes the [start, start + size) range size for a dynamic slice
1342   // ('start' is specified dynamically in the second operand of the operation).
1343   std::vector<int64> dynamic_slice_sizes_;
1344 };
1345 
1346 class HloDynamicUpdateSliceInstruction : public HloDynamicIndexInstruction {
1347  public:
1348   explicit HloDynamicUpdateSliceInstruction(const Shape& shape,
1349                                             HloInstruction* operand,
1350                                             HloInstruction* update,
1351                                             HloInstruction* start_indices);
1352   explicit HloDynamicUpdateSliceInstruction(
1353       const Shape& shape, HloInstruction* operand, HloInstruction* update,
1354       absl::Span<HloInstruction* const> start_indices);
1355 
first_index_operand_number()1356   int64 first_index_operand_number() const override { return 2; }
1357 };
1358 
1359 class HloGatherInstruction : public HloInstruction {
1360  public:
1361   explicit HloGatherInstruction(
1362       const Shape& shape, HloInstruction* operand,
1363       HloInstruction* start_indices,
1364       const GatherDimensionNumbers& gather_dim_numbers,
1365       absl::Span<const int64> slice_sizes);
gather_dimension_numbers()1366   const GatherDimensionNumbers& gather_dimension_numbers() const {
1367     CHECK(gather_dimension_numbers_ != nullptr);
1368     return *gather_dimension_numbers_;
1369   }
gather_slice_sizes()1370   absl::Span<const int64> gather_slice_sizes() const {
1371     return gather_slice_sizes_;
1372   }
1373   // Returns the dump string of the gather dimension numbers.
1374   string GatherDimensionNumbersToString() const;
1375   // Returns a serialized representation of this instruction.
1376   HloInstructionProto ToProto() const override;
1377 
1378   // Creates an instance of GatherDimensionNumbers.
1379   static GatherDimensionNumbers MakeGatherDimNumbers(
1380       absl::Span<const int64> offset_dims,
1381       absl::Span<const int64> collapsed_slice_dims,
1382       absl::Span<const int64> start_index_map, int64 index_vector_dim);
1383 
1384  private:
1385   std::vector<string> ExtraAttributesToStringImpl(
1386       const HloPrintOptions& options) const override;
1387   bool IdenticalSlowPath(
1388       const HloInstruction& other,
1389       const std::function<bool(const HloComputation*, const HloComputation*)>&
1390           eq_computations) const override;
1391   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1392       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1393       HloCloneContext* context) const override;
1394 
1395   std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
1396   std::vector<int64> gather_slice_sizes_;
1397 };
1398 
1399 class HloScatterInstruction : public HloInstruction {
1400  public:
1401   explicit HloScatterInstruction(
1402       const Shape& shape, HloInstruction* operand,
1403       HloInstruction* scatter_indices, HloInstruction* updates,
1404       HloComputation* update_computation,
1405       const ScatterDimensionNumbers& scatter_dim_numbers);
scatter_dimension_numbers()1406   const ScatterDimensionNumbers& scatter_dimension_numbers() const {
1407     CHECK(scatter_dimension_numbers_ != nullptr);
1408     return *scatter_dimension_numbers_;
1409   }
1410   // Returns the dump string of the scatter dimension numbers.
1411   string ScatterDimensionNumbersToString() const;
1412   // Returns a serialized representation of this instruction.
1413   HloInstructionProto ToProto() const override;
1414 
1415   // Creates an instance of ScatterDimensionNumbers.
1416   static ScatterDimensionNumbers MakeScatterDimNumbers(
1417       absl::Span<const int64> update_window_dims,
1418       absl::Span<const int64> inserted_window_dims,
1419       absl::Span<const int64> scatter_dims_to_operand_dims,
1420       int64 index_vector_dim);
1421 
1422  private:
1423   std::vector<string> ExtraAttributesToStringImpl(
1424       const HloPrintOptions& options) const override;
1425   bool IdenticalSlowPath(
1426       const HloInstruction& other,
1427       const std::function<bool(const HloComputation*, const HloComputation*)>&
1428           eq_computations) const override;
1429   // Implementation for non-common logic of CloneWithNewOperands.
1430   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1431       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1432       HloCloneContext* context) const override;
1433 
1434   std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
1435 };
1436 
1437 class HloIotaInstruction : public HloInstruction {
1438  public:
1439   explicit HloIotaInstruction(const Shape& shape, int64 iota_dimension);
1440   // Returns the dimension sizes or numbers associated with this instruction.
iota_dimension()1441   int64 iota_dimension() const { return iota_dimension_; }
1442   // Returns a serialized representation of this instruction.
1443   HloInstructionProto ToProto() const override;
1444 
1445  private:
1446   std::vector<string> ExtraAttributesToStringImpl(
1447       const HloPrintOptions& options) const override;
1448   bool IdenticalSlowPath(
1449       const HloInstruction& other,
1450       const std::function<bool(const HloComputation*, const HloComputation*)>&
1451           eq_computations) const override;
1452   // Implementation for non-common logic of CloneWithNewOperands.
1453   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1454       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1455       HloCloneContext* context) const override;
1456 
1457   const int64 iota_dimension_;
1458 };
1459 
1460 class HloDotInstruction : public HloInstruction {
1461  public:
1462   // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
1463   // dimensions specified in 'dimension_numbers'.
1464   explicit HloDotInstruction(const Shape& shape, HloInstruction* lhs,
1465                              HloInstruction* rhs,
1466                              const DotDimensionNumbers& dimension_numbers,
1467                              const PrecisionConfig& precision_config);
1468 
1469   // Returns data on the dimension numbers used for a dot operation.
dot_dimension_numbers()1470   const DotDimensionNumbers& dot_dimension_numbers() const {
1471     return dot_dimension_numbers_;
1472   }
1473 
1474   // Returns the information used to tell the implementation information about
1475   // what sort of precision is requested. The meaning of the field is backend
1476   // specific. At the moment, it is only supported for kConvolution and kDot.
1477   // Transformations on one kDot or kConvolution to another will preserve this
1478   // information. Transformations to other HLOs will not preserve this
1479   // information but it is presumed that the alternate lowering is strictly
1480   // superior.
precision_config()1481   const PrecisionConfig& precision_config() const { return precision_config_; }
mutable_precision_config()1482   PrecisionConfig* mutable_precision_config() { return &precision_config_; }
1483 
1484   // Returns a serialized representation of this instruction.
1485   HloInstructionProto ToProto() const override;
1486 
1487  private:
1488   std::vector<string> ExtraAttributesToStringImpl(
1489       const HloPrintOptions& options) const override;
1490   bool IdenticalSlowPath(
1491       const HloInstruction& other,
1492       const std::function<bool(const HloComputation*, const HloComputation*)>&
1493           eq_computations) const override;
1494   // Implementation for non-common logic of CloneWithNewOperands.
1495   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1496       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1497       HloCloneContext* context) const override;
1498   // Returns the dump string of the dot dimension numbers.
1499   string DotDimensionNumbersToString() const;
1500 
1501   // Describes the dimension numbers used for a dot.
1502   DotDimensionNumbers dot_dimension_numbers_;
1503 
1504   // Information used to communicate to the implementation about the algorithm
1505   // used to produce results. See the documentation on precision_config().
1506   PrecisionConfig precision_config_;
1507 };
1508 
1509 class HloDomainInstruction : public HloInstruction {
1510  public:
1511   explicit HloDomainInstruction(
1512       const Shape& shape, HloInstruction* operand,
1513       std::unique_ptr<DomainMetadata> operand_side_metadata,
1514       std::unique_ptr<DomainMetadata> user_side_metadata);
1515 
1516   // Returns a serialized representation of this instruction.
1517   HloInstructionProto ToProto() const override;
1518 
1519   // Retrieves the operand side metadata of a kDomain instruction.
operand_side_metadata()1520   const DomainMetadata& operand_side_metadata() const {
1521     return *operand_side_metadata_;
1522   }
1523   // Retrieves the user side metadata of a kDomain instruction.
user_side_metadata()1524   const DomainMetadata& user_side_metadata() const {
1525     return *user_side_metadata_;
1526   }
1527 
1528  private:
1529   std::vector<string> ExtraAttributesToStringImpl(
1530       const HloPrintOptions& options) const override;
1531   bool IdenticalSlowPath(
1532       const HloInstruction& other,
1533       const std::function<bool(const HloComputation*, const HloComputation*)>&
1534           eq_computations) const override;
1535   // Implementation for non-common logic of CloneWithNewOperands.
1536   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1537       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1538       HloCloneContext* context) const override;
1539 
1540   std::unique_ptr<DomainMetadata> operand_side_metadata_;
1541   std::unique_ptr<DomainMetadata> user_side_metadata_;
1542 };
1543 
1544 class HloGetDimensionSizeInstruction : public HloInstruction {
1545  public:
1546   explicit HloGetDimensionSizeInstruction(const Shape& shape,
1547                                           HloInstruction* operand,
1548                                           int64 dimension);
1549 
1550   // Returns the dimension sizes or numbers associated with this instruction.
dimension()1551   int64 dimension() const { return dimension_; }
1552   // Returns a serialized representation of this instruction.
1553   HloInstructionProto ToProto() const override;
1554 
1555  private:
1556   std::vector<string> ExtraAttributesToStringImpl(
1557       const HloPrintOptions& options) const override;
1558   bool IdenticalSlowPath(
1559       const HloInstruction& other,
1560       const std::function<bool(const HloComputation*, const HloComputation*)>&
1561           eq_computations) const override;
1562   // Implementation for non-common logic of CloneWithNewOperands.
1563   std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
1564       const Shape& shape, absl::Span<HloInstruction* const> new_operands,
1565       HloCloneContext* context) const override;
1566 
1567   int64 dimension_;
1568 };
1569 
1570 }  // namespace xla
1571 
1572 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
1573