1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17 
18 #include <string>
19 
20 #include "absl/types/span.h"
21 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/test_helpers.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 
28 namespace xla {
29 namespace {
30 
31 using ::testing::ContainsRegex;
32 using ::testing::HasSubstr;
33 
// Base test fixture providing a grab bag of commonly used scalar, vector and
// matrix shapes shared by the shape-inference tests below.
class ShapeInferenceTest : public ::testing::Test {
 protected:
  // Some handy scalar shapes.
  const Shape s32_ = ShapeUtil::MakeShape(S32, {});
  const Shape f16_ = ShapeUtil::MakeShape(F16, {});
  const Shape f32_ = ShapeUtil::MakeShape(F32, {});
  const Shape f64_ = ShapeUtil::MakeShape(F64, {});
  const Shape pred_ = ShapeUtil::MakeShape(PRED, {});

  // Some handy vector and matrix shapes of F32 type.
  // Suffix: vector_length_, matrix_rows_cols_
  const Shape vector_32_ = ShapeUtil::MakeShape(F32, {32});
  const Shape vector_64_ = ShapeUtil::MakeShape(F32, {64});
  const Shape matrix_32_48_ = ShapeUtil::MakeShape(F32, {32, 48});
  const Shape matrix_32_64_ = ShapeUtil::MakeShape(F32, {32, 64});
  const Shape matrix_64_48_ = ShapeUtil::MakeShape(F32, {64, 48});

  // Some handy S32 arrays.
  const Shape s32matrix_64_64_ = ShapeUtil::MakeShape(S32, {64, 64});
};
54 
55 // Subclass for testing InferReduceShape.
56 class ReduceShapeInferenceTest : public ShapeInferenceTest {
57  protected:
58   // Helper that runs reduce shape inference with the input 'arg' and given
59   // dimensions to reduce, and checks the inferred shape is as expected. The
60   // element type here is hard-coded to F32.
ExpectInferredReduceShape(const Shape & expected_inferred_shape,const Shape & arg,absl::Span<const int64> dimensions_to_reduce)61   void ExpectInferredReduceShape(const Shape& expected_inferred_shape,
62                                  const Shape& arg,
63                                  absl::Span<const int64> dimensions_to_reduce) {
64     ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
65     auto inferred_status = ShapeInference::InferReduceShape(
66         {&arg, &f32_}, dimensions_to_reduce, to_apply);
67     EXPECT_IS_OK(inferred_status.status());
68     EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
69                                  inferred_status.ValueOrDie()));
70   }
71 };
72 
// Subclass for testing InferSelectAndScatterShape.
class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest {
 protected:
  SelectAndScatterShapeInferenceTest() {
    // An 8x16 operand with the 2x2/stride-2 window below gives a 4x8 source.
    operand_shape_ = ShapeUtil::MakeShape(F32, {8, 16});
    source_shape_ = ShapeUtil::MakeShape(F32, {4, 8});
    // 2x2 window, stride 2, no padding, no dilation, applied to both dims.
    WindowDimension dim;
    dim.set_size(2);
    dim.set_stride(2);
    dim.set_padding_low(0);
    dim.set_padding_high(0);
    dim.set_window_dilation(1);
    dim.set_base_dilation(1);
    *window_.add_dimensions() = dim;
    *window_.add_dimensions() = dim;
    init_value_shape_ = ShapeUtil::MakeShape(F32, {});
    // Well-formed select: (F32, F32) -> PRED; scatter: (F32, F32) -> F32.
    select_program_shape_ = ShapeUtil::MakeProgramShape(
        {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
    scatter_program_shape_ = ShapeUtil::MakeProgramShape(
        {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  }

  Shape operand_shape_;
  Shape source_shape_;
  Window window_;
  Shape init_value_shape_;
  ProgramShape select_program_shape_;
  ProgramShape scatter_program_shape_;
};
102 
TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
  // Negation is elementwise, so the inferred shape equals the operand shape.
  const Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  const auto negated =
      ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape);
  ASSERT_IS_OK(negated.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, negated.ValueOrDie()));
}
110 
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) {
  // A scalar predicate selecting between identical tuples yields that tuple.
  const Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_});
  const auto selected = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, tuple, tuple);
  ASSERT_IS_OK(selected.status());
  ASSERT_TRUE(ShapeUtil::Equal(tuple, selected.ValueOrDie()));
}
118 
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
  // A scalar predicate selecting between equal-shaped arrays yields that
  // array shape.
  const auto selected = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(selected.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, selected.ValueOrDie()));
}
125 
TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) {
  // An elementwise PRED array of matching shape is also a valid predicate.
  const Shape pred_matrix = ShapeUtil::MakeShape(PRED, {64, 48});
  const auto selected = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_matrix, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(selected.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, selected.ValueOrDie()));
}
133 
TEST_F(ShapeInferenceTest, SelectBadShapes) {
  // on-true and on-false operands must have identical shapes.
  auto inferred_status_error1 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_);
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("Operands to select must be the same shape"));

  // The predicate must have PRED element type (S32 here).
  auto inferred_status_error2 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("pred operand must have PRED"));

  // A non-scalar predicate must match the operands' shape; a rank-1 pred
  // against rank-2 operands is rejected.
  auto inferred_status_error3 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_,
      matrix_64_48_);
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("with non-scalar predicate with dimensionality"));

  // Tuples have a TUPLE element type and cannot be the pred of a select.
  auto inferred_status_error4 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}),
      ShapeUtil::MakeTupleShape({f32_, f32_}),
      ShapeUtil::MakeTupleShape({f32_, f32_}));
  ASSERT_FALSE(inferred_status_error4.ok());
  ASSERT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("pred operand must have PRED element type"));
}
163 
TEST_F(ShapeInferenceTest, ClampAllMatrix) {
  // Clamp of three identically-shaped matrices yields that matrix shape.
  const auto clamped = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(clamped.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, clamped.ValueOrDie()));
}
170 
TEST_F(ShapeInferenceTest, ClampAllScalar) {
  // Clamp of three scalars yields a scalar.
  const auto clamped =
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_);
  ASSERT_IS_OK(clamped.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, clamped.ValueOrDie()));
}
177 
TEST_F(ShapeInferenceTest, ClampMinScalar) {
  // A scalar min broadcasts against matrix operand/max.
  const auto clamped = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(clamped.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, clamped.ValueOrDie()));
}
184 
TEST_F(ShapeInferenceTest, ClampMaxScalar) {
  // A scalar max broadcasts against matrix min/operand.
  const auto clamped = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_);
  ASSERT_IS_OK(clamped.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, clamped.ValueOrDie()));
}
191 
TEST_F(ShapeInferenceTest, ClampOperandScalar) {
  // A scalar operand broadcasts against matrix min/max.
  const auto clamped = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_);
  ASSERT_IS_OK(clamped.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, clamped.ValueOrDie()));
}
198 
TEST_F(ShapeInferenceTest, ClampMinMatrix) {
  // A matrix min with scalar operand/max yields the matrix shape.
  const auto clamped = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, f32_);
  ASSERT_IS_OK(clamped.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, clamped.ValueOrDie()));
}
205 
TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
  // A matrix max with scalar min/operand yields the matrix shape.
  const auto clamped = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, f32_, matrix_64_48_);
  ASSERT_IS_OK(clamped.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, clamped.ValueOrDie()));
}
212 
TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
  // A matrix operand with scalar min/max yields the matrix shape.
  const auto clamped = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, f32_);
  ASSERT_IS_OK(clamped.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, clamped.ValueOrDie()));
}
219 
TEST_F(ShapeInferenceTest, ClampBadShapes) {
  // All of the cases below must fail shape inference.
  // Type mismatch
  ASSERT_FALSE(
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, s32_, f32_, f32_)
          .ok());
  ASSERT_FALSE(
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, s32_, f32_)
          .ok());
  ASSERT_FALSE(
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, s32_)
          .ok());
  // Dimension mismatch
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
                   HloOpcode::kClamp, vector_64_, vector_32_, vector_32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
                   HloOpcode::kClamp, vector_32_, vector_64_, vector_32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
                   HloOpcode::kClamp, vector_32_, vector_32_, vector_64_)
                   .ok());
  // Dimension mismatch, where one operand is a scalar
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp,
                                                   vector_64_, vector_32_, f32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp,
                                                   vector_64_, f32_, vector_32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_,
                                                   vector_64_, vector_32_)
                   .ok());
}
252 
TEST_F(ShapeInferenceTest, Complex) {
  // Helper: infer the shape of complex(lhs, rhs) with the given broadcast
  // dimensions.
  auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
                           absl::Span<const int64> bcast) {
    return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
                                              bcast);
  };
  // Inputs must be FP.
  ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
  ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
  // Component types must match.
  ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
  // Only F32->C64 and F64->C128 supported.
  ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok());
  // Validate correct uses: scalar/scalar, then every scalar/vector pairing.
  Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
  TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {})));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  // Fixed: this case previously duplicated the (vector, scalar) call above;
  // it is meant to cover the (vector, vector) combination.
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, vector_32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));

  // Matrix/vector and matrix/scalar combinations with broadcasting.
  Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64});
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(vector_64_, matrix_32_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, vector_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, matrix_32_64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));

  // F64 components produce C128.
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {})));
}
293 
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
  // Tuplifying (s32, f32) scalars yields the corresponding tuple shape.
  const StatusOr<Shape> tuplified =
      ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_});
  ASSERT_IS_OK(tuplified.status());
  const Shape expected = ShapeUtil::MakeTupleShape({s32_, f32_});
  ASSERT_TRUE(ShapeUtil::Equal(tuplified.ValueOrDie(), expected));
}
301 
TEST_F(ShapeInferenceTest, ReduceWindowInHalf) {
  // Reducing an 8x8 matrix with a 2x2 window at stride 2 (no padding, no
  // dilation) halves each dimension, so the inferred shape is 4x4.
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8});
  Window window;
  WindowDimension dim;
  dim.set_size(2);
  dim.set_stride(2);
  dim.set_padding_low(0);
  dim.set_padding_high(0);
  dim.set_window_dilation(1);
  dim.set_base_dilation(1);
  *window.add_dimensions() = dim;
  *window.add_dimensions() = dim;
  Shape init_value_shape = ShapeUtil::MakeShape(F32, {});
  // Binary (F32, F32) -> F32 reduction computation.
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  // Note: removed two unused locals ('window_shape' and 'float_scalar') that
  // were never passed to shape inference.
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      matrix_shape, init_value_shape, window, to_apply);

  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), inferred));
}
326 
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) {
  // With consistent shapes, the inferred result matches the operand shape.
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, result.ValueOrDie()));
}
335 
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) {
  // A source whose shape disagrees with the windowed operand is rejected.
  const Shape bad_source = ShapeUtil::MakeShape(F32, {4, 6});
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, bad_source,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Source shape does not match"));
}
345 
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
  // A select computation taking only one parameter is rejected.
  const ProgramShape bad_select =
      ShapeUtil::MakeProgramShape({ShapeUtil::MakeShape(F32, {})}, pred_);
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_select, window_, source_shape_, init_value_shape_,
      scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function must take 2 parameters"));
}
356 
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
  // A select computation returning F32 instead of PRED is rejected.
  const ProgramShape bad_select = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_select, window_, source_shape_, init_value_shape_,
      scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function must have rank-0 PRED"));
}
367 
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
  // A select computation whose first parameter is S32 (not the operand's F32)
  // is rejected.
  const ProgramShape bad_select = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_select, window_, source_shape_, init_value_shape_,
      scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function's first parameter"));
}
378 
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
  // A select computation whose second parameter is U32 (not the operand's
  // F32) is rejected.
  const ProgramShape bad_select = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(U32, {})}, pred_);
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_select, window_, source_shape_, init_value_shape_,
      scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function's second parameter"));
}
389 
TEST_F(ShapeInferenceTest, Convolve) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  // x0: input 3, window 3, stride 2, pad 1+1 => (3 + 2 - 3) / 2 + 1 = 2.
  dim0->set_size(3);
  dim0->set_stride(2);
  dim0->set_padding_low(1);
  dim0->set_padding_high(1);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(1);
  // x1: input 4, window 2, stride 1, no pad => (4 - 2) / 1 + 1 = 3.
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(0);
  dim1->set_padding_high(0);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  // batch 10, output features 12 (from the kernel), spatial {2, 3}.
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
                               inferred_shape));
}
434 
TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  // x0: window 3 dilated by 6 has effective size 6 * (3 - 1) + 1 = 13;
  // input 103, stride 3, no pad => (103 - 13) / 3 + 1 = 31.
  auto dim0 = window.add_dimensions();
  dim0->set_size(3);
  dim0->set_stride(3);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim0->set_window_dilation(6);
  dim0->set_base_dilation(1);

  // x1: window 2 dilated by 2 has effective size 2 * (2 - 1) + 1 = 3;
  // input 4, pad 2+1, stride 1 => (4 + 3 - 3) / 1 + 1 = 5.
  auto dim1 = window.add_dimensions();
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim1->set_window_dilation(2);
  dim1->set_base_dilation(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
                               inferred_shape));
}
480 
TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  // x0: input 3 base-dilated by 6 has size 6 * (3 - 1) + 1 = 13;
  // window 4, stride 3, no pad => (13 - 4) / 3 + 1 = 4.
  auto dim0 = window.add_dimensions();
  dim0->set_size(4);
  dim0->set_stride(3);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(6);

  // x1: input 4 base-dilated by 2 has size 2 * (4 - 1) + 1 = 7;
  // pad 2+1, window 2, stride 1 => (7 + 3 - 2) / 1 + 1 = 9.
  auto dim1 = window.add_dimensions();
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(2);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
                               inferred_shape));
}
526 
TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
  // Dimension order for this test: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2});

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(3);
  dnums.set_output_batch_dimension(3);
  dnums.set_input_feature_dimension(2);
  dnums.set_output_feature_dimension(2);
  dnums.add_input_spatial_dimensions(0);
  dnums.add_output_spatial_dimensions(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.add_output_spatial_dimensions(1);
  dnums.set_kernel_input_feature_dimension(0);  // duplicated with kernel_x0
  dnums.set_kernel_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);

  // Window dilations are left unset (proto defaults); this test exercises
  // only the dimension-numbers validation, which must fire on the duplicated
  // kernel dimension above.
  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(2);
  dim0->set_stride(1);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim1->set_size(3);
  dim1->set_stride(2);
  dim1->set_padding_low(1);
  dim1->set_padding_high(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("each dimension exactly once"));
}
564 
TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
  // Mapping an F32 -> S32 computation over an F32 vector yields an S32
  // vector of the same length.
  const Shape arg = ShapeUtil::MakeShape(F32, {20});
  const ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_);
  const auto mapped = ShapeInference::InferMapShape({&arg}, to_apply, {0});
  EXPECT_IS_OK(mapped.status());
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {20}),
                               mapped.ValueOrDie()));
}
573 
TEST_F(ShapeInferenceTest, Map) {
  // Two rank-1 F32 operands with a binary scalar computation map to the
  // operand shape.
  auto inferred_status_r1f32 = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  EXPECT_IS_OK(inferred_status_r1f32.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie()));

  // It's OK to provide a single argument, as long as the applied arity matches
  // (this degenerates to a Map).
  auto inferred_status_r1f32_one = ShapeInference::InferMapShape(
      {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0});
  EXPECT_IS_OK(inferred_status_r1f32_one.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie()));

  // Rank-2 S32 operands mapped over both dimensions.
  auto inferred_status_r2s32 = ShapeInference::InferMapShape(
      {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_},
      ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1});
  EXPECT_IS_OK(inferred_status_r2s32.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie()));

  // Map requires at least one operand.
  auto no_args_error = ShapeInference::InferMapShape(
      {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {});
  ASSERT_FALSE(no_args_error.ok());
  ASSERT_THAT(no_args_error.status().error_message(),
              HasSubstr("expects at least one argument"));

  // All operands must share the same shape.
  auto args_diff_shapes_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_64_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(args_diff_shapes_error.ok());
  ASSERT_THAT(args_diff_shapes_error.status().error_message(),
              HasSubstr("requires all operands to have the same shape"));

  // Computation arity must equal the number of operands.
  auto arity_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_),
      {0});
  ASSERT_FALSE(arity_error.ok());
  ASSERT_THAT(arity_error.status().error_message(),
              HasSubstr("function arity must match"));

  // The mapped computation must return a scalar.
  auto output_shape_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0});
  ASSERT_FALSE(output_shape_error.ok());
  ASSERT_THAT(output_shape_error.status().error_message(),
              HasSubstr("result has to be a scalar"));

  // Each parameter of the mapped computation must be a scalar.
  auto param_shape_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0});
  ASSERT_FALSE(param_shape_error.ok());
  ASSERT_THAT(param_shape_error.status().error_message(),
              HasSubstr("parameter has to be a scalar"));

  // Parameter element types must match the argument element types.
  auto param_element_type_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0});
  ASSERT_FALSE(param_element_type_error.ok());
  ASSERT_THAT(param_element_type_error.status().error_message(),
              HasSubstr("parameter type has to match argument"));

  // Single-operand success case.
  Shape arg = ShapeUtil::MakeShape(F32, {20});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
  auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
  EXPECT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie()));

  // One operand against a binary computation: arity mismatch.
  auto inferred_status_error1 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("arity must match number of arguments"));

  // Non-scalar parameter.
  auto inferred_status_error2 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0});
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("has to be a scalar"));

  // Non-scalar result.
  auto inferred_status_error3 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0});
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("has to be a scalar"));

  // S32 parameter against an F32 argument.
  auto inferred_status_error5 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0});
  ASSERT_FALSE(inferred_status_error5.ok());
  ASSERT_THAT(inferred_status_error5.status().error_message(),
              HasSubstr("parameter type has to match argument"));
}
667 
TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) {
  // Collapsing the sole dimension of a vector yields a scalar.
  const Shape input = ShapeUtil::MakeShape(F32, {128});
  ExpectInferredReduceShape(f32_, input, /*dimensions_to_reduce=*/{0});
}
672 
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstDimension) {
  // Reducing dimension 0 of a [2,3,4] cube leaves a [3,4] matrix.
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3, 4}), cube,
                            /*dimensions_to_reduce=*/{0});
}
678 
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongMiddleDimension) {
  // Reducing dimension 1 of a [2,3,4] cube leaves a [2,4] matrix.
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2, 4}), cube,
                            /*dimensions_to_reduce=*/{1});
}
684 
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstTwoDimensions) {
  // Reducing dimensions 0 and 1 of a [2,3,4] cube leaves a [4] vector.
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {4}), cube,
                            /*dimensions_to_reduce=*/{0, 1});
}
690 
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongLastTwoDimensions) {
  // Reducing dimensions 1 and 2 of a [2,3,4] cube leaves a [2] vector.
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2}), cube,
                            /*dimensions_to_reduce=*/{1, 2});
}
696 
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstAndLastDimensions) {
  // Reducing dimensions 0 and 2 of a [2,3,4] cube leaves a [3] vector.
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {3});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{0, 2});

  // The inferred shape must not depend on the order of dimensions_to_reduce.
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{2, 0});
}
707 
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
  // Reducing every dimension collapses the cube to a scalar.
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(f32_, cube, /*dimensions_to_reduce=*/{0, 1, 2});
}
712 
TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) {
  // Variadic reduce over an (f32, s32) pair of [5,3] arrays, fully reduced,
  // yields a tuple of two scalars.
  const Shape f32_arg = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_arg = ShapeUtil::MakeShape(S32, {5, 3});
  const Shape expected = ShapeUtil::MakeTupleShape({f32_, s32_});
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, expected);
  auto result = ShapeInference::InferReduceShape(
      {&f32_arg, &s32_arg, &f32_, &s32_}, {0, 1}, reducer);
  EXPECT_IS_OK(result.status());
  EXPECT_TRUE(ShapeUtil::Equal(expected, result.ValueOrDie()));
}
724 
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) {
  // A reducer that takes six parameters cannot serve a two-input reduce,
  // which requires exactly four (two accumulators + two elements).
  const Shape f32_arg = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_arg = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_},
                                  ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto result = ShapeInference::InferReduceShape(
      {&f32_arg, &s32_arg, &f32_, &s32_}, {0, 1}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("must take 4 parameters, but takes 6 parameter(s)"));
}
737 
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) {
  // The first reducer parameter is s32 but the corresponding result element
  // is f32; the accumulator and result types must match.
  const Shape f32_arg = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_arg = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape reducer = ShapeUtil::MakeProgramShape(
      {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto result = ShapeInference::InferReduceShape(
      {&f32_arg, &s32_arg, &f32_, &s32_}, {0, 1}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr(
          "parameter shape differs from the result shape: s32[] vs f32[]"));
}
751 
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) {
  // An empty argument list is rejected: reduce needs at least one input and
  // one init value.
  const ProgramShape reducer = ShapeUtil::MakeProgramShape(
      {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto result = ShapeInference::InferReduceShape({}, {0, 1}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("must have at least 2 arguments, has 0"));
}
760 
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) {
  // A two-input reduce requires the reducer to return a 2-element tuple;
  // returning a scalar is an error.
  const Shape f32_arg = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_arg = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_);
  auto result = ShapeInference::InferReduceShape(
      {&f32_arg, &s32_arg, &f32_, &s32_}, {0, 1}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but produces a scalar"));
}
773 
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) {
  // A two-input reduce requires a 2-element result tuple; a 3-element tuple
  // is an error.
  const Shape f32_arg = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_arg = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape reducer = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_}));
  auto result = ShapeInference::InferReduceShape(
      {&f32_arg, &s32_arg, &f32_, &s32_}, {0, 1}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but has 3 elements"));
}
786 
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) {
  // The reducer accumulates in s32 everywhere while the first init value is
  // f32; the accumulator must match the init_value shape.
  const Shape f32_arg = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape s32_arg = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape reducer = ShapeUtil::MakeProgramShape(
      {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_}));
  auto result = ShapeInference::InferReduceShape(
      {&f32_arg, &s32_arg, &f32_, &s32_}, {0, 1}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("accumulator shape at index 0 differs from the "
                        "init_value shape: s32[] vs f32[]"));
}
799 
TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
  // Dimensions 3 and 4 do not exist on a rank-2 input.
  const ProgramShape reducer = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
  const Shape input = ShapeUtil::MakeShape(F32, {5, 3});
  auto result = ShapeInference::InferReduceShape(
      {&input, &f32_},
      /*dimensions_to_reduce=*/{3, 4}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("out-of-bounds dimension"));
}
810 
TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
  // A single-input reduce needs a two-parameter reducer; three is an error.
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);
  const Shape input = ShapeUtil::MakeShape(F32, {5, 3});
  auto result = ShapeInference::InferReduceShape(
      {&input, &f32_}, /*dimensions_to_reduce=*/{0}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(), HasSubstr("take 2 parameters"));
}
821 
TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
  // The reducer returns s32 while the input element type is f32.
  const ProgramShape reducer = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);
  const Shape input = ShapeUtil::MakeShape(F32, {5, 3});
  auto result = ShapeInference::InferReduceShape(
      {&input, &f32_}, /*dimensions_to_reduce=*/{0}, reducer);
  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("0-th parameter shape differs"));
}
832 
TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
  // Slicing rows [32,64) of a [128,64] matrix with unit strides gives [32,64].
  const Shape input = ShapeUtil::MakeShape(F32, {128, 64});
  auto result =
      ShapeInference::InferSliceShape(input, {32, 0}, {64, 64}, {1, 1});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}),
                               result.ValueOrDie()));
}
841 
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
  // Strides 2 and 4 over spans of 32 and 64 elements yield 16 and 16.
  const Shape input = ShapeUtil::MakeShape(F32, {128, 64});
  auto result =
      ShapeInference::InferSliceShape(input, {32, 0}, {64, 64}, {2, 4});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}),
                               result.ValueOrDie()));
}
850 
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
  // Spans that are not a multiple of the stride round the extent up:
  // ceil(5/2)=3 and ceil(13/4)=4.
  const Shape input = ShapeUtil::MakeShape(F32, {128, 64});
  auto result =
      ShapeInference::InferSliceShape(input, {15, 0}, {20, 13}, {2, 4});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}),
                               result.ValueOrDie()));
}
859 
TEST_F(ShapeInferenceTest, InferInvalidStride) {
  // A zero stride is rejected with INVALID_ARGUMENT.
  const Shape input = ShapeUtil::MakeShape(F32, {128, 64});
  auto result =
      ShapeInference::InferSliceShape(input, {127, 0}, {129, 2}, {0, 1});
  ASSERT_FALSE(result.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, result.status().code());
}
868 
TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
  // Limit 129 exceeds the dimension size 128, so the slice is rejected.
  const Shape input = ShapeUtil::MakeShape(F32, {128, 64});
  auto result =
      ShapeInference::InferSliceShape(input, {127, 0}, {129, 2}, {1, 1});
  ASSERT_FALSE(result.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, result.status().code());
}
877 
TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
  // Slicing [2,4) with stride 1 from a length-17 vector gives a length-2
  // vector.
  Shape vector_shape = ShapeUtil::MakeShape(F32, {17});
  auto inferred_status =
      ShapeInference::InferSliceShape(vector_shape, {2}, {4}, {1});
  // Consistency fix: use ASSERT_IS_OK like the other slice tests; it prints
  // the error status on failure, unlike a bare ASSERT_TRUE(ok()).
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2})));
}
886 
TEST_F(ShapeInferenceTest, InferConstIndexShape) {
  // get-tuple-element extracts the shape of the addressed tuple member.
  const Shape tuple = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto elem0 = ShapeInference::InferGetTupleElementShape(tuple, 0);
  auto elem1 = ShapeInference::InferGetTupleElementShape(tuple, 1);
  ASSERT_IS_OK(elem0.status());
  ASSERT_IS_OK(elem1.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, elem0.ValueOrDie()));
  ASSERT_TRUE(ShapeUtil::Equal(s32_, elem1.ValueOrDie()));
}
898 
TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) {
  // Indices outside [0, tuple_size) are rejected on both sides.
  const Shape tuple = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto below = ShapeInference::InferGetTupleElementShape(tuple, -1);
  auto above = ShapeInference::InferGetTupleElementShape(tuple, 2);
  ASSERT_FALSE(below.ok());
  ASSERT_FALSE(above.ok());
  EXPECT_THAT(below.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
  EXPECT_THAT(above.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
}
912 
TEST_F(ShapeInferenceTest, InferPowShape) {
  // Power of a vector by a scalar broadcasts to the vector's shape.
  const Shape ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto result = ShapeInference::InferBinaryOpShape(HloOpcode::kPower,
                                                   ten_floats, f32_, {});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(ten_floats, result.ValueOrDie()));
}
920 
TEST_F(ShapeInferenceTest, InferCompareShape) {
  // Comparison keeps the broadcast shape but changes element type to PRED.
  const Shape ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto result = ShapeInference::InferBinaryOpShape(HloOpcode::kCompare,
                                                   ten_floats, f32_, {});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
                               result.ValueOrDie()));
}
929 
TEST_F(ShapeInferenceTest, BroadcastScalar) {
  // Broadcast semantics for scalars and vectors, across several element
  // types.
  for (auto element_type : {F32, U32, S8}) {
    const Shape scalar = ShapeUtil::MakeShape(element_type, {});
    {  // no-op scalar broadcast
      auto status = ShapeInference::InferBroadcastShape(scalar, {});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(scalar, status.ValueOrDie()));
    }
    const Shape vec3 = ShapeUtil::MakeShape(element_type, {3});
    {  // scalar -> 1d broadcast
      auto status = ShapeInference::InferBroadcastShape(scalar, {3});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(vec3, status.ValueOrDie()));
    }
    {  // no-op 1d broadcast
      auto status = ShapeInference::InferBroadcastShape(vec3, {});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(vec3, status.ValueOrDie()));
    }
    const Shape mat2x3 = ShapeUtil::MakeShape(element_type, {2, 3});
    {  // scalar -> 2d broadcast
      auto status = ShapeInference::InferBroadcastShape(scalar, {2, 3});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(mat2x3, status.ValueOrDie()));
    }
    {  // 1d -> 2d broadcast
      auto status = ShapeInference::InferBroadcastShape(vec3, {2});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(mat2x3, status.ValueOrDie()));
    }
  }
}
962 
963 // scalar <dot> vector: error
TEST_F(ShapeInferenceTest, ScalarDotVector) {
  // A scalar lhs is rejected: dot requires operands of at least rank 1.
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  auto result = ShapeInference::InferDotOpShape(f32_, vector_32_, dnums);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Dot only supports rank"));
}
974 
// 3D <dot> 2D: supported; the extra lhs dimension is carried into the result.
TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  // [32,32,32] <dot> [32,64] contracting lhs dim 1 with rhs dim 0.
  auto result = ShapeInference::InferDotOpShape(
      ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dnums);
  EXPECT_TRUE(result.ok());
  EXPECT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
                               ShapeUtil::MakeShape(F32, {32, 32, 64})));
}
986 
987 // vector <dot> vector -> scalar
TEST_F(ShapeInferenceTest, VectorDotVector) {
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(0);
  dnums.add_rhs_contracting_dimensions(0);
  // Two length-64 vectors contract to a scalar.
  auto result = ShapeInference::InferDotOpShape(vector_64_, vector_64_, dnums);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, result.ValueOrDie()));
  // Mismatched contracting sizes (64 vs 32) must be rejected.
  auto mismatch =
      ShapeInference::InferDotOpShape(vector_64_, vector_32_, dnums);
  ASSERT_FALSE(mismatch.ok());
}
1000 
1001 // matrix <dot> vector -> vector
TEST_F(ShapeInferenceTest, MatrixDotVector) {
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  // [32,64] <dot> [64] -> [32].
  auto result =
      ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dnums);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), vector_32_));
  // Mismatched contracting sizes (64 vs 32) must be rejected.
  auto mismatch =
      ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dnums);
  ASSERT_FALSE(mismatch.ok());
}
1014 
1015 // vector <dot> matrix -> vector
TEST_F(ShapeInferenceTest, VectorDotMatrix) {
  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(0);
  dnums.add_rhs_contracting_dimensions(0);
  // [32] <dot> [32,64] -> [64].
  auto result =
      ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dnums);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), vector_64_));
  // Mismatched contracting sizes (64 vs 32) must be rejected.
  auto mismatch =
      ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dnums);
  ASSERT_FALSE(mismatch.ok());
}
1028 
1029 // matrix <dot> matrix -> matrix
TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  // [32,64] <dot> [64,48] -> [32,48].
  auto inferred_status_match =
      ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums);
  ASSERT_IS_OK(inferred_status_match.status());
  ASSERT_TRUE(
      ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
      << "inferred: "
      << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
      // BUG FIX: the failure message used to print matrix_64_48_, but the
      // shape the assertion compares against is matrix_32_48_.
      << " expected: " << ShapeUtil::HumanString(matrix_32_48_);
  // Mismatched contracting sizes (64 vs 32) must be rejected.
  auto inferred_status_mismatch =
      ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums);
  ASSERT_FALSE(inferred_status_mismatch.ok());
}
1046 
1047 // BatchMatMul with two batch dimensions and one contracting dimension.
TEST_F(ShapeInferenceTest, DotGeneral) {
  // Batched matmul: batch dims {5,2}, contraction over the size-3 dims.
  const Shape lhs = ShapeUtil::MakeShape(F32, {5, 2, 11, 3});
  const Shape rhs = ShapeUtil::MakeShape(F32, {5, 2, 3, 14});
  const Shape expected = ShapeUtil::MakeShape(F32, {5, 2, 11, 14});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(3);
  dnums.add_lhs_batch_dimensions(0);
  dnums.add_lhs_batch_dimensions(1);

  dnums.add_rhs_contracting_dimensions(2);
  dnums.add_rhs_batch_dimensions(0);
  dnums.add_rhs_batch_dimensions(1);

  auto result = ShapeInference::InferDotOpShape(lhs, rhs, dnums);
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), expected))
      << "inferred: " << ShapeUtil::HumanString(result.ValueOrDie())
      << " expected: " << ShapeUtil::HumanString(expected);
}
1071 
1072 // BatchMatMul with two contracting dimensions fails.
TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) {
  // Two contracting dims on the lhs but only one on the rhs is invalid.
  const Shape lhs = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
  const Shape rhs = ShapeUtil::MakeShape(F32, {2, 3, 14});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_lhs_contracting_dimensions(3);
  dnums.add_lhs_batch_dimensions(0);

  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_rhs_batch_dimensions(0);

  auto result = ShapeInference::InferDotOpShape(lhs, rhs, dnums);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Must specify the same number of contracting "
                        "dimensions for lhs and rhs."));
}
1092 
TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
  // Matching pairs of contracting dims (sizes 3 and 2) are allowed.
  const Shape lhs = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
  const Shape rhs = ShapeUtil::MakeShape(F32, {2, 3, 2, 14});
  const Shape expected = ShapeUtil::MakeShape(F32, {2, 11, 14});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_lhs_contracting_dimensions(3);
  dnums.add_lhs_batch_dimensions(0);

  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(2);
  dnums.add_rhs_batch_dimensions(0);

  auto result = ShapeInference::InferDotOpShape(lhs, rhs, dnums);
  EXPECT_TRUE(result.ok());
  EXPECT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), expected));
}
1112 
1113 // BatchMatMul with different batch dimension sizes fails.
TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) {
  // Batch dims of size 2 (lhs) vs 3 (rhs) cannot be matched.
  const Shape lhs = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs = ShapeUtil::MakeShape(F32, {3, 3, 14});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_lhs_batch_dimensions(0);

  dnums.add_rhs_contracting_dimensions(1);
  dnums.add_rhs_batch_dimensions(0);

  auto result = ShapeInference::InferDotOpShape(lhs, rhs, dnums);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Batch dimension sizes must match"));
}
1131 
1132 // BatchMatMul with different batch dimension numbers passes
TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersPasses) {
  // The batch dims need not occupy the same position on lhs and rhs, as long
  // as their sizes agree.
  const Shape lhs = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs = ShapeUtil::MakeShape(F32, {3, 2, 14});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(2);
  dnums.add_lhs_batch_dimensions(0);

  dnums.add_rhs_contracting_dimensions(0);
  dnums.add_rhs_batch_dimensions(1);

  auto result = ShapeInference::InferDotOpShape(lhs, rhs, dnums);
  ASSERT_TRUE(result.ok());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
                               ShapeUtil::MakeShape(F32, {2, 11, 14})));
}
1150 
1151 // BatchMatMul with out-of-range dimension numbers fails.
TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) {
  // Contracting dim 3 does not exist on a rank-3 lhs.
  const Shape lhs = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs = ShapeUtil::MakeShape(F32, {2, 3, 14});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(3);
  dnums.add_lhs_batch_dimensions(0);

  dnums.add_rhs_contracting_dimensions(0);
  dnums.add_rhs_batch_dimensions(1);

  auto result = ShapeInference::InferDotOpShape(lhs, rhs, dnums);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("A dimension number is out of range"));
}
1169 
1170 // BatchMatMul with non-unique dimension numbers fails.
TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) {
  // Dimension 0 is listed as both a contracting and a batch dim on the lhs.
  const Shape lhs = ShapeUtil::MakeShape(F32, {2, 11, 3});
  const Shape rhs = ShapeUtil::MakeShape(F32, {2, 3, 14});

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(0);
  dnums.add_lhs_batch_dimensions(0);

  dnums.add_rhs_contracting_dimensions(0);
  dnums.add_rhs_batch_dimensions(1);

  auto result = ShapeInference::InferDotOpShape(lhs, rhs, dnums);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("A dimension number is not unique"));
}
1188 
TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
  // Broadcasting a vector against a matrix in a binary add: the broadcast
  // dimension must pick the matrix dimension matching the vector's length.
  const Shape mat = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
  const Shape vec16 = ShapeUtil::MakeShape(F32, {16});

  auto ok_result =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1});
  ASSERT_IS_OK(ok_result.status());
  ASSERT_TRUE(ShapeUtil::Equal(ok_result.ValueOrDie(), mat));

  auto bad_result =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0});
  ASSERT_FALSE(bad_result.ok());

  ok_result =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0});
  ASSERT_IS_OK(ok_result.status());
  ASSERT_TRUE(ShapeUtil::Equal(ok_result.ValueOrDie(), mat));

  bad_result =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1});
  ASSERT_FALSE(bad_result.ok());
}
1214 
TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) {
  // Broadcasting a matrix against a cube: each pair of broadcast dimensions
  // must select the cube dimensions matching the matrix's sizes.
  const Shape cube = ShapeUtil::MakeShape(F32, {16, 8, 4});
  const Shape mat8_4 = ShapeUtil::MakeShape(F32, {8, 4});
  const Shape mat16_4 = ShapeUtil::MakeShape(F32, {16, 4});
  const Shape mat16_8 = ShapeUtil::MakeShape(F32, {16, 8});

  auto result =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube, mat8_4, {1, 2});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), cube));

  result = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube, mat16_4,
                                              {0, 2});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), cube));

  result = ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, cube, mat16_8,
                                              {0, 1});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(), cube));
}
1237 
TEST_F(ShapeInferenceTest,BinOpBroadcastBadDimension)1238 TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
1239   // Test various errors with the broadcast argument.
1240   const Shape tensor = ShapeUtil::MakeShape(F32, {16, 8, 4});
1241   const Shape tensor8_8_8 = ShapeUtil::MakeShape(F32, {8, 8, 8});
1242   const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
1243   const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4});
1244   const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8});
1245 
1246   // "magical" broadcast rejected
1247   auto inferred_status_error1 =
1248       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {});
1249   ASSERT_FALSE(inferred_status_error1.ok());
1250   ASSERT_THAT(inferred_status_error1.status().error_message(),
1251               HasSubstr("Automatic"));
1252 
1253   // broadcast_dimension out of bounds for tensor's rank
1254   auto inferred_status_error2 =
1255       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3});
1256   ASSERT_FALSE(inferred_status_error2.ok());
1257   ASSERT_THAT(inferred_status_error2.status().error_message(),
1258               ContainsRegex("Broadcast dimension number .* too large"));
1259 
1260   // broadcast_dimension doesn't match corresponding dimension
1261   auto inferred_status_error3 =
1262       ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0});
1263   ASSERT_FALSE(inferred_status_error3.ok());
1264   ASSERT_THAT(inferred_status_error3.status().error_message(),
1265               HasSubstr("Broadcast dimension 0 mismatch"));
1266 
1267   // broadcast_dimensions list too long
1268   auto inferred_status_error4 = ShapeInference::InferBinaryOpShape(
1269       HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2});
1270   ASSERT_FALSE(inferred_status_error4.ok());
1271   ASSERT_THAT(inferred_status_error4.status().error_message(),
1272               HasSubstr("broadcast_dimensions has to match"));
1273 
1274   // there's a dimension above the rank of the tensor
1275   auto inferred_status_error5 = ShapeInference::InferBinaryOpShape(
1276       HloOpcode::kAdd, tensor, matrix8_4, {3, 0});
1277   ASSERT_FALSE(inferred_status_error5.ok());
1278   ASSERT_THAT(inferred_status_error5.status().error_message(),
1279               ContainsRegex("dimension number .* too large"));
1280 
1281   // broadcasting dimensions don't match in this order
1282   auto inferred_status_error6 = ShapeInference::InferBinaryOpShape(
1283       HloOpcode::kAdd, tensor, matrix8_4, {2, 1});
1284   ASSERT_FALSE(inferred_status_error6.ok());
1285   ASSERT_THAT(inferred_status_error6.status().error_message(),
1286               HasSubstr("dimension 0 mismatch"));
1287 
1288   // The following two tests make sure that broadcasting dimensions are listed
1289   // in a proper (strictly increasing) order, even if the lower-rank array
1290   // matches the higher-rank array in many different ways.
1291   auto inferred_status_error7 = ShapeInference::InferBinaryOpShape(
1292       HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0});
1293   ASSERT_FALSE(inferred_status_error7.ok());
1294   ASSERT_THAT(inferred_status_error7.status().error_message(),
1295               HasSubstr("dimensions order is wrong"));
1296 
1297   auto inferred_status_error8 = ShapeInference::InferBinaryOpShape(
1298       HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0});
1299   ASSERT_FALSE(inferred_status_error8.ok());
1300   ASSERT_THAT(inferred_status_error8.status().error_message(),
1301               HasSubstr("dimensions order is wrong"));
1302 }
1303 
1304 // Tests for the while instruction with proper shapes.
TEST_F(ShapeInferenceTest,WhileWithCorrectShapes)1305 TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) {
1306   Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
1307   ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
1308   ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);
1309   auto inferred_status =
1310       ShapeInference::InferWhileShape(cond, body, result_shape);
1311   ASSERT_IS_OK(inferred_status.status());
1312   Shape inferred = inferred_status.ValueOrDie();
1313   ASSERT_TRUE(ShapeUtil::Equal(result_shape, inferred));
1314 }
1315 
1316 // Tests for the while instruction with wrong shapes.
TEST_F(ShapeInferenceTest,WhileWithBadShapes)1317 TEST_F(ShapeInferenceTest, WhileWithBadShapes) {
1318   Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
1319   ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
1320   ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);
1321 
1322   auto bad_shape_1 = ShapeUtil::MakeProgramShape({s32_, result_shape}, pred_);
1323   auto inferred_status_error1 =
1324       ShapeInference::InferWhileShape(bad_shape_1, body, result_shape);
1325   ASSERT_FALSE(inferred_status_error1.ok());
1326   ASSERT_THAT(inferred_status_error1.status().error_message(),
1327               HasSubstr("Condition must take 1 arguments"));
1328 
1329   auto bad_shape_2 =
1330       ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape);
1331   auto inferred_status_error2 =
1332       ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape);
1333   ASSERT_FALSE(inferred_status_error2.ok());
1334   ASSERT_THAT(inferred_status_error2.status().error_message(),
1335               HasSubstr("Body must take 1 arguments"));
1336 
1337   auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_);
1338   auto inferred_status_error3 =
1339       ShapeInference::InferWhileShape(bad_shape_3, body, result_shape);
1340   ASSERT_FALSE(inferred_status_error3.ok());
1341   ASSERT_THAT(inferred_status_error3.status().error_message(),
1342               HasSubstr("Condition must return a boolean"));
1343 
1344   auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_);
1345   auto inferred_status_error4 =
1346       ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape);
1347   ASSERT_FALSE(inferred_status_error4.ok());
1348   ASSERT_THAT(inferred_status_error4.status().error_message(),
1349               HasSubstr("parameter of condition and body"));
1350 }
1351 
1352 // Tests for the concatenate instruction with proper shapes.
TEST_F(ShapeInferenceTest,ConcatenateWithCorrectShapes)1353 TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) {
1354   auto inferred_status_1 = ShapeInference::InferConcatOpShape(
1355       {&vector_32_, &vector_64_}, /*dimension=*/0);
1356   ASSERT_IS_OK(inferred_status_1.status());
1357   Shape inferred_1 = inferred_status_1.ValueOrDie();
1358   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), inferred_1));
1359 
1360   auto inferred_status_2 = ShapeInference::InferConcatOpShape(
1361       {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0);
1362   ASSERT_IS_OK(inferred_status_2.status());
1363   Shape inferred_2 = inferred_status_2.ValueOrDie();
1364   ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), inferred_2));
1365 
1366   auto inferred_status_3 = ShapeInference::InferConcatOpShape(
1367       {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1);
1368   ASSERT_IS_OK(inferred_status_3.status());
1369   Shape inferred_3 = inferred_status_3.ValueOrDie();
1370   ASSERT_TRUE(
1371       ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}), inferred_3));
1372 }
1373 
1374 // Tests for the concatenate instruction with wrong shapes.
TEST_F(ShapeInferenceTest,ConcatenateWithBadShapes)1375 TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
1376   auto inferred_status_error1 =
1377       ShapeInference::InferConcatOpShape({}, /*dimension=*/0);
1378   ASSERT_FALSE(inferred_status_error1.ok());
1379   ASSERT_THAT(inferred_status_error1.status().error_message(),
1380               HasSubstr("Concatenate expects at least one argument"));
1381 
1382   auto inferred_status_error2 =
1383       ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1);
1384   ASSERT_FALSE(inferred_status_error2.ok());
1385   ASSERT_THAT(inferred_status_error2.status().error_message(),
1386               HasSubstr("dimension out of bounds: -1"));
1387 
1388   auto inferred_status_error3 =
1389       ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1);
1390   ASSERT_FALSE(inferred_status_error3.ok());
1391   ASSERT_THAT(inferred_status_error3.status().error_message(),
1392               HasSubstr("dimension out of bounds: 1"));
1393 
1394   Shape tuple = ShapeUtil::MakeTupleShape({vector_32_});
1395   auto inferred_status_error4 = ShapeInference::InferConcatOpShape(
1396       {&vector_32_, &tuple}, /*dimension=*/0);
1397   ASSERT_FALSE(inferred_status_error4.ok());
1398   ASSERT_THAT(
1399       inferred_status_error4.status().error_message(),
1400       HasSubstr("Expected array argument for operand of concatenation"));
1401 
1402   const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
1403   auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
1404       {&vector_32_, &vector_s32}, /*dimension=*/0);
1405   ASSERT_FALSE(inferred_status_error5.ok());
1406   ASSERT_THAT(inferred_status_error5.status().error_message(),
1407               HasSubstr("concatenate arrays with different element types"));
1408 
1409   auto inferred_status_error6 = ShapeInference::InferConcatOpShape(
1410       {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0);
1411   ASSERT_FALSE(inferred_status_error6.ok());
1412   ASSERT_THAT(inferred_status_error6.status().error_message(),
1413               HasSubstr("concatenate arrays that differ in "
1414                         "dimensions other than the one being "
1415                         "concatenated"));
1416 }
1417 
TEST_F(ShapeInferenceTest, Pad) {
  const Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});
  const Shape padding_value_shape = ShapeUtil::MakeShape(F32, {});
  // Padding for dimension 0: {low: 0, high: 2, interior: 3}
  // Padding for dimension 1: {low: 1, high: 5, interior: 0}
  PaddingConfig padding_config;
  auto* dim0 = padding_config.add_dimensions();
  dim0->set_edge_padding_low(0);
  dim0->set_edge_padding_high(2);
  dim0->set_interior_padding(3);
  auto* dim1 = padding_config.add_dimensions();
  dim1->set_edge_padding_low(1);
  dim1->set_edge_padding_high(5);
  dim1->set_interior_padding(0);

  // dim0: 10 + 2 edge + 9*3 interior = 39; dim1: 25 + 1 + 5 edge = 31.
  auto padded = ShapeInference::InferPadShape(input_shape, padding_value_shape,
                                              padding_config);
  ASSERT_IS_OK(padded.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}),
                               padded.ValueOrDie()));

  // Negative edge padding that shrinks dimension 1 below zero is an error.
  dim1->set_edge_padding_low(-20);
  dim1->set_edge_padding_high(-10);
  auto negative_dimension_size = ShapeInference::InferPadShape(
      input_shape, padding_value_shape, padding_config);
  ASSERT_FALSE(negative_dimension_size.ok());
  ASSERT_THAT(negative_dimension_size.status().error_message(),
              HasSubstr("negative size for dimension 1"));
}
1448 
TEST_F(ShapeInferenceTest, Reverse) {
  // Reversing along every dimension leaves the shape unchanged.
  const Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});
  auto reversed = ShapeInference::InferReverseShape(input_shape, {0, 1});
  ASSERT_IS_OK(reversed.status());
  ASSERT_TRUE(ShapeUtil::Equal(input_shape, reversed.ValueOrDie()));
}
1457 
TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
  const Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});

  // Dimension index beyond the operand's rank.
  auto err_too_large = ShapeInference::InferReverseShape(input_shape, {0, 2});
  ASSERT_FALSE(err_too_large.ok());
  ASSERT_THAT(err_too_large.status().error_message(),
              HasSubstr("out-of-bounds"));

  // Negative dimension index.
  auto err_negative = ShapeInference::InferReverseShape(input_shape, {0, -1});
  ASSERT_FALSE(err_negative.ok());
  ASSERT_THAT(err_negative.status().error_message(),
              HasSubstr("out-of-bounds"));

  // The same dimension listed twice.
  auto err_duplicate = ShapeInference::InferReverseShape(input_shape, {0, 0});
  ASSERT_FALSE(err_duplicate.ok());
  ASSERT_THAT(err_duplicate.status().error_message(), HasSubstr("duplicated"));

  // Tuple operands cannot be reversed.
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({input_shape, input_shape});
  auto err_tuple = ShapeInference::InferReverseShape(tuple_shape, {0});
  ASSERT_FALSE(err_tuple.ok());
  ASSERT_THAT(err_tuple.status().error_message(),
              HasSubstr("Expected array argument"));
}
1486 
TEST_F(ShapeInferenceTest, Call) {
  // A nullary callee simply yields its result shape.
  auto ok0 =
      ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({}, f32_));
  EXPECT_IS_OK(ok0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, ok0.ValueOrDie()));

  // Arguments of mixed ranks and types matching the callee's parameters.
  auto ok1 = ShapeInference::InferCallShape(
      {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_},
      ShapeUtil::MakeProgramShape(
          {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_));
  EXPECT_IS_OK(ok1.status());
  EXPECT_TRUE(ShapeUtil::Equal(s32matrix_64_64_, ok1.ValueOrDie()));

  // Too few arguments.
  auto err0 = ShapeInference::InferCallShape(
      {}, ShapeUtil::MakeProgramShape({f32_}, f32_));
  EXPECT_FALSE(err0.ok());
  EXPECT_THAT(err0.status().error_message(), HasSubstr("arity must match"));

  // Too many arguments.
  auto err1 = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({}, f32_));
  EXPECT_FALSE(err1.ok());
  EXPECT_THAT(err1.status().error_message(), HasSubstr("arity must match"));

  // Argument type disagrees with the parameter type.
  auto err2 = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_));
  EXPECT_FALSE(err2.ok());
  EXPECT_THAT(err2.status().error_message(),
              HasSubstr("parameter must match argument"));
}
1519 
TEST_F(ShapeInferenceTest, Transpose) {
  // Permutation {1, 2, 3, 0} rotates the dimensions left by one.
  const Shape a_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
  auto transposed = ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0});
  EXPECT_IS_OK(transposed);
  EXPECT_TRUE(ShapeUtil::Compatible(transposed.ValueOrDie(),
                                    ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
}
1529 
TEST_F(ShapeInferenceTest, Rank1Transpose) {
  // The identity permutation on a vector is a no-op.
  const Shape a_shape = ShapeUtil::MakeShape(F32, {5});
  auto transposed = ShapeInference::InferTransposeShape(a_shape, {0});
  EXPECT_IS_OK(transposed);
  EXPECT_TRUE(ShapeUtil::Compatible(transposed.ValueOrDie(),
                                    ShapeUtil::MakeShape(F32, {5})));
}
1539 
TEST_F(ShapeInferenceTest,ConditionalPred)1540 TEST_F(ShapeInferenceTest, ConditionalPred) {
1541   auto inferred_status0 = ShapeInference::InferConditionalShape(
1542       pred_,
1543       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
1544        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
1545       {vector_32_, vector_64_});
1546   EXPECT_IS_OK(inferred_status0.status());
1547   EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));
1548 
1549   auto inferred_status1 = ShapeInference::InferConditionalShape(
1550       pred_,
1551       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
1552        ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)},
1553       {matrix_32_48_, vector_32_});
1554   EXPECT_IS_OK(inferred_status1.status());
1555   EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));
1556 
1557   auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
1558   auto inferred_status2 = ShapeInference::InferConditionalShape(
1559       pred_,
1560       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
1561        ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
1562       {matrix_32_48_, tuple_f32_v32});
1563   EXPECT_IS_OK(inferred_status2.status());
1564   EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));
1565 
1566   auto inferred_status_error0 = ShapeInference::InferConditionalShape(
1567       f32_,
1568       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
1569        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
1570       {vector_32_, vector_64_});
1571   EXPECT_FALSE(inferred_status_error0.ok());
1572   EXPECT_THAT(inferred_status_error0.status().error_message(),
1573               HasSubstr("must be bool or int32"));
1574 
1575   auto inferred_status_error1 = ShapeInference::InferConditionalShape(
1576       pred_,
1577       {ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
1578        ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
1579       {ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_});
1580   EXPECT_FALSE(inferred_status_error1.ok());
1581   EXPECT_THAT(inferred_status_error1.status().error_message(),
1582               HasSubstr("branch computation 0 must take 1 argument"));
1583 
1584   auto inferred_status_error2 = ShapeInference::InferConditionalShape(
1585       pred_,
1586       {ShapeUtil::MakeProgramShape({vector_64_}, f32_),
1587        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
1588       {vector_32_, vector_64_});
1589   EXPECT_FALSE(inferred_status_error2.ok());
1590   EXPECT_THAT(inferred_status_error2.status().error_message(),
1591               HasSubstr("branch operand 0 must match the shape of the only "
1592                         "parameter of branch computation 0"));
1593 
1594   auto inferred_status_error3 = ShapeInference::InferConditionalShape(
1595       pred_,
1596       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
1597        ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)},
1598       {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})});
1599   EXPECT_FALSE(inferred_status_error3.ok());
1600   EXPECT_THAT(inferred_status_error3.status().error_message(),
1601               HasSubstr("branch computation 1 must take 1 argument"));
1602 
1603   auto inferred_status_error4 = ShapeInference::InferConditionalShape(
1604       pred_,
1605       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
1606        ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
1607       {vector_32_, vector_64_});
1608   EXPECT_FALSE(inferred_status_error4.ok());
1609   EXPECT_THAT(inferred_status_error4.status().error_message(),
1610               HasSubstr("branch operand 1 must match the shape of the only "
1611                         "parameter of branch computation 1"));
1612 
1613   auto inferred_status_error5 = ShapeInference::InferConditionalShape(
1614       pred_,
1615       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
1616        ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
1617       {vector_32_, vector_64_});
1618   EXPECT_FALSE(inferred_status_error5.ok());
1619   EXPECT_THAT(inferred_status_error5.status().error_message(),
1620               HasSubstr("the result of branch 0 computation and branch 1 "
1621                         "computation must have the same shape"));
1622 }
1623 
TEST_F(ShapeInferenceTest,ConditionalIndexed)1624 TEST_F(ShapeInferenceTest, ConditionalIndexed) {
1625   auto r0s32 = ShapeUtil::MakeShape(S32, {});
1626   auto inferred_status0 = ShapeInference::InferConditionalShape(
1627       r0s32,
1628       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
1629        ShapeUtil::MakeProgramShape({vector_64_}, f32_),
1630        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
1631       {vector_32_, vector_64_, vector_64_});
1632   EXPECT_IS_OK(inferred_status0.status());
1633   EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));
1634 
1635   auto inferred_status1 = ShapeInference::InferConditionalShape(
1636       r0s32,
1637       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
1638        ShapeUtil::MakeProgramShape({vector_32_}, vector_64_),
1639        ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)},
1640       {matrix_32_48_, vector_32_, matrix_32_48_});
1641   EXPECT_IS_OK(inferred_status1.status());
1642   EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));
1643 
1644   auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
1645   auto inferred_status2 = ShapeInference::InferConditionalShape(
1646       r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
1647       {tuple_f32_v32});
1648   EXPECT_IS_OK(inferred_status2.status());
1649   EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));
1650 
1651   auto inferred_status_error0 = ShapeInference::InferConditionalShape(
1652       pred_,
1653       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
1654        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
1655        ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
1656       {vector_32_, vector_32_, vector_64_});
1657   EXPECT_FALSE(inferred_status_error0.ok());
1658   EXPECT_THAT(inferred_status_error0.status().error_message(),
1659               HasSubstr("2 == branch_computations.size()"));
1660 
1661   auto inferred_status_error1 = ShapeInference::InferConditionalShape(
1662       r0s32,
1663       {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
1664        ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
1665        ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
1666       {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}),
1667        matrix_32_48_});
1668   EXPECT_FALSE(inferred_status_error1.ok());
1669   EXPECT_THAT(inferred_status_error1.status().error_message(),
1670               HasSubstr("branch computation 1 must take 1 argument"));
1671 
1672   auto inferred_status_error2 = ShapeInference::InferConditionalShape(
1673       r0s32,
1674       {ShapeUtil::MakeProgramShape({r0s32}, f32_),
1675        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
1676        ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
1677       {r0s32, vector_32_, vector_64_});
1678   EXPECT_FALSE(inferred_status_error2.ok());
1679   EXPECT_THAT(inferred_status_error2.status().error_message(),
1680               HasSubstr("branch operand 2 must match the shape of the only "
1681                         "parameter of branch computation 2"));
1682 
1683   auto inferred_status_error3 = ShapeInference::InferConditionalShape(
1684       r0s32,
1685       {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
1686        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
1687        ShapeUtil::MakeProgramShape({vector_32_}, f32_),
1688        ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
1689       {vector_32_, vector_32_, vector_32_, vector_64_});
1690   EXPECT_FALSE(inferred_status_error3.ok());
1691   EXPECT_THAT(inferred_status_error3.status().error_message(),
1692               HasSubstr("the result of branch 0 computation and branch 3 "
1693                         "computation must have the same shape"));
1694 
1695   auto inferred_status_error4 =
1696       ShapeInference::InferConditionalShape(r0s32, {}, {});
1697   EXPECT_FALSE(inferred_status_error4.ok());
1698   EXPECT_THAT(inferred_status_error4.status().error_message(),
1699               HasSubstr("!branch_computations.empty()"));
1700 }
1701 
TEST_F(ShapeInferenceTest, BadSlice) {
  // Slice limit 5 exceeds the dimension size of 4.
  const auto arg = ShapeUtil::MakeShape(F32, {4});
  StatusOr<Shape> statusor =
      ShapeInference::InferSliceShape(arg, {0}, {5}, {1});
  ASSERT_FALSE(statusor.ok());

  LOG(INFO) << statusor.status();

  // The message should name both the offending bound and the argument shape.
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("less than or equal to dimension size"))
      << statusor.status();
  EXPECT_THAT(statusor.status().error_message(), HasSubstr("argument shape"))
      << statusor.status();
}
1716 
TEST_F(ShapeInferenceTest, BadSort) {
  // Keys and values operands must have identical dimensions.
  const auto keys = ShapeUtil::MakeShape(F32, {4});
  const auto values = ShapeUtil::MakeShape(F32, {5});
  StatusOr<Shape> statusor =
      ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values});
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("dimensions must match"))
      << statusor.status();
}
1727 
TEST_F(ShapeInferenceTest, BadSortValuesMismatch) {
  // One conforming values operand does not excuse a later mismatched one.
  const auto keys = ShapeUtil::MakeShape(F32, {4});
  const auto values_good = ShapeUtil::MakeShape(F32, {4});
  const auto values_bad = ShapeUtil::MakeShape(F32, {5});
  StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &values_good, &values_bad});
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("dimensions must match"))
      << statusor.status();
}
1739 
TEST_F(ShapeInferenceTest, SortManyValues) {
  // Sorting with several values operands yields a tuple of all operand shapes.
  const auto keys = ShapeUtil::MakeShape(F32, {4});
  const auto values_s32 = ShapeUtil::MakeShape(S32, {4});
  const auto values_u32 = ShapeUtil::MakeShape(U32, {4});
  StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &values_s32, &values_u32});
  EXPECT_IS_OK(statusor);
  EXPECT_TRUE(ShapeUtil::Compatible(
      statusor.ValueOrDie(),
      ShapeUtil::MakeTupleShape({keys, values_s32, values_u32})));
}
1752 
// Fixture providing index and operand shapes shared by the scatter/gather
// shape inference tests below.  Shape names encode their dimensions, e.g.
// s64_4d_tensor_10_9_8_7_1_ is s64[10,9,8,7,1].
class ScatterGatherShapeInferenceTest : public ShapeInferenceTest {
 protected:
  // Index operands of varying ranks; the tests choose which dimension of
  // these holds the index vector via index_vector_dim.
  const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});
  const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5});
  const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32});
  const Shape s64_4d_tensor_10_9_8_7_1_ =
      ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1});
  const Shape s64_4d_tensor_10_9_8_7_5_ =
      ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
  const Shape s64_4d_tensor_5_10_9_7_6_ =
      ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6});
  const Shape s64_4d_tensor_10_9_5_7_6_ =
      ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
  // Rank-5 f32 operand that the tests gather from.
  const Shape f32_5d_tensor_50_49_48_47_46_ =
      ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  // Tuple shape used to exercise the "must be an array" error paths.
  const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
      {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_});
  // Binary (f32, f32) -> f32 computation; presumably a scatter update
  // computation — confirm against the scatter tests that use it.
  const ProgramShape to_apply_ =
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
};
1773 
1774 // Shape inference tests for Gather.
1775 
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) {
  // Gather whole 64x1 columns of a 64x48 matrix, one per each of 32 indices.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          matrix_64_48_, s64_vector_32_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{0},
              /*collapsed_slice_dims=*/{1},
              /*start_index_map=*/{1},
              /*index_vector_dim=*/1),
          /*slice_sizes=*/{64, 1}));
  EXPECT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {64, 32})))
      << ShapeUtil::HumanString(inferred);
}
1790 
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) {
  // Gather whole 1x48 rows of a 64x48 matrix, one per each of 32 indices.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          matrix_64_48_, s64_vector_32_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{1},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0},
              /*index_vector_dim=*/1),
          /*slice_sizes=*/{1, 48}));
  EXPECT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {32, 48})))
      << ShapeUtil::HumanString(inferred);
}
1805 
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) {
  // Batched row gather: the rank-5 index tensor (trailing size-1 index
  // vector) selects one 1x48 row per batch position.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0},
              /*index_vector_dim=*/4),
          /*slice_sizes=*/{1, 48}));
  EXPECT_TRUE(ShapeUtil::Equal(inferred,
                               ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
      << ShapeUtil::HumanString(inferred);
}
1820 
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
  // Each length-5 index vector selects a full 30x29x28x27x26 window,
  // producing batch dims {10,9,8,7} followed by the window dims.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4, 5, 6, 7, 8},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/4),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));
  EXPECT_TRUE(ShapeUtil::Equal(
      inferred, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(inferred);
}
1837 
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
  // The index vector lives in dimension 2 of the indices (not the trailing
  // dimension); the remaining dims {10, 9, 7, 6} become batch dimensions.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4, 5, 6, 7, 8},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/2),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));

  EXPECT_TRUE(ShapeUtil::Equal(
      inferred, ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(inferred);
}
1855 
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
  // The index vector lives in the leading dimension of the indices; the
  // remaining dims {10, 9, 7, 6} become batch dimensions.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4, 5, 6, 7, 8},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));

  EXPECT_TRUE(ShapeUtil::Equal(
      inferred, ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(inferred);
}
1873 
TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) {
  // This is equivalent to a dynamic slice: a single length-5 index vector
  // and no batch dimensions in the output.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{0, 1, 2, 3, 4},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));

  EXPECT_TRUE(ShapeUtil::Equal(inferred,
                               ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(inferred);
}
1890 
TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) {
  // The gather indices "tensor" is a scalar S here that's used to slice out
  // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{0, 1, 2, 3},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{1, 30, 29, 28, 27}));

  EXPECT_TRUE(ShapeUtil::Equal(inferred,
                               ShapeUtil::MakeShape(F32, {30, 29, 28, 27})))
      << ShapeUtil::HumanString(inferred);
}
1908 
// A tuple-shaped operand must be rejected: gather only accepts arrays.
TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      tuple_shape_, s64_vector_32_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/1),
      /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Expected array argument for input"))
      << statusor.status();
}

// Tuple-shaped start_indices must be rejected: indices must be an array.
TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      s64_vector_32_, tuple_shape_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/0),
      /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Expected array argument for gather indices"))
      << statusor.status();
}

