1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "tensorflow/compiler/xla/layout_util.h"
22 #include "tensorflow/compiler/xla/literal.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/hlo_parser.h"
27 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
28 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
29 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/test.h"
32 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
33 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
34 #include "tensorflow/compiler/xla/types.h"
35
36 namespace xla {
37 namespace {
38
39 namespace m = xla::match;
40
41 using HloConstantFoldingTest = HloTestBase;
42
// convert(f32 constant 42.0) -> s64 must fold into the s64 constant 42.
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input))));

  HloConstantFolding const_folder;
  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
  EXPECT_TRUE(result);

  // ASSERT rather than EXPECT: literal() is only meaningful on a constant, so
  // stop here if folding did not produce one instead of reading the literal
  // of some other instruction.
  ASSERT_THAT(computation->root_instruction(), GmockMatch(m::Constant()));
  EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<int64>(),
            42);
}
64
// convert(s64 constant 42) -> f32 must fold into the f32 constant 42.0.
TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input))));

  HloConstantFolding const_folder;
  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
  EXPECT_TRUE(result);

  // ASSERT rather than EXPECT: literal() is only meaningful on a constant, so
  // stop here if folding did not produce one.
  ASSERT_THAT(computation->root_instruction(), GmockMatch(m::Constant()));
  EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(),
            42.0f);
}
86
// Element-wise convert of the f32[2] constant {42, 19} to s64[2] must fold
// into the s64 constant {42, 19}.
TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({42.0f, 19.0f})));
  builder.AddInstruction(
      HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input))));

  HloConstantFolding const_folder;
  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
  EXPECT_TRUE(result);

  // ASSERT rather than EXPECT: the element reads below need a constant root.
  ASSERT_THAT(computation->root_instruction(), GmockMatch(m::Constant()));
  EXPECT_EQ(computation->root_instruction()->literal().Get<int64>({0}), 42);
  EXPECT_EQ(computation->root_instruction()->literal().Get<int64>({1}), 19);
}
108
// Concatenating constant operands along one dimension must fold into a single
// constant of the concatenated shape.
TEST_F(HloConstantFoldingTest, Concatenate) {
  // `dimensions` is the rank-5 operand shape template; its `concat_dimension`
  // entry is overwritten with each element of `concat_sizes` to build the
  // operands.
  //
  // The members own their storage (std::vector, not absl::Span): a Span built
  // from a braced initializer list points at the list's temporary backing
  // array, which is destroyed at the end of the full-expression initializing
  // `test_configs`, leaving such spans dangling inside the loop below.
  struct TestConfig {
    int concat_dimension;
    std::vector<int64> dimensions;
    std::vector<int64> concat_sizes;
  };
  const TestConfig test_configs[] = {
      {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}},
      {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}},
  };

  for (const TestConfig& test_config : test_configs) {
    SCOPED_TRACE(::testing::Message()
                 << "concat_dimension=" << test_config.concat_dimension);
    HloComputation::Builder builder(TestName());
    // Working copy: the concat dimension is rewritten for every operand.
    std::vector<int64> dimensions = test_config.dimensions;
    int64 concat_size = 0;
    std::vector<HloInstruction*> operands;
    operands.reserve(test_config.concat_sizes.size());
    for (int64 csize : test_config.concat_sizes) {
      dimensions[test_config.concat_dimension] = csize;
      concat_size += csize;
      auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions);
      operands.push_back(builder.AddInstruction(
          HloInstruction::CreateConstant(std::move(literal))));
    }
    // The result's concat dimension is the sum of the operand sizes.
    dimensions[test_config.concat_dimension] = concat_size;
    Shape shape = ShapeUtil::MakeShape(F32, dimensions);
    builder.AddInstruction(HloInstruction::CreateConcatenate(
        shape, operands, test_config.concat_dimension));
    auto module = CreateNewVerifiedModule();
    auto computation = module->AddEntryComputation(builder.Build());

    HloConstantFolding const_folder;
    TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
    EXPECT_TRUE(result);

    HloInstruction* root = computation->root_instruction();
    EXPECT_THAT(root, GmockMatch(m::Constant()));
    EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
  }
}
149
// A unit-stride slice of a random rank-5 constant must fold into a constant
// whose shape is the sliced shape.
TEST_F(HloConstantFoldingTest, Slice) {
  const int64 dimensions[] = {11, 8, 7, 5, 9};
  const int64 slice_start[] = {4, 2, 3, 1, 5};
  const int64 slice_limits[] = {10, 8, 6, 5, 9};
  const int64 slice_strides[] = {1, 1, 1, 1, 1};
  // Per dimension: slice_limits - slice_start.
  const Shape sliced_shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});

  HloComputation::Builder builder(TestName());
  TF_ASSERT_OK_AND_ASSIGN(auto operand_literal,
                          LiteralUtil::CreateRandomLiteral<F32>(
                              ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(operand_literal)));
  builder.AddInstruction(HloInstruction::CreateSlice(
      sliced_shape, operand, slice_start, slice_limits, slice_strides));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  HloConstantFolding folder;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folder.Run(module.get()));
  EXPECT_TRUE(changed);

  HloInstruction* folded = entry->root_instruction();
  EXPECT_THAT(folded, GmockMatch(m::Constant()));
  EXPECT_TRUE(ShapeUtil::Equal(folded->shape(), sliced_shape));
}
175
// Transposing a random rank-5 constant must fold into a constant whose cells
// are the input's cells reordered by `permutation`.
TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
  HloComputation::Builder builder(TestName());
  const int64 dimensions[] = {11, 8, 7, 5, 9};
  TF_ASSERT_OK_AND_ASSIGN(auto literal,
                          LiteralUtil::CreateRandomLiteral<F32>(
                              ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
  // Keep a copy of the input values: `literal` is moved into the constant.
  auto literal_clone = literal.Clone();
  HloInstruction* literal_instruction = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(literal)));
  Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
  const int64 permutation[] = {1, 2, 0, 4, 3};
  builder.AddInstruction(
      HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  HloConstantFolding const_folder;
  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
  EXPECT_TRUE(result);

  // Both checks are ASSERTs: the cell comparison below reads root's literal
  // (only meaningful on a constant) and indexes `literal_clone` with indices
  // derived from root's shape, so it must not run if either check fails.
  HloInstruction* root = computation->root_instruction();
  ASSERT_THAT(root, GmockMatch(m::Constant()));
  ASSERT_TRUE(ShapeUtil::Compatible(root->shape(), shape));

  using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
  bool matched = true;
  root->literal().EachCell<NativeT>(
      [&](absl::Span<const int64> indices, NativeT value) {
        std::vector<int64> rindexes = Permute(permutation, indices);
        matched = matched && (value == literal_clone.Get<NativeT>(rindexes));
      });
  EXPECT_TRUE(matched);
}
209
// HLO text whose entry reduces the constant s32[3] {1, 2, 3} with an `add`
// reducer and init value 0. Shared by the ConstantFoldReduce* tests.
constexpr char kConstantFoldReduce[] = R"(
HloModule ConstantFoldReduce

