1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
17 
18 #include "tensorflow/compiler/xla/client/xla_builder.h"
19 #include "tensorflow/compiler/xla/literal.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
23 #include "tensorflow/compiler/xla/service/hlo_module.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/service/hlo_parser.h"
26 #include "tensorflow/compiler/xla/service/hlo_runner.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test_benchmark.h"
35 
36 namespace op = xla::testing::opcode_matchers;
37 
38 namespace xla {
39 namespace {
40 
41 class DynamicPadderTest : public HloTestBase {
42  protected:
DynamicPadderTest()43   DynamicPadderTest() : HloTestBase() { module_ = CreateNewVerifiedModule(); }
44 
RunPadder()45   StatusOr<bool> RunPadder() {
46     DynamicPadder padder;
47     return padder.Run(module_.get());
48   }
49 
ExpectPadded(const HloInstruction * inst)50   void ExpectPadded(const HloInstruction* inst) {
51     EXPECT_THAT(inst,
52                 op::Select(op::Lt(op::Iota(), op::Broadcast(op::Parameter())),
53                            ::testing::_, op::Broadcast()));
54   }
55 
GetScalarAddComputation()56   HloComputation* GetScalarAddComputation() {
57     auto embedded_builder = HloComputation::Builder("add");
58     auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
59         0, ShapeUtil::MakeShape(F32, {}), "lhs"));
60     auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
61         1, ShapeUtil::MakeShape(F32, {}), "rhs"));
62     embedded_builder.AddInstruction(
63         HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
64     return module_->AddEmbeddedComputation(embedded_builder.Build());
65   }
66 
67   std::unique_ptr<HloModule> module_;
68   const Shape scalar_shape_ = ShapeUtil::MakeShape(U32, {});
69 };
70 
// A reduce whose operand carries a dynamic dimension must receive a masked
// (padded) operand after the pass runs. Here dimension 1 of `data_param` has
// its runtime size supplied by the scalar `size_param`.
TEST_F(DynamicPadderTest, ReduceTest) {
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  auto reduce_shape = ShapeUtil::MakeShape(F32, {2});

  auto data_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "data_param"));
  builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  // The negate inherits the dynamic dimension from `data_param`; it is the
  // operand the padder is expected to mask.
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param));

  auto init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));

  // Reduces dimensions 0 and 2; the dynamic dimension 1 survives into the
  // F32[2] result.
  auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      reduce_shape, negate, init, {0, 2}, GetScalarAddComputation()));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 holds the runtime size of dimension 1 of parameter 0.
  // TF_ASSERT_OK (not TF_CHECK_OK) turns a failed bind into a test failure
  // instead of aborting the whole test binary.
  TF_ASSERT_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_ASSERT_OK(RunPadder().status());

  ExpectPadded(reduce->operand(0));
}
101 
// A convolution whose contracted dimension is dynamic must receive a masked
// (padded) lhs operand after the pass runs, so that out-of-bounds elements
// do not contribute to the result.
TEST_F(DynamicPadderTest, ConvolutionTest) {
  auto builder = HloComputation::Builder(TestName());
  constexpr int xdim = 3;
  constexpr int ydim = 2;
  constexpr int zdim = 1;
  auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
  auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
  auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, xy_shape, "A"));
  auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, yz_shape, "B"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  // No spatial dimensions: this convolution behaves like a matrix multiply
  // of A[x, y] and B[y, z] producing out[z, x].
  auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0);

  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(1);
  dnums.set_output_feature_dimension(0);

  Window window;

  auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      zx_shape, a_param, b_param, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums,
      HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(builder.Build());

  // Make the contracting dimension (dimension 1 of A) dynamic, sized by
  // parameter 2. TF_ASSERT_OK (not TF_CHECK_OK) turns a failed bind into a
  // test failure instead of aborting the whole test binary.
  TF_ASSERT_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_ASSERT_OK(RunPadder().status());

  ExpectPadded(conv->operand(0));
}
144 
// When only a non-contracting (batch) dimension of a convolution operand is
// dynamic, the operand must be left untouched by the pass.
TEST_F(DynamicPadderTest, ConvolutionNoPad) {
  auto builder = HloComputation::Builder(TestName());
  constexpr int xdim = 3;
  constexpr int ydim = 2;
  constexpr int zdim = 1;
  auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
  auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
  auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, xy_shape, "A"));
  auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, yz_shape, "B"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  // No spatial dimensions: this convolution behaves like a matrix multiply
  // of A[x, y] and B[y, z] producing out[z, x].
  auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0);

  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(1);
  dnums.set_output_feature_dimension(0);

  Window window;

  auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      zx_shape, a_param, b_param, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums,
      HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(builder.Build());

  // Make the batch dimension (dimension 0 of A, a non-contracting dimension)
  // dynamic, sized by parameter 2. TF_ASSERT_OK (not TF_CHECK_OK) turns a
  // failed bind into a test failure instead of aborting the test binary.
  TF_ASSERT_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{2, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 0}));

  TF_ASSERT_OK(RunPadder().status());

  // The lhs must still be exactly parameter 0 (A), not merely some parameter.
  EXPECT_THAT(conv->operand(0), op::Parameter(0));
}
187 
// A reduce-window whose window along the dynamic dimension is trivial
// (size 1, no padding) must not have its operand padded.
TEST_F(DynamicPadderTest, ReduceWindowNoPadForTrivialWindow) {
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {4, 5});
  auto reduce_shape = ShapeUtil::MakeShape(F32, {3, 5});

  auto input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "input"));
  builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  auto init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
  // Window of size 2 along dimension 0 (4 -> 3) and size 1 along dimension 1
  // (5 -> 5), with no padding on either dimension.
  TF_ASSERT_OK_AND_ASSIGN(Window window, ParseWindow("size=2x1 pad=0_0x0_0"));
  auto output = builder.AddInstruction(HloInstruction::CreateReduceWindow(
      reduce_shape, input, init, window, GetScalarAddComputation()));

  module_->AddEntryComputation(builder.Build());

  // Make dimension 1 of parameter 0 (the trivial-window dimension) dynamic,
  // sized by parameter 1. TF_ASSERT_OK (not TF_CHECK_OK) turns a failed bind
  // into a test failure instead of aborting the whole test binary.
  TF_ASSERT_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{1, {}},
      DynamicParameterBinding::DynamicDimension{0, {}, 1}));

  TF_ASSERT_OK(RunPadder().status());

  // The operand must still be exactly parameter 0 (`input`).
  EXPECT_THAT(output->operand(0), op::Parameter(0));
}
214 
215 }  // namespace
216 }  // namespace xla
217