// Floating-point start_indices must be rejected: indices must be integral.
TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      s64_vector_32_, vector_32_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/0),
      /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Gather indices parameter must be an integral tensor"))
      << statusor.status();
}

// offset_dims must be given in ascending order; {..., 8, 7} is rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingWindowIndices) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 8, 7},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Output window dimensions in gather op must be ascending"))
      << statusor.status();
}

// offset_dims must not contain duplicates; {..., 7, 7} is rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowIndices) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 7},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Output window dimensions in gather op must not repeat"))
      << statusor.status();
}

// offset_dims entries must be valid output dimensions; 99+ is far out of
// bounds for the 9-dimensional output.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 99, 100, 101},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Offset dimension 2 in gather op is out of bounds"))
      << statusor.status();
}

// Same as WindowIndexOutOfBounds, but off by exactly one (dim 9 when the
// output has rank 9, i.e. valid dims are [0, 9)).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 9},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Offset dimension 4 in gather op is out of bounds"))
      << statusor.status();
}

// Each input dimension must be either an offset dimension or collapsed;
// here 5 offset dims + 1 collapsed dim over-covers the 5 input dims.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{4},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("All components of the offset index in a gather op must either "
                "be a offset dimension or explicitly collapsed"))
      << statusor.status();
}

// collapsed_slice_dims entries must index into the 5-dimensional operand;
// 19 is outside [0, 5).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{0, 1, 2, 3, 19},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Invalid collapsed_slice_dims set in gather op; valid "
                        "range is [0, 5), got: 19"))
      << statusor.status();
}

