1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/dynamic_padder.h"
17
18 #include "tensorflow/compiler/xla/client/xla_builder.h"
19 #include "tensorflow/compiler/xla/literal.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
23 #include "tensorflow/compiler/xla/service/hlo_module.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/service/hlo_parser.h"
26 #include "tensorflow/compiler/xla/service/hlo_runner.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/test.h"
30 #include "tensorflow/compiler/xla/test_helpers.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 #include "tensorflow/core/platform/test_benchmark.h"
35
36 namespace op = xla::testing::opcode_matchers;
37
38 namespace xla {
39 namespace {
40
41 class DynamicPadderTest : public HloTestBase {
42 protected:
DynamicPadderTest()43 DynamicPadderTest() : HloTestBase() { module_ = CreateNewVerifiedModule(); }
44
RunPadder()45 StatusOr<bool> RunPadder() {
46 DynamicPadder padder;
47 return padder.Run(module_.get());
48 }
49
ExpectPadded(const HloInstruction * inst)50 void ExpectPadded(const HloInstruction* inst) {
51 EXPECT_THAT(inst,
52 op::Select(op::Lt(op::Iota(), op::Broadcast(op::Parameter())),
53 ::testing::_, op::Broadcast()));
54 }
55
GetScalarAddComputation()56 HloComputation* GetScalarAddComputation() {
57 auto embedded_builder = HloComputation::Builder("add");
58 auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
59 0, ShapeUtil::MakeShape(F32, {}), "lhs"));
60 auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
61 1, ShapeUtil::MakeShape(F32, {}), "rhs"));
62 embedded_builder.AddInstruction(
63 HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
64 return module_->AddEmbeddedComputation(embedded_builder.Build());
65 }
66
67 std::unique_ptr<HloModule> module_;
68 const Shape scalar_shape_ = ShapeUtil::MakeShape(U32, {});
69 };
70
// Negates a [1, 2, 2] input and reduces it over dimensions {0, 2}. Dimension 1
// of the input is bound as dynamic, and the test expects the padder to rewrite
// the reduce's operand into a masking select.
//
// NOTE(review): dimension 1 is not one of the reduced dimensions {0, 2}, yet
// the test still expects the operand to be padded. Confirm this matches the
// padder's intended policy for reduce.
TEST_F(DynamicPadderTest, ReduceTest) {
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {1, 2, 2});
  auto reduce_shape = ShapeUtil::MakeShape(F32, {2});

  auto data_param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "data_param"));
  builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));

  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(input_shape, HloOpcode::kNegate, data_param));

  auto init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));

  auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      reduce_shape, negate, init, {0, 2}, GetScalarAddComputation()));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 (size_param) supplies the runtime size of dimension 1 of
  // parameter 0 (data_param). TF_ASSERT_OK is used instead of TF_CHECK_OK so
  // that a binding failure is reported as a test failure rather than aborting
  // the whole test binary.
  TF_ASSERT_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{/*parameter_num=*/1,
                                                /*parameter_index=*/{}},
      DynamicParameterBinding::DynamicDimension{
          /*parameter_num=*/0, /*parameter_index=*/{}, /*dimension=*/1}));

  TF_ASSERT_OK(RunPadder().status());

  ExpectPadded(reduce->operand(0));
}
101
// Builds a convolution with no spatial dimensions, which behaves like a
// matmul of A[x, y] and B[y, z]. A's dimension 1 (size ydim) is contracted
// against B's dimension 0. Making the contracting dimension of A dynamic must
// cause the padder to mask A with a select.
TEST_F(DynamicPadderTest, ConvolutionTest) {
  auto builder = HloComputation::Builder(TestName());
  constexpr int xdim = 3;
  constexpr int ydim = 2;
  constexpr int zdim = 1;
  auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
  auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
  auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, xy_shape, "A"));
  auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, yz_shape, "B"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  // Start from the default numbering with zero spatial dimensions, then
  // override the batch and feature dimensions of the kernel and output.
  auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0);

  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(1);
  dnums.set_output_feature_dimension(0);

  // No spatial dimensions, so the window is empty.
  Window window;

  auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      zx_shape, a_param, b_param, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums,
      HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(builder.Build());

  // Bind the contracting dimension (dimension 1 of parameter 0, "A") to
  // size_param. TF_ASSERT_OK reports a failed bind as a test failure instead
  // of aborting the binary as TF_CHECK_OK would.
  TF_ASSERT_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{/*parameter_num=*/2,
                                                /*parameter_index=*/{}},
      DynamicParameterBinding::DynamicDimension{
          /*parameter_num=*/0, /*parameter_index=*/{}, /*dimension=*/1}));

  TF_ASSERT_OK(RunPadder().status());

  ExpectPadded(conv->operand(0));
}
144
// Same matmul-like convolution as ConvolutionTest, but the dynamic dimension
// is A's batch dimension (dimension 0), which is not contracted. The padder is
// expected to leave the convolution's lhs operand as the original parameter.
TEST_F(DynamicPadderTest, ConvolutionNoPad) {
  auto builder = HloComputation::Builder(TestName());
  constexpr int xdim = 3;
  constexpr int ydim = 2;
  constexpr int zdim = 1;
  auto xy_shape = ShapeUtil::MakeShape(F32, {xdim, ydim});
  auto yz_shape = ShapeUtil::MakeShape(F32, {ydim, zdim});
  auto zx_shape = ShapeUtil::MakeShape(F32, {zdim, xdim});

  auto* a_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/0, xy_shape, "A"));
  auto* b_param = builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/1, yz_shape, "B"));
  builder.AddInstruction(HloInstruction::CreateParameter(
      /*parameter_number=*/2, scalar_shape_, "size_param"));

  // Start from the default numbering with zero spatial dimensions, then
  // override the batch and feature dimensions of the kernel and output.
  auto dnums = XlaBuilder::CreateDefaultConvDimensionNumbers(0);

  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(1);
  dnums.set_output_feature_dimension(0);

  // No spatial dimensions, so the window is empty.
  Window window;

  auto* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
      zx_shape, a_param, b_param, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums,
      HloTestBase::DefaultPrecisionConfig(2)));

  module_->AddEntryComputation(builder.Build());

  // Bind the non-contracting batch dimension (dimension 0 of parameter 0, "A")
  // to size_param. TF_ASSERT_OK reports a failed bind as a test failure
  // instead of aborting the binary as TF_CHECK_OK would.
  TF_ASSERT_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{/*parameter_num=*/2,
                                                /*parameter_index=*/{}},
      DynamicParameterBinding::DynamicDimension{
          /*parameter_num=*/0, /*parameter_index=*/{}, /*dimension=*/0}));

  TF_ASSERT_OK(RunPadder().status());

  EXPECT_THAT(conv->operand(0), op::Parameter());
}
187
// Reduce-window over a [4, 5] input with window 2x1 and no padding. The
// dynamic dimension (dimension 1) has a window size of 1 there, so the test
// expects the padder to leave the reduce-window's operand as the original
// parameter.
TEST_F(DynamicPadderTest, ReduceWindowNoPadForTrivialWindow) {
  auto builder = HloComputation::Builder(TestName());
  auto input_shape = ShapeUtil::MakeShape(F32, {4, 5});
  auto reduce_shape = ShapeUtil::MakeShape(F32, {3, 5});

  auto input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, input_shape, "input"));
  builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape_, "size_param"));
  auto init = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0)));
  TF_ASSERT_OK_AND_ASSIGN(Window window, ParseWindow("size=2x1 pad=0_0x0_0"));
  auto output = builder.AddInstruction(HloInstruction::CreateReduceWindow(
      reduce_shape, input, init, window, GetScalarAddComputation()));

  module_->AddEntryComputation(builder.Build());

  // Parameter 1 (size_param) supplies the runtime size of dimension 1 of
  // parameter 0 (input). TF_ASSERT_OK reports a failed bind as a test failure
  // instead of aborting the binary as TF_CHECK_OK would.
  TF_ASSERT_OK(module_->dynamic_parameter_binding().Bind(
      DynamicParameterBinding::DynamicParameter{/*parameter_num=*/1,
                                                /*parameter_index=*/{}},
      DynamicParameterBinding::DynamicDimension{
          /*parameter_num=*/0, /*parameter_index=*/{}, /*dimension=*/1}));

  TF_ASSERT_OK(RunPadder().status());

  EXPECT_THAT(output->operand(0), op::Parameter());
}
214
215 } // namespace
216 } // namespace xla
217