1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "tensorflow/compiler/xla/layout_util.h"
22 #include "tensorflow/compiler/xla/literal.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/hlo_parser.h"
27 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
28 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
29 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/test.h"
32 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
33 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
34 #include "tensorflow/compiler/xla/types.h"
35 
36 namespace xla {
37 namespace {
38 
39 namespace m = xla::match;
40 
41 using HloConstantFoldingTest = HloTestBase;
42 
// Folding convert(f32 constant 42.0) -> s64 must replace the convert with an
// s64 constant holding 42.
TEST_F(HloConstantFoldingTest, ConvertF32ToS64) {
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input))));

  HloConstantFolding const_folder;
  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
  EXPECT_TRUE(result);

  // ASSERT rather than EXPECT: literal() below is only meaningful on a
  // constant instruction (presumably it CHECK-fails otherwise), so stop here
  // with a test failure if the convert was not folded.
  ASSERT_THAT(computation->root_instruction(), GmockMatch(m::Constant()));
  EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<int64>(),
            42);
}
64 
// Folding convert(s64 constant 42) -> f32 must replace the convert with an f32
// constant holding 42.0f.
TEST_F(HloConstantFoldingTest, ConvertS64ToF32) {
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42)));
  builder.AddInstruction(
      HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input))));

  HloConstantFolding const_folder;
  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
  EXPECT_TRUE(result);

  // ASSERT rather than EXPECT: literal() below is only meaningful on a
  // constant instruction, so bail out if the convert was not folded.
  ASSERT_THAT(computation->root_instruction(), GmockMatch(m::Constant()));
  EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(),
            42.0f);
}
86 
// Folding convert(f32[2] constant {42, 19}) -> s64[2] must produce an s64[2]
// constant holding {42, 19}.
TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) {
  HloComputation::Builder builder(TestName());
  HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({42.0f, 19.0f})));
  builder.AddInstruction(
      HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input));

  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_THAT(computation->root_instruction(),
              GmockMatch(m::Convert().WithOperand(0, m::Op().Is(input))));

  HloConstantFolding const_folder;
  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
  EXPECT_TRUE(result);

  // ASSERT rather than EXPECT: the element reads below go through literal(),
  // which is only meaningful once the root has been folded to a constant.
  ASSERT_THAT(computation->root_instruction(), GmockMatch(m::Constant()));
  EXPECT_EQ(computation->root_instruction()->literal().Get<int64>({0}), 42);
  EXPECT_EQ(computation->root_instruction()->literal().Get<int64>({1}), 19);
}
108 
// Folding a concatenate of constant operands must yield a single constant with
// the concatenated shape. Each config concatenates along `concat_dimension`;
// operand i has size concat_sizes[i] along that dimension, and the other
// dimensions come from `dimensions`.
TEST_F(HloConstantFoldingTest, Concatenate) {
  // The configs own their dimension lists. absl::Span members initialized from
  // brace-init lists would dangle: each initializer_list's backing array is
  // destroyed at the end of this declaration, before the loop reads it.
  const struct TestConfig {
    int concat_dimension;
    std::vector<int64> dimensions;
    std::vector<int64> concat_sizes;
  } test_configs[] = {
      {1, {11, 0, 7, 5, 9}, {2, 5, 7, 11}},
      {3, {1, 4, 17, 0, 8}, {1, 3, 9, 12}},
  };

  for (const auto& test_config : test_configs) {
    HloComputation::Builder builder(TestName());
    // Mutable copy: the concat dimension is overwritten per operand below.
    std::vector<int64> dimensions = test_config.dimensions;
    int64 concat_size = 0;
    std::vector<HloInstruction*> operands;
    for (int64 csize : test_config.concat_sizes) {
      dimensions[test_config.concat_dimension] = csize;
      concat_size += csize;
      auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions);
      HloInstruction* insn = builder.AddInstruction(
          HloInstruction::CreateConstant(std::move(literal)));
      operands.push_back(insn);
    }
    // Result shape: the concat dimension is the sum of the operand sizes.
    dimensions[test_config.concat_dimension] = concat_size;
    Shape shape = ShapeUtil::MakeShape(F32, dimensions);
    builder.AddInstruction(HloInstruction::CreateConcatenate(
        shape, operands, test_config.concat_dimension));
    auto module = CreateNewVerifiedModule();
    auto computation = module->AddEntryComputation(builder.Build());

    HloConstantFolding const_folder;
    TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
    EXPECT_TRUE(result);

    HloInstruction* root = computation->root_instruction();
    EXPECT_THAT(root, GmockMatch(m::Constant()));
    EXPECT_TRUE(ShapeUtil::Equal(root->shape(), shape));
  }
}
149 
// Folding a slice of a random f32[11,8,7,5,9] constant must produce a constant
// whose shape is exactly the sliced shape.
TEST_F(HloConstantFoldingTest, Slice) {
  const int64 dimensions[] = {11, 8, 7, 5, 9};
  const int64 slice_start[] = {4, 2, 3, 1, 5};
  const int64 slice_limits[] = {10, 8, 6, 5, 9};
  const int64 slice_strides[] = {1, 1, 1, 1, 1};
  // Expected result: (limit - start) along each dimension, since every stride
  // is 1.
  const Shape sliced_shape = ShapeUtil::MakeShape(F32, {6, 6, 3, 4, 4});

  HloComputation::Builder builder(TestName());
  TF_ASSERT_OK_AND_ASSIGN(auto operand_literal,
                          LiteralUtil::CreateRandomLiteral<F32>(
                              ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
  HloInstruction* operand = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(operand_literal)));
  builder.AddInstruction(HloInstruction::CreateSlice(
      sliced_shape, operand, slice_start, slice_limits, slice_strides));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  HloConstantFolding folder;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folder.Run(module.get()));
  EXPECT_TRUE(changed);

  HloInstruction* folded = entry->root_instruction();
  EXPECT_THAT(folded, GmockMatch(m::Constant()));
  EXPECT_TRUE(ShapeUtil::Equal(folded->shape(), sliced_shape));
}
175 
// Folding a transpose of a random f32[11,8,7,5,9] constant must produce a
// constant of the transposed shape whose every element equals the operand
// element it was moved from.
TEST_F(HloConstantFoldingTest, TransposeConstantFold) {
  HloComputation::Builder builder(TestName());
  const int64 dimensions[] = {11, 8, 7, 5, 9};
  TF_ASSERT_OK_AND_ASSIGN(auto literal,
                          LiteralUtil::CreateRandomLiteral<F32>(
                              ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0));
  // Keep a copy of the operand values; `literal` is moved into the constant.
  auto literal_clone = literal.Clone();
  HloInstruction* literal_instruction = builder.AddInstruction(
      HloInstruction::CreateConstant(std::move(literal)));
  Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5});
  const int64 permutation[] = {1, 2, 0, 4, 3};
  builder.AddInstruction(
      HloInstruction::CreateTranspose(shape, literal_instruction, permutation));
  auto module = CreateNewVerifiedModule();
  auto computation = module->AddEntryComputation(builder.Build());

  HloConstantFolding const_folder;
  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
  EXPECT_TRUE(result);

  HloInstruction* root = computation->root_instruction();
  // ASSERT rather than EXPECT: the element-wise comparison below reads
  // root->literal(), which is only meaningful once root is a constant.
  ASSERT_THAT(root, GmockMatch(m::Constant()));
  // Compatible rather than Equal, presumably because the folded constant's
  // layout need not match `shape`'s.
  EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), shape));

  using NativeT = typename primitive_util::PrimitiveTypeToNative<F32>::type;
  bool matched = true;
  // For each output index, Permute(permutation, indices) recovers the operand
  // index the transpose moved to `indices`; compare the two values.
  root->literal().EachCell<NativeT>(
      [&](absl::Span<const int64> indices, NativeT value) {
        std::vector<int64> rindexes = Permute(permutation, indices);
        matched = matched && (value == literal_clone.Get<NativeT>(rindexes));
      });
  EXPECT_TRUE(matched);
}
209 
// HLO module whose entry reduces the s32[3] constant {1, 2, 3} with an `add`
// reducer starting from 0; fully folded, the entry root is the constant 6.
constexpr char kConstantFoldReduce[] = R"(
  HloModule ConstantFoldReduce

  add {
    a = s32[] parameter(0)
    b = s32[] parameter(1)
    ROOT add = s32[] add(a, b)
  }

  ENTRY r {
    x = s32[3] constant({1, 2, 3})
    init = s32[] constant(0)
    ROOT reduce = s32[] reduce(x, init), dimensions={0}, to_apply=add
  })";
