1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
17
18 #include <memory>
19 #include <string>
20
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_dce.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
28 #include "tensorflow/compiler/xla/service/hlo_parser.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34
35 namespace xla {
36 namespace {
37
38 class HloScheduleTest : public HloTestBase {};
39
TEST_F(HloScheduleTest, UpdateScheduleUnchangedModule) {
  // Updating the schedule of an unchanged HLO module should not affect the
  // schedule at all.
  const string module_str = R"(
HloModule UpdateScheduleUnchanged

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  c = f32[] constant(42.0)
  sum = f32[] add(a, b)
  neg = f32[] negate(c)
  ROOT root = f32[] multiply(sum, neg)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), [](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape());
      }));

  // Deliberately take a *copy* of the sequence (plain `auto`, no `&`). If
  // instructions() hands back a reference into `schedule`, binding a reference
  // here would observe whatever Update() writes, and the comparison at the end
  // of the test would compare the schedule with itself and pass vacuously.
  const auto entry_schedule =
      schedule.sequence(module->entry_computation()).instructions();

  EXPECT_EQ(entry_schedule.size(), 6);

  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  // The pre-update snapshot must match the post-update sequence exactly.
  EXPECT_EQ(entry_schedule,
            schedule.sequence(module->entry_computation()).instructions());
}
73
TEST_F(HloScheduleTest, UpdateScheduleWithNewInstructions) {
  // Grow the entry computation with instructions the existing schedule has
  // never seen, then check that Update() folds them into the sequence.
  const string hlo_text = R"(
HloModule UpdateScheduleWithNewInstructions

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  c = f32[] constant(42.0)
  sum = f32[] add(a, b)
  neg = f32[] negate(c)
  ROOT root = f32[] multiply(sum, neg)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(hlo_text));
  auto size_of_buffer = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(hlo_module.get(), size_of_buffer));

  HloComputation* entry = hlo_module->entry_computation();
  HloInstruction* old_root = entry->root_instruction();
  const Shape result_shape = old_root->shape();

  // Build `42.0 - old_root` and promote it to be the new root.
  HloInstruction* forty_two = entry->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  HloInstruction* difference =
      entry->AddInstruction(HloInstruction::CreateBinary(
          result_shape, HloOpcode::kSubtract, forty_two, old_root));
  entry->set_root_instruction(difference);

  // True iff `hlo` currently appears in the entry computation's sequence.
  auto is_scheduled = [&schedule, entry](const HloInstruction* hlo) {
    return absl::c_linear_search(schedule.sequence(entry).instructions(), hlo);
  };

  // Before updating, the schedule still reflects only the six original ops.
  EXPECT_EQ(schedule.sequence(entry).size(), 6);
  EXPECT_FALSE(is_scheduled(forty_two));
  EXPECT_FALSE(is_scheduled(difference));

  ASSERT_IS_NOT_OK(schedule.Verify());
  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  // After updating, both new instructions have been scheduled.
  EXPECT_EQ(schedule.sequence(entry).size(), 8);
  EXPECT_TRUE(is_scheduled(forty_two));
  EXPECT_TRUE(is_scheduled(difference));
}
121
TEST_F(HloScheduleTest, UpdateScheduleWithAddedAndDeletedInstruction) {
  // Mix additions with deletions: graft a new root onto the entry computation,
  // let DCE strip the instructions that became unreachable, and check that
  // Update() reconciles the schedule with the result.
  const string hlo_text = R"(
HloModule UpdateScheduleWithAddedAndDeletedInstruction

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  c = f32[] constant(42.0)
  sum = f32[] add(a, b)
  neg = f32[] negate(c)
  ROOT root = f32[] multiply(sum, neg)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(hlo_text));
  auto size_of_buffer = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(hlo_module.get(), size_of_buffer));

  // The new root is `42.0 - parameter(0)`; nothing else feeds it.
  HloComputation* entry = hlo_module->entry_computation();
  HloInstruction* param0 = entry->parameter_instruction(0);
  HloInstruction* forty_two = entry->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  HloInstruction* difference =
      entry->AddInstruction(HloInstruction::CreateBinary(
          forty_two->shape(), HloOpcode::kSubtract, forty_two, param0));
  entry->set_root_instruction(difference);

  // Only the parameters and the two freshly added instructions survive DCE.
  HloDCE dce;
  TF_ASSERT_OK(dce.Run(hlo_module.get()).status());

  // The schedule is stale: it still lists the six original instructions.
  EXPECT_EQ(schedule.sequence(entry).size(), 6);

  ASSERT_IS_NOT_OK(schedule.Verify());
  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  // Two parameters + constant + subtract.
  EXPECT_EQ(schedule.sequence(entry).size(), 4);
}
168
TEST_F(HloScheduleTest, UpdateScheduleWithCompletelyReplacedModule) {
  // Swap out every instruction of the entry computation for a brand-new set
  // and confirm that Update() rebuilds the schedule accordingly.
  const string hlo_text = R"(
HloModule UpdateScheduleWithCompletelyReplacedModule

ENTRY main {
  a = f32[] constant(42.0)
  b = f32[] constant(123.0)
  ROOT sum = f32[] add(a, b)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(hlo_text));
  auto size_of_buffer = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(hlo_module.get(), size_of_buffer));

  // The replacement program is simply `negate(1.0)`.
  HloComputation* entry = hlo_module->entry_computation();
  HloInstruction* one = entry->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* negated = entry->AddInstruction(
      HloInstruction::CreateUnary(one->shape(), HloOpcode::kNegate, one));
  entry->set_root_instruction(negated);

  // The original constants and add are now unreachable; let DCE remove them.
  HloDCE dce;
  TF_ASSERT_OK(dce.Run(hlo_module.get()).status());

  // The schedule still lists the three original instructions...
  EXPECT_EQ(schedule.sequence(entry).size(), 3);

  ASSERT_IS_NOT_OK(schedule.Verify());
  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  // ...and afterwards only the two replacement instructions.
  EXPECT_EQ(schedule.sequence(entry).size(), 2);
}
210
TEST_F(HloScheduleTest, UpdateScheduleWithMultipleComputations) {
  // Edit two computations at once (the while condition and the while body)
  // and check that a single Update() repairs the schedule for both.
  const string hlo_text = R"(
HloModule UpdateScheduleWithMultipleComputations

%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
  %param.1 = (s32[], token[]) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
  %after-all = token[] after-all(token[] %get-tuple-element.2)
  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}

