1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_schedule.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_dce.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
28 #include "tensorflow/compiler/xla/service/hlo_parser.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/lib/core/status_test_util.h"
34 
35 namespace xla {
36 namespace {
37 
// Test fixture for the HloSchedule tests below. It declares no members of its
// own; everything it offers comes from HloTestBase.
class HloScheduleTest : public HloTestBase {};
39 
TEST_F(HloScheduleTest, UpdateScheduleUnchangedModule) {
  // Updating the schedule of an unchanged HLO module should not affect the
  // schedule at all.
  const string module_str = R"(
HloModule UpdateScheduleUnchanged

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  c = f32[] constant(42.0)
  sum = f32[] add(a, b)
  neg = f32[] negate(c)
  ROOT root = f32[] multiply(sum, neg)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), [](const BufferValue& buffer) {
        return ShapeUtil::ByteSizeOf(buffer.shape());
      }));

  // Snapshot the entry sequence BY VALUE before calling Update(). A reference
  // obtained from `schedule` would either observe the post-Update() contents
  // (so the final comparison would compare the schedule against itself and
  // could never fail) or be left dangling if Update() replaces the underlying
  // storage. `auto` without `&` forces the copy.
  const auto entry_schedule_before_update =
      schedule.sequence(module->entry_computation()).instructions();

  // The entry computation contains exactly the six instructions listed above.
  EXPECT_EQ(entry_schedule_before_update.size(), 6);

  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  // The updated sequence must match the pre-update snapshot element-wise.
  EXPECT_EQ(entry_schedule_before_update,
            schedule.sequence(module->entry_computation()).instructions());
}
73 
TEST_F(HloScheduleTest, UpdateScheduleWithNewInstructions) {
  // Scenario: schedule a module, append two fresh instructions to the entry
  // computation, and check that Update() folds them into the schedule.
  const string hlo_text = R"(
HloModule UpdateScheduleWithNewInstructions

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  c = f32[] constant(42.0)
  sum = f32[] add(a, b)
  neg = f32[] negate(c)
  ROOT root = f32[] multiply(sum, neg)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));

  // Buffer size function handed to the scheduler.
  const auto buffer_bytes = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(module.get(), buffer_bytes));

  // Build `42.0 - <old root>` and make it the new root of the entry.
  HloComputation* entry = module->entry_computation();
  HloInstruction* previous_root = entry->root_instruction();
  const Shape result_shape = previous_root->shape();
  HloInstruction* new_constant = entry->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  HloInstruction* difference =
      entry->AddInstruction(HloInstruction::CreateBinary(
          result_shape, HloOpcode::kSubtract, new_constant, previous_root));
  entry->set_root_instruction(difference);

  // True iff `hlo` currently appears in the entry computation's sequence.
  const auto is_scheduled = [&](const HloInstruction* hlo) {
    const auto& scheduled = schedule.sequence(entry).instructions();
    return absl::c_find(scheduled, hlo) != scheduled.end();
  };

  // Before Update(): still the six original instructions, none of the new
  // ones, and the schedule no longer verifies against the module.
  EXPECT_EQ(schedule.sequence(entry).size(), 6);
  EXPECT_FALSE(is_scheduled(new_constant));
  EXPECT_FALSE(is_scheduled(difference));
  ASSERT_IS_NOT_OK(schedule.Verify());

  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  // After Update(): both additions are present, for eight in total.
  EXPECT_EQ(schedule.sequence(entry).size(), 8);
  EXPECT_TRUE(is_scheduled(new_constant));
  EXPECT_TRUE(is_scheduled(difference));
}
121 
TEST_F(HloScheduleTest, UpdateScheduleWithAddedAndDeletedInstruction) {
  // Scenario: schedule a module, then both add instructions to it and remove
  // (via DCE) instructions from it, and check that Update() copes with the
  // combined change.
  const string hlo_text = R"(
HloModule UpdateScheduleWithAddedAndDeletedInstruction

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  c = f32[] constant(42.0)
  sum = f32[] add(a, b)
  neg = f32[] negate(c)
  ROOT root = f32[] multiply(sum, neg)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));

  // Buffer size function handed to the scheduler.
  const auto buffer_bytes = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(module.get(), buffer_bytes));

  // Re-root the entry at `42.0 - parameter(0)`, an expression that touches
  // only a new constant and an existing parameter.
  HloComputation* entry = module->entry_computation();
  HloInstruction* forty_two = entry->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  HloInstruction* first_param = entry->parameter_instruction(0);
  HloInstruction* replacement_root =
      entry->AddInstruction(HloInstruction::CreateBinary(
          forty_two->shape(), HloOpcode::kSubtract, forty_two, first_param));
  entry->set_root_instruction(replacement_root);

  // After DCE only the parameters and the newly added code should survive.
  HloDCE dce;
  TF_ASSERT_OK(dce.Run(module.get()).status());

  // The schedule has not been refreshed yet: it still lists the six original
  // instructions and therefore fails verification.
  EXPECT_EQ(schedule.sequence(entry).size(), 6);
  ASSERT_IS_NOT_OK(schedule.Verify());

  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  // Four instructions are expected once the schedule is brought up to date.
  EXPECT_EQ(schedule.sequence(entry).size(), 4);
}
168 
TEST_F(HloScheduleTest, UpdateScheduleWithCompletelyReplacedModule) {
  // Scenario: every instruction that was scheduled is swapped out for a
  // brand-new set, and Update() must still produce a valid schedule.
  const string hlo_text = R"(
HloModule UpdateScheduleWithCompletelyReplacedModule

ENTRY main {
  a = f32[] constant(42.0)
  b = f32[] constant(123.0)
  ROOT sum = f32[] add(a, b)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));

  // Buffer size function handed to the scheduler.
  const auto buffer_bytes = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(module.get(), buffer_bytes));

  // The new entry body is simply `negate(constant(1.0))`.
  HloComputation* entry = module->entry_computation();
  HloInstruction* one = entry->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
  HloInstruction* negated_one = entry->AddInstruction(
      HloInstruction::CreateUnary(one->shape(), HloOpcode::kNegate, one));
  entry->set_root_instruction(negated_one);

  // Sweep away the instructions that are no longer reachable from the root.
  HloDCE dce;
  TF_ASSERT_OK(dce.Run(module.get()).status());

  // Stale schedule: three entries from the original module, fails to verify.
  EXPECT_EQ(schedule.sequence(entry).size(), 3);
  ASSERT_IS_NOT_OK(schedule.Verify());

  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  // Refreshed schedule: just the constant and its negation.
  EXPECT_EQ(schedule.sequence(entry).size(), 2);
}
210 
TEST_F(HloScheduleTest, UpdateScheduleWithMultipleComputations) {
  // Scenario: modify two different computations of the same module (the while
  // body and the while condition) and check that a single Update() brings the
  // schedule back in sync for both.
  const string hlo_text = R"(
HloModule UpdateScheduleWithMultipleComputations

%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
  %param.1 = (s32[], token[]) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
  %after-all = token[] after-all(token[] %get-tuple-element.2)
  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}