// collapsed_slice_dims must not contain duplicates; {..., 3, 3} is rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{0, 1, 2, 3, 3},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Repeated dimensions not allowed in "
                        "collapsed_slice_dims in gather op"))
      << statusor.status();
}

// start_index_map must have exactly as many entries as the bound of the
// index_vector_dim of start_indices (5 here); 4 entries is rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Gather op has 4 elements in start_index_map and "
                        "the bound of dimension index_vector_dim=4 of "
                        "start_indices is 5. These two numbers must be equal."))
      << statusor.status();
}

// start_index_map entries must index into the 5-dimensional operand;
// mapping index 4 to operand dim 7 is out of range.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 7},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7"))
      << statusor.status();
}

// start_index_map must not contain duplicates; {..., 3, 3} is rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 3},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Repeated dimensions are not allowed in start_index_map"))
      << statusor.status();
}

// collapsed_slice_dims must be sorted; {2, 1} is rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{2, 1},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{1, 1, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("collapsed_slice_dims in gather op must be sorted"))
      << statusor.status();
}

// Each slice size must fit within the corresponding operand dimension;
// 300 exceeds the bound 48 of operand dimension 3.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowBoundsTooLarge) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7},
          /*collapsed_slice_dims=*/{2},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 1, 300, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Slice size at index 3 in gather op is out of range, "
                        "must be within [0, 48), got 300."))
      << statusor.status();
}