%Cond (param: (s32[], token[])) -> pred[] {
  %param = (s32[], token[]) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
  %constant = s32[] constant(42)
  ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %WhileLoop () -> s32[] {
  %zero = s32[] constant(0)
  %init_token = token[] after-all()
  %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
  ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(hlo_text));
  // Sizes are measured with an explicit pointer size, presumably because this
  // module contains tuple-shaped values.
  auto size_of_buffer = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(),
                                 /*pointer_size=*/sizeof(void*));
  };
  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(hlo_module.get(), size_of_buffer));

  // The entry root is a get-tuple-element whose operand is the while loop.
  const HloInstruction* while_loop =
      hlo_module->entry_computation()->root_instruction()->operand(0);
  HloComputation* while_body = while_loop->while_body();
  HloComputation* while_cond = while_loop->while_condition();

  // Condition: replace the predicate P with NOT(P).
  HloInstruction* predicate = while_cond->root_instruction();
  HloInstruction* inverted = while_cond->AddInstruction(
      HloInstruction::CreateUnary(ShapeUtil::MakeShape(PRED, {}),
                                  HloOpcode::kNot, predicate));
  while_cond->set_root_instruction(inverted);

  // Body: turn it into an identity function of its parameter.
  while_body->set_root_instruction(while_body->parameter_instruction(0));

  // Everything else in the body is now unreachable; let DCE remove it.
  HloDCE dce;
  TF_ASSERT_OK(dce.Run(hlo_module.get()).status());

  // Stale schedule: the body still lists its seven original instructions and
  // the condition does not yet include the new NOT.
  EXPECT_EQ(schedule.sequence(while_body).size(), 7);
  EXPECT_EQ(schedule.sequence(while_cond).size(), 4);

  ASSERT_IS_NOT_OK(schedule.Verify());
  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  // Refreshed schedule: the body is just its parameter, and the condition
  // has gained the NOT.
  EXPECT_EQ(schedule.sequence(while_body).size(), 1);
  EXPECT_EQ(schedule.sequence(while_cond).size(), 5);
}
280
TEST_F(HloScheduleTest, UpdateScheduleComputationRemoved) {
  // Remove computations from a module and verify the schedule can be updated.
  const string module_str = R"(
HloModule UpdateScheduleComputationRemoved

%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
  %param.1 = (s32[], token[]) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
  %after-all = token[] after-all(token[] %get-tuple-element.2)
  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}

%Cond (param: (s32[], token[])) -> pred[] {
  %param = (s32[], token[]) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
  %constant = s32[] constant(42)
  ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %WhileLoop () -> s32[] {
  %zero = s32[] constant(0)
  %init_token = token[] after-all()
  %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
  ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), [](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape(),
                                     /*pointer_size=*/sizeof(void*));
      }));

  HloComputation* entry = module->entry_computation();
  HloInstruction* xla_while = entry->root_instruction()->mutable_operand(0);
  HloInstruction* init = xla_while->mutable_operand(0);

  // The initial entry schedule covers all five entry instructions:
  // zero, init_token, init_tuple, while and root.
  EXPECT_EQ(schedule.sequence(entry).size(), 5);

  // Replace the while with its init value. The condition and body
  // computations should then be dead.
  TF_ASSERT_OK(xla_while->ReplaceAllUsesWith(init));

  // DCE the now-unused while instruction; this also removes the condition and
  // body computations from the module.
  HloDCE dce;
  ASSERT_EQ(module->computation_count(), 3);
  TF_ASSERT_OK(dce.Run(module.get()).status());
  ASSERT_EQ(module->computation_count(), 1);

  ASSERT_IS_NOT_OK(schedule.Verify());
  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  // Only the surviving entry instructions remain scheduled:
  // zero, init_token, init_tuple and root.
  EXPECT_EQ(schedule.sequence(entry).size(), 4);
}
339
340 } // namespace
341 } // namespace xla
342