/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"

#include <memory>
#include <string>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"

namespace xla {
namespace {

class HloSchedulingTest : public HloTestBase {};

TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
  // Tests scheduling of the following HLO code:
  //
  //   %ab = abs(%param)
  //   %exp = exp(%param)
  //   %add = add(%ab, %exp)
  //   %negate = negate(%exp)
  //   %sub = subtract(%add, %negate)
  //
  // %add should be scheduled before %negate because %add is the last (and
  // only) use of %ab, so scheduling %add first frees %ab's buffer sooner.
  const Shape vec = ShapeUtil::MakeShape(xla::F32, {42});
  auto builder = HloComputation::Builder(TestName());
  auto param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param"));
  auto ab = builder.AddInstruction(
      HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param));
  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(vec, HloOpcode::kExp, param));

  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp));
  auto negate = builder.AddInstruction(
      HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp));
  auto sub = builder.AddInstruction(
      HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());

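  // Run the scheduling pass with a simple size function: each buffer is
  // charged its flat byte size (this graph contains no tuples).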
  HloMemoryScheduler scheduler([](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  });
  ASSERT_FALSE(module->has_schedule());
  TF_ASSERT_OK_AND_ASSIGN(bool changed, scheduler.Run(module.get()));
  EXPECT_TRUE(changed);
  ASSERT_TRUE(module->has_schedule());
  TF_ASSERT_OK(module->schedule().Verify());

  // Verify that all instructions are in the sequence.
  const std::vector<HloInstruction*>& sequence =
      module->schedule().sequence(module->entry_computation()).instructions();
  EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());

  // The first instruction should be the parameter and the last the root "sub".
  EXPECT_EQ(param, sequence.front());
  EXPECT_EQ(sub, sequence.back());

  SequentialHloOrdering ordering(module->schedule());
  EXPECT_TRUE(ordering.ExecutesBefore(add, negate));

  // Clear the schedule using the descheduling pass.
  HloDescheduler descheduler;
  EXPECT_TRUE(module->has_schedule());
  TF_ASSERT_OK_AND_ASSIGN(bool descheduler_changed,
                          descheduler.Run(module.get()));
  EXPECT_TRUE(descheduler_changed);
  EXPECT_FALSE(module->has_schedule());
}

TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) {
  const char* module_str = R"(
HloModule test_aliasing_module

ENTRY root {
  param = s32[1000] parameter(0)
  p0 = s32[1000] copy(param)
  p1 = s32[1000] copy(param)
  t = (s32[1000], s32[1000]) tuple(p0, p1)
  a = s32[1000] get-tuple-element(t), index=0
  b = s32[1000] get-tuple-element(t), index=1
  c = s32[1000] add(a, b)
  d = s32[1000] add(c, b)
  e = s32[1000] add(c, c)
  f = s32[1000] add(e, e)
  ROOT result = (s32[1000], s32[1000], s32[1000]) tuple(d, e, f)
})";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));

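  // Size function that charges 8 bytes per element pointer for tuple-shaped
  // buffers, so the tuples in the module above are not treated as free.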
  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  };
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), size_fn, ListMemoryScheduler));
  // Verify that all instructions are in the sequence.
  const std::vector<HloInstruction*>& sequence =
      schedule.sequence(module->entry_computation()).instructions();
  EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());

  std::unordered_map<string, const HloInstruction*> instructions_by_name;
  for (const HloInstruction* instruction : sequence) {
    instructions_by_name[instruction->name()] = instruction;
  }

  // The first instruction should be the parameter and the last the root.
  EXPECT_EQ(instructions_by_name.at("param"), sequence.front());
  EXPECT_EQ(instructions_by_name.at("result"), sequence.back());

  // Instructions "d" and "e" will both be schedulable at the same time, but
  // instruction "d" allows us to free the buffer of "p1", so the list scheduler
  // should prefer it.
  SequentialHloOrdering ordering(schedule);
  EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"),
                                      instructions_by_name.at("e")));
}

TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
  auto builder = HloComputation::Builder(TestName());
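  // Pointer size passed to ByteSizeOf in the size function below; at 1 byte
  // per pointer the tuple buffer is cheap relative to the f32[6] elements.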
  const auto TUPLE_SIZE = 1;
  const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {6});

  // Wrap lit in abs because constants are considered free by
  // IgnoreInstruction, and it skews the accounting.
  auto lit = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1, 1, 1, 1, 1, 1})));
  auto abs_const = builder.AddInstruction(
      HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, lit));

  auto abs_abs1 = builder.AddInstruction(
      HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const));
  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple(
      absl::Span<HloInstruction* const>({abs_abs1})));
  auto tuple_elm = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(r1f32, tuple, 0));

  auto abs_abs2 = builder.AddInstruction(
      HloInstruction::CreateUnary(r1f32, HloOpcode::kAbs, abs_const));

  builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd,
                                                      tuple_elm, abs_abs2));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(builder.Build());
  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(
                              module.get(),
                              [](const BufferValue& buffer) {
                                return ShapeUtil::ByteSizeOf(buffer.shape(),
                                                             TUPLE_SIZE);
                              },
                              ListMemoryScheduler));

  // Verify that all instructions are in the sequence.
  EXPECT_EQ(module->entry_computation()->instruction_count(),
            schedule.sequence(module->entry_computation()).size());
  SequentialHloOrdering ordering(schedule);
  // tuple allocates the tuple buffer and doesn't free anything.
  // abs_abs2 uses the same buffer for input/output, so its bytes-freed is 0.
  // abs_abs2 should be scheduled before tuple by List.
  EXPECT_TRUE(ordering.ExecutesBefore(abs_abs2, tuple));
}

TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
  const Shape r1f32 = ShapeUtil::MakeShape(xla::F32, {5});
  HloComputation::Builder builder(TestName());

  auto c1 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1, 1, 1, 1, 1})));
  auto c2 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1, 2, 3, 4, 5})));
  auto c3 = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({0, 2, 4, 6, 8})));

  auto add = builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, c1, c2));
  auto mul = builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kMultiply, add, c3));
  auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({add, mul}));

  auto tuple_elm = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(r1f32, tuple, 0));

  auto exp = builder.AddInstruction(
      HloInstruction::CreateUnary(r1f32, HloOpcode::kExp, c3));

  builder.AddInstruction(
      HloInstruction::CreateBinary(r1f32, HloOpcode::kAdd, tuple_elm, exp));

  auto module = CreateNewVerifiedModule();
  auto* computation = module->AddEntryComputation(builder.Build());

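  // Fuse add, mul, and the tuple into a single loop fusion, so the fusion
  // instruction defines both tuple element buffers (multi-output fusion).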
  auto fusion = computation->CreateFusionInstruction(
      {tuple, mul, add}, HloInstruction::FusionKind::kLoop);

  TF_ASSERT_OK_AND_ASSIGN(HloSchedule schedule,
                          ScheduleModule(
                              module.get(),
                              [](const BufferValue& buffer) {
                                return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
                              },
                              ListMemoryScheduler));

  // Verify that all instructions are in the sequence.
  EXPECT_EQ(module->entry_computation()->instruction_count(),
            schedule.sequence(module->entry_computation()).size());
  SequentialHloOrdering ordering(schedule);
  // fusion allocates memory for the tuple elements and doesn't free anything,
  // so it's more expensive than exp.
  EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion));
}

TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
  auto module = CreateNewUnverifiedModule();
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});

  // param != 0
  // Needs 17 bytes
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
  HloInstruction* zero_vector =
      cond_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({0, 0, 0, 0})));
  cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_param,
                                    zero_vector, ComparisonDirection::kNe));
  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());

  // param - 1
  // Needs 16 bytes
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "body_param"));
  HloInstruction* one_vector =
      body_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
  body_builder.AddInstruction(HloInstruction::CreateBinary(
      r1f32, HloOpcode::kSubtract, body_param, one_vector));
  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* while_init =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1, 1, 1, 1}})));
  // Creates 16 bytes, ignoring subcomputations
  builder.AddInstruction(HloInstruction::CreateWhile(
      r1f32, cond_computation, body_computation, while_init));

  module->AddEntryComputation(builder.Build());

  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), size_fn, ListMemoryScheduler));
  // Verify that all instructions are in the sequence.
  auto entry_computation = module->entry_computation();
  EXPECT_EQ(module->entry_computation()->instruction_count(),
            schedule.sequence(module->entry_computation()).size());

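  // Peak memory of each sub-computation, matching the byte counts noted above
  // (17 bytes for the condition, 16 for the body). Passing this map to the
  // heap simulator makes it account for the while's nested computations.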
  absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
  memory_by_computation[cond_computation] = 17;
  memory_by_computation[body_computation] = 16;
  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
      TuplePointsToAnalysis::Run(module.get()).ValueOrDie();

  // HeapSimulator doesn't account for subcomputations
  EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation(
                    *entry_computation, schedule.sequence(entry_computation),
                    *points_to_analysis, size_fn)
                    .ValueOrDie());
  // HeapSimulator accounts for subcomputations. Cond is the largest one.
  // The output buffer of the while is aliased.
  EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation(
                    *entry_computation, schedule.sequence(entry_computation),
                    *points_to_analysis, size_fn, &memory_by_computation)
                    .ValueOrDie());
}

TEST_F(HloSchedulingTest, TrivialScheduler) {
  const char* const hlo_string = R"(
HloModule ModuleWithWhile

body {
  param.b = (s32[], s32[]) parameter(0)
  gte.0 = s32[] get-tuple-element(param.b), index=0
  gte.1 = s32[] get-tuple-element(param.b), index=1
  add = s32[] add(gte.0, gte.1)
  ROOT tuple = (s32[], s32[]) tuple(gte.0, add)
}

cond {
  param.c = (s32[], s32[]) parameter(0)
  ROOT constant = pred[] constant(true)
}

ENTRY main {
  init = (s32[], s32[]) parameter(0)
  ROOT while = (s32[], s32[]) while(init), condition=cond, body=body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_string));
  EXPECT_FALSE(module->has_schedule());
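  // The trivial scheduler only needs to produce some valid sequence for each
  // computation; unlike the list scheduler it makes no attempt to minimize
  // memory.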
  TF_ASSERT_OK(HloTrivialScheduler().Run(module.get()).status());
  ASSERT_TRUE(module->has_schedule());
  TF_ASSERT_OK(module->schedule().Verify());

  // Verify that a clone of the module also has a schedule.
  std::unique_ptr<HloModule> clone = module->Clone();
  ASSERT_TRUE(clone->has_schedule());
  TF_ASSERT_OK(clone->schedule().Verify());
}

}  // namespace
}  // namespace xla