1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/shape_inference.h"
17
18 #include <string>
19
20 #include "absl/types/span.h"
21 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/test.h"
24 #include "tensorflow/compiler/xla/test_helpers.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27
28 namespace xla {
29 namespace {
30
31 using ::testing::ContainsRegex;
32 using ::testing::HasSubstr;
33
34 class ShapeInferenceTest : public ::testing::Test {
35 protected:
36 // Some handy scalar shapes.
37 const Shape s32_ = ShapeUtil::MakeShape(S32, {});
38 const Shape f16_ = ShapeUtil::MakeShape(F16, {});
39 const Shape f32_ = ShapeUtil::MakeShape(F32, {});
40 const Shape f64_ = ShapeUtil::MakeShape(F64, {});
41 const Shape pred_ = ShapeUtil::MakeShape(PRED, {});
42
43 // Some handy vector and matrix shapes of F32 type.
44 // Suffix: vector_length_, matrix_rows_cols_
45 const Shape vector_32_ = ShapeUtil::MakeShape(F32, {32});
46 const Shape vector_64_ = ShapeUtil::MakeShape(F32, {64});
47 const Shape matrix_32_48_ = ShapeUtil::MakeShape(F32, {32, 48});
48 const Shape matrix_32_64_ = ShapeUtil::MakeShape(F32, {32, 64});
49 const Shape matrix_64_48_ = ShapeUtil::MakeShape(F32, {64, 48});
50
51 // Some handy S32 arrays.
52 const Shape s32matrix_64_64_ = ShapeUtil::MakeShape(S32, {64, 64});
53 };
54
55 // Subclass for testing InferReduceShape.
56 class ReduceShapeInferenceTest : public ShapeInferenceTest {
57 protected:
58 // Helper that runs reduce shape inference with the input 'arg' and given
59 // dimensions to reduce, and checks the inferred shape is as expected. The
60 // element type here is hard-coded to F32.
ExpectInferredReduceShape(const Shape & expected_inferred_shape,const Shape & arg,absl::Span<const int64> dimensions_to_reduce)61 void ExpectInferredReduceShape(const Shape& expected_inferred_shape,
62 const Shape& arg,
63 absl::Span<const int64> dimensions_to_reduce) {
64 ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
65 auto inferred_status = ShapeInference::InferReduceShape(
66 {&arg, &f32_}, dimensions_to_reduce, to_apply);
67 EXPECT_IS_OK(inferred_status.status());
68 EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
69 inferred_status.ValueOrDie()));
70 }
71 };
72
73 // Subclass for testing InferSelectAndScatterShape.
74 class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest {
75 protected:
SelectAndScatterShapeInferenceTest()76 SelectAndScatterShapeInferenceTest() {
77 operand_shape_ = ShapeUtil::MakeShape(F32, {8, 16});
78 source_shape_ = ShapeUtil::MakeShape(F32, {4, 8});
79 WindowDimension dim;
80 dim.set_size(2);
81 dim.set_stride(2);
82 dim.set_padding_low(0);
83 dim.set_padding_high(0);
84 dim.set_window_dilation(1);
85 dim.set_base_dilation(1);
86 *window_.add_dimensions() = dim;
87 *window_.add_dimensions() = dim;
88 init_value_shape_ = ShapeUtil::MakeShape(F32, {});
89 select_program_shape_ = ShapeUtil::MakeProgramShape(
90 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
91 scatter_program_shape_ = ShapeUtil::MakeProgramShape(
92 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
93 }
94
95 Shape operand_shape_;
96 Shape source_shape_;
97 Window window_;
98 Shape init_value_shape_;
99 ProgramShape select_program_shape_;
100 ProgramShape scatter_program_shape_;
101 };
102
// Negating an F32 matrix must infer a shape equal to the operand's shape.
TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  const auto negated =
      ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, operand);
  ASSERT_IS_OK(negated.status());
  const Shape& actual = negated.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(operand, actual));
}
110
// Select with a scalar PRED and two identical tuple operands infers that
// tuple shape.
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) {
  const Shape branch = ShapeUtil::MakeTupleShape({s32_, f32_});
  const auto selected = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, branch, branch);
  ASSERT_IS_OK(selected.status());
  const Shape& actual = selected.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(branch, actual));
}
118
// Select with a scalar PRED and two F32[64,48] operands infers F32[64,48].
TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
  const Shape& branch = matrix_64_48_;
  const auto selected = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, branch, branch);
  ASSERT_IS_OK(selected.status());
  const Shape& actual = selected.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, actual));
}
125
// Select with a PRED[64,48] predicate and two F32[64,48] operands infers
// F32[64,48].
TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) {
  const Shape pred_matrix = ShapeUtil::MakeShape(PRED, {64, 48});
  const auto selected = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_matrix, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(selected.status());
  const Shape& actual = selected.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, actual));
}
133
// Each section feeds select one invalid combination of shapes and checks
// both that inference fails and that the message names the problem.
TEST_F(ShapeInferenceTest, SelectBadShapes) {
  // The two selectable operands have different shapes.
  const auto mismatched_operands = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_);
  ASSERT_FALSE(mismatched_operands.ok());
  ASSERT_THAT(mismatched_operands.status().error_message(),
              HasSubstr("Operands to select must be the same shape"));

  // An S32 scalar is passed in the predicate position.
  const auto s32_predicate = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(s32_predicate.ok());
  ASSERT_THAT(s32_predicate.status().error_message(),
              HasSubstr("pred operand must have PRED"));

  // The predicate is PRED[64] while the operands are F32[64,48].
  const Shape pred_vector = ShapeUtil::MakeShape(PRED, {64});
  const auto rank1_predicate = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_vector, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(rank1_predicate.ok());
  ASSERT_THAT(rank1_predicate.status().error_message(),
              HasSubstr("with non-scalar predicate with dimensionality"));

  // A tuple's element type is TUPLE, so a tuple of PREDs is not accepted as
  // the predicate of a select.
  const Shape pred_pair = ShapeUtil::MakeTupleShape({pred_, pred_});
  const Shape f32_pair = ShapeUtil::MakeTupleShape({f32_, f32_});
  const auto tuple_predicate = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_pair, f32_pair, f32_pair);
  ASSERT_FALSE(tuple_predicate.ok());
  ASSERT_THAT(tuple_predicate.status().error_message(),
              HasSubstr("pred operand must have PRED element type"));
}
163
// Clamp with all three arguments F32[64,48] infers F32[64,48].
TEST_F(ShapeInferenceTest, ClampAllMatrix) {
  const Shape& matrix = matrix_64_48_;
  const auto clamped = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix, matrix, matrix);
  ASSERT_IS_OK(clamped.status());
  const Shape& actual = clamped.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, actual));
}
170
// Clamp with all three arguments being F32 scalars infers an F32 scalar.
TEST_F(ShapeInferenceTest, ClampAllScalar) {
  const Shape& scalar = f32_;
  const auto clamped = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, scalar, scalar, scalar);
  ASSERT_IS_OK(clamped.status());
  const Shape& actual = clamped.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(f32_, actual));
}
177
// Clamp with a scalar first argument and F32[64,48] for the other two
// infers F32[64,48].
TEST_F(ShapeInferenceTest, ClampMinScalar) {
  const Shape& matrix = matrix_64_48_;
  const auto clamped = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix, matrix);
  ASSERT_IS_OK(clamped.status());
  const Shape& actual = clamped.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, actual));
}
184
// Clamp with a scalar third argument and F32[64,48] for the first two
// infers F32[64,48].
TEST_F(ShapeInferenceTest, ClampMaxScalar) {
  const Shape& matrix = matrix_64_48_;
  const auto clamped = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix, matrix, f32_);
  ASSERT_IS_OK(clamped.status());
  const Shape& actual = clamped.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, actual));
}
191
// Clamp with a scalar second argument and F32[64,48] for the first and
// third infers F32[64,48].
TEST_F(ShapeInferenceTest, ClampOperandScalar) {
  const Shape& matrix = matrix_64_48_;
  const auto clamped = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix, f32_, matrix);
  ASSERT_IS_OK(clamped.status());
  const Shape& actual = clamped.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, actual));
}
198
// Clamp with F32[64,48] as the first argument and scalars for the other two
// infers F32[64,48].
TEST_F(ShapeInferenceTest, ClampMinMatrix) {
  const Shape& scalar = f32_;
  const auto clamped = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, scalar, scalar);
  ASSERT_IS_OK(clamped.status());
  const Shape& actual = clamped.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, actual));
}
205
// Clamp with F32[64,48] as the third argument and scalars for the first two
// infers F32[64,48].
TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
  const Shape& scalar = f32_;
  const auto clamped = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, scalar, scalar, matrix_64_48_);
  ASSERT_IS_OK(clamped.status());
  const Shape& actual = clamped.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, actual));
}
212
// Clamp with F32[64,48] as the second argument and scalars for the first
// and third infers F32[64,48].
TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
  const Shape& scalar = f32_;
  const auto clamped = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, scalar, matrix_64_48_, scalar);
  ASSERT_IS_OK(clamped.status());
  const Shape& actual = clamped.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, actual));
}
219
// Clamp must be rejected whenever its three arguments disagree in element
// type or in dimensions.
TEST_F(ShapeInferenceTest, ClampBadShapes) {
  // Returns whether clamp shape inference succeeds for the three argument
  // shapes, passed through in the given order.
  const auto clamp_infers = [](const Shape& arg0, const Shape& arg1,
                               const Shape& arg2) {
    return ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, arg0, arg1,
                                               arg2)
        .ok();
  };

  // One argument is S32 while the other two are F32.
  ASSERT_FALSE(clamp_infers(s32_, f32_, f32_));
  ASSERT_FALSE(clamp_infers(f32_, s32_, f32_));
  ASSERT_FALSE(clamp_infers(f32_, f32_, s32_));

  // One vector has length 64 while the other two have length 32.
  ASSERT_FALSE(clamp_infers(vector_64_, vector_32_, vector_32_));
  ASSERT_FALSE(clamp_infers(vector_32_, vector_64_, vector_32_));
  ASSERT_FALSE(clamp_infers(vector_32_, vector_32_, vector_64_));

  // Two vectors of different length plus one scalar argument.
  ASSERT_FALSE(clamp_infers(vector_64_, vector_32_, f32_));
  ASSERT_FALSE(clamp_infers(vector_64_, f32_, vector_32_));
  ASSERT_FALSE(clamp_infers(f32_, vector_64_, vector_32_));
}
252
// Exercises shape inference for the binary kComplex op: rejected input
// types first, then valid scalar / vector / matrix combinations, including
// ones that use explicit broadcast dimensions.
TEST_F(ShapeInferenceTest, Complex) {
  // Infers the kComplex result shape for (lhs, rhs) with the given broadcast
  // dimensions.
  auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
                           absl::Span<const int64> bcast) {
    return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
                                              bcast);
  };
  // Inputs must be FP.
  ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
  ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
  // Component types must match.
  ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
  // Only F32->C64 and F64->C128 supported.
  ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok());
  // Validate correct uses.
  Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
  TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {})));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  // Both components are vectors of the same length. (This used to repeat the
  // vector/scalar case above, leaving vector/vector uncovered.)
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, vector_32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));

  Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64});
  // A length-64 vector is mapped onto dimension 1 of the 32x64 matrix.
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(vector_64_, matrix_32_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, vector_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, matrix_32_64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));

  // F64 components produce C128.
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {})));
}
293
// kTuple over (s32, f32) scalars infers the tuple shape (s32, f32).
TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
  const Shape expected = ShapeUtil::MakeTupleShape({s32_, f32_});
  const StatusOr<Shape> tupled =
      ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_});
  ASSERT_IS_OK(tupled.status());
  const Shape& actual = tupled.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(actual, expected));
}
301
// Reduce-window over an F32[8,8] operand with a 2x2 window, stride 2 and no
// padding is expected to infer F32[4,4].
TEST_F(ShapeInferenceTest, ReduceWindowInHalf) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8});

  // Both window dimensions share the same settings.
  Window window;
  WindowDimension dim;
  dim.set_size(2);
  dim.set_stride(2);
  dim.set_padding_low(0);
  dim.set_padding_high(0);
  dim.set_window_dilation(1);
  dim.set_base_dilation(1);
  *window.add_dimensions() = dim;
  *window.add_dimensions() = dim;

  Shape init_value_shape = ShapeUtil::MakeShape(F32, {});
  // Reducer: (f32, f32) -> f32.
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      matrix_shape, init_value_shape, window, to_apply);

  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), inferred));
}
326
// With the fixture's unmodified arguments, select-and-scatter infers a shape
// equal to the operand shape.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) {
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_IS_OK(result.status());
  const Shape& actual = result.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, actual));
}
335
// Replacing the fixture's F32[4,8] source with F32[4,6] must be rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) {
  const Shape bad_source = ShapeUtil::MakeShape(F32, {4, 6});
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, bad_source,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Source shape does not match"));
}
345
// A select function taking a single parameter, (f32) -> pred, must be
// rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  const ProgramShape unary_select =
      ShapeUtil::MakeProgramShape({f32_scalar}, pred_);
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, unary_select, window_, source_shape_, init_value_shape_,
      scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function must take 2 parameters"));
}
356
// A select function that returns F32 instead of PRED, (f32, f32) -> f32,
// must be rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
  const Shape f32_scalar = ShapeUtil::MakeShape(F32, {});
  const ProgramShape f32_returning_select =
      ShapeUtil::MakeProgramShape({f32_scalar, f32_scalar}, f32_);
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, f32_returning_select, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function must have rank-0 PRED"));
}
367
// A select function whose first parameter is S32, (s32, f32) -> pred, must
// be rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
  const Shape first_param = ShapeUtil::MakeShape(S32, {});
  const Shape second_param = ShapeUtil::MakeShape(F32, {});
  const ProgramShape bad_select =
      ShapeUtil::MakeProgramShape({first_param, second_param}, pred_);
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_select, window_, source_shape_, init_value_shape_,
      scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function's first parameter"));
}
378
// A select function whose second parameter is U32, (f32, u32) -> pred, must
// be rejected.
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
  const Shape first_param = ShapeUtil::MakeShape(F32, {});
  const Shape second_param = ShapeUtil::MakeShape(U32, {});
  const ProgramShape bad_select =
      ShapeUtil::MakeProgramShape({first_param, second_param}, pred_);
  const auto result = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, bad_select, window_, source_shape_, init_value_shape_,
      scatter_program_shape_);
  ASSERT_FALSE(result.ok());
  ASSERT_THAT(result.status().error_message(),
              HasSubstr("Select function's second parameter"));
}
389
// Convolution of F32[10,11,3,4] with an F32[2,12,11,3] kernel, with all
// dilation factors equal to 1, is expected to infer F32[10,12,2,3].
TEST_F(ShapeInferenceTest, Convolve) {
  const Shape input = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  const Shape kernel = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});

  ConvolutionDimensionNumbers dim_numbers;
  // Input and output layout: dim 0 = batch, dim 1 = feature, dims 2 and 3 =
  // spatial x0 and x1.
  dim_numbers.set_input_batch_dimension(0);
  dim_numbers.set_input_feature_dimension(1);
  dim_numbers.add_input_spatial_dimensions(2);
  dim_numbers.add_input_spatial_dimensions(3);
  dim_numbers.set_output_batch_dimension(0);
  dim_numbers.set_output_feature_dimension(1);
  dim_numbers.add_output_spatial_dimensions(2);
  dim_numbers.add_output_spatial_dimensions(3);
  // Kernel layout: dim 0 = x1, dim 1 = output feature, dim 2 = input
  // feature, dim 3 = x0.
  dim_numbers.set_kernel_input_feature_dimension(2);
  dim_numbers.set_kernel_output_feature_dimension(1);
  dim_numbers.add_kernel_spatial_dimensions(3);
  dim_numbers.add_kernel_spatial_dimensions(0);

  Window window;
  // Appends one window dimension with window and base dilation fixed at 1.
  const auto add_window_dim = [&window](int64 size, int64 stride,
                                        int64 padding_low,
                                        int64 padding_high) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(size);
    dim->set_stride(stride);
    dim->set_padding_low(padding_low);
    dim->set_padding_high(padding_high);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  };
  add_window_dim(/*size=*/3, /*stride=*/2, /*padding_low=*/1,
                 /*padding_high=*/1);
  add_window_dim(/*size=*/2, /*stride=*/1, /*padding_low=*/0,
                 /*padding_high=*/0);

  const auto convolved = ShapeInference::InferConvolveShape(
      input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dim_numbers);
  ASSERT_IS_OK(convolved.status());
  const Shape& actual = convolved.ValueOrDie();
  ASSERT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}), actual));
}
434
// Convolution of F32[10,11,103,4] with an F32[2,12,11,3] kernel whose window
// dimensions use window dilation factors 6 and 2 is expected to infer
// F32[10,12,31,5].
TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
  const Shape input = ShapeUtil::MakeShape(F32, {10, 11, 103, 4});
  const Shape kernel = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});

  ConvolutionDimensionNumbers dim_numbers;
  // Input and output layout: dim 0 = batch, dim 1 = feature, dims 2 and 3 =
  // spatial x0 and x1.
  dim_numbers.set_input_batch_dimension(0);
  dim_numbers.set_input_feature_dimension(1);
  dim_numbers.add_input_spatial_dimensions(2);
  dim_numbers.add_input_spatial_dimensions(3);
  dim_numbers.set_output_batch_dimension(0);
  dim_numbers.set_output_feature_dimension(1);
  dim_numbers.add_output_spatial_dimensions(2);
  dim_numbers.add_output_spatial_dimensions(3);
  // Kernel layout: dim 0 = x1, dim 1 = output feature, dim 2 = input
  // feature, dim 3 = x0.
  dim_numbers.set_kernel_input_feature_dimension(2);
  dim_numbers.set_kernel_output_feature_dimension(1);
  dim_numbers.add_kernel_spatial_dimensions(3);
  dim_numbers.add_kernel_spatial_dimensions(0);

  Window window;
  // Appends one window dimension with base dilation fixed at 1.
  const auto add_window_dim = [&window](int64 size, int64 stride,
                                        int64 padding_low, int64 padding_high,
                                        int64 window_dilation) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(size);
    dim->set_stride(stride);
    dim->set_padding_low(padding_low);
    dim->set_padding_high(padding_high);
    dim->set_window_dilation(window_dilation);
    dim->set_base_dilation(1);
  };
  add_window_dim(/*size=*/3, /*stride=*/3, /*padding_low=*/0,
                 /*padding_high=*/0, /*window_dilation=*/6);
  add_window_dim(/*size=*/2, /*stride=*/1, /*padding_low=*/2,
                 /*padding_high=*/1, /*window_dilation=*/2);

  const auto convolved = ShapeInference::InferConvolveShape(
      input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dim_numbers);
  ASSERT_IS_OK(convolved.status());
  const Shape& actual = convolved.ValueOrDie();
  ASSERT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}), actual));
}
480
// Convolution of F32[10,11,3,4] with an F32[2,12,11,4] kernel whose window
// dimensions use base dilation factors 6 and 2 is expected to infer
// F32[10,12,4,9].
TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
  const Shape input = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  const Shape kernel = ShapeUtil::MakeShape(F32, {2, 12, 11, 4});

  ConvolutionDimensionNumbers dim_numbers;
  // Input and output layout: dim 0 = batch, dim 1 = feature, dims 2 and 3 =
  // spatial x0 and x1.
  dim_numbers.set_input_batch_dimension(0);
  dim_numbers.set_input_feature_dimension(1);
  dim_numbers.add_input_spatial_dimensions(2);
  dim_numbers.add_input_spatial_dimensions(3);
  dim_numbers.set_output_batch_dimension(0);
  dim_numbers.set_output_feature_dimension(1);
  dim_numbers.add_output_spatial_dimensions(2);
  dim_numbers.add_output_spatial_dimensions(3);
  // Kernel layout: dim 0 = x1, dim 1 = output feature, dim 2 = input
  // feature, dim 3 = x0.
  dim_numbers.set_kernel_input_feature_dimension(2);
  dim_numbers.set_kernel_output_feature_dimension(1);
  dim_numbers.add_kernel_spatial_dimensions(3);
  dim_numbers.add_kernel_spatial_dimensions(0);

  Window window;
  // Appends one window dimension with window dilation fixed at 1.
  const auto add_window_dim = [&window](int64 size, int64 stride,
                                        int64 padding_low, int64 padding_high,
                                        int64 base_dilation) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(size);
    dim->set_stride(stride);
    dim->set_padding_low(padding_low);
    dim->set_padding_high(padding_high);
    dim->set_window_dilation(1);
    dim->set_base_dilation(base_dilation);
  };
  add_window_dim(/*size=*/4, /*stride=*/3, /*padding_low=*/0,
                 /*padding_high=*/0, /*base_dilation=*/6);
  add_window_dim(/*size=*/2, /*stride=*/1, /*padding_low=*/2,
                 /*padding_high=*/1, /*base_dilation=*/2);

  const auto convolved = ShapeInference::InferConvolveShape(
      input, kernel, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dim_numbers);
  ASSERT_IS_OK(convolved.status());
  const Shape& actual = convolved.ValueOrDie();
  ASSERT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}), actual));
}
526
// Convolution dimension numbers that use the same kernel dimension twice
// must be rejected with an "each dimension exactly once" error.
TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2});

  // Input and output layout as configured here: dims 0 and 1 = spatial x0
  // and x1, dim 2 = feature, dim 3 = batch.
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(3);
  dnums.set_output_batch_dimension(3);
  dnums.set_input_feature_dimension(2);
  dnums.set_output_feature_dimension(2);
  dnums.add_input_spatial_dimensions(0);
  dnums.add_output_spatial_dimensions(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.add_output_spatial_dimensions(1);
  dnums.set_kernel_input_feature_dimension(0);  // duplicated with kernel_x0
  dnums.set_kernel_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);

  // The dilation factors are set to 1, as in the other convolution tests, so
  // that the duplicated kernel dimension is the only irregularity in the
  // arguments.
  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(2);
  dim0->set_stride(1);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(1);
  dim1->set_size(3);
  dim1->set_stride(2);
  dim1->set_padding_low(1);
  dim1->set_padding_high(1);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("each dimension exactly once"));
}
564
// Mapping an (f32) -> s32 function over F32[20] is expected to infer
// S32[20]: the result takes the element type returned by the function.
TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
  Shape arg = ShapeUtil::MakeShape(F32, {20});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_);
  auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
  // Fatal assertion: ValueOrDie() below would abort the whole test binary if
  // inference had failed.
  ASSERT_IS_OK(inferred_status.status());
  Shape expected = ShapeUtil::MakeShape(S32, {20});
  EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie()));
}
573
// Covers map shape inference: valid calls with one, two and three operands,
// followed by each rejected combination of operands and mapped function.
//
// Every successful inference is guarded by a fatal ASSERT_IS_OK before
// ValueOrDie() is called, because ValueOrDie() on a failed StatusOr would
// abort the whole test binary rather than fail this test.
TEST_F(ShapeInferenceTest, Map) {
  // Two F32[32] operands with an (f32, f32) -> f32 function.
  auto inferred_status_r1f32 = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_IS_OK(inferred_status_r1f32.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie()));

  // It's OK to provide a single argument, as long as the applied arity matches
  // (this degenerates to a Map).
  auto inferred_status_r1f32_one = ShapeInference::InferMapShape(
      {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0});
  ASSERT_IS_OK(inferred_status_r1f32_one.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie()));

  // Three S32[64,64] operands with an (s32, s32, s32) -> s32 function.
  auto inferred_status_r2s32 = ShapeInference::InferMapShape(
      {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_},
      ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1});
  ASSERT_IS_OK(inferred_status_r2s32.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie()));

  // No operands at all.
  auto no_args_error = ShapeInference::InferMapShape(
      {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {});
  ASSERT_FALSE(no_args_error.ok());
  ASSERT_THAT(no_args_error.status().error_message(),
              HasSubstr("expects at least one argument"));

  // Operands of different shapes.
  auto args_diff_shapes_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_64_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(args_diff_shapes_error.ok());
  ASSERT_THAT(args_diff_shapes_error.status().error_message(),
              HasSubstr("requires all operands to have the same shape"));

  // Two operands but a one-parameter function.
  auto arity_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_),
      {0});
  ASSERT_FALSE(arity_error.ok());
  ASSERT_THAT(arity_error.status().error_message(),
              HasSubstr("function arity must match"));

  // The function returns a vector instead of a scalar.
  auto output_shape_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0});
  ASSERT_FALSE(output_shape_error.ok());
  ASSERT_THAT(output_shape_error.status().error_message(),
              HasSubstr("result has to be a scalar"));

  // The function takes a vector parameter instead of a scalar.
  auto param_shape_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0});
  ASSERT_FALSE(param_shape_error.ok());
  ASSERT_THAT(param_shape_error.status().error_message(),
              HasSubstr("parameter has to be a scalar"));

  // The second parameter is S32 while the operands are F32.
  auto param_element_type_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0});
  ASSERT_FALSE(param_element_type_error.ok());
  ASSERT_THAT(param_element_type_error.status().error_message(),
              HasSubstr("parameter type has to match argument"));

  // The remaining checks use a single F32[20] operand.
  Shape arg = ShapeUtil::MakeShape(F32, {20});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
  auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
  ASSERT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie()));

  // One operand but a two-parameter function.
  auto inferred_status_error1 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("arity must match number of arguments"));

  // Vector parameter.
  auto inferred_status_error2 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0});
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("has to be a scalar"));

  // Vector result.
  auto inferred_status_error3 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0});
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("has to be a scalar"));

  // S32 parameter applied to an F32 operand.
  auto inferred_status_error5 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0});
  ASSERT_FALSE(inferred_status_error5.ok());
  ASSERT_THAT(inferred_status_error5.status().error_message(),
              HasSubstr("parameter type has to match argument"));
}
667
// Reducing F32[128] over its only dimension is expected to give a scalar.
TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) {
  const Shape vector_128 = ShapeUtil::MakeShape(F32, {128});
  ExpectInferredReduceShape(/*expected_inferred_shape=*/f32_,
                            /*arg=*/vector_128,
                            /*dimensions_to_reduce=*/{0});
}
672
// Reducing F32[2,3,4] over dimension 0 is expected to give F32[3,4].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstDimension) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {3, 4});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{0});
}
678
// Reducing F32[2,3,4] over dimension 1 is expected to give F32[2,4].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongMiddleDimension) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {2, 4});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{1});
}
684
// Reducing F32[2,3,4] over dimensions 0 and 1 is expected to give F32[4].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstTwoDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {4});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{0, 1});
}
690
// Reducing F32[2,3,4] over dimensions 1 and 2 is expected to give F32[2].
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongLastTwoDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {2});
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{1, 2});
}
696
// Reducing F32[2,3,4] over dimensions 0 and 2 is expected to give F32[3],
// whichever order the dimensions are listed in.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstAndLastDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  const Shape expected = ShapeUtil::MakeShape(F32, {3});

  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{0, 2});

  // Same reduction with the dimensions listed in the opposite order.
  ExpectInferredReduceShape(expected, cube, /*dimensions_to_reduce=*/{2, 0});
}
707
// Reducing every dimension of a [2, 3, 4] array is expected to infer the F32
// scalar shape.
TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
  const Shape cube = ShapeUtil::MakeShape(F32, {2, 3, 4});
  ExpectInferredReduceShape(f32_, cube, /*dimensions_to_reduce=*/{0, 1, 2});
}
712
// Reduce with two operands (F32 and S32, both [5, 3]) and two scalar init
// values, reduced over dimensions 0 and 1 by a reducer that returns an
// (f32, s32) tuple. The inferred shape is expected to be that tuple.
TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  // ASSERT (not EXPECT): ValueOrDie() below must only be reached when the
  // status is OK, so stop the test here on failure.
  ASSERT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}),
                               inferred_status.ValueOrDie()));
}
724
// A reducer declaring six parameters is expected to be rejected when the
// reduce is given two operands and two init values.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) {
  const Shape float_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape int_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const Shape reducer_result = ShapeUtil::MakeTupleShape({f32_, s32_});
  const ProgramShape reducer = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_, f32_, s32_}, reducer_result);

  const auto result = ShapeInference::InferReduceShape(
      {&float_operand, &int_operand, &f32_, &s32_}, {0, 1}, reducer);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("must take 4 parameters, but takes 6 parameter(s)"));
}
737
// The reducer's first parameter is s32 while the first element of its result
// tuple is f32; inference is expected to report the mismatch.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) {
  const Shape float_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape int_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const Shape reducer_result = ShapeUtil::MakeTupleShape({f32_, s32_});
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({s32_, s32_, f32_, s32_}, reducer_result);

  const auto result = ShapeInference::InferReduceShape(
      {&float_operand, &int_operand, &f32_, &s32_}, {0, 1}, reducer);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr(
          "parameter shape differs from the result shape: s32[] vs f32[]"));
}
751
// Calling reduce inference with an empty argument list is expected to fail
// with an "at least 2 arguments" error.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) {
  const Shape reducer_result = ShapeUtil::MakeTupleShape({f32_, s32_});
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({s32_, s32_, f32_, s32_}, reducer_result);

  const auto result = ShapeInference::InferReduceShape({}, {0, 1}, reducer);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("must have at least 2 arguments, has 0"));
}
760
// With two operands, a reducer that returns a plain f32 scalar instead of a
// two-element tuple is expected to be rejected.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) {
  const Shape float_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape int_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_);

  const auto result = ShapeInference::InferReduceShape(
      {&float_operand, &int_operand, &f32_, &s32_}, {0, 1}, reducer);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but produces a scalar"));
}
773
// With two operands, a reducer that returns a three-element tuple is expected
// to be rejected.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) {
  const Shape float_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape int_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const Shape reducer_result = ShapeUtil::MakeTupleShape({f32_, s32_, s32_});
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, reducer_result);

  const auto result = ShapeInference::InferReduceShape(
      {&float_operand, &int_operand, &f32_, &s32_}, {0, 1}, reducer);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(
      result.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but has 3 elements"));
}
786
// The reducer uses s32 for every parameter and result element while the first
// init value is f32; inference is expected to report the accumulator mismatch
// at index 0.
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) {
  const Shape float_operand = ShapeUtil::MakeShape(F32, {5, 3});
  const Shape int_operand = ShapeUtil::MakeShape(S32, {5, 3});
  const Shape reducer_result = ShapeUtil::MakeTupleShape({s32_, s32_});
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({s32_, s32_, s32_, s32_}, reducer_result);

  const auto result = ShapeInference::InferReduceShape(
      {&float_operand, &int_operand, &f32_, &s32_}, {0, 1}, reducer);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("accumulator shape at index 0 differs from the "
                        "init_value shape: s32[] vs f32[]"));
}
799
// Asking to reduce dimensions 3 and 4 of a rank-2 operand is expected to fail
// with an out-of-bounds error.
TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
  const Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  const ProgramShape reducer = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);

  const auto result = ShapeInference::InferReduceShape(
      {&operand, &f32_}, /*dimensions_to_reduce=*/{3, 4}, reducer);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("out-of-bounds dimension"));
}
810
// A three-parameter reducer is expected to be rejected for a single-operand
// reduce.
TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
  const Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  const ProgramShape reducer =
      ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);

  const auto result = ShapeInference::InferReduceShape(
      {&operand, &f32_}, /*dimensions_to_reduce=*/{0}, reducer);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(), HasSubstr("take 2 parameters"));
}
821
// A reducer taking f32 parameters but returning s32 is expected to be
// rejected for an F32 operand.
TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
  const Shape operand = ShapeUtil::MakeShape(F32, {5, 3});
  const ProgramShape reducer = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);

  const auto result = ShapeInference::InferReduceShape(
      {&operand, &f32_}, /*dimensions_to_reduce=*/{0}, reducer);

  EXPECT_FALSE(result.ok());
  EXPECT_THAT(result.status().error_message(),
              HasSubstr("0-th parameter shape differs"));
}
832
// Slicing a [128, 64] matrix with starts {32, 0}, limits {64, 64} and unit
// strides is expected to infer [32, 64].
TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  const auto result =
      ShapeInference::InferSliceShape(operand, {32, 0}, {64, 64}, {1, 1});
  ASSERT_IS_OK(result.status());

  const Shape sliced = result.ValueOrDie();
  const Shape expected = ShapeUtil::MakeShape(F32, {32, 64});
  ASSERT_TRUE(ShapeUtil::Equal(expected, sliced));
}
841
// Same slice bounds as the unit-stride case, but with strides {2, 4}: the
// inferred shape is expected to be [16, 16].
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  const auto result =
      ShapeInference::InferSliceShape(operand, {32, 0}, {64, 64}, {2, 4});
  ASSERT_IS_OK(result.status());

  const Shape sliced = result.ValueOrDie();
  const Shape expected = ShapeUtil::MakeShape(F32, {16, 16});
  ASSERT_TRUE(ShapeUtil::Equal(expected, sliced));
}
850
// Slice extents (5 and 13) that are not multiples of the strides (2 and 4):
// the expected result [3, 4] corresponds to rounding the quotient up.
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  const auto result =
      ShapeInference::InferSliceShape(operand, {15, 0}, {20, 13}, {2, 4});
  ASSERT_IS_OK(result.status());

  const Shape sliced = result.ValueOrDie();
  const Shape expected = ShapeUtil::MakeShape(F32, {3, 4});
  ASSERT_TRUE(ShapeUtil::Equal(expected, sliced));
}
859
// A slice with a zero stride in dimension 0 is expected to be rejected with
// INVALID_ARGUMENT.
// NOTE(review): the limit 129 also exceeds the dimension size 128 (the same
// bounds as InferOobSliceShapeRank2), so this test may not isolate the stride
// check -- confirm against the implementation.
TEST_F(ShapeInferenceTest, InferInvalidStride) {
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  const auto result =
      ShapeInference::InferSliceShape(operand, {127, 0}, {129, 2}, {0, 1});
  ASSERT_FALSE(result.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, result.status().code());
}
868
// A slice whose limit (129) exceeds the size of dimension 0 (128) is expected
// to be rejected with INVALID_ARGUMENT.
TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
  const Shape operand = ShapeUtil::MakeShape(F32, {128, 64});
  const auto result =
      ShapeInference::InferSliceShape(operand, {127, 0}, {129, 2}, {1, 1});
  ASSERT_FALSE(result.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT, result.status().code());
}
877
// Slicing elements [2, 4) of a 17-element vector with unit stride is expected
// to infer a 2-element vector.
TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
  Shape vector_shape = ShapeUtil::MakeShape(F32, {17});
  auto inferred_status =
      ShapeInference::InferSliceShape(vector_shape, {2}, {4}, {1});
  // ASSERT_IS_OK (as in the rank-2 slice tests) reports the status itself on
  // failure instead of a bare "false".
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2})));
}
886
// Get-tuple-element on an (f32, s32) tuple is expected to infer f32 for
// index 0 and s32 for index 1.
TEST_F(ShapeInferenceTest, InferConstIndexShape) {
  const Shape pair = ShapeUtil::MakeTupleShape({f32_, s32_});

  const auto first = ShapeInference::InferGetTupleElementShape(pair, 0);
  const auto second = ShapeInference::InferGetTupleElementShape(pair, 1);

  ASSERT_IS_OK(first.status());
  ASSERT_IS_OK(second.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, first.ValueOrDie()));
  ASSERT_TRUE(ShapeUtil::Equal(s32_, second.ValueOrDie()));
}
898
// Get-tuple-element on a two-element tuple is expected to fail for index -1
// and for index 2, with an out-of-bounds message in both cases.
TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) {
  const Shape pair = ShapeUtil::MakeTupleShape({f32_, s32_});

  const auto below_range = ShapeInference::InferGetTupleElementShape(pair, -1);
  const auto above_range = ShapeInference::InferGetTupleElementShape(pair, 2);

  ASSERT_FALSE(below_range.ok());
  ASSERT_FALSE(above_range.ok());
  EXPECT_THAT(below_range.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
  EXPECT_THAT(above_range.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
}
912
// kPower between an F32[10] array and an F32 scalar, with no broadcast
// dimensions, is expected to infer F32[10].
TEST_F(ShapeInferenceTest, InferPowShape) {
  const Shape array_of_ten = ShapeUtil::MakeShape(F32, {10});
  const auto result = ShapeInference::InferBinaryOpShape(
      HloOpcode::kPower, array_of_ten, f32_, {});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(array_of_ten, result.ValueOrDie()));
}
920
// kCompare between an F32[10] array and an F32 scalar is expected to infer
// PRED[10].
TEST_F(ShapeInferenceTest, InferCompareShape) {
  const Shape array_of_ten = ShapeUtil::MakeShape(F32, {10});
  const auto result = ShapeInference::InferBinaryOpShape(
      HloOpcode::kCompare, array_of_ten, f32_, {});
  ASSERT_IS_OK(result.status());

  const Shape expected = ShapeUtil::MakeShape(PRED, {10});
  ASSERT_TRUE(ShapeUtil::Equal(expected, result.ValueOrDie()));
}
929
// For each of several element types, checks broadcast shape inference for a
// scalar and a 3-element vector: with no added dimensions the shape is
// expected to be unchanged, and the added sizes are expected to appear as
// leading dimensions of the result.
TEST_F(ShapeInferenceTest, BroadcastScalar) {
  for (const auto element_type : {F32, U32, S8}) {
    const Shape rank0 = ShapeUtil::MakeShape(element_type, {});
    const Shape rank1 = ShapeUtil::MakeShape(element_type, {3});
    const Shape rank2 = ShapeUtil::MakeShape(element_type, {2, 3});

    // Scalar with no added dimensions.
    const auto scalar_noop = ShapeInference::InferBroadcastShape(rank0, {});
    ASSERT_IS_OK(scalar_noop.status());
    ASSERT_TRUE(ShapeUtil::Equal(rank0, scalar_noop.ValueOrDie()));

    // Scalar broadcast to one dimension.
    const auto scalar_to_1d = ShapeInference::InferBroadcastShape(rank0, {3});
    ASSERT_IS_OK(scalar_to_1d.status());
    ASSERT_TRUE(ShapeUtil::Equal(rank1, scalar_to_1d.ValueOrDie()));

    // Vector with no added dimensions.
    const auto vector_noop = ShapeInference::InferBroadcastShape(rank1, {});
    ASSERT_IS_OK(vector_noop.status());
    ASSERT_TRUE(ShapeUtil::Equal(rank1, vector_noop.ValueOrDie()));

    // Scalar broadcast to two dimensions.
    const auto scalar_to_2d =
        ShapeInference::InferBroadcastShape(rank0, {2, 3});
    ASSERT_IS_OK(scalar_to_2d.status());
    ASSERT_TRUE(ShapeUtil::Equal(rank2, scalar_to_2d.ValueOrDie()));

    // Vector broadcast to two dimensions.
    const auto vector_to_2d = ShapeInference::InferBroadcastShape(rank1, {2});
    ASSERT_IS_OK(vector_to_2d.status());
    ASSERT_TRUE(ShapeUtil::Equal(rank2, vector_to_2d.ValueOrDie()));
  }
}
962
963 // scalar <dot> vector: error
TEST_F(ShapeInferenceTest,ScalarDotVector)964 TEST_F(ShapeInferenceTest, ScalarDotVector) {
965 DotDimensionNumbers dot_dnums;
966 dot_dnums.add_lhs_contracting_dimensions(1);
967 dot_dnums.add_rhs_contracting_dimensions(0);
968 auto inferred_status =
969 ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums);
970 ASSERT_FALSE(inferred_status.ok());
971 ASSERT_THAT(inferred_status.status().error_message(),
972 HasSubstr("Dot only supports rank"));
973 }
974
975 // 3D <dot> 2D: error
TEST_F(ShapeInferenceTest,DotWithRankHigherThanTwo)976 TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
977 DotDimensionNumbers dot_dnums;
978 dot_dnums.add_lhs_contracting_dimensions(1);
979 dot_dnums.add_rhs_contracting_dimensions(0);
980 auto inferred_status = ShapeInference::InferDotOpShape(
981 ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums);
982 EXPECT_TRUE(inferred_status.ok());
983 EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
984 ShapeUtil::MakeShape(F32, {32, 32, 64})));
985 }
986
987 // vector <dot> vector -> scalar
TEST_F(ShapeInferenceTest,VectorDotVector)988 TEST_F(ShapeInferenceTest, VectorDotVector) {
989 DotDimensionNumbers dot_dnums;
990 dot_dnums.add_lhs_contracting_dimensions(0);
991 dot_dnums.add_rhs_contracting_dimensions(0);
992 auto inferred_status =
993 ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums);
994 ASSERT_IS_OK(inferred_status.status());
995 ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
996 auto inferred_status_mismatch =
997 ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums);
998 ASSERT_FALSE(inferred_status_mismatch.ok());
999 }
1000
1001 // matrix <dot> vector -> vector
TEST_F(ShapeInferenceTest,MatrixDotVector)1002 TEST_F(ShapeInferenceTest, MatrixDotVector) {
1003 DotDimensionNumbers dot_dnums;
1004 dot_dnums.add_lhs_contracting_dimensions(1);
1005 dot_dnums.add_rhs_contracting_dimensions(0);
1006 auto inferred_status =
1007 ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums);
1008 ASSERT_IS_OK(inferred_status.status());
1009 ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_));
1010 auto inferred_status_mismatch =
1011 ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums);
1012 ASSERT_FALSE(inferred_status_mismatch.ok());
1013 }
1014
1015 // vector <dot> matrix -> vector
TEST_F(ShapeInferenceTest,VectorDotMatrix)1016 TEST_F(ShapeInferenceTest, VectorDotMatrix) {
1017 DotDimensionNumbers dot_dnums;
1018 dot_dnums.add_lhs_contracting_dimensions(0);
1019 dot_dnums.add_rhs_contracting_dimensions(0);
1020 auto inferred_status =
1021 ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums);
1022 ASSERT_IS_OK(inferred_status.status());
1023 ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_));
1024 auto inferred_status_mismatch =
1025 ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums);
1026 ASSERT_FALSE(inferred_status_mismatch.ok());
1027 }
1028
1029 // matrix <dot> matrix -> matrix
TEST_F(ShapeInferenceTest,MatrixDotMatrix)1030 TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
1031 DotDimensionNumbers dot_dnums;
1032 dot_dnums.add_lhs_contracting_dimensions(1);
1033 dot_dnums.add_rhs_contracting_dimensions(0);
1034 auto inferred_status_match =
1035 ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums);
1036 ASSERT_IS_OK(inferred_status_match.status());
1037 ASSERT_TRUE(
1038 ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
1039 << "inferred: "
1040 << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
1041 << " expected: " << ShapeUtil::HumanString(matrix_64_48_);
1042 auto inferred_status_mismatch =
1043 ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums);
1044 ASSERT_FALSE(inferred_status_mismatch.ok());
1045 }
1046
1047 // BatchMatMul with two batch dimensions and one contracting dimension.
TEST_F(ShapeInferenceTest,DotGeneral)1048 TEST_F(ShapeInferenceTest, DotGeneral) {
1049 Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3});
1050 Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14});
1051 Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14});
1052
1053 DotDimensionNumbers dot_dnums;
1054 dot_dnums.add_lhs_contracting_dimensions(3);
1055 dot_dnums.add_lhs_batch_dimensions(0);
1056 dot_dnums.add_lhs_batch_dimensions(1);
1057
1058 dot_dnums.add_rhs_contracting_dimensions(2);
1059 dot_dnums.add_rhs_batch_dimensions(0);
1060 dot_dnums.add_rhs_batch_dimensions(1);
1061
1062 auto inferred_status_match =
1063 ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
1064 ASSERT_IS_OK(inferred_status_match.status());
1065 ASSERT_TRUE(
1066 ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape))
1067 << "inferred: "
1068 << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
1069 << " expected: " << ShapeUtil::HumanString(output_shape);
1070 }
1071
1072 // BatchMatMul with two contracting dimensions fails.
TEST_F(ShapeInferenceTest,DotWithTwoContractingDimsFails)1073 TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) {
1074 Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
1075 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
1076
1077 DotDimensionNumbers dot_dnums;
1078 dot_dnums.add_lhs_contracting_dimensions(2);
1079 dot_dnums.add_lhs_contracting_dimensions(3);
1080 dot_dnums.add_lhs_batch_dimensions(0);
1081
1082 dot_dnums.add_rhs_contracting_dimensions(1);
1083 dot_dnums.add_rhs_batch_dimensions(0);
1084
1085 auto inferred_status =
1086 ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
1087 ASSERT_FALSE(inferred_status.ok());
1088 ASSERT_THAT(inferred_status.status().error_message(),
1089 HasSubstr("Must specify the same number of contracting "
1090 "dimensions for lhs and rhs."));
1091 }
1092
// Both sides list two contracting dimensions of matching sizes (3 and 2) plus
// one batch dimension; inference is expected to succeed with [2, 11, 14].
TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 2, 14});
  Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_lhs_contracting_dimensions(3);
  dot_dnums.add_lhs_batch_dimensions(0);

  dot_dnums.add_rhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(2);
  dot_dnums.add_rhs_batch_dimensions(0);

  auto inferred_status =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
  // ASSERT (not EXPECT): ValueOrDie() below must only be reached when the
  // status is OK, so stop the test here on failure.
  ASSERT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape));
}
1112
1113 // BatchMatMul with different batch dimension sizes fails.
TEST_F(ShapeInferenceTest,DotWithMisatchedBatchDimSizesFails)1114 TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) {
1115 Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1116 Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14});
1117
1118 DotDimensionNumbers dot_dnums;
1119 dot_dnums.add_lhs_contracting_dimensions(2);
1120 dot_dnums.add_lhs_batch_dimensions(0);
1121
1122 dot_dnums.add_rhs_contracting_dimensions(1);
1123 dot_dnums.add_rhs_batch_dimensions(0);
1124
1125 auto inferred_status =
1126 ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
1127 ASSERT_FALSE(inferred_status.ok());
1128 ASSERT_THAT(inferred_status.status().error_message(),
1129 HasSubstr("Batch dimension sizes must match"));
1130 }
1131
1132 // BatchMatMul with different batch dimension numbers passes
TEST_F(ShapeInferenceTest,DotWithMisatchedBatchDimNumbersPasses)1133 TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersPasses) {
1134 Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1135 Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14});
1136
1137 DotDimensionNumbers dot_dnums;
1138 dot_dnums.add_lhs_contracting_dimensions(2);
1139 dot_dnums.add_lhs_batch_dimensions(0);
1140
1141 dot_dnums.add_rhs_contracting_dimensions(0);
1142 dot_dnums.add_rhs_batch_dimensions(1);
1143
1144 auto inferred_status =
1145 ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
1146 ASSERT_TRUE(inferred_status.ok());
1147 ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
1148 ShapeUtil::MakeShape(F32, {2, 11, 14})));
1149 }
1150
1151 // BatchMatMul with out-of-range dimension numbers fails.
TEST_F(ShapeInferenceTest,DotWithContractingDimNumberOutOfRange)1152 TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) {
1153 Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1154 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
1155
1156 DotDimensionNumbers dot_dnums;
1157 dot_dnums.add_lhs_contracting_dimensions(3);
1158 dot_dnums.add_lhs_batch_dimensions(0);
1159
1160 dot_dnums.add_rhs_contracting_dimensions(0);
1161 dot_dnums.add_rhs_batch_dimensions(1);
1162
1163 auto inferred_status =
1164 ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
1165 ASSERT_FALSE(inferred_status.ok());
1166 ASSERT_THAT(inferred_status.status().error_message(),
1167 HasSubstr("A dimension number is out of range"));
1168 }
1169
1170 // BatchMatMul with non-unique dimension numbers fails.
TEST_F(ShapeInferenceTest,DotWithContractingNonUniqueDimNumber)1171 TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) {
1172 Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
1173 Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
1174
1175 DotDimensionNumbers dot_dnums;
1176 dot_dnums.add_lhs_contracting_dimensions(0);
1177 dot_dnums.add_lhs_batch_dimensions(0);
1178
1179 dot_dnums.add_rhs_contracting_dimensions(0);
1180 dot_dnums.add_rhs_batch_dimensions(1);
1181
1182 auto inferred_status =
1183 ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
1184 ASSERT_FALSE(inferred_status.ok());
1185 ASSERT_THAT(inferred_status.status().error_message(),
1186 HasSubstr("A dimension number is not unique"));
1187 }
1188
TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
  // Adds a vector to a [16, 8] matrix. The add is expected to succeed only
  // when the broadcast dimension names the matrix dimension whose size equals
  // the vector length.
  const Shape matrix = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape len8 = ShapeUtil::MakeShape(F32, {8});
  const Shape len16 = ShapeUtil::MakeShape(F32, {16});

  // Length-8 vector along dimension 1 (size 8): accepted.
  const auto len8_on_dim1 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, matrix, len8, {1});
  ASSERT_IS_OK(len8_on_dim1.status());
  ASSERT_TRUE(ShapeUtil::Equal(len8_on_dim1.ValueOrDie(), matrix));

  // Length-8 vector along dimension 0 (size 16): rejected.
  const auto len8_on_dim0 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, matrix, len8, {0});
  ASSERT_FALSE(len8_on_dim0.ok());

  // Length-16 vector along dimension 0 (size 16): accepted.
  const auto len16_on_dim0 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, matrix, len16, {0});
  ASSERT_IS_OK(len16_on_dim0.status());
  ASSERT_TRUE(ShapeUtil::Equal(len16_on_dim0.ValueOrDie(), matrix));

  // Length-16 vector along dimension 1 (size 8): rejected.
  const auto len16_on_dim1 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, matrix, len16, {1});
  ASSERT_FALSE(len16_on_dim1.ok());
}
1214
TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) {
  // Adds a matrix to a [16, 8, 4] array, mapping the matrix onto each pair of
  // array dimensions with matching sizes; every case is expected to infer
  // the shape of the rank-3 array.
  const Shape rank3 = ShapeUtil::MakeShape(F32, {16, 8, 4});
  const Shape mat_8x4 = ShapeUtil::MakeShape(F32, {8, 4});
  const Shape mat_16x4 = ShapeUtil::MakeShape(F32, {16, 4});
  const Shape mat_16x8 = ShapeUtil::MakeShape(F32, {16, 8});

  const auto on_dims_1_2 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, rank3, mat_8x4, {1, 2});
  ASSERT_IS_OK(on_dims_1_2.status());
  ASSERT_TRUE(ShapeUtil::Equal(on_dims_1_2.ValueOrDie(), rank3));

  const auto on_dims_0_2 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, rank3, mat_16x4, {0, 2});
  ASSERT_IS_OK(on_dims_0_2.status());
  ASSERT_TRUE(ShapeUtil::Equal(on_dims_0_2.ValueOrDie(), rank3));

  const auto on_dims_0_1 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, rank3, mat_16x8, {0, 1});
  ASSERT_IS_OK(on_dims_0_1.status());
  ASSERT_TRUE(ShapeUtil::Equal(on_dims_0_1.ValueOrDie(), rank3));
}
1237
TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
  // Walks through the error cases for the broadcast_dimensions argument of a
  // binary add between arrays of different rank.
  const Shape rank3 = ShapeUtil::MakeShape(F32, {16, 8, 4});
  const Shape rank3_all8 = ShapeUtil::MakeShape(F32, {8, 8, 8});
  const Shape len8 = ShapeUtil::MakeShape(F32, {8});
  const Shape mat_8x4 = ShapeUtil::MakeShape(F32, {8, 4});
  const Shape mat_8x8 = ShapeUtil::MakeShape(F32, {8, 8});

  // Empty broadcast_dimensions for operands of different rank.
  const auto no_dims =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, rank3, len8, {});
  ASSERT_FALSE(no_dims.ok());
  ASSERT_THAT(no_dims.status().error_message(), HasSubstr("Automatic"));

  // Dimension number 3 does not exist in a rank-3 array.
  const auto dim_past_rank =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, rank3, len8, {3});
  ASSERT_FALSE(dim_past_rank.ok());
  ASSERT_THAT(dim_past_rank.status().error_message(),
              ContainsRegex("Broadcast dimension number .* too large"));

  // Dimension 0 has size 16, which does not match the length-8 vector.
  const auto size_mismatch =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, rank3, len8, {0});
  ASSERT_FALSE(size_mismatch.ok());
  ASSERT_THAT(size_mismatch.status().error_message(),
              HasSubstr("Broadcast dimension 0 mismatch"));

  // Three broadcast dimensions given for a rank-2 operand.
  const auto too_many_dims = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, rank3, mat_8x4, {0, 1, 2});
  ASSERT_FALSE(too_many_dims.ok());
  ASSERT_THAT(too_many_dims.status().error_message(),
              HasSubstr("broadcast_dimensions has to match"));

  // One of the listed dimension numbers (3) exceeds the array's rank.
  const auto one_dim_past_rank = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, rank3, mat_8x4, {3, 0});
  ASSERT_FALSE(one_dim_past_rank.ok());
  ASSERT_THAT(one_dim_past_rank.status().error_message(),
              ContainsRegex("dimension number .* too large"));

  // The listed dimensions exist but their sizes do not line up in this order.
  const auto wrong_pairing = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, rank3, mat_8x4, {2, 1});
  ASSERT_FALSE(wrong_pairing.ok());
  ASSERT_THAT(wrong_pairing.status().error_message(),
              HasSubstr("dimension 0 mismatch"));

  // The last two cases use all-8 shapes so that every size matches; the only
  // thing wrong is that the dimension list is not strictly increasing.
  const auto repeated_dim = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, rank3_all8, mat_8x8, {0, 0});
  ASSERT_FALSE(repeated_dim.ok());
  ASSERT_THAT(repeated_dim.status().error_message(),
              HasSubstr("dimensions order is wrong"));

  const auto decreasing_dims = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, rank3_all8, mat_8x8, {1, 0});
  ASSERT_FALSE(decreasing_dims.ok());
  ASSERT_THAT(decreasing_dims.status().error_message(),
              HasSubstr("dimensions order is wrong"));
}
1303
1304 // Tests for the while instruction with proper shapes.
TEST_F(ShapeInferenceTest,WhileWithCorrectShapes)1305 TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) {
1306 Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
1307 ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
1308 ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);
1309 auto inferred_status =
1310 ShapeInference::InferWhileShape(cond, body, result_shape);
1311 ASSERT_IS_OK(inferred_status.status());
1312 Shape inferred = inferred_status.ValueOrDie();
1313 ASSERT_TRUE(ShapeUtil::Equal(result_shape, inferred));
1314 }
1315
1316 // Tests for the while instruction with wrong shapes.
TEST_F(ShapeInferenceTest,WhileWithBadShapes)1317 TEST_F(ShapeInferenceTest, WhileWithBadShapes) {
1318 Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
1319 ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
1320 ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);
1321
1322 auto bad_shape_1 = ShapeUtil::MakeProgramShape({s32_, result_shape}, pred_);
1323 auto inferred_status_error1 =
1324 ShapeInference::InferWhileShape(bad_shape_1, body, result_shape);
1325 ASSERT_FALSE(inferred_status_error1.ok());
1326 ASSERT_THAT(inferred_status_error1.status().error_message(),
1327 HasSubstr("Condition must take 1 arguments"));
1328
1329 auto bad_shape_2 =
1330 ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape);
1331 auto inferred_status_error2 =
1332 ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape);
1333 ASSERT_FALSE(inferred_status_error2.ok());
1334 ASSERT_THAT(inferred_status_error2.status().error_message(),
1335 HasSubstr("Body must take 1 arguments"));
1336
1337 auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_);
1338 auto inferred_status_error3 =
1339 ShapeInference::InferWhileShape(bad_shape_3, body, result_shape);
1340 ASSERT_FALSE(inferred_status_error3.ok());
1341 ASSERT_THAT(inferred_status_error3.status().error_message(),
1342 HasSubstr("Condition must return a boolean"));
1343
1344 auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_);
1345 auto inferred_status_error4 =
1346 ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape);
1347 ASSERT_FALSE(inferred_status_error4.ok());
1348 ASSERT_THAT(inferred_status_error4.status().error_message(),
1349 HasSubstr("parameter of condition and body"));
1350 }
1351
1352 // Tests for the concatenate instruction with proper shapes.
TEST_F(ShapeInferenceTest,ConcatenateWithCorrectShapes)1353 TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) {
1354 auto inferred_status_1 = ShapeInference::InferConcatOpShape(
1355 {&vector_32_, &vector_64_}, /*dimension=*/0);
1356 ASSERT_IS_OK(inferred_status_1.status());
1357 Shape inferred_1 = inferred_status_1.ValueOrDie();
1358 ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), inferred_1));
1359
1360 auto inferred_status_2 = ShapeInference::InferConcatOpShape(
1361 {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0);
1362 ASSERT_IS_OK(inferred_status_2.status());
1363 Shape inferred_2 = inferred_status_2.ValueOrDie();
1364 ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), inferred_2));
1365
1366 auto inferred_status_3 = ShapeInference::InferConcatOpShape(
1367 {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1);
1368 ASSERT_IS_OK(inferred_status_3.status());
1369 Shape inferred_3 = inferred_status_3.ValueOrDie();
1370 ASSERT_TRUE(
1371 ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}), inferred_3));
1372 }
1373
1374 // Tests for the concatenate instruction with wrong shapes.
TEST_F(ShapeInferenceTest,ConcatenateWithBadShapes)1375 TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
1376 auto inferred_status_error1 =
1377 ShapeInference::InferConcatOpShape({}, /*dimension=*/0);
1378 ASSERT_FALSE(inferred_status_error1.ok());
1379 ASSERT_THAT(inferred_status_error1.status().error_message(),
1380 HasSubstr("Concatenate expects at least one argument"));
1381
1382 auto inferred_status_error2 =
1383 ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1);
1384 ASSERT_FALSE(inferred_status_error2.ok());
1385 ASSERT_THAT(inferred_status_error2.status().error_message(),
1386 HasSubstr("dimension out of bounds: -1"));
1387
1388 auto inferred_status_error3 =
1389 ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1);
1390 ASSERT_FALSE(inferred_status_error3.ok());
1391 ASSERT_THAT(inferred_status_error3.status().error_message(),
1392 HasSubstr("dimension out of bounds: 1"));
1393
1394 Shape tuple = ShapeUtil::MakeTupleShape({vector_32_});
1395 auto inferred_status_error4 = ShapeInference::InferConcatOpShape(
1396 {&vector_32_, &tuple}, /*dimension=*/0);
1397 ASSERT_FALSE(inferred_status_error4.ok());
1398 ASSERT_THAT(
1399 inferred_status_error4.status().error_message(),
1400 HasSubstr("Expected array argument for operand of concatenation"));
1401
1402 const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
1403 auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
1404 {&vector_32_, &vector_s32}, /*dimension=*/0);
1405 ASSERT_FALSE(inferred_status_error5.ok());
1406 ASSERT_THAT(inferred_status_error5.status().error_message(),
1407 HasSubstr("concatenate arrays with different element types"));
1408
1409 auto inferred_status_error6 = ShapeInference::InferConcatOpShape(
1410 {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0);
1411 ASSERT_FALSE(inferred_status_error6.ok());
1412 ASSERT_THAT(inferred_status_error6.status().error_message(),
1413 HasSubstr("concatenate arrays that differ in "
1414 "dimensions other than the one being "
1415 "concatenated"));
1416 }
1417
// Checks InferPadShape on a [10, 25] operand: first with a valid padding
// configuration, then with negative edge padding that must be rejected.
TEST_F(ShapeInferenceTest, Pad) {
  const Shape operand = ShapeUtil::MakeShape(F32, {10, 25});
  const Shape pad_value = ShapeUtil::MakeShape(F32, {});

  PaddingConfig config;
  // Padding for dimension 0: {low: 0, high: 2, interior: 3}
  auto* dim0 = config.add_dimensions();
  dim0->set_edge_padding_low(0);
  dim0->set_edge_padding_high(2);
  dim0->set_interior_padding(3);
  // Padding for dimension 1: {low: 1, high: 5, interior: 0}
  auto* dim1 = config.add_dimensions();
  dim1->set_edge_padding_low(1);
  dim1->set_edge_padding_high(5);
  dim1->set_interior_padding(0);

  // Expected sizes: 10 + 0 + 2 + 3 * (10 - 1) = 39 and 25 + 1 + 5 = 31.
  auto padded = ShapeInference::InferPadShape(operand, pad_value, config);
  ASSERT_IS_OK(padded.status());
  const Shape padded_shape = padded.ValueOrDie();
  ASSERT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), padded_shape));

  // Negative edge padding on dimension 1 (25 - 20 - 10 < 0) is an error.
  dim1->set_edge_padding_low(-20);
  dim1->set_edge_padding_high(-10);
  auto negative_size = ShapeInference::InferPadShape(operand, pad_value, config);
  ASSERT_FALSE(negative_size.ok());
  ASSERT_THAT(negative_size.status().error_message(),
              HasSubstr("negative size for dimension 1"));
}
1448
// Reversing along valid dimensions must leave the operand shape unchanged.
TEST_F(ShapeInferenceTest, Reverse) {
  const Shape operand = ShapeUtil::MakeShape(F32, {10, 25});

  auto reversed = ShapeInference::InferReverseShape(operand, {0, 1});
  ASSERT_IS_OK(reversed.status());
  const Shape reversed_shape = reversed.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(operand, reversed_shape));
}
1457
// InferReverseShape must reject out-of-range dimensions, duplicated
// dimensions, and tuple-shaped operands.
TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
  const Shape operand = ShapeUtil::MakeShape(F32, {10, 25});

  // Dimension 2 does not exist on a rank-2 operand.
  auto too_large = ShapeInference::InferReverseShape(operand, {0, 2});
  ASSERT_FALSE(too_large.ok());
  ASSERT_THAT(too_large.status().error_message(), HasSubstr("out-of-bounds"));

  // Negative dimensions are rejected as well.
  auto negative = ShapeInference::InferReverseShape(operand, {0, -1});
  ASSERT_FALSE(negative.ok());
  ASSERT_THAT(negative.status().error_message(), HasSubstr("out-of-bounds"));

  // The same dimension listed twice.
  auto repeated = ShapeInference::InferReverseShape(operand, {0, 0});
  ASSERT_FALSE(repeated.ok());
  ASSERT_THAT(repeated.status().error_message(), HasSubstr("duplicated"));

  // A tuple operand is not an array.
  const Shape tuple_operand = ShapeUtil::MakeTupleShape({operand, operand});
  auto on_tuple = ShapeInference::InferReverseShape(tuple_operand, {0});
  ASSERT_FALSE(on_tuple.ok());
  ASSERT_THAT(on_tuple.status().error_message(),
              HasSubstr("Expected array argument"));
}
1486
// Checks InferCallShape: a call whose arguments match the callee's parameters
// yields the callee's result shape, while arity or parameter/argument
// mismatches are reported as errors.
TEST_F(ShapeInferenceTest, Call) {
  // Nullary callee returning an F32 scalar.
  auto inferred_status0 =
      ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({}, f32_));
  // Fatal assertion: ValueOrDie() below would abort the whole test binary if
  // the status were not OK.
  ASSERT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));

  // Five arguments of mixed types that match the callee's parameters exactly.
  auto inferred_status1 = ShapeInference::InferCallShape(
      {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_},
      ShapeUtil::MakeProgramShape(
          {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_));
  ASSERT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(s32matrix_64_64_, inferred_status1.ValueOrDie()));

  // Too few arguments.
  auto inferred_status_error0 = ShapeInference::InferCallShape(
      {}, ShapeUtil::MakeProgramShape({f32_}, f32_));
  EXPECT_FALSE(inferred_status_error0.ok());
  EXPECT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("arity must match"));

  // Too many arguments.
  auto inferred_status_error1 = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({}, f32_));
  EXPECT_FALSE(inferred_status_error1.ok());
  EXPECT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("arity must match"));

  // Right arity, but an F32 argument is passed for an S32 parameter.
  auto inferred_status_error2 = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_));
  EXPECT_FALSE(inferred_status_error2.ok());
  EXPECT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("parameter must match argument"));
}
1519
// Transposing a [2, 3, 4, 5] array with permutation {1, 2, 3, 0} must produce
// a shape compatible with [3, 4, 5, 2].
TEST_F(ShapeInferenceTest, Transpose) {
  Shape a_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
  auto inferred_shape_and_status =
      ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0});
  // Fatal assertion: ValueOrDie() below would abort the whole test binary if
  // the status were not OK.
  ASSERT_IS_OK(inferred_shape_and_status.status());
  Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
  EXPECT_TRUE(ShapeUtil::Compatible(inferred_shape,
                                    ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
}
1529
// Transposing a rank-1 array with the identity permutation {0} must produce a
// shape compatible with the input.
TEST_F(ShapeInferenceTest, Rank1Transpose) {
  Shape a_shape = ShapeUtil::MakeShape(F32, {5});
  auto inferred_shape_and_status =
      ShapeInference::InferTransposeShape(a_shape, {0});
  // Fatal assertion: ValueOrDie() below would abort the whole test binary if
  // the status were not OK.
  ASSERT_IS_OK(inferred_shape_and_status.status());
  Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
  EXPECT_TRUE(
      ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {5})));
}
1539
// Checks InferConditionalShape with a PRED selector and two branches: valid
// combinations yield the common branch result shape; a bad selector type,
// wrong branch arity, operand/parameter mismatches, and differing branch
// results are reported as errors.
TEST_F(ShapeInferenceTest, ConditionalPred) {
  // Both branches take different operands but return an F32 scalar.
  auto inferred_status0 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_});
  // Fatal assertion: ValueOrDie() below would abort the whole test binary if
  // the status were not OK.
  ASSERT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));

  // Matrix and vector operands, both branches returning a 64-element vector.
  auto inferred_status1 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
       ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)},
      {matrix_32_48_, vector_32_});
  ASSERT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));

  // A tuple-shaped operand is accepted as the single branch parameter.
  auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
  auto inferred_status2 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
       ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
      {matrix_32_48_, tuple_f32_v32});
  ASSERT_IS_OK(inferred_status2.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));

  // The selector is F32 instead of PRED.
  auto inferred_status_error0 = ShapeInference::InferConditionalShape(
      f32_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error0.ok());
  EXPECT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("must be bool or int32"));

  // Branch 0 declares two parameters.
  auto inferred_status_error1 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
      {ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_});
  EXPECT_FALSE(inferred_status_error1.ok());
  EXPECT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("branch computation 0 must take 1 argument"));

  // Operand 0 (32 elements) does not match branch 0's parameter (64 elements).
  auto inferred_status_error2 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_64_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error2.ok());
  EXPECT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("branch operand 0 must match the shape of the only "
                        "parameter of branch computation 0"));

  // Branch 1 declares two parameters.
  auto inferred_status_error3 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
       ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)},
      {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})});
  EXPECT_FALSE(inferred_status_error3.ok());
  EXPECT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("branch computation 1 must take 1 argument"));

  // Operand 1 (64 elements) does not match branch 1's parameter (32 elements).
  auto inferred_status_error4 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error4.ok());
  EXPECT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("branch operand 1 must match the shape of the only "
                        "parameter of branch computation 1"));

  // The branches return different shapes (F32 scalar vs. 32-element vector).
  auto inferred_status_error5 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error5.ok());
  EXPECT_THAT(inferred_status_error5.status().error_message(),
              HasSubstr("the result of branch 0 computation and branch 1 "
                        "computation must have the same shape"));
}
1623
// Checks InferConditionalShape with an S32 branch index and a variable number
// of branches: valid combinations yield the common branch result shape, while
// a PRED selector with three branches, wrong branch arity, operand/parameter
// mismatches, differing branch results, and an empty branch list are errors.
TEST_F(ShapeInferenceTest, ConditionalIndexed) {
  auto r0s32 = ShapeUtil::MakeShape(S32, {});

  // Three branches, all returning an F32 scalar.
  auto inferred_status0 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_, vector_64_});
  // Fatal assertion: ValueOrDie() below would abort the whole test binary if
  // the status were not OK.
  ASSERT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));

  // Three branches, all returning a 64-element vector.
  auto inferred_status1 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
       ShapeUtil::MakeProgramShape({vector_32_}, vector_64_),
       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)},
      {matrix_32_48_, vector_32_, matrix_32_48_});
  ASSERT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));

  // A single branch taking a tuple-shaped operand.
  auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
  auto inferred_status2 = ShapeInference::InferConditionalShape(
      r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
      {tuple_f32_v32});
  ASSERT_IS_OK(inferred_status2.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));

  // A PRED selector combined with three branches.
  auto inferred_status_error0 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error0.ok());
  EXPECT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("2 == branch_computations.size()"));

  // Branch 1 declares two parameters.
  auto inferred_status_error1 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
       ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
      {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}),
       matrix_32_48_});
  EXPECT_FALSE(inferred_status_error1.ok());
  EXPECT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("branch computation 1 must take 1 argument"));

  // Operand 2 (64 elements) does not match branch 2's parameter (32 elements).
  auto inferred_status_error2 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({r0s32}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
      {r0s32, vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error2.ok());
  EXPECT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("branch operand 2 must match the shape of the only "
                        "parameter of branch computation 2"));

  // Branch 3 returns a vector while the other branches return an F32 scalar.
  auto inferred_status_error3 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
      {vector_32_, vector_32_, vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error3.ok());
  EXPECT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("the result of branch 0 computation and branch 3 "
                        "computation must have the same shape"));

  // No branches at all.
  auto inferred_status_error4 =
      ShapeInference::InferConditionalShape(r0s32, {}, {});
  EXPECT_FALSE(inferred_status_error4.ok());
  EXPECT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("!branch_computations.empty()"));
}
1701
// Slicing [0, 5) out of a 4-element vector must fail, and the error text must
// mention both the violated bound and the argument shape.
TEST_F(ShapeInferenceTest, BadSlice) {
  const Shape operand = ShapeUtil::MakeShape(F32, {4});
  StatusOr<Shape> result =
      ShapeInference::InferSliceShape(operand, {0}, {5}, {1});
  ASSERT_FALSE(result.ok());

  // Log the complete status for anyone inspecting the test output.
  LOG(INFO) << result.status();

  const auto& error = result.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("less than or equal to dimension size"))
      << error;
  EXPECT_THAT(error.error_message(), HasSubstr("argument shape")) << error;
}
1716
// Sort with a keys operand of 4 elements and a values operand of 5 elements
// must be rejected.
TEST_F(ShapeInferenceTest, BadSort) {
  const Shape sort_keys = ShapeUtil::MakeShape(F32, {4});
  const Shape sort_values = ShapeUtil::MakeShape(F32, {5});
  StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&sort_keys, &sort_values});
  EXPECT_FALSE(result.ok());
  const auto& error = result.status();
  EXPECT_THAT(error.error_message(), HasSubstr("dimensions must match"))
      << error;
}
1727
// Sort must be rejected when only a later values operand (5 elements)
// disagrees with the keys (4 elements), even if an earlier one matches.
TEST_F(ShapeInferenceTest, BadSortValuesMismatch) {
  const Shape sort_keys = ShapeUtil::MakeShape(F32, {4});
  const Shape matching_values = ShapeUtil::MakeShape(F32, {4});
  const Shape mismatched_values = ShapeUtil::MakeShape(F32, {5});
  StatusOr<Shape> result = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&sort_keys, &matching_values, &mismatched_values});
  EXPECT_FALSE(result.ok());
  const auto& error = result.status();
  EXPECT_THAT(error.error_message(), HasSubstr("dimensions must match"))
      << error;
}
1739
// Sort with one keys operand and two values operands of different element
// types must infer a tuple of the three operand shapes.
TEST_F(ShapeInferenceTest, SortManyValues) {
  auto keys = ShapeUtil::MakeShape(F32, {4});
  auto values_s32 = ShapeUtil::MakeShape(S32, {4});
  auto values_u32 = ShapeUtil::MakeShape(U32, {4});
  StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &values_s32, &values_u32});
  // Fatal assertion: ValueOrDie() below would abort the whole test binary if
  // the status were not OK.
  ASSERT_IS_OK(statusor.status());
  Shape inferred_shape = statusor.ValueOrDie();
  EXPECT_TRUE(ShapeUtil::Compatible(
      inferred_shape,
      ShapeUtil::MakeTupleShape({keys, values_s32, values_u32})));
}
1752
1753 class ScatterGatherShapeInferenceTest : public ShapeInferenceTest {
1754 protected:
1755 const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});
1756 const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5});
1757 const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32});
1758 const Shape s64_4d_tensor_10_9_8_7_1_ =
1759 ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1});
1760 const Shape s64_4d_tensor_10_9_8_7_5_ =
1761 ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
1762 const Shape s64_4d_tensor_5_10_9_7_6_ =
1763 ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6});
1764 const Shape s64_4d_tensor_10_9_5_7_6_ =
1765 ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
1766 const Shape f32_5d_tensor_50_49_48_47_46_ =
1767 ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
1768 const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
1769 {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_});
1770 const ProgramShape to_apply_ =
1771 ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
1772 };
1773
// Shape inference tests for Gather.
1775
// Gathers full columns (slice size {64, 1}) of a [64, 48] matrix at 32 start
// indices; the inferred shape must be [64, 32].
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{0},
      /*collapsed_slice_dims=*/{1},
      /*start_index_map=*/{1},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
                                       dim_numbers,
                                       /*slice_sizes=*/{64, 1}));
  const Shape expected = ShapeUtil::MakeShape(F32, {64, 32});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
