1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"

#include <utility>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test_benchmark.h"
33
34 namespace op = xla::testing::opcode_matchers;
35
36 namespace xla {
37 namespace {
38
39 class DynamicDimensionInferenceTest : public HloTestBase {
40 protected:
DynamicDimensionInferenceTest()41 DynamicDimensionInferenceTest() : HloTestBase() {
42 module_ = CreateNewVerifiedModule();
43 }
44
RunInference()45 Status RunInference() {
46 TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference,
47 DynamicDimensionInference::Run(module_.get()));
48
49 inference_ = absl::make_unique<DynamicDimensionInference>(inference);
50 return Status::OK();
51 }
52
GetAdd()53 HloComputation* GetAdd() {
54 auto embedded_builder = HloComputation::Builder("add");
55 auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
56 0, ShapeUtil::MakeShape(F32, {}), "lhs"));
57 auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
58 1, ShapeUtil::MakeShape(F32, {}), "rhs"));
59 embedded_builder.AddInstruction(
60 HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
61 return module_->AddEmbeddedComputation(embedded_builder.Build());
62 }
63
GetGe()64 HloComputation* GetGe() {
65 auto embedded_builder = HloComputation::Builder("ge");
66 auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
67 0, ShapeUtil::MakeShape(F32, {}), "lhs"));
68 auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
69 1, ShapeUtil::MakeShape(F32, {}), "rhs"));
70 embedded_builder.AddInstruction(HloInstruction::CreateCompare(
71 ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kGe));
72 return module_->AddEmbeddedComputation(embedded_builder.Build());
73 }
74
75 std::unique_ptr<HloModule> module_;
76 std::unique_ptr<DynamicDimensionInference> inference_;
77 const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {});
78 };
79
// A dynamic size bound directly to an entry-parameter dimension is reported
// for that dimension, but not for an unbound dimension or for the size
// parameter itself.
TEST_F(DynamicDimensionInferenceTest, ParamTest) {
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});

  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  // Give the size parameter its own name so the module dump printed by
  // SCOPED_TRACE does not show two parameters both called "param".
  auto size_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  // Set up dynamic parameter binding: parameter 1 holds the runtime size of
  // dimension 1 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(param, {}, 1), size_param);
  EXPECT_EQ(inference_->GetDynamicSize(param, {}, 0), nullptr);
  EXPECT_EQ(inference_->GetDynamicSize(size_param, {}, 0), nullptr);
}
102
// The data and its size live in the same tuple parameter. Inference must
// report a get-tuple-element of the size slot as the dynamic size of the
// bound dimension.
TEST_F(DynamicDimensionInferenceTest, ParamTestTuple) {
  using DynParam = DynamicParameterBinding::DynamicParameter;
  using DynDim = DynamicParameterBinding::DynamicDimension;

  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const Shape param_shape =
      ShapeUtil::MakeTupleShape({data_shape, scalar_shape_});

  HloInstruction* tuple_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  module_->AddEntryComputation(builder.Build());

  // Size source: parameter 0, tuple index {1}.
  // Dynamic dimension: parameter 0, tuple index {0}, dimension 1.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(DynParam{0, {1}},
                                                         DynDim{0, {0}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  auto* const bound_size = inference_->GetDynamicSize(tuple_param, {0}, 1);
  EXPECT_THAT(bound_size, op::GetTupleElement(tuple_param, 1));
  // Dimension 0 was never bound, so it has no dynamic size.
  EXPECT_EQ(inference_->GetDynamicSize(tuple_param, {0}, 0), nullptr);
}
123
// A get-tuple-element keeps the dynamic size of the element it extracts. Only
// the leading entry of the shape index is consumed: {0} on the tuple becomes
// {} on the GTE.
TEST_F(DynamicDimensionInferenceTest, GetTupleElement) {
  using DynParam = DynamicParameterBinding::DynamicParameter;
  using DynDim = DynamicParameterBinding::DynamicDimension;

  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});

  HloInstruction* tuple_param =
      builder.AddInstruction(HloInstruction::CreateParameter(
          0, ShapeUtil::MakeTupleShape({data_shape, scalar_shape_}), "param"));
  HloInstruction* data_gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(data_shape, tuple_param, 0));

  module_->AddEntryComputation(builder.Build());

  // Tuple element {1} of parameter 0 holds the size of dimension 1 of
  // tuple element {0}.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(DynParam{0, {1}},
                                                         DynDim{0, {0}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  const auto is_size_slot = op::GetTupleElement(tuple_param, 1);
  EXPECT_THAT(inference_->GetDynamicSize(tuple_param, {0}, 1), is_size_slot);
  EXPECT_THAT(inference_->GetDynamicSize(data_gte, {}, 1), is_size_slot);
  EXPECT_EQ(inference_->GetDynamicSize(tuple_param, {0}, 0), nullptr);
}
152
// An elementwise op (negate) passes its operand's dynamic size through
// unchanged.
TEST_F(DynamicDimensionInferenceTest, ElementwiseTest) {
  HloComputation::Builder builder(TestName());
  const Shape data_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});

  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data_param"));
  HloInstruction* size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, data));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 is the runtime size of dimension 1 of parameter 0.
  const DynamicParameterBinding::DynamicParameter size_source{1, {}};
  const DynamicParameterBinding::DynamicDimension dynamic_dim{0, {}, 1};
  TF_CHECK_OK(
      module_->dynamic_parameter_binding().Bind(size_source, dynamic_dim));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(negated, {}, 1), size);
}
177
// Reducing dimensions {0, 2} of a [1,2,2] operand leaves only input
// dimension 1, which becomes output dimension 0. Its dynamic size must
// follow it there.
TEST_F(DynamicDimensionInferenceTest, ReduceTestI) {
  HloComputation::Builder b(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2});

  HloInstruction* data = b.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "data_param"));
  HloInstruction* size = b.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  HloInstruction* negated = b.AddInstruction(
      HloInstruction::CreateUnary(operand_shape, HloOpcode::kNegate, data));
  HloInstruction* zero = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloComputation* sum = GetAdd();
  HloInstruction* reduce = b.AddInstruction(HloInstruction::CreateReduce(
      result_shape, negated, zero, /*dimensions_to_reduce=*/{0, 2}, sum));

  module_->AddEntryComputation(b.Build());

  // Parameter 1 carries the runtime size of dimension 1 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), size);
}
208
// Like ReduceTestI, but only dimension 1 is reduced away. The dynamic input
// dimension 2 therefore shifts down to output dimension 1, and output
// dimension 0 stays static.
TEST_F(DynamicDimensionInferenceTest, ReduceTestII) {
  using DynParam = DynamicParameterBinding::DynamicParameter;
  using DynDim = DynamicParameterBinding::DynamicDimension;

  HloComputation::Builder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 2});

  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "data_param"));
  HloInstruction* size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  HloInstruction* negated = builder.AddInstruction(
      HloInstruction::CreateUnary(operand_shape, HloOpcode::kNegate, data));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      result_shape, negated, zero, /*dimensions_to_reduce=*/{1}, GetAdd()));

  module_->AddEntryComputation(builder.Build());

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(DynParam{1, {}},
                                                         DynDim{0, {}, 2}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 1), size);
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), nullptr);
}
241
// Dot of A[x, y] and B[y, z], contracting over y. A's non-contracting
// dimension 0 becomes output dimension 0, which must pick up the dynamic
// size. Binding the contracting dimensions must not make output dimension 1
// dynamic.
TEST_F(DynamicDimensionInferenceTest, DotTest) {
  using DynParam = DynamicParameterBinding::DynamicParameter;
  using DynDim = DynamicParameterBinding::DynamicDimension;
  constexpr int kX = 3;
  constexpr int kY = 2;
  constexpr int kZ = 1;

  HloComputation::Builder builder(TestName());
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {kX, kY});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {kY, kZ});
  const Shape out_shape = ShapeUtil::MakeShape(F32, {kX, kZ});

  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, lhs_shape, "A"));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, rhs_shape, "B"));
  HloInstruction* size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  DotDimensionNumbers contract_y;
  contract_y.add_lhs_contracting_dimensions(1);
  contract_y.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(
      HloInstruction::CreateDot(out_shape, lhs, rhs, contract_y,
                                HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(builder.Build());

  // All bindings read their size from parameter 2. The first targets A's
  // non-contracting dimension; the next two target the contracting pair
  // (A dimension 1 and B dimension 0).
  const DynParam size_source{2, {}};
  for (const DynDim& target :
       {DynDim{0, {}, 0}, DynDim{0, {}, 1}, DynDim{1, {}, 0}}) {
    TF_CHECK_OK(
        module_->dynamic_parameter_binding().Bind(size_source, target));
  }

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), size);
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
}
285
// Convolution with zero spatial dimensions. The input's batch dimension
// (dimension 0) is dynamic and must surface at the output batch dimension
// (dimension 1). The other bound input dimension is contracted against the
// kernel, so output dimension 0 stays static.
TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
  using DynParam = DynamicParameterBinding::DynamicParameter;
  using DynDim = DynamicParameterBinding::DynamicDimension;
  constexpr int kX = 3;
  constexpr int kY = 2;
  constexpr int kZ = 1;

  HloComputation::Builder builder(TestName());
  const Shape input_shape = ShapeUtil::MakeShape(F32, {kX, kY});
  const Shape kernel_shape = ShapeUtil::MakeShape(F32, {kY, kZ});
  const Shape output_shape = ShapeUtil::MakeShape(F32, {kZ, kX});

  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/0, input_shape, "A"));
  HloInstruction* kernel = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, kernel_shape, "B"));
  HloInstruction* size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  // Start from the default numbering for zero spatial dimensions, then
  // override it. The kernel is laid out [input_feature, output_feature] and
  // the output [output_feature, batch]. The input feature dimension is
  // presumably left at its default (dimension 1).
  ConvolutionDimensionNumbers dnums =
      XlaBuilder::CreateDefaultConvDimensionNumbers(0);
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(1);
  dnums.set_output_feature_dimension(0);

  Window no_spatial_window;

  HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      output_shape, input, kernel, /*feature_group_count=*/1,
      /*batch_group_count=*/1, no_spatial_window, dnums,
      HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(builder.Build());

  const DynParam size_source{2, {}};
  // Input batch dimension.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(size_source,
                                                         DynDim{0, {}, 0}));
  // Input dimension contracted against the kernel.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(size_source,
                                                         DynDim{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 1), size);
  EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 0), nullptr);
}
334
// Every operand dimension of a transpose is dynamic, each with its own size
// parameter. The transpose only permutes dimensions, so each size must
// follow its dimension to the permuted position.
TEST_F(DynamicDimensionInferenceTest, TransposeTest) {
  using DynParam = DynamicParameterBinding::DynamicParameter;
  using DynDim = DynamicParameterBinding::DynamicDimension;
  constexpr int64 kRank = 3;

  HloComputation::Builder builder(TestName());
  const Shape in_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
  const Shape out_shape = ShapeUtil::MakeShape(F32, {3, 2, 1});

  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/0, in_shape, "A"));
  // sizes[d] is the parameter (number d + 1) holding the size of operand
  // dimension d.
  HloInstruction* sizes[kRank];
  for (int64 d = 0; d < kRank; ++d) {
    sizes[d] = builder.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/d + 1, scalar_shape_, "size_param"));
  }

  HloInstruction* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(out_shape, operand, {2, 1, 0}));

  module_->AddEntryComputation(builder.Build());

  for (int64 d = 0; d < kRank; ++d) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(DynParam{d + 1, {}},
                                                           DynDim{0, {}, d}));
  }

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  // The permutation {2, 1, 0} reverses the dimensions, so output dimension d
  // takes the size of operand dimension kRank - 1 - d.
  for (int64 d = 0; d < kRank; ++d) {
    EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, d),
              sizes[kRank - 1 - d]);
  }
}
373
// Reshape [2,3,4,5,6] -> [6,4,1,5,2,3]. Operand dimensions 2 (size 4) and
// 3 (size 5) pass through unmodified, landing at output dimensions 1 and 3.
// Their dynamic sizes must be traced there, and no other output dimension
// may become dynamic.
TEST_F(DynamicDimensionInferenceTest, ReshapeTest) {
  HloComputation::Builder builder(TestName());
  const Shape in_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6});
  const Shape out_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3});

  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/0, in_shape, "A"));
  HloInstruction* size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));
  HloInstruction* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(out_shape, operand));

  module_->AddEntryComputation(builder.Build());

  // Both untouched operand dimensions share the same size parameter.
  for (int64 operand_dim : {2, 3}) {
    TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
        DynamicParameterBinding::DynamicParameter{1, {}},
        DynamicParameterBinding::DynamicDimension{0, {}, operand_dim}));
  }

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  for (int64 out_dim = 0; out_dim < 6; ++out_dim) {
    HloInstruction* const expected =
        (out_dim == 1 || out_dim == 3) ? size : nullptr;
    EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, out_dim), expected);
  }
}
407
// Reshape [2,3,4,5,6] -> [6,4,1,5,2,3] with operand dimension 1 (size 3)
// dynamic. That dimension is merged with operand dimension 0 (2 * 3 = 6)
// into output dimension 0, so its size cannot be traced through unmodified.
// Inference is expected to fail with UNIMPLEMENTED.
TEST_F(DynamicDimensionInferenceTest, ReshapeTestUnimplemented) {
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6});
  auto output_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, input_shape, "A"));

  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  builder.AddInstruction(HloInstruction::CreateReshape(output_shape, a_param));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 holds the size of operand dimension 1, the merged dimension.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  Status status = RunInference();
  EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);
}
432
// Broadcast [2] -> [3,2,4], placing the operand's only dimension at result
// dimension 1. Only that dimension inherits the dynamic size; the dimensions
// introduced by the broadcast stay static.
TEST_F(DynamicDimensionInferenceTest, BroadcastTest) {
  HloComputation::Builder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {3, 2, 4});

  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/0, operand_shape,
                                      "A"));
  HloInstruction* size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));
  HloInstruction* broadcast =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          result_shape, operand, /*broadcast_dimensions=*/{1}));

  module_->AddEntryComputation(builder.Build());

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  for (int64 dim = 0; dim < 3; ++dim) {
    EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, dim),
              dim == 1 ? size : nullptr);
  }
}
459
// Dynamic sizes must be traced into while loops. After inference the loop
// carries the size as an extra tuple element, and the add inside the body
// has a dynamic size as well. Inference is run twice to check that re-running
// it on the already-rewritten module gives the same result (idempotency).
TEST_F(DynamicDimensionInferenceTest, WhileTest) {
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  auto tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape});

  // Body:
  //
  //   Param
  //   |   |
  // GTE1 GTE2
  //   |   |
  //    ADD
  auto body_builder = HloComputation::Builder("body");
  auto body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  auto gte_0 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(input_shape, body_param, 0));
  auto gte_1 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(input_shape, body_param, 1));
  auto add = body_builder.AddInstruction(
      HloInstruction::CreateBinary(input_shape, HloOpcode::kAdd, gte_0, gte_1));
  body_builder.AddInstruction(HloInstruction::CreateTuple({add, add}));

  HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());

  // Condition: ignores its parameter and returns the constant false.
  auto cond_builder = HloComputation::Builder("condition");
  cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
  HloComputation* condition =
      module_->AddEmbeddedComputation(cond_builder.Build());

  // Entry:
  //
  //  Param
  //    |
  //  While
  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, tuple_shape, "A"));
  auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));
  builder.AddInstruction(
      HloInstruction::CreateWhile(tuple_shape, condition, body, a_param));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 holds the size of dimension 0 of both tuple elements of
  // parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {0}, 0}));

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {1}, 0}));

  // A lambda is used so the same checks can run after each of the two
  // inference passes.
  auto test_dynamic_dimension = [&]() {
    HloInstruction* while_hlo = nullptr;
    // The while hlo has been replaced, find the new one.
    for (HloInstruction* inst : module_->entry_computation()->instructions()) {
      if (inst->opcode() == HloOpcode::kWhile) {
        while_hlo = inst;
      }
    }
    ASSERT_NE(while_hlo, nullptr);
    // The original while shape has 2 elements. With the dynamic size passed
    // in as an extra element, the tuple should have 3 elements.
    EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 3);
    // Named `body_add` rather than `add` so it does not shadow the `add`
    // instruction built above, which the lambda captures by reference.
    HloInstruction* body_add = nullptr;
    for (HloInstruction* inst : while_hlo->while_body()->instructions()) {
      if (inst->opcode() == HloOpcode::kAdd) {
        body_add = inst;
      }
    }
    // Stop here rather than run follow-up expectations against a null
    // instruction.
    ASSERT_NE(body_add, nullptr);
    EXPECT_NE(inference_->GetDynamicSize(body_add, {}, 0), nullptr);
    EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {0}, 0), size_param);
    EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {1}, 0), size_param);
  };

  TF_ASSERT_OK(RunInference());
  test_dynamic_dimension();
  TF_ASSERT_OK(RunInference());
  test_dynamic_dimension();
}
549
// Reduce-window whose first (batch) dimension uses a 1x1 unit window while
// the other two are pooled 2x2 with stride 2. The batch dimension is dynamic,
// and its size must carry over to the output's dimension 0.
TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) {
  HloComputation::Builder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});

  // Appends one window dimension with no padding and no dilation.
  Window window;
  const auto append_window_dim = [&window](int64 size, int64 stride) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(size);
    dim->set_stride(stride);
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  };
  append_window_dim(/*size=*/1, /*stride=*/1);  // Batch: unchanged.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Pooled.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Pooled.

  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/0, operand_shape,
                                      "A"));
  HloInstruction* size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* reduce_window =
      builder.AddInstruction(HloInstruction::CreateReduceWindow(
          result_shape, operand, zero, window, GetAdd()));

  module_->AddEntryComputation(builder.Build());

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), size);
}
599
// Select-and-scatter using the same window layout as ReduceWindowBatchTest.
// Dimension 0 of both the operand and the source is dynamic (bound to the
// same size), and the result's dimension 0 must carry that size.
TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) {
  using DynParam = DynamicParameterBinding::DynamicParameter;
  using DynDim = DynamicParameterBinding::DynamicDimension;

  HloComputation::Builder builder(TestName());
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  const Shape source_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});

  // Appends one window dimension with no padding and no dilation.
  Window window;
  const auto append_window_dim = [&window](int64 size, int64 stride) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(size);
    dim->set_stride(stride);
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  };
  append_window_dim(/*size=*/1, /*stride=*/1);  // Batch: unchanged.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Pooled.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Pooled.

  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/0, operand_shape,
                                      "A"));
  HloInstruction* size = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));
  HloInstruction* source = builder.AddInstruction(
      HloInstruction::CreateParameter(/*parameter_number=*/2, source_shape,
                                      "B"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* select_and_scatter =
      builder.AddInstruction(HloInstruction::CreateSelectAndScatter(
          operand_shape, operand, GetGe(), window, source, zero, GetAdd()));

  module_->AddEntryComputation(builder.Build());

  const DynParam size_source{1, {}};
  // Operand batch dimension.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(size_source,
                                                         DynDim{0, {}, 0}));
  // Source batch dimension.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(size_source,
                                                         DynDim{2, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(select_and_scatter, {}, 0), size);
}
653
// A full-range, unit-stride slice keeps every dimension intact. The dynamic
// size of the bound dimension must be forwarded unchanged, and the unbound
// dimension must stay static.
TEST_F(DynamicDimensionInferenceTest, SliceTest) {
  auto builder = HloComputation::Builder(TestName());

  auto data_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  auto size_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  auto* slice = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {5, 7}), data_param, /*start_indices=*/{0, 0},
      /*limit_indices=*/{5, 7}, /*strides=*/{1, 1}));

  module_->AddEntryComputation(builder.Build());
  // Set up dynamic parameter binding: parameter 1 holds the runtime size of
  // dimension 1 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  // Dump the module on failure, as every other test in this file does.
  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(slice, {}, 1), size_param);
  EXPECT_EQ(inference_->GetDynamicSize(slice, {}, 0), nullptr);
}
675
676 } // namespace
677 } // namespace xla
678