add {
a = s32[] parameter(0)
b = s32[] parameter(1)
ROOT add = s32[] add(a, b)
}

ENTRY r {
x = s32[3] constant({1, 2, 3})
init = s32[] constant(0)
ROOT reduce = s32[] reduce(x, init), dimensions={0}, to_apply=add
})";
224
// reduce(add) over the constant {1, 2, 3} with init 0 must fold into the
// scalar constant 6.
TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
  // Named `module`, not `m`, so it does not shadow the `m` matcher namespace.
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kConstantFoldReduce));
  HloConstantFolding const_folder;
  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
  EXPECT_TRUE(result);

  // literal() is only meaningful on a constant; stop if folding did not
  // produce one instead of reading the literal of the unfolded reduce.
  HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root, GmockMatch(m::Constant()));
  EXPECT_EQ(6, root->literal().GetFirstElement<int32>());
}
237
// With the layout cleared from the reducer's root, the pass must report no
// change and leave the entry's reduce in place.
TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kConstantFoldReduce));
  // The first computation of the module is presumably the `add` reducer (it is
  // declared first in kConstantFoldReduce) -- strip its root's layout.
  HloInstruction* reducer_root =
      module->computations().begin()->root_instruction();
  LayoutUtil::ClearLayout(reducer_root->mutable_shape());

  HloConstantFolding folder;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folder.Run(module.get()));
  EXPECT_FALSE(changed);

  HloInstruction* entry_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, GmockMatch(m::Reduce()));
}
250
// HLO text whose entry pads the f32[1,1,1] constant {{{7}}} with the constant
// 42 up to f32[2048,2048,128].
constexpr char kConstantFoldLargePad[] = R"(
HloModule ConstantFoldLargePad

ENTRY r {
a = f32[1,1,1] constant({{{7}}})
b = f32[] constant(42)
ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63
})";
259
// The pad in kConstantFoldLargePad produces f32[2048,2048,128]; the pass is
// expected to leave it unfolded, still a pad of its two constant operands.
TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kConstantFoldLargePad));

  HloConstantFolding folder;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folder.Run(module.get()));
  EXPECT_FALSE(changed);

  HloInstruction* entry_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, GmockMatch(m::Pad(m::Constant(), m::Constant())));
}
270
// A call whose callee contains an after-all (token) instruction must not be
// folded, even though the callee's root is an operand-free iota.
TEST_F(HloConstantFoldingTest, DontFoldSubcomputationContainingAfterAll) {
  const char* const kHlo = R"(
HloModule test

Fn {
tok = token[] after-all()
ROOT root = f32[10] iota(), iota_dimension=0
}

ENTRY entry {
ROOT call = f32[10] call(), to_apply=Fn
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  HloConstantFolding folder;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&folder, module.get()));
  EXPECT_FALSE(changed);
}
290
// A call must not be folded when its callee reaches an rng instruction only
// indirectly, through a fusion whose fused computation (InnerFn) holds the rng.
TEST_F(HloConstantFoldingTest,
       DontFoldSubcomputationTransitivelyContainingRng) {
  const char* const kHlo = R"(
HloModule test

InnerFn {
c0 = f32[] constant(0)
c1 = f32[] constant(1)
ROOT rng = f32[10] rng(c0, c1), distribution=rng_uniform
}

Fn {
ROOT fusion = f32[10] fusion(), kind=kLoop, calls=InnerFn
}

ENTRY entry {
ROOT call = f32[10] call(), to_apply=Fn
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHlo));

  HloConstantFolding folder;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&folder, module.get()));
  EXPECT_FALSE(changed);
}
316
317 } // namespace
318 } // namespace xla
319