1790
// Gathers full rows (slice size {1, 48}) of a [64, 48] matrix at 32 start
// indices; the inferred shape must be [32, 48].
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{1},
      /*collapsed_slice_dims=*/{0},
      /*start_index_map=*/{0},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
                                       dim_numbers,
                                       /*slice_sizes=*/{1, 48}));
  const Shape expected = ShapeUtil::MakeShape(F32, {32, 48});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
1805
// Gathers full rows of a [64, 48] matrix using a [10, 9, 8, 7, 1] index
// tensor whose last dimension holds the index vector; the inferred shape must
// be [10, 9, 8, 7, 48].
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4},
      /*collapsed_slice_dims=*/{0},
      /*start_index_map=*/{0},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(matrix_64_48_,
                                       s64_4d_tensor_10_9_8_7_1_, dim_numbers,
                                       /*slice_sizes=*/{1, 48}));
  const Shape expected = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
1820
// Gathers [30, 29, 28, 27, 26] windows out of a rank-5 operand with a
// [10, 9, 8, 7, 5] index tensor (index vector in the last dimension); the
// inferred shape must be the four batch dimensions followed by the window.
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
          dim_numbers, /*slice_sizes=*/{30, 29, 28, 27, 26}));
  const Shape expected =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
1837
// Same windowed gather as the batch dynamic slice case, but the index vector
// lives in dimension 2 of the [10, 9, 5, 7, 6] index tensor; the remaining
// index dimensions [10, 9, 7, 6] must lead the inferred shape.
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/2);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
          dim_numbers, /*slice_sizes=*/{30, 29, 28, 27, 26}));

  const Shape expected =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