// slice_sizes must have one entry per operand dimension; 4 entries for a
// 5-dimensional operand is rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Gather op must have one slice size for every input dimension"))
      << statusor.status();
}

// A collapsed slice dimension must have slice size 1; dim 1 has size 29.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 26, 20});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Gather op can only collapse slice dims with bound 1, "
                        "but bound is 29 for index 1 at position 0."))
      << statusor.status();
}

// index_vector_dim must lie within [0, rank(start_indices) + 1); 32 is far
// beyond the rank-4 indices tensor.
TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/32),
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Gather index leaf dimension must be within [0, "
                        "rank(start_indices) + 1)"))
      << statusor.status();
}

// Shape inference tests for Scatter.

// TensorFlow-style scatter where the updates cover entire operand columns.
// The inferred shape is always the operand's shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdates) {
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              matrix_64_48_, s64_vector_32_,
                              ShapeUtil::MakeShape(F32, {64, 32}), to_apply_,
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{0},
                                  /*inserted_window_dims=*/{1},
                                  /*scatter_dims_to_operand_dims=*/{1},
                                  /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

// Same as TfScatterWithFullUpdates, but scattering along operand dim 0
// with the update window on dim 1.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdatesV2) {
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              matrix_64_48_, s64_vector_32_,
                              ShapeUtil::MakeShape(F32, {32, 48}), to_apply_,
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{1},
                                  /*inserted_window_dims=*/{0},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

