1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/while_util.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
20 #include "tensorflow/compiler/xla/service/hlo_parser.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/util.h"
23 
24 namespace xla {
25 namespace {
26 
27 namespace op = ::xla::testing::opcode_matchers;
28 
GetParsedModule(HloComputation ** entry_computation,HloInstruction ** param0,HloInstruction ** param1,HloInstruction ** param2)29 StatusOr<std::unique_ptr<HloModule>> GetParsedModule(
30     HloComputation** entry_computation, HloInstruction** param0,
31     HloInstruction** param1, HloInstruction** param2) {
32   const char* const hlo_string = R"(
33 HloModule ModuleWithWhile
34 
35 while_body {
36   ROOT p_body = (f32[32,32]{1,0}, f32[32,32]{1,0}) parameter(0)
37 }
38 
39 while_condition {
40   p_cond = f32[32,32]{1,0} parameter(0)
41   ROOT result = pred[] constant(true)
42 }
43 
44 ENTRY entry {
45   p_entry_0 = f32[32,32]{1,0} parameter(0)
46   p_entry_1 = s32[32,32]{1,0} parameter(1)
47   p_entry_2 = s64[32,32]{1,0} parameter(2)
48   while_init = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p_entry_0, p_entry_0)
49   ROOT while = (f32[32,32]{1,0}, f32[32,32]{1,0}) while(while_init), condition=while_condition, body=while_body
50 }
51 )";
52 
53   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
54                       ParseHloString(hlo_string));
55 
56   *entry_computation = module->entry_computation();
57   *param0 = (*entry_computation)->parameter_instruction(0);
58   *param1 = (*entry_computation)->parameter_instruction(1);
59   *param2 = (*entry_computation)->parameter_instruction(2);
60 
61   return std::move(module);
62 }
63 
// Calls MakeInstructionsLiveIn with an empty instruction list. The test checks
// that the original while is still replaced by a new one, with the same two
// visible tuple elements. It also checks that the new body's root rebuilds the
// original loop state from the first two elements of its parameter.
TEST(WhileUtilTest, MakeZeroInstructionsLiveOp) {
  HloInstruction *param0, *param1, *param2;
  HloComputation* entry_computation;

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      GetParsedModule(&entry_computation, &param0, &param1, &param2));

  HloInstruction* while_instr = entry_computation->root_instruction();
  ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile);

  TF_ASSERT_OK_AND_ASSIGN(
      WhileUtil::MakeInstructionsLiveInResult make_live_in_result,
      WhileUtil::MakeInstructionsLiveIn(while_instr, /*instructions=*/{}));

  HloInstruction* new_while_instr = make_live_in_result.new_while_instr;

  // The entry root now forwards the two loop-state elements of the new while.
  EXPECT_THAT(
      entry_computation->root_instruction(),
      op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0),
                op::GetTupleElement(::testing::Eq(new_while_instr), 1)));

  // Inside the new body, the old root (the parameter itself) is expressed as
  // a tuple of the parameter's first two elements.
  auto param_reconstructed =
      op::Tuple(op::GetTupleElement(op::Parameter(0), 0),
                op::GetTupleElement(op::Parameter(0), 1));

  EXPECT_THAT(new_while_instr->while_body()->root_instruction(),
              op::Tuple(op::GetTupleElement(param_reconstructed, 0),
                        op::GetTupleElement(param_reconstructed, 1)));
}
94 
// Makes two entry parameters (param0 and param1) live in the loop. Afterwards:
//  - the entry root forwards only the two original loop-state elements of the
//    new while;
//  - the new body's root rebuilds those two elements from the widened
//    parameter and passes the live-in values through at indices 2 and 3.
TEST(WhileUtilTest, MakeTwoInstructionsLive) {
  HloComputation* entry_computation;
  HloInstruction* param0;
  HloInstruction* param1;
  HloInstruction* param2;  // Not used here; GetParsedModule fills it anyway.

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      GetParsedModule(&entry_computation, &param0, &param1, &param2));

  HloInstruction* const old_while = entry_computation->root_instruction();
  ASSERT_EQ(old_while->opcode(), HloOpcode::kWhile);

  TF_ASSERT_OK_AND_ASSIGN(
      WhileUtil::MakeInstructionsLiveInResult live_in_result,
      WhileUtil::MakeInstructionsLiveIn(old_while,
                                        /*instructions=*/{param0, param1}));

  HloInstruction* const new_while = live_in_result.new_while_instr;

  XLA_VLOG_LINES(3, module->ToString());

  // Entry side: only the original two elements are visible to users.
  EXPECT_THAT(entry_computation->root_instruction(),
              op::Tuple(op::GetTupleElement(::testing::Eq(new_while), 0),
                        op::GetTupleElement(::testing::Eq(new_while), 1)));

  // Body side: the narrow loop state is a prefix of the widened parameter...
  auto narrow_state = op::Tuple(op::GetTupleElement(op::Parameter(0), 0),
                                op::GetTupleElement(op::Parameter(0), 1));
  // ...and the two live-in values are forwarded from their own slots.
  auto live_in_slot_2 = op::GetTupleElement(op::Parameter(0), 2);
  auto live_in_slot_3 = op::GetTupleElement(op::Parameter(0), 3);

  EXPECT_THAT(new_while->while_body()->root_instruction(),
              op::Tuple(op::GetTupleElement(narrow_state, 0),
                        op::GetTupleElement(narrow_state, 1), live_in_slot_2,
                        live_in_slot_3));
}
130 
// The body below returns tuple(gte.0, add). Element 0 of the parameter is
// forwarded to index 0 unchanged, while element 1 is replaced by `add`. The
// test expects GetInvariantGTEsForWhileBody to report exactly gte.0.
TEST(WhileUtilTest, GetInvariantGTEsForWhileBody) {
  const char* const hlo_string = R"(
HloModule ModuleWithWhile

body {
  param.b = (s32[], s32[]) parameter(0)
  gte.0 = s32[] get-tuple-element(param.b), index=0
  gte.1 = s32[] get-tuple-element(param.b), index=1
  add = s32[] add(gte.0, gte.1)
  ROOT tuple = (s32[], s32[]) tuple(gte.0, add)
}

cond {
  param.c = (s32[], s32[]) parameter(0)
  ROOT constant = pred[] constant(true)
}

ENTRY main {
  init = (s32[], s32[]) parameter(0)
  ROOT while = (s32[], s32[]) while(init), condition=cond, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));

  HloComputation* while_body = module->GetComputationWithName("body");

  // GetComputationWithName yields nullptr when no computation has that name.
  ASSERT_NE(while_body, nullptr)
      << "Expected a computation named 'body' in the parsed module";

  std::vector<HloInstruction*> gte_list =
      WhileUtil::GetInvariantGTEsForWhileBody(*while_body);

  // Compare against an unsigned literal to avoid a signed/unsigned mismatch.
  ASSERT_EQ(gte_list.size(), 1u);
  EXPECT_EQ(gte_list.front()->name(), "gte.0");
}
168 
// The loop condition below performs an infeed, presumably so the old while
// cannot simply be discarded as dead code. MakeInstructionsLiveIn must still
// remove the while it replaces. Afterwards, `main` should contain exactly one
// while, and it must be the newly created one.
TEST(WhileUtilTest, AlwaysRemovePreviousWhileBody) {
  const char* const hlo_string = R"(
HloModule WhileWithSideEffects

body {
  param.b = (s32[], s32[]) parameter(0)
  gte.0 = s32[] get-tuple-element(param.b), index=0
  gte.1 = s32[] get-tuple-element(param.b), index=1
  add = s32[] add(gte.0, gte.1)
  ROOT tuple = (s32[], s32[]) tuple(gte.0, add)
}

cond {
  param.c = (s32[], s32[]) parameter(0)
  token0 = token[] after-all()
  infeed = (pred[], token[]) infeed(token0)
  ROOT condition = pred[] get-tuple-element(infeed), index=0
}

ENTRY main {
  init = (s32[], s32[]) parameter(0)
  to_make_live_in = f32[100] parameter(1)
  ROOT while = (s32[], s32[]) while(init), condition=cond, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));

  // Named main_computation rather than `main` to avoid confusion with ::main.
  HloComputation* main_computation = module->GetComputationWithName("main");
  ASSERT_NE(main_computation, nullptr);

  HloInstruction* while_instr = main_computation->root_instruction();
  ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile);
  HloInstruction* to_make_live_in = main_computation->parameter_instruction(1);

  TF_ASSERT_OK_AND_ASSIGN(
      WhileUtil::MakeInstructionsLiveInResult make_live_in_result,
      WhileUtil::MakeInstructionsLiveIn(while_instr,
                                        /*instructions=*/{to_make_live_in}));

  HloInstruction* new_while_instr = make_live_in_result.new_while_instr;

  auto is_while = [](const HloInstruction* instr) {
    return instr->opcode() == HloOpcode::kWhile;
  };
  EXPECT_EQ(absl::c_count_if(main_computation->instructions(), is_while), 1);

  // The single surviving while must be the replacement, not the original:
  // the entry root now forwards the original two elements from it.
  EXPECT_THAT(
      main_computation->root_instruction(),
      op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0),
                op::GetTupleElement(::testing::Eq(new_while_instr), 1)));
}
212 }  // namespace
213 }  // namespace xla
214