1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/while_util.h"
17
18 #include "absl/algorithm/container.h"
19 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
20 #include "tensorflow/compiler/xla/service/hlo_parser.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/util.h"
23
24 namespace xla {
25 namespace {
26
27 namespace op = ::xla::testing::opcode_matchers;
28
GetParsedModule(HloComputation ** entry_computation,HloInstruction ** param0,HloInstruction ** param1,HloInstruction ** param2)29 StatusOr<std::unique_ptr<HloModule>> GetParsedModule(
30 HloComputation** entry_computation, HloInstruction** param0,
31 HloInstruction** param1, HloInstruction** param2) {
32 const char* const hlo_string = R"(
33 HloModule ModuleWithWhile
34
35 while_body {
36 ROOT p_body = (f32[32,32]{1,0}, f32[32,32]{1,0}) parameter(0)
37 }
38
39 while_condition {
40 p_cond = f32[32,32]{1,0} parameter(0)
41 ROOT result = pred[] constant(true)
42 }
43
44 ENTRY entry {
45 p_entry_0 = f32[32,32]{1,0} parameter(0)
46 p_entry_1 = s32[32,32]{1,0} parameter(1)
47 p_entry_2 = s64[32,32]{1,0} parameter(2)
48 while_init = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p_entry_0, p_entry_0)
49 ROOT while = (f32[32,32]{1,0}, f32[32,32]{1,0}) while(while_init), condition=while_condition, body=while_body
50 }
51 )";
52
53 TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
54 ParseHloString(hlo_string));
55
56 *entry_computation = module->entry_computation();
57 *param0 = (*entry_computation)->parameter_instruction(0);
58 *param1 = (*entry_computation)->parameter_instruction(1);
59 *param2 = (*entry_computation)->parameter_instruction(2);
60
61 return std::move(module);
62 }
63
// Calls MakeInstructionsLiveIn with an empty instruction list and checks two
// things:
//  - The entry root becomes a tuple of GTEs from the new while instruction.
//  - The new body's root is a tuple of GTEs taken from a tuple reconstructed
//    out of the new body parameter.
TEST(WhileUtilTest, MakeZeroInstructionsLiveOp) {
  HloInstruction* param0 = nullptr;
  HloInstruction* param1 = nullptr;
  HloInstruction* param2 = nullptr;
  HloComputation* entry_computation = nullptr;

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      GetParsedModule(&entry_computation, &param0, &param1, &param2));

  HloInstruction* while_instr = entry_computation->root_instruction();
  ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile);

  TF_ASSERT_OK_AND_ASSIGN(
      WhileUtil::MakeInstructionsLiveInResult make_live_in_result,
      WhileUtil::MakeInstructionsLiveIn(while_instr, /*instructions=*/{}));

  HloInstruction* new_while_instr = make_live_in_result.new_while_instr;

  EXPECT_THAT(
      entry_computation->root_instruction(),
      op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0),
                op::GetTupleElement(::testing::Eq(new_while_instr), 1)));

  // The original two-element state rebuilt from the new body's parameter.
  auto param_reconstructed =
      op::Tuple(op::GetTupleElement(op::Parameter(0), 0),
                op::GetTupleElement(op::Parameter(0), 1));

  EXPECT_THAT(new_while_instr->while_body()->root_instruction(),
              op::Tuple(op::GetTupleElement(param_reconstructed, 0),
                        op::GetTupleElement(param_reconstructed, 1)));
}
94
// Makes entry parameters 0 and 1 live in the loop and checks two things:
//  - The entry root is a tuple of GTEs over the first two elements of the new
//    while instruction.
//  - The new body's root has four elements. The first two are GTEs of the
//    original state reconstructed from the new body parameter. The last two
//    are the live-in values, taken directly from new body parameter elements
//    2 and 3.
TEST(WhileUtilTest, MakeTwoInstructionsLive) {
  HloInstruction* param0 = nullptr;
  HloInstruction* param1 = nullptr;
  HloInstruction* param2 = nullptr;
  HloComputation* entry_computation = nullptr;

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module,
      GetParsedModule(&entry_computation, &param0, &param1, &param2));

  HloInstruction* while_instr = entry_computation->root_instruction();
  ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile);

  TF_ASSERT_OK_AND_ASSIGN(
      WhileUtil::MakeInstructionsLiveInResult make_live_in_result,
      WhileUtil::MakeInstructionsLiveIn(while_instr,
                                        /*instructions=*/{param0, param1}));

  HloInstruction* new_while_instr = make_live_in_result.new_while_instr;

  XLA_VLOG_LINES(3, module->ToString());

  EXPECT_THAT(
      entry_computation->root_instruction(),
      op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0),
                op::GetTupleElement(::testing::Eq(new_while_instr), 1)));

  // The original two-element state rebuilt from the widened body parameter.
  auto first_half_param_reconstructed =
      op::Tuple(op::GetTupleElement(op::Parameter(0), 0),
                op::GetTupleElement(op::Parameter(0), 1));

  EXPECT_THAT(new_while_instr->while_body()->root_instruction(),
              op::Tuple(op::GetTupleElement(first_half_param_reconstructed, 0),
                        op::GetTupleElement(first_half_param_reconstructed, 1),
                        op::GetTupleElement(op::Parameter(0), 2),
                        op::GetTupleElement(op::Parameter(0), 3)));
}
130
// In the body below, tuple element 0 is forwarded unchanged (root operand
// gte.0), while element 1 is replaced by `add`. The test expects exactly one
// GTE, gte.0, to be reported as invariant.
TEST(WhileUtilTest, GetInvariantGTEsForWhileBody) {
  const char* const hlo_string = R"(
HloModule ModuleWithWhile

body {
  param.b = (s32[], s32[]) parameter(0)
  gte.0 = s32[] get-tuple-element(param.b), index=0
  gte.1 = s32[] get-tuple-element(param.b), index=1
  add = s32[] add(gte.0, gte.1)
  ROOT tuple = (s32[], s32[]) tuple(gte.0, add)
}

cond {
  param.c = (s32[], s32[]) parameter(0)
  ROOT constant = pred[] constant(true)
}

ENTRY main {
  init = (s32[], s32[]) parameter(0)
  ROOT while = (s32[], s32[]) while(init), condition=cond, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));

  HloComputation* while_body = module->GetComputationWithName("body");
  ASSERT_NE(while_body, nullptr)
      << "Expected a computation named 'body' in the parsed module";

  std::vector<HloInstruction*> gte_list =
      WhileUtil::GetInvariantGTEsForWhileBody(*while_body);

  // Unsigned literal avoids a signed/unsigned comparison inside ASSERT_EQ.
  ASSERT_EQ(gte_list.size(), 1u);
  EXPECT_EQ(gte_list.front()->name(), "gte.0");
}
168
// The loop condition reads from infeed, so the original while is not a
// trivially dead instruction (module "WhileWithSideEffects"). After
// MakeInstructionsLiveIn, the entry computation must contain exactly one
// while: the replacement loop, reached through the new root tuple.
TEST(WhileUtilTest, AlwaysRemovePreviousWhileBody) {
  const char* const hlo_string = R"(
HloModule WhileWithSideEffects

body {
  param.b = (s32[], s32[]) parameter(0)
  gte.0 = s32[] get-tuple-element(param.b), index=0
  gte.1 = s32[] get-tuple-element(param.b), index=1
  add = s32[] add(gte.0, gte.1)
  ROOT tuple = (s32[], s32[]) tuple(gte.0, add)
}

cond {
  param.c = (s32[], s32[]) parameter(0)
  token0 = token[] after-all()
  infeed = (pred[], token[]) infeed(token0)
  ROOT condition = pred[] get-tuple-element(infeed), index=0
}

ENTRY main {
  init = (s32[], s32[]) parameter(0)
  to_make_live_in = f32[100] parameter(1)
  ROOT while = (s32[], s32[]) while(init), condition=cond, body=body
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));

  HloComputation* main_computation = module->GetComputationWithName("main");
  ASSERT_NE(main_computation, nullptr)
      << "Expected a computation named 'main' in the parsed module";

  HloInstruction* while_instr = main_computation->root_instruction();
  ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile);
  HloInstruction* to_make_live_in = main_computation->parameter_instruction(1);

  TF_ASSERT_OK_AND_ASSIGN(
      WhileUtil::MakeInstructionsLiveInResult make_live_in_result,
      WhileUtil::MakeInstructionsLiveIn(while_instr,
                                        /*instructions=*/{to_make_live_in}));

  HloInstruction* new_while_instr = make_live_in_result.new_while_instr;

  auto is_while = [](const HloInstruction* instr) {
    return instr->opcode() == HloOpcode::kWhile;
  };
  EXPECT_EQ(absl::c_count_if(main_computation->instructions(), is_while), 1);

  // The one remaining while must be the replacement, not the original loop.
  EXPECT_THAT(
      main_computation->root_instruction(),
      op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0),
                op::GetTupleElement(::testing::Eq(new_while_instr), 1)));
}
212 } // namespace
213 } // namespace xla
214