1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h"

#include <utility>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test_benchmark.h"
33 
34 namespace op = xla::testing::opcode_matchers;
35 
36 namespace xla {
37 namespace {
38 
39 class DynamicDimensionInferenceTest : public HloTestBase {
40  protected:
DynamicDimensionInferenceTest()41   DynamicDimensionInferenceTest() : HloTestBase() {
42     module_ = CreateNewVerifiedModule();
43   }
44 
RunInference()45   Status RunInference() {
46     TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference,
47                         DynamicDimensionInference::Run(module_.get()));
48 
49     inference_ = absl::make_unique<DynamicDimensionInference>(inference);
50     return Status::OK();
51   }
52 
GetAdd()53   HloComputation* GetAdd() {
54     auto embedded_builder = HloComputation::Builder("add");
55     auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
56         0, ShapeUtil::MakeShape(F32, {}), "lhs"));
57     auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
58         1, ShapeUtil::MakeShape(F32, {}), "rhs"));
59     embedded_builder.AddInstruction(
60         HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
61     return module_->AddEmbeddedComputation(embedded_builder.Build());
62   }
63 
GetGe()64   HloComputation* GetGe() {
65     auto embedded_builder = HloComputation::Builder("ge");
66     auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
67         0, ShapeUtil::MakeShape(F32, {}), "lhs"));
68     auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
69         1, ShapeUtil::MakeShape(F32, {}), "rhs"));
70     embedded_builder.AddInstruction(HloInstruction::CreateCompare(
71         ShapeUtil::MakeShape(PRED, {}), lhs, rhs, ComparisonDirection::kGe));
72     return module_->AddEmbeddedComputation(embedded_builder.Build());
73   }
74 
75   std::unique_ptr<HloModule> module_;
76   std::unique_ptr<DynamicDimensionInference> inference_;
77   const Shape scalar_shape_ = ShapeUtil::MakeShape(S32, {});
78 };
79 
TEST_F(DynamicDimensionInferenceTest, ParamTest) {
  // Binding a scalar parameter as the size of one dimension of another entry
  // parameter makes inference report it for exactly that dimension.
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});

  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "param"));
  // Named "size_param" (as in the other tests) so it cannot be confused with
  // the data parameter in the SCOPED_TRACE dump.
  auto size_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  module_->AddEntryComputation(builder.Build());
  SCOPED_TRACE(module_->ToString());

  // Parameter 1 holds the size of dimension 1 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(param, {}, 1), size_param);
  EXPECT_EQ(inference_->GetDynamicSize(param, {}, 0), nullptr);
  // The size parameter has no dynamic size of its own.
  EXPECT_EQ(inference_->GetDynamicSize(size_param, {}, 0), nullptr);
}
102 
TEST_F(DynamicDimensionInferenceTest, ParamTestTuple) {
  // One tuple parameter carries both the data (element {0}) and the dynamic
  // size of its dimension 1 (element {1}); inference should report that size
  // as a GetTupleElement of the parameter.
  const Shape data_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const Shape tuple_shape =
      ShapeUtil::MakeTupleShape({data_shape, scalar_shape_});

  HloComputation::Builder builder(TestName());
  HloInstruction* tuple_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "param"));
  module_->AddEntryComputation(builder.Build());

  // Element {1} of parameter 0 sizes dimension 1 of element {0}.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(
      binding.Bind(DynamicParameterBinding::DynamicParameter{0, {1}},
                   DynamicParameterBinding::DynamicDimension{0, {0}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_THAT(inference_->GetDynamicSize(tuple_param, {0}, 1),
              op::GetTupleElement(tuple_param, 1));
  EXPECT_EQ(inference_->GetDynamicSize(tuple_param, {0}, 0), nullptr);
}
123 
TEST_F(DynamicDimensionInferenceTest, GetTupleElement) {
  // A dynamic size attached to a tuple element follows the data through a
  // GetTupleElement: the GTE result reports the same size, addressed with the
  // leading tuple index removed.
  const Shape element_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const Shape param_shape =
      ShapeUtil::MakeTupleShape({element_shape, scalar_shape_});

  HloComputation::Builder builder(TestName());
  HloInstruction* tuple_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "param"));
  HloInstruction* element = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(element_shape, tuple_param, 0));
  module_->AddEntryComputation(builder.Build());

  // Element {1} of parameter 0 sizes dimension 1 of element {0}.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(
      binding.Bind(DynamicParameterBinding::DynamicParameter{0, {1}},
                   DynamicParameterBinding::DynamicDimension{0, {0}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());

  // Before the GTE: addressed through tuple index {0}.
  EXPECT_THAT(inference_->GetDynamicSize(tuple_param, {0}, 1),
              op::GetTupleElement(tuple_param, 1));
  // After the GTE: same size, empty index.
  EXPECT_THAT(inference_->GetDynamicSize(element, {}, 1),
              op::GetTupleElement(tuple_param, 1));
  EXPECT_EQ(inference_->GetDynamicSize(tuple_param, {0}, 0), nullptr);
}
152 
TEST_F(DynamicDimensionInferenceTest, ElementwiseTest) {
  // An elementwise op (here a negate) forwards its operand's dynamic size to
  // its result unchanged.
  const Shape data_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});

  HloComputation::Builder builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, data_shape, "data_param"));
  HloInstruction* dim_size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  HloInstruction* neg = builder.AddInstruction(
      HloInstruction::CreateUnary(data_shape, HloOpcode::kNegate, data));
  module_->AddEntryComputation(builder.Build());

  // Parameter 1 holds the size of dimension 1 of parameter 0.
  using Binding = DynamicParameterBinding;
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      Binding::DynamicParameter{1, {}}, Binding::DynamicDimension{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(neg, {}, 1), dim_size);
}
177 
TEST_F(DynamicDimensionInferenceTest, ReduceTestI) {
  // Reducing away dimensions 0 and 2 of a [1, 2, 2] operand leaves only
  // dimension 1. It becomes dimension 0 of the [2] result and keeps its
  // dynamic size.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2});

  HloComputation::Builder builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "data_param"));
  HloInstruction* dim_size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  HloInstruction* neg = builder.AddInstruction(
      HloInstruction::CreateUnary(operand_shape, HloOpcode::kNegate, data));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloComputation* add_fn = GetAdd();
  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      result_shape, neg, zero, /*dimensions_to_reduce=*/{0, 2}, add_fn));
  module_->AddEntryComputation(builder.Build());

  // Parameter 1 holds the size of dimension 1 of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(
      binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
                   DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), dim_size);
}
208 
TEST_F(DynamicDimensionInferenceTest, ReduceTestII) {
  // Same as ReduceTestI, but only reduce one dimension. Reducing dimension 1
  // of a [1, 2, 2] operand keeps dimensions 0 and 2, so the dynamic
  // dimension 2 becomes dimension 1 of the [1, 2] result.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {1, 2});

  HloComputation::Builder builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "data_param"));
  HloInstruction* dim_size = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  HloInstruction* neg = builder.AddInstruction(
      HloInstruction::CreateUnary(operand_shape, HloOpcode::kNegate, data));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloComputation* add_fn = GetAdd();
  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      result_shape, neg, zero, /*dimensions_to_reduce=*/{1}, add_fn));
  module_->AddEntryComputation(builder.Build());

  // Parameter 1 holds the size of dimension 2 of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(
      binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
                   DynamicParameterBinding::DynamicDimension{0, {}, 2}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 1), dim_size);
  EXPECT_EQ(inference_->GetDynamicSize(reduce, {}, 0), nullptr);
}
241 
TEST_F(DynamicDimensionInferenceTest, DotTest) {
  // [x, y] . [y, z] -> [x, z]. A dynamic non-contracting dimension of the lhs
  // appears in the result; dynamic contracting dimensions do not.
  constexpr int kX = 3;
  constexpr int kY = 2;
  constexpr int kZ = 1;
  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {kX, kY});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {kY, kZ});
  const Shape out_shape = ShapeUtil::MakeShape(F32, {kX, kZ});

  HloComputation::Builder builder(TestName());
  HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, lhs_shape, "A"));
  HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, rhs_shape, "B"));
  HloInstruction* dim_size =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/2, scalar_shape_, "size_param"));

  // Contract lhs dimension 1 against rhs dimension 0.
  DotDimensionNumbers contraction;
  contraction.add_lhs_contracting_dimensions(1);
  contraction.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      out_shape, lhs, rhs, contraction,
      HloTestBase::DefaultPrecisionConfig(2)));
  module_->AddEntryComputation(builder.Build());

  auto& binding = module_->dynamic_parameter_binding();
  const DynamicParameterBinding::DynamicParameter size_source{2, {}};
  // Non-contracting: dimension 0 of A.
  TF_CHECK_OK(binding.Bind(
      size_source, DynamicParameterBinding::DynamicDimension{0, {}, 0}));
  // Contracting: dimension 1 of A and dimension 0 of B.
  TF_CHECK_OK(binding.Bind(
      size_source, DynamicParameterBinding::DynamicDimension{0, {}, 1}));
  TF_CHECK_OK(binding.Bind(
      size_source, DynamicParameterBinding::DynamicDimension{1, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 0), dim_size);
  EXPECT_EQ(inference_->GetDynamicSize(dot, {}, 1), nullptr);
}
285 
TEST_F(DynamicDimensionInferenceTest, ConvolutionTest) {
  // A convolution with no spatial dimensions, laid out so that:
  //   A (input)  [x, y]: dimension 0 is batch, dimension 1 is input feature.
  //   B (kernel) [y, z]: dimension 0 is input feature, 1 is output feature.
  //   output     [z, x]: dimension 0 is output feature, 1 is batch.
  // A dynamic batch size on A must surface on output dimension 1. A dynamic
  // input feature size is contracted over and must not surface anywhere.
  auto builder = HloComputation::Builder(TestName());
  constexpr int xdim = 3;
  constexpr int ydim = 2;
  constexpr int zdim = 1;
  auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
  auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
  auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, xy_shape, "A"));
  auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, yz_shape, "B"));
  auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  // Start from the default dimension numbers with zero spatial dimensions and
  // override the kernel/output layouts. A's input feature dimension is left at
  // its default; for the shapes above to line up it must be dimension 1.
  auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0);

  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(1);
  dnums.set_output_feature_dimension(0);

  // No spatial dimensions, so the window has no dimensions either.
  Window window;

  auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      zx_shape, a_param, b_param, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums,
      HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(builder.Build());

  // Bind parameter 2 as the size of dimension 0 of A (the input batch
  // dimension, which becomes output dimension 1).
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  // Also bind it as the size of dimension 1 of A (the input feature
  // dimension, which the convolution contracts over).
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 1), size_param);
  EXPECT_EQ(inference_->GetDynamicSize(conv, {}, 0), nullptr);
}
334 
TEST_F(DynamicDimensionInferenceTest, TransposeTest) {
  // Test the ability to trace unmodified dimensions through a transpose.
  // Every input dimension is dynamic with its own size parameter. The
  // permutation {2, 1, 0} reverses the dimensions, so output dimension i must
  // report the size bound to input dimension 2 - i.
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 3});
  auto output_shape = ShapeUtil::MakeShape(F32, {3, 2, 1});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, input_shape, "A"));
  // Distinct HLO names so the three size parameters can be told apart in the
  // SCOPED_TRACE dump.
  auto* size_param_1 = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param_1"));
  auto* size_param_2 = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param_2"));
  auto* size_param_3 = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/3, scalar_shape_, "size_param_3"));

  auto* transpose = builder.AddInstruction(
      HloInstruction::CreateTranspose(output_shape, a_param, {2, 1, 0}));

  module_->AddEntryComputation(builder.Build());

  // Parameter k (k = 1, 2, 3) holds the size of dimension k - 1 of A.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{3, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 2}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 0), size_param_3);
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 1), size_param_2);
  EXPECT_EQ(inference_->GetDynamicSize(transpose, {}, 2), size_param_1);
}
373 
TEST_F(DynamicDimensionInferenceTest, ReshapeTest) {
  // Test the ability to trace unmodified reshape dimensions.
  // [2, 3, 4, 5, 6] -> [6, 4, 1, 5, 2, 3]: input dimensions 2 (size 4) and
  // 3 (size 5) reappear unchanged as output dimensions 1 and 3. Every other
  // output dimension is a merge, a split, or an inserted size-1 dimension.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3});

  HloComputation::Builder builder(TestName());
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/0, operand_shape, "A"));
  HloInstruction* dim_size =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/1, scalar_shape_, "size_param"));
  HloInstruction* reshape = builder.AddInstruction(
      HloInstruction::CreateReshape(result_shape, operand));
  module_->AddEntryComputation(builder.Build());

  // Parameter 1 sizes both dimension 2 and dimension 3 of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  for (int input_dim : {2, 3}) {
    TF_CHECK_OK(binding.Bind(
        DynamicParameterBinding::DynamicParameter{1, {}},
        DynamicParameterBinding::DynamicDimension{0, {}, input_dim}));
  }

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  for (int out_dim = 0; out_dim < result_shape.dimensions_size(); ++out_dim) {
    const bool expect_dynamic = out_dim == 1 || out_dim == 3;
    EXPECT_EQ(inference_->GetDynamicSize(reshape, {}, out_dim),
              expect_dynamic ? dim_size : nullptr);
  }
}
407 
TEST_F(DynamicDimensionInferenceTest, ReshapeTestUnimplemented) {
  // Test that a reshape which does not keep the dynamic dimension intact is
  // rejected. In [2, 3, 4, 5, 6] -> [6, 4, 1, 5, 2, 3], input dimension 1
  // (size 3) is folded together with input dimension 0 into output
  // dimension 0 (2 * 3 = 6), so it has no single output dimension to map to.
  // Inference is expected to report UNIMPLEMENTED.
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5, 6});
  auto output_shape = ShapeUtil::MakeShape(F32, {6, 4, 1, 5, 2, 3});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, input_shape, "A"));

  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, scalar_shape_, "size_param"));

  builder.AddInstruction(HloInstruction::CreateReshape(output_shape, a_param));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 holds the size of dimension 1 of A.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  SCOPED_TRACE(module_->ToString());
  Status status = RunInference();
  EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED) << status;
}
432 
TEST_F(DynamicDimensionInferenceTest, BroadcastTest) {
  // Test the ability to trace broadcast dimension. Broadcasting a [2] operand
  // into [3, 2, 4] with broadcast dimensions {1} maps operand dimension 0 onto
  // output dimension 1. Only that output dimension is dynamic.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {3, 2, 4});

  HloComputation::Builder builder(TestName());
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/0, operand_shape, "A"));
  HloInstruction* dim_size =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/1, scalar_shape_, "size_param"));
  HloInstruction* broadcast = builder.AddInstruction(
      HloInstruction::CreateBroadcast(result_shape, operand, {1}));
  module_->AddEntryComputation(builder.Build());

  // Parameter 1 holds the size of dimension 0 of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(
      binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
                   DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  for (int out_dim = 0; out_dim < result_shape.dimensions_size(); ++out_dim) {
    HloInstruction* expected = (out_dim == 1) ? dim_size : nullptr;
    EXPECT_EQ(inference_->GetDynamicSize(broadcast, {}, out_dim), expected);
  }
}
459 
TEST_F(DynamicDimensionInferenceTest,WhileTest)460 TEST_F(DynamicDimensionInferenceTest, WhileTest) {
461   // Test the ability to trace into while loops.
462   auto builder = HloComputation::Builder(TestName());
463   auto input_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
464   auto output_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});
465   auto tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape});
466 
467   // Body:
468   //
469   //   Param
470   //   |  |
471   // GTE1 GTE2
472   //   |  |
473   //    ADD
474   auto body_builder = HloComputation::Builder("body");
475   auto body_param = body_builder.AddInstruction(
476       HloInstruction::CreateParameter(0, tuple_shape, "param"));
477   auto gte_0 = body_builder.AddInstruction(
478       HloInstruction::CreateGetTupleElement(input_shape, body_param, 0));
479   auto gte_1 = body_builder.AddInstruction(
480       HloInstruction::CreateGetTupleElement(input_shape, body_param, 1));
481   auto add = body_builder.AddInstruction(
482       HloInstruction::CreateBinary(input_shape, HloOpcode::kAdd, gte_0, gte_1));
483   body_builder.AddInstruction(HloInstruction::CreateTuple({add, add}));
484 
485   HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
486 
487   auto cond_builder = HloComputation::Builder("condition");
488   cond_builder.AddInstruction(
489       HloInstruction::CreateParameter(0, tuple_shape, "param"));
490   cond_builder.AddInstruction(
491       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(false)));
492   HloComputation* condition =
493       module_->AddEmbeddedComputation(cond_builder.Build());
494 
495   // Entry:
496   //
497   //  Param
498   //   |
499   //  While
500   auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
501       /*parameter_number=*/0, tuple_shape, "A"));
502   auto* size_param = builder.AddInstruction(HloInstruction::CreateParameter(
503       /*parameter_number=*/1, scalar_shape_, "size_param"));
504   builder.AddInstruction(
505       HloInstruction::CreateWhile(tuple_shape, condition, body, a_param));
506 
507   module_->AddEntryComputation(builder.Build());
508 
509   TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
510       DynamicParameterBinding::DynamicParameter{1, {}},
511       DynamicParameterBinding::DynamicDimension{0, {0}, 0}));
512 
513   TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
514       DynamicParameterBinding::DynamicParameter{1, {}},
515       DynamicParameterBinding::DynamicDimension{0, {1}, 0}));
516 
517   // Test that dynamic dimension inference does the right thing. A lambda is
518   // used here since we want to test twice by running inference again
519   // (idempotency).
520   auto test_dynamic_dimension = [&]() {
521     HloInstruction* while_hlo = nullptr;
522     // The while hlo has been replaced, find the new one.
523     for (HloInstruction* inst : module_->entry_computation()->instructions()) {
524       if (inst->opcode() == HloOpcode::kWhile) {
525         while_hlo = inst;
526       }
527     }
528     ASSERT_NE(while_hlo, nullptr);
529     // The original while shape has 2 parameters. With dynamic size passed in
530     // as an extra parameter, the tuple should have 3 elements.
531     EXPECT_EQ(while_hlo->shape().tuple_shapes_size(), 3);
532     HloInstruction* add = nullptr;
533     for (HloInstruction* inst : while_hlo->while_body()->instructions()) {
534       if (inst->opcode() == HloOpcode::kAdd) {
535         add = inst;
536       }
537     }
538     EXPECT_NE(add, nullptr);
539     EXPECT_NE(inference_->GetDynamicSize(add, {}, 0), nullptr);
540     EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {0}, 0), size_param);
541     EXPECT_EQ(inference_->GetDynamicSize(while_hlo, {1}, 0), size_param);
542   };
543 
544   TF_ASSERT_OK(RunInference());
545   test_dynamic_dimension();
546   TF_ASSERT_OK(RunInference());
547   test_dynamic_dimension();
548 }
549 
TEST_F(DynamicDimensionInferenceTest, ReduceWindowBatchTest) {
  // Test the ability to trace reduce window batch dimensions. The window
  // leaves dimension 0 untouched (size 1, stride 1), so its dynamic size is
  // carried through to the result.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});

  // Appends one unpadded, undilated window dimension with the given size and
  // stride.
  Window window;
  auto append_window_dim = [&window](int64 size, int64 stride) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(size);
    dim->set_stride(stride);
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  };
  append_window_dim(/*size=*/1, /*stride=*/1);  // Dimension 0: unchanged.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Dimension 1: 4 -> 2.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Dimension 2: 4 -> 2.

  HloComputation::Builder builder(TestName());
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/0, operand_shape, "A"));
  HloInstruction* dim_size =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/1, scalar_shape_, "size_param"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloComputation* add_fn = GetAdd();
  HloInstruction* reduce_window =
      builder.AddInstruction(HloInstruction::CreateReduceWindow(
          result_shape, operand, zero, window, add_fn));
  module_->AddEntryComputation(builder.Build());

  // Parameter 1 holds the size of dimension 0 of parameter 0.
  auto& binding = module_->dynamic_parameter_binding();
  TF_CHECK_OK(
      binding.Bind(DynamicParameterBinding::DynamicParameter{1, {}},
                   DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(reduce_window, {}, 0), dim_size);
}
599 
TEST_F(DynamicDimensionInferenceTest, SelectAndScatterTest) {
  // Test the ability to trace select and scatter batch dimensions. The window
  // leaves dimension 0 untouched, and dimension 0 of both the operand and the
  // source is dynamic, so the result's dimension 0 should be dynamic too.
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {2, 4, 4});
  const Shape source_shape = ShapeUtil::MakeShape(F32, {2, 2, 2});

  // Appends one unpadded, undilated window dimension with the given size and
  // stride.
  Window window;
  auto append_window_dim = [&window](int64 size, int64 stride) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(size);
    dim->set_stride(stride);
    dim->set_padding_low(0);
    dim->set_padding_high(0);
    dim->set_window_dilation(1);
    dim->set_base_dilation(1);
  };
  append_window_dim(/*size=*/1, /*stride=*/1);  // Dimension 0: unchanged.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Dimension 1: 4 -> 2.
  append_window_dim(/*size=*/2, /*stride=*/2);  // Dimension 2: 4 -> 2.

  HloComputation::Builder builder(TestName());
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/0, operand_shape, "A"));
  HloInstruction* dim_size =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/1, scalar_shape_, "size_param"));
  HloInstruction* source =
      builder.AddInstruction(HloInstruction::CreateParameter(
          /*parameter_number=*/2, source_shape, "B"));
  HloInstruction* zero = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  HloInstruction* sns =
      builder.AddInstruction(HloInstruction::CreateSelectAndScatter(
          operand_shape, operand, GetGe(), window, source, zero, GetAdd()));
  module_->AddEntryComputation(builder.Build());

  // Parameter 1 sizes dimension 0 of the operand (parameter 0) and of the
  // source (parameter 2).
  auto& binding = module_->dynamic_parameter_binding();
  for (int bound_param : {0, 2}) {
    TF_CHECK_OK(binding.Bind(
        DynamicParameterBinding::DynamicParameter{1, {}},
        DynamicParameterBinding::DynamicDimension{bound_param, {}, 0}));
  }

  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(sns, {}, 0), dim_size);
}
653 
TEST_F(DynamicDimensionInferenceTest, SliceTest) {
  // A full-extent, unit-stride slice keeps the operand's dynamic size on
  // dimension 1.
  auto builder = HloComputation::Builder(TestName());

  auto data_param = builder.AddInstruction(HloInstruction::CreateParameter(
      0, ShapeUtil::MakeShape(F32, {5, 7}), "data_param"));
  auto size_param = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  auto* slice = builder.AddInstruction(HloInstruction::CreateSlice(
      ShapeUtil::MakeShape(F32, {5, 7}), data_param, /*start_indices=*/{0, 0},
      /*limit_indices=*/{5, 7}, /*strides=*/{1, 1}));

  module_->AddEntryComputation(builder.Build());
  // Set up dynamic parameter binding: parameter 1 holds the size of
  // dimension 1 of parameter 0.
  TF_CHECK_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  // Dump the module on failure, as every other test in this file does.
  SCOPED_TRACE(module_->ToString());
  TF_ASSERT_OK(RunInference());
  EXPECT_EQ(inference_->GetDynamicSize(slice, {}, 1), size_param);
}
675 
676 }  // namespace
677 }  // namespace xla
678