// Scatter updates may be smaller than the operand dimension they update
// (10 < 64 here); the result shape is still the operand's shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdates) {
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              matrix_64_48_, s64_vector_32_,
                              ShapeUtil::MakeShape(F32, {10, 32}), to_apply_,
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{0},
                                  /*inserted_window_dims=*/{1},
                                  /*scatter_dims_to_operand_dims=*/{1},
                                  /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

// Partial updates along operand dim 1 (8 < 48); result is the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdatesV2) {
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              matrix_64_48_, s64_vector_32_,
                              ShapeUtil::MakeShape(F32, {32, 8}), to_apply_,
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{1},
                                  /*inserted_window_dims=*/{0},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

// Update window bound 65 exceeds operand dim bound 64 and must be rejected.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {65, 32}),
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << statusor.status();
}

// Update window bound 49 exceeds operand dim bound 48 and must be rejected.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {32, 49}),
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{1},
          /*inserted_window_dims=*/{0},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << statusor.status();
}

// The scatter (non-window) dimension of updates (31) must match the
// corresponding scatter-indices bound (32).
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterWithUpdatesNotMatchingIndices) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {64, 31}),
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << statusor.status();
}

// Same mismatch as above, with the scatter dimension in position 0 of the
// updates shape (31 vs the indices bound 32).
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterWithUpdatesNotMatchingIndicesV2) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {31, 48}),
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{1},
          /*inserted_window_dims=*/{0},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << statusor.status();
}