224 
// Folding reduce(add) over the constant {1, 2, 3} must produce the constant 6.
TEST_F(HloConstantFoldingTest, ConstantFoldReduce) {
  // Named `module` rather than `m` so it does not shadow the `m` matcher
  // namespace alias used below.
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kConstantFoldReduce));
  HloConstantFolding const_folder;
  TF_ASSERT_OK_AND_ASSIGN(bool result, const_folder.Run(module.get()));
  EXPECT_TRUE(result);

  HloInstruction* root = module->entry_computation()->root_instruction();
  // Stop with a test failure before reading literal() if the reduce was not
  // folded into a constant.
  ASSERT_THAT(root, GmockMatch(m::Constant()));
  EXPECT_EQ(6, root->literal().GetFirstElement<int32>());
}
237 
// When an instruction in the first computation of the module has its layout
// cleared, constant folding must report no change and the entry root must
// still be a reduce.
TEST_F(HloConstantFoldingTest, ConstantFoldReduceNoLayout) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kConstantFoldReduce));
  // Root of the first computation in the module (presumably the `add`
  // reducer of kConstantFoldReduce — confirm if the parse order changes).
  HloInstruction* reducer_root =
      module->computations().begin()->root_instruction();
  LayoutUtil::ClearLayout(reducer_root->mutable_shape());

  HloConstantFolding folder;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folder.Run(module.get()));
  EXPECT_FALSE(changed);

  HloInstruction* entry_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, GmockMatch(m::Reduce()));
}
250 
// HLO module that pads the f32[1,1,1] constant {{{7}}} with 42 up to
// f32[2048,2048,128] (2048*2048*128 = 536,870,912 elements).
constexpr char kConstantFoldLargePad[] = R"(
  HloModule ConstantFoldLargePad

  ENTRY r {
    a = f32[1,1,1] constant({{{7}}})
    b = f32[] constant(42)
    ROOT pad = f32[2048,2048,128] pad(a, b), padding=1024_1023x1024_1023x64_63
  })";
