1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
17
18 #include <memory>
19 #include <string>
20
21 #include "absl/algorithm/container.h"
22 #include "absl/container/flat_hash_map.h"
23 #include "tensorflow/compiler/xla/service/heap_simulator.h"
24 #include "tensorflow/compiler/xla/service/hlo_computation.h"
25 #include "tensorflow/compiler/xla/service/hlo_dce.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
28 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
29 #include "tensorflow/compiler/xla/service/hlo_parser.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35
36 namespace xla {
37 namespace {
38
39 class HloSchedulingTest : public HloTestBase {};
40
// Verifies that HloMemoryScheduler runs the last user of a buffer as early as
// it can, and that HloDescheduler afterwards removes the schedule again.
TEST_F(HloSchedulingTest, LastUseScheduledFirst) {
  // HLO built below:
  //
  //   %ab = abs(%param)
  //   %exp = exp(%param)
  //   %add = add(%ab, %exp)
  //   %negate = negate(%exp)
  //   %sub = subtract(%add, %negate)
  //
  // %ab has exactly one user, %add, so %add is also its last use. Running %add
  // ahead of %negate releases %ab's buffer sooner, which is the order the
  // scheduler is expected to choose.
  const Shape f32x42 = ShapeUtil::MakeShape(xla::F32, {42});
  HloComputation::Builder hlo_builder(TestName());
  HloInstruction* param = hlo_builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32x42, "param"));
  HloInstruction* ab = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(f32x42, HloOpcode::kAbs, param));
  HloInstruction* exp = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(f32x42, HloOpcode::kExp, param));

  HloInstruction* add = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(f32x42, HloOpcode::kAdd, ab, exp));
  HloInstruction* negate = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(f32x42, HloOpcode::kNegate, exp));
  HloInstruction* sub = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(f32x42, HloOpcode::kSubtract, add, negate));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(hlo_builder.Build());

  // Buffers are sized by the plain byte size of their shape.
  const auto byte_size = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  HloMemoryScheduler scheduler(byte_size);

  // The pass must attach a schedule to a module that did not have one.
  ASSERT_FALSE(module->has_schedule());
  TF_ASSERT_OK_AND_ASSIGN(bool scheduled, scheduler.Run(module.get()));
  EXPECT_TRUE(scheduled);
  ASSERT_TRUE(module->has_schedule());
  TF_ASSERT_OK(module->schedule().Verify());

  // Every instruction of the entry computation has to show up in the sequence.
  const HloComputation* entry = module->entry_computation();
  const std::vector<HloInstruction*>& scheduled_order =
      module->schedule().sequence(entry).instructions();
  EXPECT_EQ(entry->instruction_count(), scheduled_order.size());

  // The parameter opens the sequence and the root "sub" closes it.
  EXPECT_EQ(param, scheduled_order.front());
  EXPECT_EQ(sub, scheduled_order.back());

  SequentialHloOrdering ordering(module->schedule());
  EXPECT_TRUE(ordering.ExecutesBefore(add, negate));

  // Running the descheduling pass must drop the schedule and report a change.
  HloDescheduler descheduler;
  EXPECT_TRUE(module->has_schedule());
  TF_ASSERT_OK_AND_ASSIGN(bool descheduled, descheduler.Run(module.get()));
  EXPECT_TRUE(descheduled);
  EXPECT_FALSE(module->has_schedule());
}
100
// Checks that ListMemoryScheduler takes tuple / get-tuple-element aliasing into
// account: "b" is element 1 of tuple "t", i.e. it refers to the buffer of "p1".
TEST_F(HloSchedulingTest, ListSchedulerHandlesAliasing) {
  const char* module_str = R"(
HloModule test_aliasing_module

ENTRY root {
param = s32[1000] parameter(0)
p0 = s32[1000] copy(param)
p1 = s32[1000] copy(param)
t = (s32[1000], s32[1000]) tuple(p0, p1)
a = s32[1000] get-tuple-element(t), index=0
b = s32[1000] get-tuple-element(t), index=1
c = s32[1000] add(a, b)
d = s32[1000] add(c, b)
e = s32[1000] add(c, c)
f = s32[1000] add(e, e)
ROOT result = (s32[1000], s32[1000], s32[1000]) tuple(d, e, f)
})";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));

  // Tuple buffers are sized assuming 8-byte pointers.
  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
  };
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), size_fn, ListMemoryScheduler));
  // Verify that all instructions are in the sequence.
  const std::vector<HloInstruction*>& sequence =
      schedule.sequence(module->entry_computation()).instructions();
  EXPECT_EQ(module->entry_computation()->instruction_count(), sequence.size());

  // Index the scheduled instructions by name so the checks below can refer to
  // the names used in the HLO text. absl::flat_hash_map is used because its
  // header is already included by this file (and it is the map type the other
  // tests here use), whereas <unordered_map> is not included.
  absl::flat_hash_map<string, const HloInstruction*> instructions_by_name;
  for (const HloInstruction* instruction : sequence) {
    instructions_by_name[instruction->name()] = instruction;
  }

  // The first instruction should be the parameter and the last the root.
  EXPECT_EQ(instructions_by_name.at("param"), sequence.front());
  EXPECT_EQ(instructions_by_name.at("result"), sequence.back());

  // Instructions "d" and "e" will both be schedulable at the same time, but
  // instruction "d" allows us to free the buffer of "p1", so the list scheduler
  // should prefer it.
  SequentialHloOrdering ordering(schedule);
  EXPECT_TRUE(ordering.ExecutesBefore(instructions_by_name.at("d"),
                                      instructions_by_name.at("e")));
}
149
// Checks how ListMemoryScheduler weighs an instruction that allocates a tuple
// buffer against one that can reuse its operand's buffer.
TEST_F(HloSchedulingTest, TuplesAreAccountedCorrectly) {
  // Pointer size handed to ShapeUtil::ByteSizeOf when sizing buffers.
  const int kTuplePointerSize = 1;
  const Shape f32x6 = ShapeUtil::MakeShape(xla::F32, {6});
  HloComputation::Builder hlo_builder(TestName());

  // Constants are treated as free by IgnoreInstruction, which would skew the
  // accounting, so the literal is passed through an abs first.
  HloInstruction* literal =
      hlo_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1, 1, 1})));
  HloInstruction* abs_of_literal = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(f32x6, HloOpcode::kAbs, literal));

  HloInstruction* first_abs = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(f32x6, HloOpcode::kAbs, abs_of_literal));
  HloInstruction* tuple = hlo_builder.AddInstruction(
      HloInstruction::CreateTuple(
          absl::Span<HloInstruction* const>({first_abs})));
  HloInstruction* tuple_element = hlo_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32x6, tuple, 0));

  HloInstruction* second_abs = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(f32x6, HloOpcode::kAbs, abs_of_literal));

  hlo_builder.AddInstruction(HloInstruction::CreateBinary(
      f32x6, HloOpcode::kAdd, tuple_element, second_abs));

  auto module = CreateNewVerifiedModule();
  module->AddEntryComputation(hlo_builder.Build());

  const auto byte_size = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), kTuplePointerSize);
  };
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), byte_size, ListMemoryScheduler));

  // Every instruction of the entry computation has to be scheduled.
  const HloComputation* entry = module->entry_computation();
  EXPECT_EQ(entry->instruction_count(), schedule.sequence(entry).size());

  SequentialHloOrdering ordering(schedule);
  // The tuple instruction allocates the tuple buffer and releases nothing.
  // second_abs reuses one buffer for input and output, so its bytes-freed is
  // 0. The list scheduler should therefore run second_abs before the tuple.
  EXPECT_TRUE(ordering.ExecutesBefore(second_abs, tuple));
}
195
// Checks that ListMemoryScheduler charges a multi-output fusion for the tuple
// elements it allocates and therefore runs the cheaper exp first.
TEST_F(HloSchedulingTest, MultiOutputFusionAccountedCorrectly) {
  const Shape f32x5 = ShapeUtil::MakeShape(xla::F32, {5});
  HloComputation::Builder hlo_builder(TestName());

  HloInstruction* ones = hlo_builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1, 1})));
  HloInstruction* ramp = hlo_builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 2, 3, 4, 5})));
  HloInstruction* evens = hlo_builder.AddInstruction(
      HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({0, 2, 4, 6, 8})));

  HloInstruction* add = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(f32x5, HloOpcode::kAdd, ones, ramp));
  HloInstruction* mul = hlo_builder.AddInstruction(
      HloInstruction::CreateBinary(f32x5, HloOpcode::kMultiply, add, evens));
  HloInstruction* tuple =
      hlo_builder.AddInstruction(HloInstruction::CreateTuple({add, mul}));

  HloInstruction* tuple_element = hlo_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32x5, tuple, 0));

  HloInstruction* exp = hlo_builder.AddInstruction(
      HloInstruction::CreateUnary(f32x5, HloOpcode::kExp, evens));

  hlo_builder.AddInstruction(HloInstruction::CreateBinary(
      f32x5, HloOpcode::kAdd, tuple_element, exp));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(hlo_builder.Build());

  // Fold tuple, mul and add into a single loop fusion.
  HloInstruction* fusion = entry->CreateFusionInstruction(
      {tuple, mul, add}, HloInstruction::FusionKind::kLoop);

  const auto byte_size = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape(), 2);
  };
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), byte_size, ListMemoryScheduler));

  // Every instruction of the entry computation has to be scheduled.
  EXPECT_EQ(module->entry_computation()->instruction_count(),
            schedule.sequence(module->entry_computation()).size());

  SequentialHloOrdering ordering(schedule);
  // The fusion allocates memory for its tuple elements and frees nothing, so
  // it is the more expensive choice and exp should come first.
  EXPECT_TRUE(ordering.ExecutesBefore(exp, fusion));
}
244
// Checks that HeapSimulator::MinimumMemoryForComputation only includes the
// memory of called computations when it is given a memory_by_computation map.
TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
  // NOTE(review): the condition below yields a scalar PRED from rank-1
  // operands; presumably that is why an unverified module is used here.
  auto module = CreateNewUnverifiedModule();
  const Shape r1f32 = ShapeUtil::MakeShape(F32, {4});

  // param != 0
  // Needs 17 bytes
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "cond_param"));
  HloInstruction* zero_vector =
      cond_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({0, 0, 0, 0})));
  cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_param,
                                    zero_vector, ComparisonDirection::kNe));
  auto cond_computation = module->AddEmbeddedComputation(cond_builder.Build());

  // param - 1
  // Needs 16 bytes
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, r1f32, "body_param"));
  // Rank-1 literal so that the operand matches r1f32 (f32[4]); the byte size
  // stays at 4 * sizeof(float) = 16.
  HloInstruction* one_vector =
      body_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  body_builder.AddInstruction(HloInstruction::CreateBinary(
      r1f32, HloOpcode::kSubtract, body_param, one_vector));
  auto body_computation = module->AddEmbeddedComputation(body_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  // The loop state is declared as r1f32, so the initial value is rank-1 too.
  HloInstruction* while_init =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 1, 1, 1})));
  // Creates 16 bytes, ignoring subcomputations
  builder.AddInstruction(HloInstruction::CreateWhile(
      r1f32, cond_computation, body_computation, while_init));

  module->AddEntryComputation(builder.Build());

  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  TF_ASSERT_OK_AND_ASSIGN(
      HloSchedule schedule,
      ScheduleModule(module.get(), size_fn, ListMemoryScheduler));
  // Verify that all instructions are in the sequence.
  const HloComputation* entry_computation = module->entry_computation();
  EXPECT_EQ(entry_computation->instruction_count(),
            schedule.sequence(entry_computation).size());

  // Memory requirements of the called computations, as noted where they are
  // built above.
  absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
  memory_by_computation[cond_computation] = 17;
  memory_by_computation[body_computation] = 16;
  std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
      TuplePointsToAnalysis::Run(module.get()).ValueOrDie();

  // HeapSimulator doesn't account for subcomputations
  EXPECT_EQ(16, HeapSimulator::MinimumMemoryForComputation(
                    *entry_computation, schedule.sequence(entry_computation),
                    *points_to_analysis, size_fn)
                    .ValueOrDie());
  // HeapSimulator accounts for subcomputations. Cond is the largest one.
  // The output buffer of the while is aliased.
  EXPECT_EQ(17, HeapSimulator::MinimumMemoryForComputation(
                    *entry_computation, schedule.sequence(entry_computation),
                    *points_to_analysis, size_fn, &memory_by_computation)
                    .ValueOrDie());
}
313
// Checks that HloTrivialScheduler attaches a verifiable schedule to a module
// containing a while loop, and that cloning the module keeps a schedule.
TEST_F(HloSchedulingTest, TrivialScheduler) {
  const char* const while_module_text = R"(
HloModule ModuleWithWhile

body {
param.b = (s32[], s32[]) parameter(0)
gte.0 = s32[] get-tuple-element(param.b), index=0
gte.1 = s32[] get-tuple-element(param.b), index=1
add = s32[] add(gte.0, gte.1)
ROOT tuple = (s32[], s32[]) tuple(gte.0, add)
}

cond {
param.c = (s32[], s32[]) parameter(0)
ROOT constant = pred[] constant(true)
}

ENTRY main {
init = (s32[], s32[]) parameter(0)
ROOT while = (s32[], s32[]) while(init), condition=cond, body=body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(while_module_text));

  // A freshly parsed module carries no schedule; the pass has to add one.
  EXPECT_FALSE(module->has_schedule());
  HloTrivialScheduler trivial_scheduler;
  TF_ASSERT_OK(trivial_scheduler.Run(module.get()).status());
  ASSERT_TRUE(module->has_schedule());
  TF_ASSERT_OK(module->schedule().Verify());

  // The schedule must be present, and valid, on a clone of the module too.
  std::unique_ptr<HloModule> cloned_module = module->Clone();
  ASSERT_TRUE(cloned_module->has_schedule());
  TF_ASSERT_OK(cloned_module->schedule().Verify());
}
348
349 } // namespace
350 } // namespace xla
351