1855
// Same windowed gather, with the index vector in dimension 0 of the
// [5, 10, 9, 7, 6] index tensor; the remaining index dimensions [10, 9, 7, 6]
// must lead the inferred shape.
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
          dim_numbers, /*slice_sizes=*/{30, 29, 28, 27, 26}));

  const Shape expected =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
1873
// With a single 5-element index vector there are no batch dimensions in the
// result, so the inferred shape must be exactly the slice shape.
TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) {
  // This is equivalent to a dynamic slice.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{0, 1, 2, 3, 4},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, dim_numbers,
          /*slice_sizes=*/{30, 29, 28, 27, 26}));

  const Shape expected = ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
1890
// A scalar start index with operand dimension 0 collapsed must infer the
// remaining [30, 29, 28, 27] window as the result shape.
TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) {
  // The gather indices "tensor" is a scalar S here that's used to slice out
  // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result.
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{0, 1, 2, 3},
      /*collapsed_slice_dims=*/{0},
      /*start_index_map=*/{0},
      /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_scalar_, dim_numbers,
          /*slice_sizes=*/{1, 30, 29, 28, 27}));

  const Shape expected = ShapeUtil::MakeShape(F32, {30, 29, 28, 27});
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape, expected))
      << ShapeUtil::HumanString(gather_shape);
}
1908
// Passing a tuple shape as the gather operand must be rejected.
TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{0},
      /*collapsed_slice_dims=*/{1},
      /*start_index_map=*/{1},
      /*index_vector_dim=*/1);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      tuple_shape_, s64_vector_32_, dim_numbers, /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(result.ok());
  const auto& error = result.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Expected array argument for input"))
      << error;
}
1923
// Passing a tuple shape as the gather indices must be rejected.
TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{0},
      /*collapsed_slice_dims=*/{1},
      /*start_index_map=*/{1},
      /*index_vector_dim=*/0);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      s64_vector_32_, tuple_shape_, dim_numbers, /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(result.ok());
  const auto& error = result.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Expected array argument for gather indices"))
      << error;
}
1938
// Passing an F32 vector as the gather indices must be rejected: indices have
// to be integral.
TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{0},
      /*collapsed_slice_dims=*/{1},
      /*start_index_map=*/{1},
      /*index_vector_dim=*/0);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      s64_vector_32_, vector_32_, dim_numbers, /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(result.ok());
  const auto& error = result.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Gather indices parameter must be an integral tensor"))
      << error;
}
1953
// offset_dims listed out of order ({..., 8, 7}) must be rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingWindowIndices) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 8, 7},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  const auto& error = result.status();
  EXPECT_THAT(
      error.error_message(),
      HasSubstr("Output window dimensions in gather op must be ascending"))
      << error;
}
1970
// offset_dims containing the same dimension twice ({..., 7, 7}) must be
// rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowIndices) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 7},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  const auto& error = result.status();
  EXPECT_THAT(
      error.error_message(),
      HasSubstr("Output window dimensions in gather op must not repeat"))
      << error;
}
1987
// offset_dims with far out-of-range entries (99, 100, 101) must be rejected;
// the error is expected to name the first offending entry (position 2).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 99, 100, 101},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  const auto& error = result.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Offset dimension 2 in gather op is out of bounds"))
      << error;
}
2003
// Only the last offset_dims entry (position 4, value 9) is invalid; it must be
// reported as out of bounds.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 9},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  const auto& error = result.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Offset dimension 4 in gather op is out of bounds"))
      << error;
}
2019
// Listing five offset dimensions while also collapsing operand dimension 4
// must be rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{4},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  const auto& error = result.status();
  EXPECT_THAT(
      error.error_message(),
      HasSubstr("All components of the offset index in a gather op must either "
                "be a offset dimension or explicitly collapsed"))
      << error;
}
2037
// A collapsed_slice_dims entry of 19 lies outside the valid range [0, 5) and
// must be rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{0, 1, 2, 3, 19},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  const auto& error = result.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Invalid collapsed_slice_dims set in gather op; valid "
                        "range is [0, 5), got: 19"))
      << error;
}
2054
// collapsed_slice_dims containing dimension 3 twice must be rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{0, 1, 2, 3, 3},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  const auto& error = result.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Repeated dimensions not allowed in "
                        "collapsed_slice_dims in gather op"))
      << error;
}
2071
// start_index_map has 4 entries while the index vector dimension of the
// indices tensor has bound 5; the mismatch must be rejected.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3},
      /*index_vector_dim=*/4);
  StatusOr<Shape> result = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(result.ok());
  const auto& error = result.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Gather op has 4 elements in start_index_map and "
                        "the bound of dimension index_vector_dim=4 of "
                        "start_indices is 5. These two numbers must be equal."))
      << error;
}
2089
// Gather shape inference must reject a start_index_map entry (position 4
// mapping to 7) that is outside the domain named in the error, [0, 5).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 7},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7"))
      << error;
}
2105
// Gather shape inference must reject a start_index_map that repeats a
// dimension (3 appears twice).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 3},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(
      error.error_message(),
      HasSubstr("Repeated dimensions are not allowed in start_index_map"))
      << error;
}
2122
// Gather shape inference must reject collapsed_slice_dims that are not in
// ascending order ({2, 1}).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{2, 1},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{1, 1, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("collapsed_slice_dims in gather op must be sorted"))
      << error;
}
2138
// Gather shape inference must reject a slice size (300 at index 3) that is
// outside the range quoted in the expected error message.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowBoundsTooLarge) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7},
      /*collapsed_slice_dims=*/{2},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 1, 300, 26});

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Slice size at index 3 in gather op is out of range, "
                        "must be within [0, 48), got 300."))
      << error;
}
2155
// Gather shape inference must reject a slice_sizes list with only four
// entries; the error demands one slice size per input dimension.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 26});

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(
      error.error_message(),
      HasSubstr("Gather op must have one slice size for every input dimension"))
      << error;
}
2172
// Gather shape inference must reject collapsing dimension 1 when its slice
// size is 29 rather than 1.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7},
      /*collapsed_slice_dims=*/{1},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 26, 20});

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Gather op can only collapse slice dims with bound 1, "
                        "but bound is 29 for index 1 at position 0."))
      << error;
}
2189
// Gather shape inference must reject an index_vector_dim (32) that is not
// within [0, rank(start_indices) + 1).
TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
  const auto dim_numbers = HloGatherInstruction::MakeGatherDimNumbers(
      /*offset_dims=*/{4, 5, 6, 7, 8},
      /*collapsed_slice_dims=*/{},
      /*start_index_map=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/32);
  const StatusOr<Shape> inferred = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, dim_numbers,
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Gather index leaf dimension must be within [0, "
                        "rank(start_indices) + 1)"))
      << error;
}
2206
2207 // Shape inference tests for Scatter.
2208
// Scatter into the F32 [64, 48] operand with F32 [64, 32] updates; the
// inferred shape must equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdates) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {64, 32});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(matrix_64_48_, s64_vector_32_,
                                        updates_shape, to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
2222
// Scatter into the F32 [64, 48] operand with F32 [32, 48] updates, using
// update window dimension 1; the inferred shape must equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdatesV2) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {32, 48});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{1},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(matrix_64_48_, s64_vector_32_,
                                        updates_shape, to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
2236
// Scatter into the F32 [64, 48] operand with smaller F32 [10, 32] updates;
// the inferred shape must still equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdates) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {10, 32});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(matrix_64_48_, s64_vector_32_,
                                        updates_shape, to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
2250
// Scatter into the F32 [64, 48] operand with smaller F32 [32, 8] updates,
// using update window dimension 1; the result must equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdatesV2) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {32, 8});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{1},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/1);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(matrix_64_48_, s64_vector_32_,
                                        updates_shape, to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
2264
// Scatter must be rejected when the F32 [65, 32] updates have a window
// dimension larger than the F32 [64, 48] operand allows.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {65, 32});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates_shape, to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(
      error.error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << error;
}
2281
// Scatter must be rejected when the F32 [32, 49] updates have a window
// dimension larger than the F32 [64, 48] operand allows.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {32, 49});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{1},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates_shape, to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(
      error.error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << error;
}
2298
// Scatter must be rejected when the scatter dimension of the F32 [64, 31]
// updates does not match the bounds of the scatter indices.
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterWithUpdatesNotMatchingIndices) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {64, 31});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates_shape, to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(
      error.error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << error;
}
2317
// Scatter must be rejected when the scatter dimension of the F32 [31, 48]
// updates does not match the bounds of the scatter indices.
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterWithUpdatesNotMatchingIndicesV2) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {31, 48});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{1},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, updates_shape, to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(
      error.error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << error;
}
2336
// ScatterNd-style case: F32 [10, 9, 8, 7, 48] updates into the F32 [64, 48]
// operand; the inferred shape must equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdates) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(matrix_64_48_,
                                        s64_4d_tensor_10_9_8_7_1_,
                                        updates_shape, to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
2351
// ScatterNd-style case: F32 [10, 9, 8, 7, 64] updates into the F32 [64, 48]
// operand with inserted window dimension 1; the result must equal the operand
// shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdatesV2) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 64});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(matrix_64_48_,
                                        s64_4d_tensor_10_9_8_7_1_,
                                        updates_shape, to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
2366
// ScatterNd-style case with a smaller window: F32 [10, 9, 8, 7, 10] updates
// into the F32 [64, 48] operand; the result must equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdates) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 10});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(matrix_64_48_,
                                        s64_4d_tensor_10_9_8_7_1_,
                                        updates_shape, to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
2381
// ScatterNd-style case with a smaller window: F32 [10, 9, 8, 7, 12] updates
// into the F32 [64, 48] operand with inserted window dimension 1; the result
// must equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 12});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(matrix_64_48_,
                                        s64_4d_tensor_10_9_8_7_1_,
                                        updates_shape, to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, matrix_64_48_))
      << ShapeUtil::HumanString(inferred_shape);
}
2396
// ScatterNd-style case must be rejected when the updates' window dimension
// (65) is larger than the F32 [64, 48] operand allows.
TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 65});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, updates_shape, to_apply_,
      dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(
      error.error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << error;
}
2413
// ScatterNd-style case must be rejected when the scatter dimensions of the
// F32 [9, 9, 8, 7, 64] updates do not match the scatter indices' bounds.
TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterNdWithUpdatesNotMatchingIndices) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {9, 9, 8, 7, 64});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, updates_shape, to_apply_,
      dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(
      error.error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << error;
}
2432
// Batched dynamic-update-slice style scatter: no inserted window dimensions
// and five update window dimensions. The inferred shape must equal the
// operand shape.
TEST_F(ScatterGatherShapeInferenceTest, TfBatchDynamicUpdateSlice) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 8},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(f32_5d_tensor_50_49_48_47_46_,
                                        s64_4d_tensor_10_9_8_7_5_,
                                        updates_shape, to_apply_, dim_numbers));
  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred_shape);
}
2448
// Scatter with index_vector_dim set to 2 instead of the last scatter-indices
// dimension; the inferred shape must equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDim) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 8},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/2);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(f32_5d_tensor_50_49_48_47_46_,
                                        s64_4d_tensor_10_9_5_7_6_,
                                        updates_shape, to_apply_, dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred_shape);
}
2465
// Scatter with index_vector_dim set to 0 (the leading scatter-indices
// dimension); the inferred shape must equal the operand shape.
TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 8},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(f32_5d_tensor_50_49_48_47_46_,
                                        s64_4d_tensor_5_10_9_7_6_,
                                        updates_shape, to_apply_, dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred_shape);
}
2482
// Scatter in which every dimension of the updates is a window dimension.
TEST_F(ScatterGatherShapeInferenceTest, NoUpdateScatterDims) {
  // With no scatter dimensions in the updates, this is equivalent to a
  // dynamic update slice.
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0, 1, 2, 3, 4},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(f32_5d_tensor_50_49_48_47_46_,
                                        s64_vector_5_, updates_shape,
                                        to_apply_, dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred_shape);
}
2499
// Scatter driven by scalar scatter indices.
TEST_F(ScatterGatherShapeInferenceTest, ScalarScatterIndices) {
  // The indices "tensor" here is a single scalar S; it selects the position
  // within the operand at which a [30,29,28,27] shaped tensor is updated.
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {30, 29, 28, 27});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0, 1, 2, 3},
      /*inserted_window_dims=*/{0},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/0);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferScatterShape(f32_5d_tensor_50_49_48_47_46_,
                                        s64_scalar_, updates_shape, to_apply_,
                                        dim_numbers));

  EXPECT_TRUE(ShapeUtil::Equal(inferred_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(inferred_shape);
}
2517
// Scatter must be rejected when the operand is the fixture's tuple shape
// rather than an array.
TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedTensorInput) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/1);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      tuple_shape_, s64_vector_32_, s64_vector_32_, to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Expected array argument for operand"))
      << error;
}
2531
// Scatter must be rejected when the scatter indices are the fixture's tuple
// shape rather than an array.
TEST_F(ScatterGatherShapeInferenceTest,
       ScatterWithTupleShapedScatterIndicesInput) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/0);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      s64_vector_32_, tuple_shape_, s64_vector_32_, to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Expected array argument for scatter indices"))
      << error;
}
2546
// Scatter must be rejected when the updates are the fixture's tuple shape
// rather than an array.
TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/0);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      s64_vector_32_, s64_vector_32_, tuple_shape_, to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Expected array argument for updates"))
      << error;
}
2560
// Scatter must be rejected when the scatter indices are the F32 vector_32_
// shape instead of an integral tensor.
TEST_F(ScatterGatherShapeInferenceTest, FloatingPointScatterIndicesInput) {
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0},
      /*inserted_window_dims=*/{1},
      /*scatter_dims_to_operand_dims=*/{1},
      /*index_vector_dim=*/0);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      s64_vector_32_, vector_32_, s64_vector_32_, to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Scatter indices parameter must be an integral tensor"))
      << error;
}
2574
// Scatter must be rejected when index_vector_dim (10) is not within
// [0, rank(scatter_indices) + 1).
TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/10);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Scatter index leaf dimension must be within [0, "
                        "rank(scatter_indices) + 1)"))
      << error;
}
2590
// Scatter must be rejected when the updates tensor has rank 8 while the
// expected error states that rank 7 is required.
TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdates) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 50});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Updates tensor must be of rank 7; got 8."))
      << error;
}
2605
// Scatter must be rejected when the update computation takes a single
// parameter instead of the two the reduction function requires.
TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdateComputation) {
  // A program shape with one F32 scalar parameter and an F32 scalar result.
  const ProgramShape unary_computation =
      ShapeUtil::MakeProgramShape({f32_}, f32_);
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      unary_computation, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(
      error.error_message(),
      HasSubstr("Reduction function must take 2 parameters, but takes 1"))
      << error;
}
2624
// Scatter must be rejected when update_window_dims is not sorted
// ({4, 5, 6, 8, 7}).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 8, 7},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("update_window_dims in scatter op must be sorted"))
      << error;
}
2640
// Scatter must be rejected when update_window_dims repeats a dimension
// (7 appears twice).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedUpdateWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 7},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("update_window_dims in scatter op must not repeat"))
      << error;
}
2656
// Scatter must be rejected when update_window_dims contains 9, which is
// outside the valid range named in the expected error, [0, 9).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6, 7, 9},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Invalid update_window_dims set in scatter op; valid "
                        "range is [0, 9)"))
      << error;
}
2673
// Scatter must be rejected when inserted_window_dims is not sorted ({2, 1}).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{2, 1},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("inserted_window_dims in scatter op must be sorted"))
      << error;
}
2689
// Scatter must be rejected when inserted_window_dims repeats a dimension
// ({1, 1}).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedInsertedWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 1},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("inserted_window_dims in scatter op must not repeat"))
      << error;
}
2705
// Scatter must be rejected when inserted_window_dims contains 5, which is
// outside the valid range named in the expected error, [0, 5).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 5},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Invalid inserted_window_dims set in scatter op; valid "
                        "range is [0, 5)"))
      << error;
}
2722
// Scatter must be rejected when scatter_dims_to_operand_dims has 4 entries
// but the index_vector_dim bound of scatter_indices is 5 (per the error).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(
      error.error_message(),
      HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and "
                "the bound of dimension index_vector_dim=4 of scatter_indices "
                "is 5. These two numbers must be equal"))
      << error;
}
2741
// Scatter must be rejected when scatter_dims_to_operand_dims maps position 4
// to 10, outside the domain named in the expected error, [0, 5).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(error.error_message(),
              HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain "
                        "is [0, 5), got: 4->10"))
      << error;
}
2758
// Scatter must be rejected when scatter_dims_to_operand_dims repeats a
// dimension (2 appears twice).
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) {
  const Shape updates_shape =
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{4, 5, 6},
      /*inserted_window_dims=*/{1, 2},
      /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3},
      /*index_vector_dim=*/4);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, updates_shape,
      to_apply_, dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(
      error.error_message(),
      HasSubstr(
          "Repeated dimensions not allowed in scatter_dims_to_operand_dims"))
      << error;
}
2776
// Scatter must be rejected when the window dimensions total 4 (four update
// window dims, no inserted dims) against an operand of rank 5.
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_InsufficientWindowDims) {
  const Shape updates_shape = ShapeUtil::MakeShape(F32, {30, 29, 28, 27});
  const auto dim_numbers = HloScatterInstruction::MakeScatterDimNumbers(
      /*update_window_dims=*/{0, 1, 2, 3},
      /*inserted_window_dims=*/{},
      /*scatter_dims_to_operand_dims=*/{0},
      /*index_vector_dim=*/0);
  const StatusOr<Shape> inferred = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_scalar_, updates_shape, to_apply_,
      dim_numbers);

  ASSERT_FALSE(inferred.ok());
  // Stream the full status on mismatch so the actual error is visible.
  const auto& error = inferred.status();
  EXPECT_THAT(
      error.error_message(),
      HasSubstr(
          "Scatter op has window of size 4; doesn't match operand of rank 5."))
      << error;
}
2794
2795 } // namespace
2796 } // namespace xla
2797