// TensorFlow scatter_nd-style scatter with 4 batch dims of indices and an
// update window covering the full second operand dimension (48).
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdates) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}), to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4},
              /*inserted_window_dims=*/{0},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

// scatter_nd-style scatter with the window on operand dim 0 (full bound 64)
// and operand dim 1 inserted.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdatesV2) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 64}), to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4},
              /*inserted_window_dims=*/{1},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

// scatter_nd-style scatter with a window (10) smaller than the operand
// dimension it updates (48).
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdates) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 10}), to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4},
              /*inserted_window_dims=*/{0},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

// scatter_nd-style scatter with a partial window (12 < 64) on operand dim 0.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 12}), to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4},
              /*inserted_window_dims=*/{1},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

// scatter_nd-style window bound 65 exceeds operand bound 64; rejected.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 65}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << statusor.status();
}

// scatter_nd-style batch dim of updates (9) does not match the indices
// bound (10); rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterNdWithUpdatesNotMatchingIndices) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
      ShapeUtil::MakeShape(F32, {9, 9, 8, 7, 64}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << statusor.status();
}

// Batched dynamic-update-slice expressed as scatter: all operand dims are
// window dims and the 4 leading update dims are batch dims from the indices.
TEST_F(ScatterGatherShapeInferenceTest, TfBatchDynamicUpdateSlice) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}),
          to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(scatter_shape);
}