%Cond (param: (s32[], token[])) -> pred[] {
  %param = (s32[], token[]) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
  %constant = s32[] constant(42)
  ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %WhileLoop () -> s32[] {
  %zero = s32[] constant(0)
  %init_token = token[] after-all()
  %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
  ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));

  // Buffer size function handed to the scheduler; the module contains tuple
  // shapes, so a pointer size is supplied as well.
  const auto buffer_bytes = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(),
                                 /*pointer_size=*/sizeof(void*));
  };
  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(module.get(), buffer_bytes));

  // The while instruction is the sole operand of the entry root.
  const HloInstruction* while_instr =
      module->entry_computation()->root_instruction()->operand(0);
  HloComputation* loop_body = while_instr->while_body();
  HloComputation* loop_cond = while_instr->while_condition();

  // Condition: wrap the existing root in a logical NOT and make that the root.
  HloInstruction* original_cond_root = loop_cond->root_instruction();
  HloInstruction* inverted_cond_root =
      loop_cond->AddInstruction(HloInstruction::CreateUnary(
          ShapeUtil::MakeShape(PRED, {}), HloOpcode::kNot,
          original_cond_root));
  loop_cond->set_root_instruction(inverted_cond_root);

  // Body: turn it into a pass-through by rooting it at its own parameter.
  loop_body->set_root_instruction(loop_body->parameter_instruction(0));

  // Remove the body instructions that the re-rooting left dead.
  HloDCE dce;
  TF_ASSERT_OK(dce.Run(module.get()).status());

  // The schedule still reflects the module as originally parsed.
  EXPECT_EQ(schedule.sequence(loop_body).size(), 7);
  EXPECT_EQ(schedule.sequence(loop_cond).size(), 4);
  ASSERT_IS_NOT_OK(schedule.Verify());

  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());

  // Body shrinks to the lone parameter; condition gains the NOT.
  EXPECT_EQ(schedule.sequence(loop_body).size(), 1);
  EXPECT_EQ(schedule.sequence(loop_cond).size(), 5);
}
280 
TEST_F(HloScheduleTest, UpdateScheduleComputationRemoved) {
  // Scenario: whole computations disappear from the module after scheduling,
  // and Update() must still yield a schedule that verifies.
  const string hlo_text = R"(
HloModule UpdateScheduleWithMultipleComputations

%Body (param.1: (s32[], token[])) -> (s32[], token[]) {
  %param.1 = (s32[], token[]) parameter(0)
  %get-tuple-element.1 = s32[] get-tuple-element((s32[], token[]) %param.1), index=0
  %constant.1 = s32[] constant(1)
  %add = s32[] add(s32[] %get-tuple-element.1, s32[] %constant.1)
  %get-tuple-element.2 = token[] get-tuple-element((s32[], token[]) %param.1), index=1
  %after-all = token[] after-all(token[] %get-tuple-element.2)
  ROOT %tuple = (s32[], token[]) tuple(s32[] %add, token[] %after-all)
}

%Cond (param: (s32[], token[])) -> pred[] {
  %param = (s32[], token[]) parameter(0)
  %get-tuple-element = s32[] get-tuple-element((s32[], token[]) %param), index=0
  %constant = s32[] constant(42)
  ROOT %less-than = pred[] compare(s32[] %get-tuple-element, s32[] %constant), direction=LT
}

ENTRY %WhileLoop () -> s32[] {
  %zero = s32[] constant(0)
  %init_token = token[] after-all()
  %init_tuple = (s32[], token[]) tuple(s32[] %zero, token[] %init_token)
  %while = (s32[], token[]) while((s32[], token[]) %init_tuple), condition=%Cond, body=%Body
  ROOT %root = s32[] get-tuple-element((s32[], token[]) %while), index=0
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));

  // Buffer size function handed to the scheduler; the module contains tuple
  // shapes, so a pointer size is supplied as well.
  const auto buffer_bytes = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(),
                                 /*pointer_size=*/sizeof(void*));
  };
  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(module.get(), buffer_bytes));

  // Locate the while (operand 0 of the entry root) and its initial value
  // (operand 0 of the while).
  HloInstruction* while_instr =
      module->entry_computation()->root_instruction()->mutable_operand(0);
  HloInstruction* while_init = while_instr->mutable_operand(0);

  // Route every user of the while to the init value instead. With nothing
  // using the while any more, its condition and body computations should be
  // dead.
  TF_ASSERT_OK(while_instr->ReplaceAllUsesWith(while_init));

  // Run DCE and confirm it drops the module from three computations to one.
  ASSERT_EQ(module->computation_count(), 3);
  HloDCE dce;
  TF_ASSERT_OK(dce.Run(module.get()).status());
  ASSERT_EQ(module->computation_count(), 1);

  // The stale schedule must fail verification; the updated one must pass.
  ASSERT_IS_NOT_OK(schedule.Verify());
  TF_ASSERT_OK(schedule.Update());
  TF_ASSERT_OK(schedule.Verify());
}
339 
340 }  // namespace
341 }  // namespace xla
342