/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Some lightweight utilities intended to make HLO instruction creation more
// ergonomic.  We don't have a complete set of helpers yet -- I expect we'll
// expand this interface as needed on an ad-hoc basis.

// Creates a unary HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeUnaryHlo(HloOpcode opcode,
                                       HloInstruction* operand);

// Creates a binary HLO instruction and adds it to the computation containing
// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
StatusOr<HloInstruction*> MakeBinaryHlo(HloOpcode opcode, HloInstruction* lhs,
                                        HloInstruction* rhs);

// Creates a compare HLO instruction and adds it to the computation containing
// `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
StatusOr<HloInstruction*> MakeCompareHlo(Comparison::Direction direction,
                                         HloInstruction* lhs,
                                         HloInstruction* rhs);
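
// Example (hypothetical usage sketch, not part of this API's documentation;
// `a` and `b` are assumed to be existing HloInstruction*s that live in the
// same computation):
//
//   TF_ASSIGN_OR_RETURN(HloInstruction* sum,
//                       MakeBinaryHlo(HloOpcode::kAdd, a, b));
//   TF_ASSIGN_OR_RETURN(HloInstruction* neg,
//                       MakeUnaryHlo(HloOpcode::kNegate, sum));
//   TF_ASSIGN_OR_RETURN(HloInstruction* lt,
//                       MakeCompareHlo(Comparison::Direction::kLt, a, b));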

// Creates a pad HLO instruction and adds it to the computation containing
// `operand` and `padding_value` (`operand` and `padding_value` must be in the
// same computation).
StatusOr<HloInstruction*> MakePadHlo(HloInstruction* operand,
                                     HloInstruction* padding_value,
                                     const PaddingConfig& padding_config);
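
// Example (hypothetical sketch; pads the single dimension of a rank-1
// `operand` with two elements of `padding_value` at the front, assuming both
// instructions already live in the same computation):
//
//   PaddingConfig config;
//   auto* dim = config.add_dimensions();
//   dim->set_edge_padding_low(2);
//   dim->set_edge_padding_high(0);
//   dim->set_interior_padding(0);
//   TF_ASSIGN_OR_RETURN(HloInstruction* padded,
//                       MakePadHlo(operand, padding_value, config));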

// Creates a slice HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeSliceHlo(HloInstruction* operand,
                                       absl::Span<const int64> start_indices,
                                       absl::Span<const int64> limit_indices,
                                       absl::Span<const int64> strides);
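
// Example (hypothetical sketch; keeps the first four rows and every column of
// a rank-2 `operand`):
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction* slice,
//       MakeSliceHlo(operand, /*start_indices=*/{0, 0},
//                    /*limit_indices=*/{4, operand->shape().dimensions(1)},
//                    /*strides=*/{1, 1}));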

// Creates a convolution HLO instruction and adds it to the computation
// containing `lhs` and `rhs` (`lhs` and `rhs` must be in the same computation).
// If the result shape has integral element type, an optional
// preferred_element_type can be specified to override the element type.
StatusOr<HloInstruction*> MakeConvolveHlo(
    HloInstruction* lhs, HloInstruction* rhs, int64 feature_group_count,
    int64 batch_group_count, const Window& window,
    const ConvolutionDimensionNumbers& dimension_numbers,
    const PrecisionConfig& precision_config,
    absl::optional<PrimitiveType> preferred_element_type);

// Creates a transpose HLO instruction and adds it to the computation
// containing `operand`.
StatusOr<HloInstruction*> MakeTransposeHlo(HloInstruction* operand,
                                           absl::Span<const int64> dimensions);

// Creates a reshape HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeReshapeHlo(const Shape& result_shape,
                                         HloInstruction* operand);

StatusOr<HloInstruction*> MakeReshapeHlo(
    absl::Span<const int64> result_shape_dim_bounds, HloInstruction* operand);
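
// Example (hypothetical sketch; swaps the two dimensions of a rank-2 `operand`
// and then flattens the result into a vector):
//
//   TF_ASSIGN_OR_RETURN(HloInstruction* transposed,
//                       MakeTransposeHlo(operand, /*dimensions=*/{1, 0}));
//   std::vector<int64> flat_dims = {
//       ShapeUtil::ElementsIn(transposed->shape())};
//   TF_ASSIGN_OR_RETURN(HloInstruction* flat,
//                       MakeReshapeHlo(flat_dims, transposed));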

// Creates a dynamic-slice HLO instruction and adds it to the computation
// containing `operand` and `start_indices` (`operand` and `start_indices` must
// be in the same computation).
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, absl::Span<HloInstruction* const> start_indices,
    absl::Span<const int64> slice_sizes);
StatusOr<HloInstruction*> MakeDynamicSliceHlo(
    HloInstruction* operand, HloInstruction* start_indices,
    absl::Span<const int64> slice_sizes);
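
// Example (hypothetical sketch; extracts one row of a rank-2 `operand` at a
// runtime-computed index, assuming `row_index` and `zero` are scalar integer
// instructions in the same computation):
//
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction* row,
//       MakeDynamicSliceHlo(
//           operand, {row_index, zero},
//           /*slice_sizes=*/{1, operand->shape().dimensions(1)}));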

// Creates a dynamic-update-slice HLO instruction and adds it to the
// computation containing `operand`, `update` and `start_indices` (`operand`,
// `update` and `start_indices` must be in the same computation).
StatusOr<HloInstruction*> MakeDynamicUpdateSliceHlo(
    HloInstruction* operand, HloInstruction* update,
    HloInstruction* start_indices);

// Creates a broadcast HLO instruction and adds it to the computation
// containing `operand`.
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
                                 absl::Span<const int64> broadcast_dimensions,
                                 absl::Span<const int64> result_shape_bounds);
HloInstruction* MakeBroadcastHlo(HloInstruction* operand,
                                 absl::Span<const int64> broadcast_dimensions,
                                 const Shape& shape);
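
// Example (hypothetical sketch; broadcasts a vector `v` of shape f32[16] to
// shape f32[8,16], mapping the vector onto dimension 1 of the result):
//
//   HloInstruction* bcast =
//       MakeBroadcastHlo(v, /*broadcast_dimensions=*/{1},
//                        /*result_shape_bounds=*/{8, 16});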

// Creates a GetTupleElement HLO instruction and adds it to the computation
// containing `operand`.
StatusOr<HloInstruction*> MakeGetTupleElementHlo(HloInstruction* operand,
                                                 int64 index);

// Creates a Concatenate HLO instruction and adds it to the computation
// containing `operands` (`operands` must be non-empty and every element must
// be contained in the same computation).
StatusOr<HloInstruction*> MakeConcatHlo(
    absl::Span<HloInstruction* const> operands, int64 dimension);
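
// Example (hypothetical sketch; concatenates two rank-2 instructions `a` and
// `b`, which live in the same computation, along their first dimension):
//
//   TF_ASSIGN_OR_RETURN(HloInstruction* cat,
//                       MakeConcatHlo({a, b}, /*dimension=*/0));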

// Creates a Convert HLO instruction that converts the given instruction to
// have the given primitive type.
HloInstruction* MakeConvertToHlo(HloInstruction* hlo, PrimitiveType type);

// Creates a BitcastConvert HLO instruction.
HloInstruction* MakeBitcastConvertToHlo(HloInstruction* hlo,
                                        PrimitiveType type);

// Creates an Iota HLO instruction.
HloInstruction* MakeIotaHlo(HloComputation* computation, const Shape& shape,
                            int64 iota_dimension);

// Creates a Dot HLO instruction and adds it to the computation containing
// `lhs` and `rhs` (both must be in the same computation). If the result shape
// has integral element type, an optional preferred_element_type can be
// specified to override the element type.
StatusOr<HloInstruction*> MakeDotHlo(
    HloInstruction* lhs, HloInstruction* rhs,
    const DotDimensionNumbers& dim_numbers,
    const PrecisionConfig& precision_config,
    absl::optional<PrimitiveType> preferred_element_type);
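
// Example (hypothetical sketch of a plain [m,k] x [k,n] matrix multiply;
// `lhs` and `rhs` are assumed to be rank-2 instructions in the same
// computation, and a default-constructed PrecisionConfig is used):
//
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_contracting_dimensions(1);
//   dnums.add_rhs_contracting_dimensions(0);
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction* dot,
//       MakeDotHlo(lhs, rhs, dnums, PrecisionConfig(),
//                  /*preferred_element_type=*/absl::nullopt));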

// Creates a Map HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation.
StatusOr<HloInstruction*> MakeMapHlo(absl::Span<HloInstruction* const> operands,
                                     HloComputation* map_computation);

// Creates a Reduce HLO instruction and adds it to the computation containing
// the operand. This will create the sub-computation needed for the reduction
// in the given module. `binary_opcode` should represent a binary operation.
StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        absl::Span<const int64> dimensions,
                                        HloOpcode binary_opcode);

StatusOr<HloInstruction*> MakeReduceHlo(HloInstruction* operand,
                                        HloInstruction* init_value,
                                        HloOpcode binary_opcode,
                                        HloModule* module);
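
// Example (hypothetical sketch; sums a rank-2 f32 `operand` over its second
// dimension, producing a rank-1 result):
//
//   HloComputation* computation = operand->parent();
//   HloInstruction* zero = MakeR0ConstantHlo<float>(computation, 0.0f);
//   TF_ASSIGN_OR_RETURN(
//       HloInstruction* sum,
//       MakeReduceHlo(operand, zero, /*dimensions=*/{1}, HloOpcode::kAdd));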

// Creates a Reverse HLO instruction and adds it to the computation containing
// `operand`.
StatusOr<HloInstruction*> MakeReverseHlo(HloInstruction* operand,
                                         absl::Span<const int64> dimensions);

// Creates a Select HLO instruction and adds it to the computation containing
// the predicate. The on_true and on_false instructions must also be contained
// in the same computation. If on_true and on_false are tuples, a tuple-select
// is created instead. `pred` is broadcast up from a scalar if necessary.
StatusOr<HloInstruction*> MakeSelectHlo(HloInstruction* pred,
                                        HloInstruction* on_true,
                                        HloInstruction* on_false,
                                        HloInstruction* derived_from = nullptr);
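
// Example (hypothetical sketch; `pred` is a scalar PRED instruction and
// `on_true`/`on_false` are shape-compatible instructions in the same
// computation; the scalar `pred` is broadcast to their shape as described
// above):
//
//   TF_ASSIGN_OR_RETURN(HloInstruction* sel,
//                       MakeSelectHlo(pred, on_true, on_false));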

// Creates a Sort HLO instruction and adds it to the computation containing the
// operands. All operands must be in the same computation. Also creates a
// default compare sub-computation which sorts the first operand into ascending
// order. 'is_stable' specifies whether the sorting should be stable.
StatusOr<HloInstruction*> MakeSortHlo(
    const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
    int64 dimension_to_sort, bool is_stable, HloComputation::Builder* builder,
    HloModule* module);

// Creates an R1 Constant HLO instruction of the given PrimitiveType with the
// given values and adds it to the given computation.
template <typename NativeT>
StatusOr<HloInstruction*> MakeR1ConstantHlo(HloComputation* computation,
                                            PrimitiveType type,
                                            absl::Span<const NativeT> values) {
  Literal literal = LiteralUtil::CreateR1<NativeT>(values);
  if (literal.shape().element_type() != type) {
    TF_ASSIGN_OR_RETURN(literal, literal.Convert(type));
  }
  return computation->AddInstruction(
      HloInstruction::CreateConstant(std::move(literal)));
}

// Creates an R0 Constant HLO instruction of the PrimitiveType corresponding to
// `NativeT` with the given value and adds it to the given computation.
template <class NativeT>
HloInstruction* MakeR0ConstantHlo(HloComputation* computation, NativeT value) {
  return computation->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)));
}

// Creates a constant with `value` converted to the element type of `base` and
// broadcasts it (if necessary) to the shape of `base`, so that the result is
// elementwise compatible with `base`.
template <class NativeT>
HloInstruction* MakeScalarLike(HloInstruction* base, NativeT value) {
  auto scalar = base->parent()->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<NativeT>(value)
                                         .Convert(base->shape().element_type())
                                         .ValueOrDie()));
  if (base->shape().rank() == 0) {
    *scalar->mutable_shape() = base->shape();
    return scalar;
  }
  return base->parent()->AddInstruction(
      HloInstruction::CreateBroadcast(base->shape(), scalar, {}));
}
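
// Example (hypothetical sketch; computes the elementwise expression x * 2 + 1
// using scalar constants shaped like an existing f32 instruction `x`):
//
//   HloInstruction* two = MakeScalarLike(x, 2.0f);
//   HloInstruction* one = MakeScalarLike(x, 1.0f);
//   TF_ASSIGN_OR_RETURN(HloInstruction* scaled,
//                       MakeBinaryHlo(HloOpcode::kMultiply, x, two));
//   TF_ASSIGN_OR_RETURN(HloInstruction* shifted,
//                       MakeBinaryHlo(HloOpcode::kAdd, scaled, one));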

// -----------------------------------------------------------------------------
// Some other miscellaneous helpers to generate common HLO patterns.  All of
// these add all the instructions they generate into the computation containing
// their operand(s).

// Collapses (via reshape) the first N (logical) dimensions of `operand` into a
// single leading dimension.  `operand` must have rank > `n` and `n` must not be
// 0.
//
// For instance if `operand` has shape f32[7,8,9] and n is 2 then the output is
// the `operand` reshaped to [56,9].
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n);

// Prepends `n` degenerate dimensions (dimensions with bound = 1) to `operand`
// using a reshape.
//
// For instance if operand has shape f32[3,4,5] then this returns the operand
// reshaped to f32[1,3,4,5].  If the operand is an f32 scalar (i.e. has shape
// f32[]) then this returns the operand reshaped to f32[1].
StatusOr<HloInstruction*> PrependDegenerateDims(HloInstruction* operand,
                                                int64 n);

// Expands (via reshape) the first (logical) dimension of `operand` into a
// sequence of `expanded_dims` dimensions.  `operand` must at least be of rank 1
// and the number of elements in its first dimension must be equal to the
// product of `expanded_dims`.
//
// For instance if `operand` has shape f32[200,9,7] and expanded_dims is
// {2,5,20} the result is `operand` reshaped to [2,5,20,9,7].
StatusOr<HloInstruction*> ExpandFirstDimIntoNDims(
    HloInstruction* operand, absl::Span<const int64> expanded_dims);

// Elides (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_elide` from `operand`.  Every dimension in
// `dims_to_elide` must be a degenerate dimension.  `dims_to_elide` must be
// sorted and not contain duplicates.
//
// For example if `operand` is of shape f32[19,1,20,1,7,1,9] and dims_to_elide
// is {1,5} then the result is `operand` reshaped to [19,20,1,7,9].
StatusOr<HloInstruction*> ElideDegenerateDims(
    HloInstruction* operand, absl::Span<const int64> dims_to_elide);

// Inserts (via reshape) a set of degenerate dimensions (dimensions containing
// exactly one element), `dims_to_insert` into `operand`. The dimensions in
// `dims_to_insert` refer to the dimensions in the result, and hence should be
// less than the rank of the result. Also, `dims_to_insert` must be sorted.
//
// For example, if `operand` is of shape f32[12,21,8,34] and dims_to_insert is
// {0, 2}, then the result is `operand` reshaped to [1,12,1,21,8,34].
StatusOr<HloInstruction*> InsertDegenerateDims(
    HloInstruction* operand, absl::Span<const int64> dims_to_insert);
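
// Example (hypothetical sketch; assuming `operand` has shape f32[7,8,9], the
// reshaping helpers above compose as follows):
//
//   // f32[7,8,9] -> f32[56,9]
//   TF_ASSIGN_OR_RETURN(HloInstruction* collapsed,
//                       CollapseFirstNDims(operand, 2));
//   // f32[56,9] -> f32[1,1,56,9]
//   TF_ASSIGN_OR_RETURN(HloInstruction* prepended,
//                       PrependDegenerateDims(collapsed, 2));
//   // f32[1,1,56,9] -> f32[56,9]
//   TF_ASSIGN_OR_RETURN(HloInstruction* elided,
//                       ElideDegenerateDims(prepended, {0, 1}));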

// Pads `operand` (which must have rank 1) with `zeros_to_prepend` zeros in the
// front and `zeros_to_append` zeros in the back.
StatusOr<HloInstruction*> PadVectorWithZeros(HloInstruction* operand,
                                             int64 zeros_to_prepend,
                                             int64 zeros_to_append);

// Broadcasts a zero value of type `element_type` into a tensor with element
// type `element_type` and dimension bounds `broadcast_dimensions`.  The
// broadcast instruction is emitted into `computation`.
HloInstruction* BroadcastZeros(HloComputation* computation,
                               PrimitiveType element_type,
                               absl::Span<const int64> broadcast_dimensions);

// Same as above, but fills the tensor with ones.
HloInstruction* BroadcastOnes(HloComputation* computation,
                              PrimitiveType element_type,
                              absl::Span<const int64> broadcast_dimensions);
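
// Example (hypothetical sketch; materializes an f32[4,4] tensor of zeros in
// `computation`):
//
//   HloInstruction* zeros =
//       BroadcastZeros(computation, F32, /*broadcast_dimensions=*/{4, 4});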

// Creates an HLO computation that takes arguments of type `domain` and
// produces a value of type `range`.
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
    absl::Span<const Shape* const> domain, const Shape& range,
    absl::string_view name);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CREATION_UTILS_H_