1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Shape inference is used by the XLA service as the user builds up
17 // computation requests.
18 
19 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
20 #define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
21 
22 #include <vector>
23 
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/platform/macros.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 namespace xla {
34 
35 // For a given operation and input shapes, infers what the resulting shape is
36 // for the operation. With this functionality, the user does not need to specify
37 // the expected result type for computations that are built up via the API --
38 // the shape that results from an operation is inferred. Some methods have
39 // overloads for inferring shape at the HLO level.
40 //
41 // TODO(b/73352135): Shape inference does not issue very good error messages, in
42 // part because HloInstruction::ToString() is not available since shape
43 // inference runs before the HloInstruction object is created. We need a
44 // solution for this.
45 class ShapeInference {
46  public:
47   // Infers the shape produced by applying the given unary operation to the
48   // given input shape.
49   static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
50                                            const Shape& shape);
51   static StatusOr<Shape> InferUnaryOpShape(HloOpcode opcode,
52                                            const HloInstruction* operand);
53 
54   // Infers the shape produced by applying the given binary operation to the
55   // given input shapes.
56   static StatusOr<Shape> InferBinaryOpShape(
57       HloOpcode opcode, const Shape& lhs, const Shape& rhs,
58       absl::Span<const int64> broadcast_dimensions);
59   static StatusOr<Shape> InferBinaryOpShape(HloOpcode opcode,
60                                             const HloInstruction* lhs,
61                                             const HloInstruction* rhs);
62 
63   // Infers the shape produced by applying the given ternary operation to the
64   // given input shapes.
65   static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode, const Shape& lhs,
66                                              const Shape& rhs,
67                                              const Shape& ehs);
68   static StatusOr<Shape> InferTernaryOpShape(HloOpcode opcode,
69                                              const HloInstruction* lhs,
70                                              const HloInstruction* rhs,
71                                              const HloInstruction* ehs);
72 
73   // Infers the shape produced by applying the given variadic operation to the
74   // given input operand shapes.
75   static StatusOr<Shape> InferVariadicOpShape(
76       HloOpcode opcode, absl::Span<const Shape* const> operand_shapes);
77   static StatusOr<Shape> InferVariadicOpShape(
78       HloOpcode opcode, absl::Span<const HloInstruction* const> operands);
79 
80   // Infers the shape produced by applying the given mapping computation shape
81   // to the given operand shapes.
82   static StatusOr<Shape> InferMapShape(
83       absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply,
84       absl::Span<const int64> dimensions);
85 
86   // Infers the shape produced by InferBatchNormTraining with the given
87   // operands.
88   static StatusOr<Shape> InferBatchNormTrainingShape(const Shape& operand_shape,
89                                                      const Shape& scale_shape,
90                                                      const Shape& offset_shape,
91                                                      int64 feature_index);
92 
93   // Infers the shape produced by InferBatchNormInference with the given
94   // operands.
95   static StatusOr<Shape> InferBatchNormInferenceShape(
96       const Shape& operand_shape, const Shape& scale_shape,
97       const Shape& offset_shape, const Shape& mean_shape,
98       const Shape& variance_shape, int64 feature_index);
99 
100   // Infers the shape produced by InferBatchNormGrad with the given operands.
101   static StatusOr<Shape> InferBatchNormGradShape(const Shape& operand_shape,
102                                                  const Shape& scale_shape,
103                                                  const Shape& mean_shape,
104                                                  const Shape& var_shape,
105                                                  const Shape& output_grad_shape,
106                                                  int64 feature_index);
107 
108   // Infers the shape produced by applying the given convolutional
109   // filter (rhs) to lhs in the way specified by the fields on window.
110   static StatusOr<Shape> InferConvolveShape(
111       const Shape& lhs, const Shape& rhs, int64 feature_group_count,
112       int64 batch_group_count, const Window& window,
113       const ConvolutionDimensionNumbers& dimension_numbers);
114 
115   // Infers the shape produced by the given FFT type on the given operand.
116   static StatusOr<Shape> InferFftShape(const Shape& in, FftType fft_type,
117                                        absl::Span<const int64> fft_length);
118 
119   // Infers the shape produced by the given triangular solve operation.
120   static StatusOr<Shape> InferTriangularSolveShape(
121       const Shape& a, const Shape& b, const TriangularSolveOptions& options);
122 
123   // Infers the shape produced by the given triangular solve operation.
124   static StatusOr<Shape> InferCholeskyShape(const Shape& a);
125 
126   // Infers the shape produced by a cross replica sum with the given operand
127   // shapes.
128   static StatusOr<Shape> InferAllReduceShape(
129       absl::Span<const Shape* const> operand_shapes);
130 
131   // Infers final shape of an Alltoall operation that is created by the xla
132   // builder.
133   static StatusOr<Shape> InferAllToAllShape(const Shape& shape,
134                                             int64 split_dimension,
135                                             int64 concat_dimension,
136                                             int64 split_count);
137 
138   // Infers the shape of an HLO all-to-all instruction.
139   static StatusOr<Shape> InferAllToAllTupleShape(
140       absl::Span<const Shape* const> operand_shapes);
141 
142   // Infers the shape of a collective permute operation.
143   static StatusOr<Shape> InferCollectivePermuteShape(const Shape& shape);
144 
145   // Infers the shape produced by applying the given reduction computation
146   // shape to the given input operand shape.
147   //
148   // If pass_index is true, the reduce function is invoked with the element
149   // index as the leading parameter, and the program shape should match
150   // accordingly (or an error will result).
151   static StatusOr<Shape> InferReduceShape(
152       absl::Span<const Shape* const> arg_shapes,
153       absl::Span<const int64> dimensions_to_reduce,
154       const ProgramShape& to_apply);
155 
156   // Infers the shape produced by applying the given computation to the operand
157   // shape with the given window and stride dimensions.
158   static StatusOr<Shape> InferReduceWindowShape(
159       const Shape& operand_shape, const Shape& init_value, const Window& window,
160       const ProgramShape& to_apply_shape);
161 
162   // Infers the shape produced by scattering the given source shape to the
163   // selected indices of each window on the operand shape.
164   static StatusOr<Shape> InferSelectAndScatterShape(
165       const Shape& operand_shape, const ProgramShape& select_shape,
166       const Window& window, const Shape& source_shape,
167       const Shape& init_value_shape, const ProgramShape& scatter_shape);
168 
169   // Infers the shape produced by a reverse operation that reverses the order
170   // of the elements in the given dimensions.
171   static StatusOr<Shape> InferReverseShape(const Shape& operand_shape,
172                                            absl::Span<const int64> dimensions);
173 
174   // Infers the shape produced by a slice operation spanning from the starts to
175   // the limits in the original shape's dimensions.
176   //
177   // e.g. slice f32[32x32] 0:16 0:16 -> f32[16x16]
178   static StatusOr<Shape> InferSliceShape(const Shape& arg,
179                                          absl::Span<const int64> starts,
180                                          absl::Span<const int64> limits,
181                                          absl::Span<const int64> strides);
182 
183   // Infers the shape produced by a dynamic slice operation of size specified
184   // in 'slice_sizes', with dynamic start indices shape 'start_indices_shape'.
185   static StatusOr<Shape> InferDynamicSliceShape(
186       const Shape& operand_shape, absl::Span<const Shape> start_index_shapes,
187       absl::Span<const int64> slice_sizes, bool allow_scalar_indices = true);
188 
189   // Infers the shape produced by a dynamic update slice operation based
190   // on the shape of operand and update.
191   static StatusOr<Shape> InferDynamicUpdateSliceShape(
192       const Shape& operand_shape, const Shape& update_shape,
193       absl::Span<const Shape> start_index_shapes,
194       bool allow_scalar_indices = true);
195 
196   // Infers the shape produced by doing a compile-time-constant indexing into
197   // the given input shape. This is essential for operations on tuples, because
198   // it is impossible to infer the type that comes out of the tuple indexing if
199   // it is not a compile time constant.
200   static StatusOr<Shape> InferGetTupleElementShape(const Shape& arg,
201                                                    int64 index);
202 
203   // Infers the shape produced from a while node. condition and body are the
204   // shapes of computations for the condition and the body of a while node, and
205   // init is the shape of data initially passed in to the body as an argument.
206   // The shapes must match; condition: T -> PRED, body: T -> T, init: T
207   static StatusOr<Shape> InferWhileShape(const ProgramShape& condition,
208                                          const ProgramShape& body,
209                                          const Shape& init);
210 
211   // Infers the shape produced by a predicated or indexed conditional operation.
212   static StatusOr<Shape> InferConditionalShape(
213       const Shape& branch_index,
214       absl::Span<const ProgramShape> branch_computations,
215       absl::Span<const Shape> branch_operands);
216 
217   // Infers the shape produced by a broadcast operation.
218   static StatusOr<Shape> InferBroadcastShape(
219       const Shape& operand, absl::Span<const int64> broadcast_sizes);
220 
221   // Checks whether the given parameters can form a broadcast. Returns the same
222   // output_shape if it's legal.
223   static StatusOr<Shape> InferBroadcastShape(
224       const Shape& operand_shape, const Shape& output_shape,
225       absl::Span<const int64> broadcast_dimensions);
226 
227   // Infers the shape produced by a reshape operation from the element type of
228   // its operand and the new dimension sizes specified.
229   static StatusOr<Shape> InferReshapeShape(const Shape& operand,
230                                            absl::Span<const int64> dimensions,
231                                            absl::Span<const int64> new_sizes);
232 
233   // Infers the shape produced by a transpose operation from the element type of
234   // its operand and its dimensions field.
235   static StatusOr<Shape> InferTransposeShape(
236       const Shape& operand, absl::Span<const int64> dimensions);
237 
238   // Helper that infers the shape produced by performing a concatenate operation
239   // with the given operand shapes.
240   static StatusOr<Shape> InferConcatOpShape(
241       absl::Span<const Shape* const> arg_shapes, int64 dimension);
242 
243   // Helper that validates the given operand shape can be converted to the
244   // target output_shape via a convert instruction -- the requirement is that
245   // the shape is identical except for the element type.
246   static StatusOr<Shape> InferConvertShape(const Shape& operand_shape,
247                                            PrimitiveType new_element_type);
248 
249   // Helper that validates the given operand shape can be bitcast converted to
250   // the target output_shape via a bitcast convert instruction -- the
251   // requirement is that the shape is identical except for the element type and
252   // the element types have identical bit-widths.
253   static StatusOr<Shape> InferBitcastConvertShape(
254       const Shape& operand_shape, PrimitiveType new_element_type);
255 
256   // Helper that validates the input data type for a reduce-precision operation,
257   // and returns the result shape.
258   static StatusOr<Shape> InferReducePrecisionShape(const Shape& operand_shape,
259                                                    const int exponent_bits,
260                                                    const int mantissa_bits);
261 
262   // Helper that infers the shape produced by a pad operation based on the
263   // padding configuration.
264   static StatusOr<Shape> InferPadShape(const Shape& operand_shape,
265                                        const Shape& padding_value_shape,
266                                        const PaddingConfig& padding_config);
267 
268   // Helper that validates the given arg_shapes are compatible with the shape of
269   // the to_apply parameters, and returns the to_apply result shape.
270   static StatusOr<Shape> InferCallShape(
271       absl::Span<const Shape* const> arg_shapes, const ProgramShape& to_apply);
272 
273   // Helper that infers the shape produced by performing a dot operation with
274   // the given LHS and RHS shapes.
275   static StatusOr<Shape> InferDotOpShape(
276       const Shape& lhs, const Shape& rhs,
277       const DotDimensionNumbers& dimension_numbers);
278 
279   // Helper that infers the shape of the tensor produced by a gather operation
280   // with the given input shape, gather indices shape and gather dimension
281   // numbers.
282   static StatusOr<Shape> InferGatherShape(
283       const Shape& input_shape, const Shape& start_indices_shape,
284       const GatherDimensionNumbers& gather_dim_numbers,
285       absl::Span<const int64> slice_sizes);
286 
287   // Helper that validates the given input shape, scatter indices shape, updates
288   // shape, and scatter dimension numbers that constitute a scatter operation,
289   // and returns the result shape of the scatter operation.
290   static StatusOr<Shape> InferScatterShape(
291       const Shape& operand_shape, const Shape& scatter_indices_shape,
292       const Shape& updates_shape, const ProgramShape& to_apply_shape,
293       const ScatterDimensionNumbers& scatter_dim_numbers);
294 
295   static StatusOr<Shape> InferGetDimensionSizeShape(const Shape& shape,
296                                                     int64 dimension);
297 
298  private:
299   // Helper that infers the shape produced by performing an element-wise binary
300   // operation with the given LHS and RHS shapes.
301   // Note: By "element-wise" we mean operations that look at a single element in
302   // the LHS and a single element in the RHS to produce a single output element,
303   // even in the presence of broadcasting of one of the operands over the other.
304   static StatusOr<Shape> InferElementwiseBinaryOpShape(
305       HloOpcode operation, const Shape& lhs, const Shape& rhs,
306       absl::Span<const int64> broadcast_dimensions);
307 
308   // Helper for inferring the shape of Clamp ops.
309   static StatusOr<Shape> InferClampShape(const Shape& min, const Shape& operand,
310                                          const Shape& max);
311 
312   // Helper for inferring the shape of Select ops.
313   static StatusOr<Shape> InferSelectShape(const Shape& pred,
314                                           const Shape& on_true,
315                                           const Shape& on_false);
316   // Helper for inferring the shape of TupleSelect ops.
317   static StatusOr<Shape> InferTupleSelectShape(const Shape& pred,
318                                                const Shape& on_true,
319                                                const Shape& on_false);
320 
321   // Helper for inferring shapes of binary operations which use degenerate
322   // dimension broadcasting (a dimension of size 1 in one operand is broadcast
323   // up to match the size of the dimension in the other operand).
324   static StatusOr<Shape> InferDegenerateDimensionBroadcastShape(
325       HloOpcode operation, const Shape& lhs, const Shape& rhs);
326 
327   // Helper for inferring shapes of binary operations using "InDim"
328   // broadcasting. This is the broadcasting used in the *InDim binary operations
329   // (for example ComputationBuilder::AddInDim). smaller_shape must be a
330   // lower-rank shape than larger_shape. Returns the shape that the
331   // smaller_shape is broadcast to.
332   static StatusOr<Shape> InferInDimBroadcastShape(
333       const Shape& smaller_shape, const Shape& larger_shape,
334       absl::Span<const int64> broadcast_dimensions);
335 
336   TF_DISALLOW_COPY_AND_ASSIGN(ShapeInference);
337 };
338 
339 }  // namespace xla
340 
341 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPE_INFERENCE_H_
342