259 
// The pass must leave the very large pad of kConstantFoldLargePad unfolded
// (presumably because of a result-size limit): it reports no change and the
// entry root is still a pad of two constants.
TEST_F(HloConstantFoldingTest, DoesNotFoldLargePad) {
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kConstantFoldLargePad));

  HloConstantFolding folder;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, folder.Run(module.get()));
  EXPECT_FALSE(changed);

  HloInstruction* entry_root = module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, GmockMatch(m::Pad(m::Constant(), m::Constant())));
}
270 
// A call whose callee contains an after-all (token[]) instruction must not be
// folded, even though the callee's root is a parameterless iota.
TEST_F(HloConstantFoldingTest, DontFoldSubcomputationContainingAfterAll) {
  constexpr char kHloText[] = R"(
  HloModule test

  Fn {
    tok = token[] after-all()
    ROOT root = f32[10] iota(), iota_dimension=0
  }

  ENTRY entry {
    ROOT call = f32[10] call(), to_apply=Fn
  })";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloConstantFolding folder;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&folder, module.get()));
  EXPECT_FALSE(changed);
}
290 
// An rng nested two levels below the entry (call -> fusion -> rng) must still
// prevent the call from being folded.
TEST_F(HloConstantFoldingTest,
       DontFoldSubcomputationTransitivelyContainingRng) {
  constexpr char kHloText[] = R"(
  HloModule test

  InnerFn {
    c0 = f32[] constant(0)
    c1 = f32[] constant(1)
    ROOT rng = f32[10] rng(c0, c1), distribution=rng_uniform
  }

  Fn {
    ROOT fusion = f32[10] fusion(), kind=kLoop, calls=InnerFn
  }

  ENTRY entry {
    ROOT call = f32[10] call(), to_apply=Fn
  })";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
  HloConstantFolding folder;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&folder, module.get()));
  EXPECT_FALSE(changed);
}
316 
317 }  // namespace
318 }  // namespace xla
319