// The index vector may live in an interior dimension of scatter_indices
// (dim 2 here, with bound 5) rather than the trailing one.
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDim) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
          ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}),
          to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/2)));

  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(scatter_shape);
}

// The index vector may also be the leading dimension of scatter_indices
// (dim 0, bound 5).
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
          ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}),
          to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0)));

  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, NoUpdateScatterDims) {
  // This is equivalent to a dynamic update slice: every update dimension is
  // a window dimension, with a single index vector selecting the position.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
          ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}), to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{0, 1, 2, 3, 4},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0)));

  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, ScalarScatterIndices) {
  // The scalar indices "tensor" is a scalar S here that's used to update a
  // [30,29,28,27] shaped tensor within the operand at position S.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
          ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{0, 1, 2, 3},
              /*inserted_window_dims=*/{0},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/0)));

  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(scatter_shape);
}

// A tuple-shaped operand is not a valid scatter input and must be rejected.
TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedTensorInput) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      tuple_shape_, s64_vector_32_, s64_vector_32_, to_apply_, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Expected array argument for operand"))
      << inferred.status();
}
2531 
// Tuple-shaped scatter indices are invalid and must produce a clear error.
TEST_F(ScatterGatherShapeInferenceTest,
       ScatterWithTupleShapedScatterIndicesInput) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/0);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      s64_vector_32_, tuple_shape_, s64_vector_32_, to_apply_, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Expected array argument for scatter indices"))
      << inferred.status();
}
2546 
// A tuple-shaped updates tensor is invalid and must produce a clear error.
TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/0);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      s64_vector_32_, s64_vector_32_, tuple_shape_, to_apply_, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Expected array argument for updates"))
      << inferred.status();
}
2560 
// Scatter indices must have an integral element type; F32 indices are an
// error.
TEST_F(ScatterGatherShapeInferenceTest, FloatingPointScatterIndicesInput) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/0);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      s64_vector_32_, vector_32_, s64_vector_32_, to_apply_, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Scatter indices parameter must be an integral tensor"))
      << inferred.status();
}
2574 
// index_vector_dim must be at most rank(scatter_indices); 10 is far beyond
// the rank-4 indices tensor here and must be rejected.
TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/10);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Scatter index leaf dimension must be within [0, "
                        "rank(scatter_indices) + 1)"))
      << inferred.status();
}
2590 
// The updates tensor's rank is determined by the dim numbers (7 here); a
// rank-8 updates tensor must be rejected.
TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdates) {
  const Shape bad_updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 50});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      bad_updates_shape, to_apply_, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Updates tensor must be of rank 7; got 8."))
      << inferred.status();
}
2605 
// The scatter update computation is a binary reduction; a unary computation
// must be rejected.
TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdateComputation) {
  const ProgramShape unary_computation =
      ShapeUtil::MakeProgramShape({f32_}, f32_);
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      unary_computation, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Reduction function must take 2 parameters, but takes 1"))
      << inferred.status();
}
2624 
// update_window_dims must be sorted in ascending order; {4,5,6,8,7} is not.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 8, 7},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("update_window_dims in scatter op must be sorted"))
      << inferred.status();
}
2640 
// update_window_dims must not contain duplicates; {4,5,6,7,7} repeats 7.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedUpdateWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 7},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("update_window_dims in scatter op must not repeat"))
      << inferred.status();
}
2656 
// Each update_window_dim must be a valid dimension of the rank-9 updates
// tensor; 9 is one past the end.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 9},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Invalid update_window_dims set in scatter op; valid "
                        "range is [0, 9)"))
      << inferred.status();
}
2673 
// inserted_window_dims must be sorted in ascending order; {2,1} is not.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{2, 1},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("inserted_window_dims in scatter op must be sorted"))
      << inferred.status();
}
2689 
// inserted_window_dims must not contain duplicates; {1,1} repeats 1.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedInsertedWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 1},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("inserted_window_dims in scatter op must not repeat"))
      << inferred.status();
}
2705 
// Each inserted_window_dim must be a valid dimension of the rank-5 operand;
// 5 is one past the end.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 5},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Invalid inserted_window_dims set in scatter op; valid "
                        "range is [0, 5)"))
      << inferred.status();
}
2722 
// scatter_dims_to_operand_dims must have exactly as many elements as the
// bound of the index_vector_dim dimension (5 here); only 4 are given.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and "
                "the bound of dimension index_vector_dim=4 of scatter_indices "
                "is 5. These two numbers must be equal"))
      << inferred.status();
}
2741 
// Every scatter_dims_to_operand_dims entry must name a dimension of the
// rank-5 operand; the mapping 4->10 is out of range.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(inferred.status().error_message(),
              HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain "
                        "is [0, 5), got: 4->10"))
      << inferred.status();
}
2758 
// scatter_dims_to_operand_dims must not map two index components to the same
// operand dimension; {0,1,2,2,3} repeats 2.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr(
          "Repeated dimensions not allowed in scatter_dims_to_operand_dims"))
      << inferred.status();
}
2776 
// update_window_dims plus inserted_window_dims must together account for all
// operand dimensions; 4 window dims cannot cover a rank-5 operand.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_InsufficientWindowDims) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {30, 29, 28, 27});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0, 1, 2, 3},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/0);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_scalar_, updates_shape, to_apply_,
      dim_numbers);
  ASSERT_FALSE(inferred.ok());
  EXPECT_THAT(
      inferred.status().error_message(),
      HasSubstr(
          "Scatter op has window of size 4; doesn't match operand of rank 5."))
      << inferred.status();
}
2794 
2795 }  // namespace
2796 }  // namespace xla
2797