1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/memory_space_assignment.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
20 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
21
22 namespace xla {
23 namespace {
24
25 namespace op = xla::testing::opcode_matchers;
26
constexpr int64 kPointerSize = 8;
// Bandwidths (bytes/sec) fed into MemorySpaceAssignmentCostAnalysis::Create.
constexpr float kAsyncCopyBandwidth = 100;
constexpr float kAlternateMemBandwidth = 1000;
// Properties fed into HloCostAnalysis to estimate instruction elapsed times.
constexpr float kBytesPerSecond = 100;
constexpr float kFlopsPerSecond = 1000;
constexpr float kTranscendentalsPerSecond = 10;
33
ShapeSize(const Shape & shape)34 int64 ShapeSize(const Shape& shape) {
35 return ShapeUtil::ByteSizeOf(shape, kPointerSize);
36 }
37
38 class MemorySpaceAssignmentTest : public HloTestBase,
39 public ::testing::WithParamInterface<bool> {
40 protected:
41 // We use the following two memory space values to describe the default (slow
42 // and large) and alternate (fast and small) memory spaces.
43 const int64 kDefaultMemorySpace = 0;
44 const int64 kAlternateMemorySpace = 1;
45
AssignMemorySpaceUsingCostAnalysis(HloModule * module)46 std::unique_ptr<PresetAssignments> AssignMemorySpaceUsingCostAnalysis(
47 HloModule* module) {
48 HloCostAnalysis hlo_cost_analysis(ShapeSize);
49 hlo_cost_analysis.set_flops_per_second(kFlopsPerSecond);
50 hlo_cost_analysis.set_bytes_per_second(kBytesPerSecond);
51 hlo_cost_analysis.set_transcendentals_per_second(kTranscendentalsPerSecond);
52 for (HloComputation* computation : module->MakeNonfusionComputations()) {
53 TF_CHECK_OK(computation->Accept(&hlo_cost_analysis));
54 }
55 auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
56 auto cost_analysis = MemorySpaceAssignmentCostAnalysis::Create(
57 hlo_cost_analysis, kAsyncCopyBandwidth,
58 kAlternateMemBandwidth, *module)
59 .ValueOrDie();
60 CostAnalysisPrefetchIntervalPicker prefetch_interval_picker(
61 CostAnalysisPrefetchIntervalPicker(
62 *cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8,
63 /*max_async_copy_to_overlap_ratio=*/10.0,
64 /*preferred_async_copy_to_overlap_ratio=*/1.5));
65 return AssignMemorySpace(
66 module, /*max_outstanding_async_copies=*/-1,
67 MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
68 *cost_analysis, &cache_),
69 &prefetch_interval_picker);
70 }
71
AssignMemorySpace(HloModule * module,int64 max_outstanding_async_copies=-1,int64 max_prefetch_interval=10,int64 min_prefetch_interval=2,absl::optional<MemorySpaceAssignment::Options> options=absl::nullopt)72 std::unique_ptr<PresetAssignments> AssignMemorySpace(
73 HloModule* module, int64 max_outstanding_async_copies = -1,
74 int64 max_prefetch_interval = 10, int64 min_prefetch_interval = 2,
75 absl::optional<MemorySpaceAssignment::Options> options = absl::nullopt) {
76 InstructionCountPrefetchIntervalPicker prefetch_interval_picker(
77 min_prefetch_interval, max_prefetch_interval);
78 return AssignMemorySpace(module, max_outstanding_async_copies,
79 /*buffer_interval_compare=*/{},
80 &prefetch_interval_picker, options);
81 }
82
AssignMemorySpace(HloModule * module,int64 max_outstanding_async_copies,absl::optional<MemorySpaceAssignment::BufferIntervalCompare> buffer_interval_compare,PrefetchIntervalPicker * prefetch_interval_picker,absl::optional<MemorySpaceAssignment::Options> memory_space_assignment_options=absl::nullopt)83 std::unique_ptr<PresetAssignments> AssignMemorySpace(
84 HloModule* module, int64 max_outstanding_async_copies,
85 absl::optional<MemorySpaceAssignment::BufferIntervalCompare>
86 buffer_interval_compare,
87 PrefetchIntervalPicker* prefetch_interval_picker,
88 absl::optional<MemorySpaceAssignment::Options>
89 memory_space_assignment_options = absl::nullopt) {
90 auto size_fn = [](const BufferValue& buffer) {
91 return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
92 };
93
94 auto is_allowed_in_alternate_mem = [](const HloValue& value) {
95 // Check if the value belongs to the entry computation.
96 HloInstruction* instruction = value.instruction();
97 HloComputation* computation = instruction->parent();
98 bool in_entry_computation =
99 (computation == computation->parent()->entry_computation());
100 if (in_entry_computation &&
101 instruction->opcode() == HloOpcode::kParameter) {
102 return false;
103 }
104 return true;
105 };
106
107 // Only check parameters in default memory if the original module didn't
108 // have the parameters in alternate memory.
109 bool check_parameters_in_default_memory = true;
110 for (const HloInstruction* parameter :
111 module->entry_computation()->parameter_instructions()) {
112 ShapeUtil::ForEachSubshape(
113 parameter->shape(),
114 [&](const Shape& subshape, const ShapeIndex& /*index*/) {
115 if (subshape.has_layout() &&
116 subshape.layout().memory_space() == kAlternateMemorySpace) {
117 check_parameters_in_default_memory = false;
118 }
119 });
120 }
121
122 MemorySpaceAssignment::Options options;
123 if (memory_space_assignment_options) {
124 options = *memory_space_assignment_options;
125 } else {
126 options.max_size_in_bytes = 128;
127 options.alignment_in_bytes = 8;
128 options.verify = true;
129 }
130
131 options.alternate_memory_space = kAlternateMemorySpace;
132 options.buffer_interval_compare = buffer_interval_compare;
133 options.prefetch_interval_picker = prefetch_interval_picker;
134 options.size_fn = size_fn;
135 options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem;
136 options.max_outstanding_prefetches = max_outstanding_async_copies;
137 options.max_outstanding_evictions = max_outstanding_async_copies;
138 options.allocate_across_sequential_calls = GetParam();
139
140 auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
141 std::unique_ptr<HloLiveRange> hlo_live_range =
142 HloLiveRange::Run(module->schedule(), *alias_analysis,
143 module->entry_computation())
144 .ValueOrDie();
145
146 std::unique_ptr<PresetAssignments> preset_assignments =
147 MemorySpaceAssignment::Run(module, *hlo_live_range, *alias_analysis,
148 options)
149 .ValueOrDie();
150 if (check_parameters_in_default_memory) {
151 CheckParametersInDefaultMemory(module);
152 }
153 CheckPresetAssignments(preset_assignments.get());
154 return preset_assignments;
155 }
156
CheckPresetAssignments(const PresetAssignments * preset_assignments)157 void CheckPresetAssignments(const PresetAssignments* preset_assignments) {
158 // Ensure that the exported preset assignments point to layouts in the
159 // alternate memory. Also ensure that the positions are unique. Note that
160 // we're using a std::set instead of absl::flat_hash_set because we can make
161 // use of HloPosition's comparator logic instead of providing a hasher.
162 std::set<HloPosition> positions_in_preset_assignments;
163 for (auto& position_and_chunk : preset_assignments->chunks()) {
164 HloPosition position = position_and_chunk.first;
165 EXPECT_EQ(positions_in_preset_assignments.find(position),
166 positions_in_preset_assignments.end());
167 positions_in_preset_assignments.insert(position);
168 const Shape& subshape =
169 ShapeUtil::GetSubshape(position.instruction->shape(), position.index);
170 EXPECT_EQ(subshape.layout().memory_space(), kAlternateMemorySpace)
171 << "Exported position is not in alternate mem: "
172 << position.ToString();
173 }
174 }
175
CheckParametersInDefaultMemory(const HloModule * module)176 void CheckParametersInDefaultMemory(const HloModule* module) {
177 // Check that all the entry parameter subshapes are placed in default
178 // memory.
179 const HloComputation* entry_computation = module->entry_computation();
180 for (const HloInstruction* parameter :
181 entry_computation->parameter_instructions()) {
182 ShapeUtil::ForEachSubshape(
183 parameter->shape(),
184 [&](const Shape& subshape, const ShapeIndex& /*index*/) {
185 if (subshape.has_layout()) {
186 EXPECT_NE(subshape.layout().memory_space(), kAlternateMemorySpace)
187 << "Parameter not in default memory: "
188 << parameter->ToString();
189 }
190 });
191 }
192 }
193
194 struct OutstandingAsyncCopies {
195 int64 max_copies;
196 int64 max_prefetches;
197 int64 max_evictions;
198 };
199
CountMaximumOutstandingAsyncCopies(const HloModule & module)200 /*static*/ OutstandingAsyncCopies CountMaximumOutstandingAsyncCopies(
201 const HloModule& module) {
202 OutstandingAsyncCopies copies{0, 0, 0};
203 int64 current_copies = 0;
204 int64 current_prefetches = 0;
205 int64 current_evictions = 0;
206 for (HloInstruction* instruction : module.schedule()
207 .sequence(module.entry_computation())
208 .instructions()) {
209 if (instruction->opcode() == HloOpcode::kCopyStart) {
210 current_copies++;
211 if (ShapeUtil::GetSubshape(instruction->shape(), {0})
212 .layout()
213 .memory_space() == kAlternateMemorySpace) {
214 current_prefetches++;
215 } else {
216 current_evictions++;
217 }
218 } else if (instruction->opcode() == HloOpcode::kCopyDone) {
219 current_copies--;
220 if (instruction->shape().layout().memory_space() ==
221 kAlternateMemorySpace) {
222 current_prefetches--;
223 } else {
224 current_evictions--;
225 }
226 }
227 copies.max_copies = std::max(copies.max_copies, current_copies);
228 copies.max_prefetches =
229 std::max(copies.max_prefetches, current_prefetches);
230 copies.max_prefetches = std::max(copies.max_evictions, current_evictions);
231 }
232 return copies;
233 }
234
GetAlternateMemoryOffset(const PresetAssignments & preset_assignments,const HloInstruction * instruction,const ShapeIndex & index={}) const235 int64 GetAlternateMemoryOffset(const PresetAssignments& preset_assignments,
236 const HloInstruction* instruction,
237 const ShapeIndex& index = {}) const {
238 // Returns the offset of the assignment, -1 if it's not in the alternate
239 // memory.
240 const HloModule* module = instruction->parent()->parent();
241 auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
242 HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(instruction, index);
243 for (auto& pos_and_chunk : preset_assignments.chunks()) {
244 for (auto& value : buffer.values()) {
245 if (pos_and_chunk.first == value->defining_position()) {
246 return pos_and_chunk.second.offset;
247 }
248 }
249 }
250 return -1;
251 }
252
CreateEvictAndPrefetchModule()253 std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() {
254 HloComputation::Builder builder(TestName());
255 Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
256 HloInstruction* p0 =
257 builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
258 HloInstruction* p1 =
259 builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
260 HloInstruction* tanh = builder.AddInstruction(
261 HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
262 // tanh should be placed in the alternate memory since there isn't much
263 // contention in the beginning. However, tanh has another consumer at the
264 // end. So it should be kicked out to default memory and prefetched back in.
265 // The graph below is meant to increase the contention to force
266 // eviction/prefetch behavior.
267 HloInstruction* a = builder.AddInstruction(
268 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh));
269 HloInstruction* b = builder.AddInstruction(
270 HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
271 HloInstruction* c = builder.AddInstruction(
272 HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
273 HloInstruction* d = builder.AddInstruction(
274 HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
275 HloInstruction* e = builder.AddInstruction(
276 HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b));
277 HloInstruction* f = builder.AddInstruction(
278 HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c));
279 HloInstruction* g = builder.AddInstruction(
280 HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d));
281 HloInstruction* h = builder.AddInstruction(
282 HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c));
283 HloInstruction* i = builder.AddInstruction(
284 HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d));
285 HloInstruction* j = builder.AddInstruction(
286 HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d));
287 HloInstruction* k = builder.AddInstruction(
288 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f));
289 HloInstruction* l = builder.AddInstruction(
290 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h));
291 HloInstruction* m = builder.AddInstruction(
292 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j));
293 HloInstruction* n = builder.AddInstruction(
294 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l));
295 HloInstruction* o = builder.AddInstruction(
296 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m));
297 // tanh is being used at the root instruction, and this should be
298 // prefetched.
299 HloInstruction* add = builder.AddInstruction(
300 HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh));
301
302 auto module = CreateNewVerifiedModule();
303 HloComputation* computation = module->AddEntryComputation(builder.Build());
304
305 HloSchedule schedule(module.get());
306 schedule.set_sequence(computation, {p0, p1, tanh, a, b, c, d, e, f, g, h, i,
307 j, k, l, m, n, o, add});
308 TF_CHECK_OK(module->set_schedule(schedule));
309 return module;
310 }
311
312 MemorySpaceAssignmentCostAnalysis::Cache cache_;
313 };
314
315 // For testing purposes, we define a cost analysis where we can control the
316 // elapsed times of each HLO and asynchronous copy.
317 class FakeMemorySpaceAssignmentCostAnalysis
318 : public MemorySpaceAssignmentCostAnalysis {
319 public:
320 static StatusOr<std::unique_ptr<FakeMemorySpaceAssignmentCostAnalysis>>
Create(const HloCostAnalysis & cost_analysis,const HloModule & module)321 Create(const HloCostAnalysis& cost_analysis, const HloModule& module) {
322 TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
323 TF_ASSIGN_OR_RETURN(auto hlo_live_range,
324 HloLiveRange::Run(module.schedule(), *alias_analysis,
325 module.entry_computation()));
326 auto call_graph = CallGraph::Build(&module);
327 return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis(
328 cost_analysis, /*async_copy_bandwidth_bytes_per_second=*/1,
329 /*alternate_mem_bandwidth_bytes_per_second=*/1,
330 std::move(alias_analysis), std::move(hlo_live_range),
331 std::move(call_graph)));
332 }
333
GetInstructionElapsed(const HloInstruction & instruction) const334 float GetInstructionElapsed(
335 const HloInstruction& instruction) const override {
336 if (get_instruction_elapsed_override_) {
337 return get_instruction_elapsed_override_(instruction);
338 }
339 return 1.0;
340 }
341
GetInstructionElapsedInAlternateMemory(const HloInstruction & instruction,absl::optional<int64> operand_in_alternate_mem,bool output_in_alternate_mem) const342 float GetInstructionElapsedInAlternateMemory(
343 const HloInstruction& instruction,
344 absl::optional<int64> operand_in_alternate_mem,
345 bool output_in_alternate_mem) const override {
346 if (get_instruction_elapsed_in_alternate_memory_override_) {
347 return get_instruction_elapsed_in_alternate_memory_override_(
348 instruction, operand_in_alternate_mem, output_in_alternate_mem);
349 }
350 if (operand_in_alternate_mem) {
351 return 0.5;
352 } else {
353 return 1.0;
354 }
355 }
356
GetAsyncCopyElapsed(const Shape & shape) const357 float GetAsyncCopyElapsed(const Shape& shape) const override {
358 if (get_async_copy_elapsed_override_) {
359 return get_async_copy_elapsed_override_(shape);
360 }
361 return 3.0;
362 }
363
364 // The following methods can be used to override what the above API calls
365 // return.
SetOverrideForGetInstructionElapsed(std::function<float (const HloInstruction &)> function)366 void SetOverrideForGetInstructionElapsed(
367 std::function<float(const HloInstruction&)> function) {
368 get_instruction_elapsed_override_ = function;
369 }
SetOverrideForGetInstructionElapsedInAlternateMemory(std::function<float (const HloInstruction &,absl::optional<int64>,bool)> function)370 void SetOverrideForGetInstructionElapsedInAlternateMemory(
371 std::function<float(const HloInstruction&, absl::optional<int64>, bool)>
372 function) {
373 get_instruction_elapsed_in_alternate_memory_override_ = function;
374 }
SetOverrideForGetAsyncCopyElapsed(std::function<float (const Shape &)> function)375 void SetOverrideForGetAsyncCopyElapsed(
376 std::function<float(const Shape&)> function) {
377 get_async_copy_elapsed_override_ = function;
378 }
379
380 protected:
FakeMemorySpaceAssignmentCostAnalysis(const HloCostAnalysis & cost_analysis,float async_copy_bandwidth_bytes_per_second,float alternate_mem_bandwidth_bytes_per_second,std::unique_ptr<HloAliasAnalysis> alias_analysis,std::unique_ptr<HloLiveRange> hlo_live_range,std::unique_ptr<CallGraph> call_graph)381 FakeMemorySpaceAssignmentCostAnalysis(
382 const HloCostAnalysis& cost_analysis,
383 float async_copy_bandwidth_bytes_per_second,
384 float alternate_mem_bandwidth_bytes_per_second,
385 std::unique_ptr<HloAliasAnalysis> alias_analysis,
386 std::unique_ptr<HloLiveRange> hlo_live_range,
387 std::unique_ptr<CallGraph> call_graph)
388 : MemorySpaceAssignmentCostAnalysis(
389 cost_analysis, async_copy_bandwidth_bytes_per_second,
390 alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
391 std::move(hlo_live_range), std::move(call_graph)) {}
392
393 private:
394 std::function<float(const HloInstruction&)>
395 get_instruction_elapsed_override_ = nullptr;
396 std::function<float(const HloInstruction&, absl::optional<int64>, bool)>
397 get_instruction_elapsed_in_alternate_memory_override_ = nullptr;
398 std::function<float(const Shape&)> get_async_copy_elapsed_override_ = nullptr;
399 };
400
TEST_P(MemorySpaceAssignmentTest, ParameterOnly) {
  // Memory space assignment should leave a lone entry parameter untouched,
  // since inputs/outputs are currently excluded from assignment.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* param =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());
  HloSchedule schedule(module.get());
  schedule.set_sequence(entry, {param});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // The parameter keeps its original (default-memory) layout.
  EXPECT_THAT(param, op::ShapeWithLayout(shape));
}
420
TEST_P(MemorySpaceAssignmentTest, Simple) {
  // A small module where add and sub feed mul. After assignment, the
  // intermediate values (add, sub) should live in the alternate memory while
  // the parameters and the output remain in default memory.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* param1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kSubtract, param0, param1));
  HloInstruction* mul = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, sub));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(entry, {param0, param1, add, sub, mul});
  TF_CHECK_OK(module->set_schedule(schedule));

  auto preset_assignments = AssignMemorySpace(module.get());

  // Inputs and outputs are currently placed in the default memory; everything
  // else should be in the alternate memory.
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  EXPECT_THAT(param0, op::ShapeWithLayout(shape));
  EXPECT_THAT(param1, op::ShapeWithLayout(shape));
  EXPECT_THAT(mul, op::ShapeWithLayout(shape));
  EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(sub, op::ShapeWithLayout(shape_in_alternate_mem));

  // Sanity-check the preset assignments: three exported chunks, one
  // assignment-information entry, and distinct offsets for add and sub.
  EXPECT_EQ(preset_assignments->chunks().size(), 3);
  EXPECT_EQ(preset_assignments->assignment_informations().size(), 1);
  EXPECT_NE(preset_assignments->chunks()[0].second.offset,
            preset_assignments->chunks()[1].second.offset);
}
466
TEST_P(MemorySpaceAssignmentTest, NegateChain) {
  // The negate chain is long enough for an asynchronous copy to be inserted
  // between p1 and add.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* param1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  // Build a chain of seven negates rooted at p0.
  constexpr int kNegateChainLength = 7;
  HloInstruction* negates[kNegateChainLength];
  HloInstruction* chain_tail = param0;
  for (int idx = 0; idx < kNegateChainLength; ++idx) {
    chain_tail = builder.AddInstruction(
        HloInstruction::CreateUnary(shape, HloOpcode::kNegate, chain_tail));
    negates[idx] = chain_tail;
  }
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, chain_tail, param1));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation, {param0, param1, negates[0], negates[1], negates[2],
                    negates[3], negates[4], negates[5], negates[6], add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // p1 should be prefetched into the alternate memory for its use in add.
  EXPECT_THAT(add, op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace,
                                                       kDefaultMemorySpace,
                                                       op::Parameter(1))));
  // Parameters are in the default memory space.
  EXPECT_THAT(param0, op::ShapeWithLayout(shape));
  EXPECT_THAT(param1, op::ShapeWithLayout(shape));
  // All negate instructions are in the alternate memory space (1).
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  for (HloInstruction* negate : negates) {
    EXPECT_THAT(negate, op::ShapeWithLayout(shape_in_alternate_mem));
  }
  // Ensure the CopyStart/CopyDone schedules: the copy starts right after the
  // parameters and completes just before add.
  const HloInstructionSequence& sequence =
      module->schedule().sequence(computation);
  EXPECT_THAT(sequence.instructions()[0], op::Parameter(0));
  EXPECT_THAT(sequence.instructions()[1], op::Parameter(1));
  EXPECT_THAT(sequence.instructions()[2], op::CopyStart());
  EXPECT_THAT(sequence.instructions()[10], op::CopyDone());
}
529
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetch) {
  auto module = CreateEvictAndPrefetchModule();

  AssignMemorySpace(module.get());

  // The tanh value should be evicted to default memory and then prefetched
  // back into alternate memory for its use at the root: an alternate->default
  // copy followed by a default->alternate copy.
  HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(
      root,
      op::Add(op::Add(),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::AsyncCopy(kDefaultMemorySpace,
                                          kAlternateMemorySpace, op::Tanh()))));
}
542
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) {
  std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/0);

  // With a limit of 0, no prefetch or eviction may be in flight at any point.
  const OutstandingAsyncCopies copies =
      CountMaximumOutstandingAsyncCopies(*module);
  EXPECT_LE(copies.max_prefetches, 0);
  EXPECT_LE(copies.max_evictions, 0);
}
551
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) {
  std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1);

  // At most one prefetch and one eviction may be in flight at any point.
  const OutstandingAsyncCopies copies =
      CountMaximumOutstandingAsyncCopies(*module);
  EXPECT_LE(copies.max_prefetches, 1);
  EXPECT_LE(copies.max_evictions, 1);
}
560
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies2) {
  std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/2);

  // At most two prefetches and two evictions may be in flight at any point.
  const OutstandingAsyncCopies copies =
      CountMaximumOutstandingAsyncCopies(*module);
  EXPECT_LE(copies.max_prefetches, 2);
  EXPECT_LE(copies.max_evictions, 2);
}
569
// TODO(berkin): This test is broken with some prefetch timing improvements.
TEST_P(MemorySpaceAssignmentTest,
       DISABLED_DontEvictWhenThereIsDefaultMemAllocation) {
  // This test is the same as EvictAndPrefetchLimitAsyncCopies1, except we check
  // that there is no eviction if not necessary (due to an existing allocation
  // in default memory).
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  HloInstruction* tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  // tanh should be placed in the alternate memory since there isn't much
  // contention in the beginning. However, tanh has another consumer at the end.
  // So it should be kicked out to default memory and prefetched back in. The
  // graph below is meant to increase the contention to force eviction/prefetch
  // behavior.
  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
  HloInstruction* d = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  // e..j consume all pairs of {a, b, c, d} to keep alternate memory busy.
  HloInstruction* e = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b));
  HloInstruction* f = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c));
  HloInstruction* g = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d));
  HloInstruction* h = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c));
  HloInstruction* i = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d));
  HloInstruction* j = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d));
  // k..o reduce e..j down to a single value.
  HloInstruction* k = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f));
  HloInstruction* l = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h));
  HloInstruction* m = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j));
  HloInstruction* n = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l));
  HloInstruction* o = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m));
  // tanh is being used at the root instruction, and this should be
  // prefetched.
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, tanh, a, b, c, d, e, f, g, h, i,
                                      j, k, l, m, n, o, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Allow at most one outstanding prefetch/eviction at a time.
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1);

  // We expect the second argument to multiply is prefetched c.
  EXPECT_THAT(f, op::Multiply(op::Add(), op::CopyDone()));
  // We make sure that the second argument to this multiply is not evicted
  // CopyDone but is the original c.
  EXPECT_THAT(h, op::Multiply(op::Subtract(), op::Multiply()));
}
640
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchAndPrefetch) {
  // Test for a memory corruption bug involving evict/prefetch/prefetch pattern,
  // where the last prefetch copied from the original buffer in alternate buffer
  // instead of evicted buffer.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  // tanh is the value of interest: it is used by add0 and add1 far below, so
  // it must survive the dense region of computation in between.
  HloInstruction* tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  // Dense fan-out/fan-in of elementwise ops (a..o). NOTE(review): presumably
  // this creates enough alternate-memory pressure to force tanh to be evicted
  // to default memory -- confirm against the pass's heuristics.
  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
  HloInstruction* d = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  HloInstruction* e = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b));
  HloInstruction* f = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c));
  HloInstruction* g = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d));
  HloInstruction* h = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c));
  HloInstruction* i = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d));
  HloInstruction* j = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d));
  HloInstruction* k = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f));
  HloInstruction* l = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h));
  HloInstruction* m = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j));
  HloInstruction* n = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l));
  HloInstruction* o = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m));
  // First use of tanh after the dense region -> first prefetch.
  HloInstruction* add0 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh));
  // Chain of negates separates the two uses of tanh, opening a second
  // prefetch window for add1.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, add0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* negate8 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate7));
  HloInstruction* negate9 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate8));
  // Second use of tanh -> second prefetch.
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate9, tanh));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation,
      {p0, p1, tanh, a, b, c, d, e,
       f, g, h, i, j, k, l, m,
       n, o, add0, negate0, negate1, negate2, negate3, negate4,
       negate5, negate6, negate7, negate8, negate9, add1});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Check that both prefetches (add0 and add1) prefetch from the eviction
  // instead of tanh, which will be placed in the alternate memory directly.
  // The matcher pattern is prefetch(eviction(tanh)): an alternate<-default
  // async copy whose operand is a default<-alternate async copy of tanh.
  EXPECT_THAT(
      add0,
      op::Add(op::Add(),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::AsyncCopy(kDefaultMemorySpace,
                                          kAlternateMemorySpace, op::Tanh()))));
  EXPECT_THAT(
      add1,
      op::Add(op::Negate(),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::AsyncCopy(kDefaultMemorySpace,
                                          kAlternateMemorySpace, op::Tanh()))));
}
737
TEST_P(MemorySpaceAssignmentTest, While) {
  // Builds a while loop with a (data, iteration-counter) tuple state and
  // checks which buffers may live in the alternate memory space.
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape});

  auto cond_builder = HloComputation::Builder("WhileCond");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_limit = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
  // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_limit, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  auto body_builder = HloComputation::Builder("WhileBody");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloInstruction* body_iter = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
  HloInstruction* body_data = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
  HloInstruction* body_iter_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
  HloInstruction* body_iter_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
  HloInstruction* body_data_increment =
      body_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}})));
  // body_data_mul is the only value in the body that is independent of the
  // loop-carried state feeding the next iteration's tuple directly; see the
  // expectation at the bottom.
  HloInstruction* body_data_mul =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kMultiply, body_data, body_data));
  HloInstruction* body_data_add =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, body_data, body_data_increment));
  HloInstruction* body_data_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, body_data_add, body_data_mul));
  HloInstruction* body_out = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_data_next, body_iter_next}));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  // Fix: the two parameter name strings were swapped (the data parameter was
  // named "param_iter" and the iteration counter "param_data").
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param_data"));
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({data, iter}));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_limit, cond_lt});
  schedule.set_sequence(body_computation,
                        {body_param, body_iter, body_data, body_iter_increment,
                         body_iter_next, body_data_increment, body_data_mul,
                         body_data_add, body_data_next, body_out});
  schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Ensure the tuple value and buffers used in the while instruction are
  // exempted from using the alternate memory when allocating across sequential
  // calls is disabled. However, body_data_mul is independent and can be safely
  // be placed in the alternate memory.
  const bool allocate_across_sequential_calls = GetParam();
  if (!allocate_across_sequential_calls) {
    EXPECT_THAT(tuple, op::ShapeWithLayout(tuple_shape));
    EXPECT_THAT(data, op::ShapeWithLayout(shape));
    EXPECT_THAT(iter, op::ShapeWithLayout(scalar_shape));
    EXPECT_THAT(body_data, op::ShapeWithLayout(shape));
    EXPECT_THAT(body_iter, op::ShapeWithLayout(scalar_shape));
    EXPECT_THAT(cond_iter, op::ShapeWithLayout(scalar_shape));
  }
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  EXPECT_THAT(body_data_mul, op::ShapeWithLayout(shape_in_alternate_mem));
}
832
TEST_P(MemorySpaceAssignmentTest, Tuple) {
  // Tests prefetching of tuple elements (including an element of a nested
  // tuple) into the alternate memory ahead of their uses.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({shape});
  Shape tuple_shape =
      ShapeUtil::MakeTupleShape({shape, shape, inner_tuple_shape});
  HloInstruction* p = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p, 0));
  // The negate chain delays the uses of p1 and p2_0, leaving room for
  // asynchronous prefetches of those tuple elements.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p, 1));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));
  // p2_0 reads through two levels of tuple indirection: p -> p2 -> p2_0.
  HloInstruction* p2 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(inner_tuple_shape, p, 2));
  HloInstruction* p2_0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p2, 0));
  HloInstruction* mul = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, p2_0));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5,
                    negate6, p1, add, p2, p2_0, mul});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Both the flat tuple element (p1) and the nested tuple element (p2_0)
  // should be fed to their consumers through alternate<-default async copies.
  EXPECT_THAT(
      mul,
      op::Multiply(op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace,
                                                       kDefaultMemorySpace,
                                                       op::GetTupleElement())),
                   op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                 op::GetTupleElement(op::GetTupleElement()))));
}
887
TEST_P(MemorySpaceAssignmentTest, Bitcast) {
  // Bitcasts can cause the position in the alternate memory to appear multiple
  // times in the preset assignments. This test ensures the preset assignments
  // refer to unique positions.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  HloInstruction* param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, param_shape, "p1"));
  HloInstruction* neg = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, param0));
  HloInstruction* bc = builder.AddInstruction(
      HloInstruction::CreateBitcast(param_shape, neg));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(param_shape, HloOpcode::kAdd, bc, param1));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param0, param1, neg, bc, sum});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // The pass may have replaced the add's operand; re-fetch the bitcast through
  // the add and check that it was assigned to the alternate memory.
  HloInstruction* final_bitcast = sum->mutable_operand(0);
  EXPECT_EQ(final_bitcast->opcode(), HloOpcode::kBitcast);
  EXPECT_EQ(final_bitcast->shape().layout().memory_space(),
            kAlternateMemorySpace);
}
919
TEST_P(MemorySpaceAssignmentTest, Bitcast2) {
  // A bitcast of a parameter that is consumed late (after a negate chain)
  // should still end up with its use reading from the alternate memory.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  HloInstruction* param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, param_shape, "p1"));
  // Build a chain of five negates feeding one another, starting from param0.
  HloInstruction* chain = param0;
  std::vector<HloInstruction*> negates;
  for (int idx = 0; idx < 5; ++idx) {
    chain = builder.AddInstruction(
        HloInstruction::CreateUnary(shape, HloOpcode::kNegate, chain));
    negates.push_back(chain);
  }
  HloInstruction* bc =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape, param1));
  HloInstruction* sum = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bc, chain));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation,
                        {param0, param1, negates[0], negates[1], negates[2],
                         negates[3], negates[4], bc, sum});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  EXPECT_EQ(sum->operand(0)->shape().layout().memory_space(),
            kAlternateMemorySpace);
}
956
TEST_P(MemorySpaceAssignmentTest, Bitcast3) {
  // Exercises chained bitcasts (bitcast-of-bitcast) and a bitcast of a value
  // (add) that itself consumes a bitcast, to verify memory-space assignment
  // composes correctly through bitcast chains.
  HloComputation::Builder builder(TestName());
  Shape shape1 = ShapeUtil::MakeShape(F32, {2, 3});
  Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
  Shape shape3 = ShapeUtil::MakeShape(F32, {1, 6});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape1, "p0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, param_shape, "p1"));
  // Negate chain to delay the uses of p1, giving room for a prefetch.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate3));
  HloInstruction* bitcast1 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape1, p1));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape1, HloOpcode::kAdd, bitcast1, negate4));
  // bitcast3(bitcast2(p1)) is a two-deep bitcast chain of the same parameter.
  HloInstruction* bitcast2 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape3, p1));
  HloInstruction* bitcast3 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape2, bitcast2));
  HloInstruction* bitcast4 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape2, add));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      shape2, HloOpcode::kMultiply, bitcast3, bitcast4));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation,
                        {p0, p1, negate0, negate1, negate2, negate3, negate4,
                         bitcast1, add, bitcast2, bitcast3, bitcast4, mul});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // We expect one bitcast on the LHS of multiply since bitcast(bitcast(foo)) is
  // converted to bitcast(foo).
  EXPECT_THAT(
      mul,
      op::Multiply(
          op::Bitcast(op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                    op::Parameter(1))),
          op::Bitcast(op::Add(
              op::Bitcast(op::AsyncCopy(kAlternateMemorySpace,
                                        kDefaultMemorySpace, op::Parameter(1))),
              op::Negate()))));
  EXPECT_EQ(add->operand(0)->shape().layout().memory_space(),
            kAlternateMemorySpace);
  EXPECT_EQ(add->shape().layout().memory_space(), kAlternateMemorySpace);
  // bitcast2 will no longer have a consumer and should get DCE'd, so we don't
  // care about its memory space.
  EXPECT_EQ(mul->operand(0)->shape().layout().memory_space(),
            kAlternateMemorySpace);
  EXPECT_EQ(mul->operand(1)->shape().layout().memory_space(),
            kAlternateMemorySpace);
}
1022
TEST_P(MemorySpaceAssignmentTest, BitcastTuple) {
  // A bitcast feeding a tuple that is consumed by a fusion; the pass should
  // run without error on this pattern.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});

  auto module = CreateNewVerifiedModule();

  // Fused computation: adds the two elements of its tuple parameter.
  HloComputation::Builder fusion_builder("fusion");
  HloInstruction* fparam = fusion_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* felem0 = fusion_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fparam, 0));
  HloInstruction* felem1 = fusion_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fparam, 1));
  fusion_builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, felem0, felem1));
  HloComputation* fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  HloInstruction* param0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* param1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, param_shape, "p1"));
  // Negate chain starting from param0.
  HloInstruction* chain = param0;
  std::vector<HloInstruction*> negates;
  for (int idx = 0; idx < 5; ++idx) {
    chain = builder.AddInstruction(
        HloInstruction::CreateUnary(shape, HloOpcode::kNegate, chain));
    negates.push_back(chain);
  }
  HloInstruction* bc =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape, param1));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({bc, param0}));
  HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      shape, HloInstruction::FusionKind::kCustom, {tuple}, fusion_computation));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation,
                        {param0, param1, negates[0], negates[1], negates[2],
                         negates[3], negates[4], bc, tuple, fusion});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());
}
1073
TEST_P(MemorySpaceAssignmentTest, BitcastGetTupleElementTuple) {
  // This test pattern was encountered in
  // //third_party/tensorflow/compiler/xla/tests:slice_test and was causing a
  // breakage when there is a GetTupleElement(Tuple(Bitcast())) pattern. Also
  // added a GetTupleElement(GetTupleElement(Tuple(Tuple(Bitcast())))) pattern.
  absl::string_view hlo_string = R"(
  HloModule DoIt_S64_10_0_5_1.3, is_scheduled=true

  ENTRY %DoIt_S64_10_0_5_1.3 (p0.1: (u32[10], u32[10])) -> (u32[5], u32[5]) {
    %p0.1 = (u32[10]{0:T(128)}, u32[10]{0:T(128)}) parameter(0)
    %get-tuple-element.1 = u32[10]{0:T(128)} get-tuple-element((u32[10]{0:T(128)}, u32[10]{0:T(128)}) %p0.1), index=1
    %bitcast.1 = u32[5]{0:T(128)} bitcast(u32[10]{0:T(128)} %get-tuple-element.1)
    %get-tuple-element = u32[10]{0:T(128)} get-tuple-element((u32[10]{0:T(128)}, u32[10]{0:T(128)}) %p0.1), index=0
    %bitcast = u32[5]{0:T(128)} bitcast(u32[10]{0:T(128)} %get-tuple-element)
    %tuple.1 = (u32[5]{0:T(128)}, u32[5]{0:T(128)}) tuple(u32[5]{0:T(128)} %bitcast, u32[5]{0:T(128)} %bitcast.1)
    %tuple.3 = ((u32[5]{0:T(128)}, u32[5]{0:T(128)}), (u32[5]{0:T(128)}, u32[5]{0:T(128)})) tuple(%tuple.1, %tuple.1)
    %get-tuple-element.4 = u32[5]{0:T(128)} get-tuple-element((u32[5]{0:T(128)}, u32[5]{0:T(128)}) %tuple.1), index=0
    %get-tuple-element.5 = (u32[5]{0:T(128)}, u32[5]{0:T(128)}) get-tuple-element(%tuple.3), index=0
    %get-tuple-element.6 = u32[5]{0:T(128)} get-tuple-element((u32[5]{0:T(128)}, u32[5]{0:T(128)}) %get-tuple-element.5), index=1
    %copy.2 = u32[5]{0:T(128)} copy(u32[5]{0:T(128)} %get-tuple-element.4)
    %copy.3 = u32[5]{0:T(128)} copy(u32[5]{0:T(128)} %get-tuple-element.6)
    ROOT %tuple.2 = (u32[5]{0:T(128)}, u32[5]{0:T(128)}) tuple(u32[5]{0:T(128)} %copy.2, u32[5]{0:T(128)} %copy.3)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Regression check: the pass must complete without crashing on this module.
  AssignMemorySpace(module.get());
}
1103
TEST_P(MemorySpaceAssignmentTest, GetSimplifiedOperandBug) {
  // Test case for a bug finding Bitcasts in GTE(Tuple(...)) pattern.
  absl::string_view hlo_string = R"(
  HloModule sort.16, is_scheduled=true

  ENTRY %sort.16 (param.0.1: s32[1], param.1.2: f32[1], param.2.3: u32[1], param.3.4: s32[1]) -> (s32[1], f32[1], u32[1], s32[1]) {
    %param.3.4 = s32[1]{0:T(128)} parameter(3)
    %param.2.3 = u32[1]{0:T(128)} parameter(2)
    %param.1.2 = f32[1]{0:T(128)} parameter(1)
    %param.0.1 = s32[1]{0:T(128)} parameter(0)
    %tuple.1 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %param.0.1, f32[1]{0:T(128)} %param.1.2, u32[1]{0:T(128)} %param.2.3, s32[1]{0:T(128)} %param.3.4)
    %get-tuple-element.4 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=0
    %get-tuple-element.5 = f32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=1
    %get-tuple-element.6 = u32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=2
    %get-tuple-element.7 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=3
    %copy.4 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.4)
    %copy.5 = f32[1]{0:T(128)} copy(f32[1]{0:T(128)} %get-tuple-element.5)
    %copy.6 = u32[1]{0:T(128)} copy(u32[1]{0:T(128)} %get-tuple-element.6)
    %copy.7 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.7)
    ROOT %tuple.2 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %copy.4, f32[1]{0:T(128)} %copy.5, u32[1]{0:T(128)} %copy.6, s32[1]{0:T(128)} %copy.7)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Regression check: the pass must complete without crashing on this module.
  AssignMemorySpace(module.get());
}
1131
TEST_P(MemorySpaceAssignmentTest, BitcastMultiUse) {
  // When there is a pattern where a bitcast has multiple uses (negate0 and add)
  // and one is in the default memory and the other is in alternate memory, they
  // both need their own bitcast.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  // Fix: parameter 0 was mislabeled "p1" (copy-paste slip); name it "p0" for
  // consistency with the other tests in this file.
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "p0"));
  HloInstruction* bitcast =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape, p0));
  // The negate chain separates the two uses of the bitcast: the early use
  // (negate0) and the late use (add).
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, bitcast));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate4));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, bitcast, negate0, negate1, negate2,
                                      negate3, negate4, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  // The early use stays in default memory; the late use gets its own bitcast
  // in the alternate memory.
  EXPECT_THAT(negate0->operand(0), op::ShapeWithLayout(shape));
  EXPECT_THAT(add->operand(0), op::ShapeWithLayout(shape_in_alternate_mem));
}
1172
TEST_P(MemorySpaceAssignmentTest, BitcastMultiUseTuple) {
  // Same as BitcastMultiUse but the second use is a tuple (consumed by a
  // fusion).
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});

  auto module = CreateNewVerifiedModule();
  // Fused computation: adds the two elements of its tuple parameter.
  HloComputation::Builder fusion_builder("fusion");
  HloInstruction* fusion_param = fusion_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* fusion_element0 = fusion_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion_param, 0));
  HloInstruction* fusion_element1 = fusion_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion_param, 1));
  fusion_builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, fusion_element0, fusion_element1));
  HloComputation* fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  // Fix: parameter 0 was mislabeled "p1" (copy-paste slip); name it "p0" for
  // consistency with the other tests in this file.
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "p0"));
  HloInstruction* bitcast =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape, p0));
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, bitcast));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({bitcast, negate4}));
  HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      shape, HloInstruction::FusionKind::kCustom, {tuple}, fusion_computation));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, bitcast, negate0, negate1, negate2,
                                      negate3, negate4, tuple, fusion});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  // The early use stays in default memory; the tuple element fed to the
  // fusion should be placed in the alternate memory.
  EXPECT_THAT(negate0->operand(0), op::ShapeWithLayout(shape));
  EXPECT_THAT(fusion->operand(0)->operand(0),
              op::ShapeWithLayout(shape_in_alternate_mem));
}
1228
TEST_P(MemorySpaceAssignmentTest, BitcastScheduleBug) {
  // Bitcasts can force asynchronous copies to be scheduled too early, possibly
  // leading to memory corruption.
  //  Bug:
  //    p0------------------>neg-->neg-->neg ... -->neg-->neg-->neg->add
  //                                                                 /
  //    p1->cs->cd->bitcast-----------------------------------------+
  //
  //  Expected:
  //    p0-->neg-->neg-->neg ... -->neg-->neg-->neg------------->add
  //                                                             /
  //    p1--------------------->cs----------------->cd->bitcast-+
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, param_shape, "p1"));
  HloInstruction* bitcast =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1));
  // Long negate chain so the copy for the bitcast operand has a wide window
  // in which it could (incorrectly) be scheduled too early.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* negate8 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate7));
  HloInstruction* negate9 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate8));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate9));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation, {p0, p1, bitcast, negate0, negate1, negate2, negate3,
                    negate4, negate5, negate6, negate7, negate8, negate9, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Narrow prefetch interval bounds constrain where the async copy may be
  // placed relative to its use.
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/4);

  EXPECT_EQ(add->operand(0)->shape().layout().memory_space(),
            kAlternateMemorySpace);
  const auto& instructions =
      module->schedule().sequence(module->entry_computation()).instructions();
  for (int i = 0; i < instructions.size(); ++i) {
    // Expect that there is a negate before and after the CopyStart and there is
    // a negate before CopyDone.
    if (instructions.at(i)->opcode() == HloOpcode::kCopyStart) {
      EXPECT_EQ(instructions.at(i - 1)->opcode(), HloOpcode::kNegate);
      EXPECT_EQ(instructions.at(i + 1)->opcode(), HloOpcode::kNegate);
    } else if (instructions.at(i)->opcode() == HloOpcode::kCopyDone) {
      EXPECT_EQ(instructions.at(i - 1)->opcode(), HloOpcode::kNegate);
    }
  }
}
1300
TEST_P(MemorySpaceAssignmentTest, TupleSelect) {
  // Make sure tuple-select is not optimized away.
  absl::string_view hlo_string = R"(
  HloModule tuple, is_scheduled=true

  ENTRY %main (a: f32[2], b: f32[2], c: f32[2], d: f32[2], cond: pred[]) -> f32[2] {
    %cond = pred[]{:T(128)E(32)} parameter(4)
    %token0 = token[] after-all()
    %d = f32[2]{0:T(128)} parameter(3)
    %c = f32[2]{0:T(128)} parameter(2)
    %b = f32[2]{0:T(128)} parameter(1)
    %a = f32[2]{0:T(128)} parameter(0)
    %tup0 = (f32[2]{0:T(128)}, f32[2]{0:T(128)}) tuple(f32[2]{0:T(128)} %a, f32[2]{0:T(128)} %b)
    %tup1 = (f32[2]{0:T(128)}, f32[2]{0:T(128)}) tuple(f32[2]{0:T(128)} %c, f32[2]{0:T(128)} %d)
    %s = (f32[2]{0:T(128)}, f32[2]{0:T(128)}) tuple-select(pred[]{:T(128)E(32)} %cond, (f32[2]{0:T(128)}, f32[2]{0:T(128)}) %tup0, (f32[2]{0:T(128)}, f32[2]{0:T(128)}) %tup1)
    %gte = f32[2]{0:T(128)} get-tuple-element((f32[2]{0:T(128)}, f32[2]{0:T(128)}) %s), index=0
    ROOT %negate = f32[2]{0:T(128)} negate(f32[2]{0:T(128)} %gte)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // The negate->gte->tuple-select chain must survive the pass intact.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Negate(op::GetTupleElement(op::TupleSelect())));
}
1328
TEST_P(MemorySpaceAssignmentTest, AddDependency) {
  // Make sure add-dependency is not optimized away.
  absl::string_view hlo_string = R"(
  HloModule AddDependency, is_scheduled=true

  ENTRY %AddDependency (p: f32[3]) -> f32[3] {
    %p = f32[3]{0} parameter(0)
    %neg0 = f32[3]{0} negate(f32[3]{0} %p)
    %neg1 = f32[3]{0} negate(f32[3]{0} %neg0)
    %neg2 = f32[3]{0} negate(f32[3]{0} %neg1)
    %neg3 = f32[3]{0} negate(f32[3]{0} %neg2)
    %neg4 = f32[3]{0} negate(f32[3]{0} %neg3)
    %neg5 = f32[3]{0} negate(f32[3]{0} %neg4)
    %neg6 = f32[3]{0} negate(f32[3]{0} %neg5)
    %token0 = token[] after-all()
    %add_dep = f32[3]{0} add-dependency(f32[3]{0} %p, token[] %token0)
    ROOT %add = f32[3]{0} add(f32[3]{0} %add_dep, f32[3]{0} %neg6)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // The add must still consume the add-dependency wrapper after assignment.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::Add(op::AddDependency(), op::Negate()));
}
1356
TEST_P(MemorySpaceAssignmentTest, WhileAllocationBug) {
  // This test is carefully crafted to include two multiply ops sized [4,3] in a
  // while body. For testing purposes, we have provided a BufferIntervalCompare
  // such that first multiply, then tanh, then other HloValues will be
  // allocated. The memory is sized just enough to fit two [4,3] buffers.
  // Because the multiplies in the while body are going to be allocated in the
  // alternate memory first, the tanh that is fed inside the while loop should
  // not be placed in the alternate memory. Otherwise, we will corrupt memory.
  absl::string_view hlo_string = R"(
  HloModule WhileAllocationBug, is_scheduled=true

  %WhileBody (body_param: (f32[4,3], f32[])) -> (f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[]) %body_param), index=1
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[]) %body_param), index=0
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.2)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[] %add)
  }

  %WhileCond (cond_param: (f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[]) %cond_param), index=1
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  ENTRY %Entry (param_iter: f32[4,3], param_data: f32[], p2: f32[4,3]) -> f32[4,3] {
    %param_data = f32[] parameter(1)
    %param_iter = f32[4,3]{1,0} parameter(0)
    %p2 = f32[4,3]{1,0} parameter(2)
    %tanh = f32[4,3]{1,0} tanh(f32[4,3]{1,0} %param_iter)
    %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
    %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
    %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
    %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
    %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
    %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
    %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
    %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %tanh)
    %tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %tanh, f32[] %param_data)
    %while = (f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[]) %while), index=0
    ROOT %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.3, f32[4,3]{1,0} %add.4)
  }
  )";

  // Allocation priority: multiplies first, then tanh, then everything else
  // (ties broken by buffer id, so the order is a strict weak ordering).
  MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
      [](const MemorySpaceAssignment::BufferInterval& a,
         const MemorySpaceAssignment::BufferInterval& b) {
        auto rank = [](const MemorySpaceAssignment::BufferInterval& interval) {
          HloOpcode opcode = interval.buffer->defining_instruction()->opcode();
          if (opcode == HloOpcode::kMultiply) {
            return 0;
          }
          if (opcode == HloOpcode::kTanh) {
            return 1;
          }
          return 2;
        };
        int a_rank = rank(a);
        int b_rank = rank(b);
        if (a_rank != b_rank) {
          return a_rank < b_rank;
        }
        return a.buffer->id() < b.buffer->id();
      };
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    buffer_interval_compare, &prefetch_interval_picker);

  for (const HloInstruction* instruction :
       module->entry_computation()->instructions()) {
    if (instruction->opcode() != HloOpcode::kWhile) {
      continue;
    }
    const Shape& while_subshape =
        ShapeUtil::GetSubshape(instruction->shape(), {0});
    // We expect shape {0} to either be in default memory for the entire while
    // loop or there has to be an eviction within the while loop.
    if (while_subshape.layout().memory_space() != kAlternateMemorySpace) {
      continue;
    }
    // Find the GTE of element 0 of the while-body parameter...
    const HloInstruction* body_param =
        instruction->while_body()->parameter_instruction(0);
    const HloInstruction* gte = nullptr;
    for (const HloInstruction* user : body_param->users()) {
      if (user->opcode() == HloOpcode::kGetTupleElement &&
          user->tuple_index() == 0) {
        gte = user;
        break;
      }
    }
    EXPECT_NE(gte, nullptr);
    // ...and expect it to feed a CopyStart (the eviction) whose destination is
    // not the alternate memory space.
    const HloInstruction* copy_start = nullptr;
    for (const HloInstruction* user : gte->users()) {
      if (user->opcode() == HloOpcode::kCopyStart) {
        copy_start = user;
        break;
      }
    }
    EXPECT_NE(copy_start, nullptr);
    const Shape& copy_start_subshape =
        ShapeUtil::GetSubshape(copy_start->shape(), {0});

    EXPECT_NE(copy_start_subshape.layout().memory_space(),
              kAlternateMemorySpace);
  }
}
1477
TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoops) {
  // Two back-to-back while loops where the second consumes outputs of the
  // first. The test passes if memory space assignment (and its internal
  // verification) completes without error.
  absl::string_view hlo_text = R"(
  HloModule WhileAllocationBug, is_scheduled=true

  %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
  }

  %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  %WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
  }

  %WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
    %param_iter = f32[] parameter(1)
    %param_data = f32[4,3]{1,0} parameter(0)
    %p2 = f32[4,3]{1,0} parameter(2)
    %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
    %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
    %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
    %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
    %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
    %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
    %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
    %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
    %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
    %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
    %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
    %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
    %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1
    %tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} get-tuple-element.5, f32[] %param_iter)
    %while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2
    %get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0
    ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.6, f32[4,3]{1,0} %add.3)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_text));
  AssignMemorySpace(module.get());
}
1554
TEST_P(MemorySpaceAssignmentTest, WhileLiveRangeBug) {
  // Tests against while live ranges being incorrect and the verifier
  // complaining about a conflict. The test passes if AssignMemorySpace (which
  // runs the verifier internally) completes without error.
  absl::string_view hlo_text = R"(
  HloModule WhileAllocationBug, is_scheduled=true

  %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
    %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
    %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
    %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
    %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
    %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
    %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
    %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
    %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
    %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
    %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
    %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
  }

  %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
    %param_iter = f32[] parameter(1)
    %param_data = f32[4,3]{1,0} parameter(0)
    %p2 = f32[4,3]{1,0} parameter(2)
    %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
    %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
    %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
    %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
    %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
    %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
    %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
    %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
    %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
    %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
    %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
    %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1
    %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
    ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %add.3)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_text));
  AssignMemorySpace(module.get());
}
1619
TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoopsOneBuffer) {
  // Tests against a bug when there are consecutive while loops with one buffer
  // (the value doesn't change in the buffer), the parameter can be colored in
  // the alternate memory space. The test passes if memory space assignment
  // completes without error.
  absl::string_view hlo_text = R"(
  HloModule WhileAllocationBug, is_scheduled=true

  %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
    %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
    %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
    %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
    %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
    %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
    %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
    %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
    %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
    %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
    %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
    %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
  }

  %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  %WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
    %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
    %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
    %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
    %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
    %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
    %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
    %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
    %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
    %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
    %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
    %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
    %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
    %constant.1 = f32[] constant(1)
    %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
    %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
    %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
    %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
    %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
    %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
    ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
  }

  %WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
    %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
    %param_iter = f32[] parameter(1)
    %param_data = f32[4,3]{1,0} parameter(0)
    %p2 = f32[4,3]{1,0} parameter(2)
    %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
    %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
    %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
    %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
    %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
    %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
    %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
    %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
    %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
    %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
    %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
    %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
    %tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} param_data, f32[] %param_iter)
    %while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2
    %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0
    %get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=1
    ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %get-tuple-element.6)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_text));
  AssignMemorySpace(module.get());
}
1721
TEST_P(MemorySpaceAssignmentTest, WhileCondAliasBug) {
  // While loop is the root of the entry computation. We should ensure the
  // output of the entry computation remains to be in default memory space.
  // Test from //third_party/tensorflow/compiler/xla/tests:while_test
  // WhileTest.WhileWithPrngScalarResult.
  absl::string_view hlo_string = R"(
  HloModule WhileWithPrngScalarResult.18, is_scheduled=true

  %fused_computation (param_0.1: s32[6], param_1.3: s32[1], param_2.3: s32[5]) -> s32[6] {
    %param_1.3 = s32[1]{0:T(128)} parameter(1)
    %constant.2 = s32[]{:T(128)} constant(-2147483648)
    %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
    %param_2.3 = s32[5]{0:T(128)} parameter(2)
    %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
    %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
    %param_0.1 = s32[6]{0:T(128)} parameter(0)
    ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
  }

  %body.3 (prev.4: s32[6]) -> s32[6] {
    %constant.7 = s32[]{:T(128)} constant(100)
    %constant.6 = s32[]{:T(128)} constant(0)
    %constant.5 = s32[1]{0:T(128)} constant({1})
    %prev.4 = s32[6]{0:T(128)} parameter(0)
    %rng.8 = s32[5]{0:T(128)} rng(s32[]{:T(128)} %constant.6, s32[]{:T(128)} %constant.7), distribution=rng_uniform
    %neg = s32[1]{0:T(128)} negate(s32[1]{0:T(128)} %constant.5)
    ROOT %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %prev.4, s32[1]{0:T(128)} %neg, s32[5]{0:T(128)} %rng.8), kind=kLoop, calls=%fused_computation
  }

  %WhileWithPrngScalarResult.11 (prev.12: s32[6]) -> pred[] {
    %constant.15 = s32[]{:T(128)} constant(1)
    %prev.12 = s32[6]{0:T(128)} parameter(0)
    %bitcast.1 = s32[1]{0:T(128)} bitcast(s32[6]{0:T(128)} %prev.12)
    %bitcast = s32[]{:T(128)} bitcast(s32[1]{0:T(128)} %bitcast.1)
    ROOT %compare.16 = pred[]{:T(128)E(32)} compare(s32[]{:T(128)} %constant.15, s32[]{:T(128)} %bitcast), direction=GT
  }

  ENTRY %WhileWithPrngScalarResult.18 () -> s32[6] {
    %constant.1 = s32[]{:T(128)} constant(0)
    %broadcast.2 = s32[6]{0:T(128)} broadcast(s32[]{:T(128)} %constant.1), dimensions={}
    ROOT %while.17 = s32[6]{0:T(128)} while(s32[6]{0:T(128)} %broadcast.2), condition=%WhileWithPrngScalarResult.11, body=%body.3
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
  // Expect the output to have default memory space.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->shape().layout().memory_space(), kDefaultMemorySpace);
}
1777
TEST_P(MemorySpaceAssignmentTest, WhileInPlaceBuffer) {
  // Ensure that a dynamic update slice within a while loop is able to get an
  // alternate memory allocation.
  absl::string_view hlo_string = R"(
  HloModule Module, is_scheduled=true

  fused_computation {
    param0 = f32[2,3] parameter(0)
    constant.1 = f32[] constant(0)
    broadcast = f32[2,1] broadcast(constant.1), dimensions={}
    constant.3 = s32[] constant(0)
    ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
  }

  %WhileBody (body_param: (f32[2,3], f32[2,3], f32[])) -> (f32[2,3], f32[2,3], f32[]) {
    %body_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element.1 = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=2
    %get-tuple-element.2 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=0
    %get-tuple-element.3 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=1
    %fusion = f32[2,3]{1,0} fusion(get-tuple-element.3), kind=kLoop, calls=fused_computation
    %multiply = f32[2,3]{1,0} multiply(f32[2,3]{1,0} %get-tuple-element.2, f32[2,3]{1,0} %fusion)
    ROOT %tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} %multiply, f32[2,3]{1,0} %fusion, f32[] %get-tuple-element.1)
  }

  %WhileCond (cond_param: (f32[2,3], f32[2,3], f32[])) -> pred[] {
    %cond_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
    %get-tuple-element = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %cond_param), index=2
    %constant = f32[] constant(50)
    ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
  }

  ENTRY %Entry (param_data: f32[2,3], param_iter: f32[], p2: f32[2,3]) -> f32[2,3] {
    %param_iter = f32[] parameter(1)
    %param_data = f32[2,3]{1,0} parameter(0)
    %p2 = f32[2,3]{1,0} parameter(2)
    %copy1 = f32[2,3]{1,0} copy(param_data)
    %copy2 = f32[2,3]{1,0} copy(p2)
    %tuple.1 = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} copy1, f32[2,3]{1,0} copy2, f32[] %param_iter)
    %while = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) while((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
    %get-tuple-element.4 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %while), index=0
    ROOT %copy3 = f32[2,3]{1,0} copy(get-tuple-element.4)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());
  const HloInstruction* while_instr =
      module->entry_computation()->GetInstructionWithName("while");
  // Tuple element {1} holds the in-place dynamic-update-slice buffer; when the
  // test parameter is set, it must land in the alternate memory space.
  if (GetParam()) {
    const Shape& dus_subshape =
        ShapeUtil::GetSubshape(while_instr->shape(), {1});
    EXPECT_EQ(dus_subshape.layout().memory_space(), kAlternateMemorySpace);
  }
}
1833
TEST_P(MemorySpaceAssignmentTest, WhileSharedBufferVerificationBug) {
  // Tests a spurious verification failure when a while has the same value
  // passed in twice (copy0) and that value is evicted within the while loop.
  // The entry tuple aliases copy0 at indices {0} and {1}, and the long negate
  // chain in the while body provides the pressure under which the eviction
  // happens.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  while_cond {
    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=3
  }

  while_body {
    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    gte2 = f32[3]{0} get-tuple-element(p0), index=2
    gte3 = pred[] get-tuple-element(p0), index=3
    add = f32[3]{0} add(gte0, gte0)
    negate0 = f32[3]{0} negate(add)
    negate1 = f32[3]{0} negate(negate0)
    negate2 = f32[3]{0} negate(negate1)
    negate3 = f32[3]{0} negate(negate2)
    negate4 = f32[3]{0} negate(negate3)
    negate5 = f32[3]{0} negate(negate4)
    negate6 = f32[3]{0} negate(negate5)
    negate7 = f32[3]{0} negate(negate6)
    negate8 = f32[3]{0} negate(negate7)
    negate9 = f32[3]{0} negate(negate8)
    negate10 = f32[3]{0} negate(negate9)
    negate11 = f32[3]{0} negate(negate10)
    negate12 = f32[3]{0} negate(negate11)
    negate13 = f32[3]{0} negate(negate12)
    negate14 = f32[3]{0} negate(negate13)
    negate15 = f32[3]{0} negate(negate14)
    negate16 = f32[3]{0} negate(negate15)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, gte0, negate16, gte3)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy0 = f32[3]{0} copy(p0)
    copy1 = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy0, copy1, p1)
    while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
    ROOT gte = f32[3]{0} get-tuple-element(while), index=2
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // The assertion here is implicit: memory space assignment (including the
  // verification it runs internally) must complete without error.
  AssignMemorySpace(module.get());
}
1886
TEST_P(MemorySpaceAssignmentTest, b172243149) {
  // Tests for the failure in b/172243149, where if we skip processing
  // non-copy allocations that are in default memory can actually cause
  // failures. In this case, the problem tensor is copy0, where it is fed to
  // both negate, while, and add0. The copy0->negate dependency can be allocated
  // in the alternate memory. Then the algorithm attempts to place the
  // copy0->while edge in the alternate memory, but since this value isn't used
  // in the while loop, it won't get an alternate memory allocation. Finally for
  // the copy0->add0 edge, the algorithm will actually replace it with
  // while{0}->add0, since this is equivalent and while is defined later than
  // copy0. However, if we actually skip processing this while{0}->add0
  // allocation, we won't replace this edge, and will end up with the
  // copy0->add0 edge, which illegally extends the lifetime of the alternate
  // memory buffer in copy0.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  while_cond {
    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=3
  }

  while_body {
    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    gte2 = f32[3]{0} get-tuple-element(p0), index=2
    gte3 = pred[] get-tuple-element(p0), index=3
    add = f32[3]{0} add(gte1, gte2)
    negate0 = f32[3]{0} negate(add)
    negate1 = f32[3]{0} negate(negate0)
    negate2 = f32[3]{0} negate(negate1)
    negate3 = f32[3]{0} negate(negate2)
    negate4 = f32[3]{0} negate(negate3)
    negate5 = f32[3]{0} negate(negate4)
    negate6 = f32[3]{0} negate(negate5)
    negate7 = f32[3]{0} negate(negate6)
    negate8 = f32[3]{0} negate(negate7)
    negate9 = f32[3]{0} negate(negate8)
    negate10 = f32[3]{0} negate(negate9)
    negate11 = f32[3]{0} negate(negate10)
    negate12 = f32[3]{0} negate(negate11)
    negate13 = f32[3]{0} negate(negate12)
    negate14 = f32[3]{0} negate(negate13)
    negate15 = f32[3]{0} negate(negate14)
    negate16 = f32[3]{0} negate(negate15)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, negate16, gte3)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy0 = f32[3]{0} copy(p0)
    copy1 = f32[3]{0} copy(p0)
    copy2 = f32[3]{0} copy(p0)
    negate = f32[3]{0} negate(copy0)
    tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy1, copy2, p1)
    while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
    gte = f32[3]{0} get-tuple-element(while), index=2
    add0 = f32[3]{0} add(negate, copy0)
    ROOT add1 = f32[3]{0} add(add0, gte)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Regression check: assignment must complete without the verification
  // failure described above; no output values are inspected.
  AssignMemorySpace(module.get());
}
1954
TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
  // Having control_predecessors on an HLO was preventing us from DCEing an op
  // that doesn't have any users (tuple.1). The scheduler assumes the graph is
  // fully DCEed, which causes some instructions not to be scheduled.
  absl::string_view hlo_string = R"(
  HloModule sort.16, is_scheduled=true

  ENTRY %sort.16 (param.0.1: s32[1], param.1.2: f32[1], param.2.3: u32[1], param.3.4: s32[1]) -> (s32[1], f32[1], u32[1], s32[1]) {
    %param.3.4 = s32[1]{0:T(128)} parameter(3)
    %param.2.3 = u32[1]{0:T(128)} parameter(2)
    %param.1.2 = f32[1]{0:T(128)} parameter(1)
    %param.0.1 = s32[1]{0:T(128)} parameter(0)
    %tuple.1 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %param.0.1, f32[1]{0:T(128)} %param.1.2, u32[1]{0:T(128)} %param.2.3, s32[1]{0:T(128)} %param.3.4), control-predecessors={%param.0.1}
    %get-tuple-element.4 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=0
    %get-tuple-element.5 = f32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=1
    %get-tuple-element.6 = u32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=2
    %get-tuple-element.7 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=3
    %copy.4 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.4)
    %copy.5 = f32[1]{0:T(128)} copy(f32[1]{0:T(128)} %get-tuple-element.5)
    %copy.6 = u32[1]{0:T(128)} copy(u32[1]{0:T(128)} %get-tuple-element.6)
    %copy.7 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.7)
    ROOT %tuple.2 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %copy.4, f32[1]{0:T(128)} %copy.5, u32[1]{0:T(128)} %copy.6, s32[1]{0:T(128)} %copy.7)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // The test passes if assignment completes cleanly despite the
  // control-predecessors annotation on %tuple.1.
  AssignMemorySpace(module.get());
}
1984
TEST_P(MemorySpaceAssignmentTest, ConditionalShouldBeAllocatedInAlternateMem) {
  // Checks that a simple conditional gets alternate memory allocations: the
  // copy feeding the conditional and the parameter uses inside both branches.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg1 = f32[3]{0} negate(gte)
  }

  false_computation {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg2 = f32[3]{0} negate(gte)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}) tuple(copy)
    ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  if (GetParam()) {
    // The copy in the entry computation should be placed in alternate memory.
    const HloInstruction* copy_instr =
        module->GetComputationWithName("entry")->GetInstructionWithName("copy");
    EXPECT_EQ(copy_instr->shape().layout().memory_space(),
              kAlternateMemorySpace);
    // Each branch root (neg1/neg2) negates the tuple parameter's element; the
    // operand it reads should also live in alternate memory.
    for (const char* computation : {"true_computation", "false_computation"}) {
      const HloInstruction* root =
          module->GetComputationWithName(computation)->root_instruction();
      EXPECT_EQ(root->operand(0)->shape().layout().memory_space(),
                kAlternateMemorySpace);
    }
  }
}
2031
TEST_P(MemorySpaceAssignmentTest, ConditionalAvoidsUnnecessaryPrefetch) {
  // Checks that we avoid an unnecessary alternate memory allocation when a
  // conditional input won't be used in the computation for a long time.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation {
    p0 = (f32[3]{0}, f32[3]{0}) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    neg0 = f32[3]{0} negate(gte0)
    neg1 = f32[3]{0} negate(neg0)
    neg2 = f32[3]{0} negate(neg1)
    neg3 = f32[3]{0} negate(neg2)
    neg4 = f32[3]{0} negate(neg3)
    neg5 = f32[3]{0} negate(neg4)
    neg6 = f32[3]{0} negate(neg5)
    neg7 = f32[3]{0} negate(neg6)
    neg8 = f32[3]{0} negate(neg7)
    neg9 = f32[3]{0} negate(neg8)
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    ROOT add = f32[3]{0} add(neg9, gte1)
  }

  false_computation {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg = f32[3]{0} negate(gte)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy0 = f32[3]{0} copy(p0)
    copy1 = f32[3]{0} copy(p0)
    tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1)
    tuple1 = (f32[3]{0}) tuple(copy0)
    ROOT conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  if (GetParam()) {
    // copy0 is used early and should stay in alternate memory.  copy1 is only
    // consumed after the long negate chain in true_computation, so it should
    // not tie up alternate memory for that whole span; instead it stays in
    // default memory and is prefetched right before the add that reads it.
    const HloComputation* entry = module->GetComputationWithName("entry");
    EXPECT_EQ(entry->GetInstructionWithName("copy0")
                  ->shape()
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
    EXPECT_EQ(entry->GetInstructionWithName("copy1")
                  ->shape()
                  .layout()
                  .memory_space(),
              kDefaultMemorySpace);
    const HloInstruction* add =
        module->GetComputationWithName("true_computation")
            ->GetInstructionWithName("add");
    EXPECT_EQ(add->operand(1)->shape().layout().memory_space(),
              kAlternateMemorySpace);
  }
}
2094
TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUse) {
  // Make sure there is an evict when there is a conditional use followed by
  // another use of the same value.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation {
    p0 = (f32[3]{0}, f32[3]{0}) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    add0 = f32[3]{0} add(gte0, gte1)
    neg0 = f32[3]{0} negate(add0)
    neg1 = f32[3]{0} negate(neg0)
    neg2 = f32[3]{0} negate(neg1)
    neg3 = f32[3]{0} negate(neg2)
    neg4 = f32[3]{0} negate(neg3)
    neg5 = f32[3]{0} negate(neg4)
    neg6 = f32[3]{0} negate(neg5)
    neg7 = f32[3]{0} negate(neg6)
    neg8 = f32[3]{0} negate(neg7)
    ROOT neg9 = f32[3]{0} negate(neg8)
  }

  false_computation {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg = f32[3]{0} negate(gte)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy0 = f32[3]{0} copy(p0)
    copy1 = f32[3]{0} copy(p0)
    tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1)
    tuple1 = (f32[3]{0}) tuple(copy0)
    conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation
    ROOT add1 = f32[3]{0} add(copy1, conditional)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  if (GetParam()) {
    // copy1 is produced in alternate memory and used inside the conditional
    // (add0's operand 1 in true_computation) from alternate memory.
    const HloComputation* entry = module->GetComputationWithName("entry");
    EXPECT_EQ(
        entry->GetInstructionWithName("copy1")->shape().layout().memory_space(),
        kAlternateMemorySpace);
    const HloInstruction* add0 =
        module->GetComputationWithName("true_computation")
            ->GetInstructionWithName("add0");
    EXPECT_EQ(add0->operand(0)->shape().layout().memory_space(),
              kAlternateMemorySpace);
    // For the use after the conditional, the value must have been evicted to
    // default memory: add1's first operand is the CopyDone of that eviction.
    const HloInstruction* evicted =
        entry->GetInstructionWithName("add1")->operand(0);
    EXPECT_EQ(evicted->opcode(), HloOpcode::kCopyDone);
    EXPECT_EQ(evicted->shape().layout().memory_space(), kDefaultMemorySpace);
  }
}
2160
TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUseInWhile) {
  // A value (gte0) is used both by a conditional and by the while body root
  // inside the same loop iteration.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg1 = f32[3]{0} negate(gte)
  }

  false_computation {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg2 = f32[3]{0} negate(gte)
  }

  while_cond {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=2
  }

  while_body {
    p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    gte2 = pred[] get-tuple-element(p0), index=2
    cond_tuple = (f32[3]{0}) tuple(gte0)
    conditional = f32[3]{0} conditional(gte2, cond_tuple, cond_tuple), true_computation=true_computation, false_computation=false_computation
    add = f32[3]{0} add(conditional, gte1)
    neg0 = f32[3]{0} negate(add)
    neg1 = f32[3]{0} negate(neg0)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, neg1, gte2)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy0 = f32[3]{0} copy(p0)
    copy1 = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy1, p1)
    while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
    ROOT gte = f32[3]{0} get-tuple-element(while), index=1
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  if (GetParam()) {
    // copy0 (feeding while{0}/cond_tuple{0}) should get an alternate memory
    // allocation, which forces an eviction followed by a prefetch around the
    // while body root's first operand.
    const HloInstruction* copy0 =
        module->GetComputationWithName("entry")->GetInstructionWithName(
            "copy0");
    EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace);
    const HloComputation* while_body =
        module->GetComputationWithName("while_body");
    const HloInstruction* cond_tuple =
        while_body->GetInstructionWithName("conditional")->operand(1);
    EXPECT_EQ(ShapeUtil::GetSubshape(cond_tuple->shape(), {0})
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
    // The body root's first operand should be an eviction (alternate->default)
    // immediately re-prefetched (default->alternate).
    EXPECT_THAT(
        while_body->root_instruction()->operand(0),
        op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                      op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace,
                                    op::GetTupleElement(op::Parameter(0)))));
  }
}
2233
TEST_P(MemorySpaceAssignmentTest, NestedConditional) {
  // Alternate memory allocations should propagate through nested conditionals.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation2 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg1 = f32[3]{0} negate(gte)
  }

  false_computation2 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg2 = f32[3]{0} negate(gte)
  }

  true_computation1 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    slice = f32[1]{0} slice(gte), slice={[0:1]}
    bitcast = f32[] bitcast(slice)
    constant = f32[] constant(0.0)
    compare = pred[] compare(bitcast, constant), direction=GT
    ROOT conditional = f32[3]{0} conditional(compare, p0, p0), true_computation=true_computation2, false_computation=false_computation2
  }

  false_computation1 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg3 = f32[3]{0} negate(gte)
  }


  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}) tuple(copy)
    ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  if (GetParam()) {
    // The copy feeding the outer conditional is in alternate memory, and the
    // allocation propagates into the inner conditional's branches (neg1/neg2)
    // as well as the outer false branch (neg3).
    EXPECT_EQ(module->GetComputationWithName("entry")
                  ->GetInstructionWithName("copy")
                  ->shape()
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
    EXPECT_EQ(module->GetComputationWithName("true_computation2")
                  ->GetInstructionWithName("neg1")
                  ->operand(0)
                  ->shape()
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
    EXPECT_EQ(module->GetComputationWithName("false_computation2")
                  ->GetInstructionWithName("neg2")
                  ->operand(0)
                  ->shape()
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
    EXPECT_EQ(module->GetComputationWithName("false_computation1")
                  ->GetInstructionWithName("neg3")
                  ->operand(0)
                  ->shape()
                  .layout()
                  .memory_space(),
              kAlternateMemorySpace);
  }
}
2302
TEST_P(MemorySpaceAssignmentTest, NestedConditionalBufferReuseVerificationBug) {
  // Tests a spurious verification failure when there are nested conditionals
  // and the innermost conditional computation reuses the buffer. Here, both the
  // parameter of true_computation2 and neg2 will get the same buffer. Make sure
  // that verification doesn't claim a failure in this case.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation2 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    neg1 = f32[3]{0} negate(gte)
    neg2 = f32[3]{0} negate(neg1)
    ROOT neg3 = f32[3]{0} negate(neg2)
  }

  false_computation2 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg4 = f32[3]{0} negate(gte)
  }

  true_computation1 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    slice = f32[1]{0} slice(gte), slice={[0:1]}
    bitcast = f32[] bitcast(slice)
    constant = f32[] constant(0.0)
    compare = pred[] compare(bitcast, constant), direction=GT
    tuple = (f32[3]{0}) tuple(gte)
    ROOT conditional = f32[3]{0} conditional(compare, tuple, tuple), true_computation=true_computation2, false_computation=false_computation2
  }

  false_computation1 {
    p0 = (f32[3]{0}) parameter(0)
    gte = f32[3]{0} get-tuple-element(p0), index=0
    ROOT neg5 = f32[3]{0} negate(gte)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}) tuple(copy)
    ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // The assertion is implicit: assignment (and its internal verification)
  // must complete without reporting a buffer-reuse failure.
  AssignMemorySpace(module.get());
}
2354
TEST_P(MemorySpaceAssignmentTest,
       RequestIdentifierShouldNotBeAllocatedInAlternateMem) {
  // Ensure that the request identifier (tuple element {1}) returned by
  // Send/Recv HLOs is never allocated in the alternate memory.
  absl::string_view hlo_string = R"(
  HloModule SendRecv, is_scheduled=true

  ENTRY %AddDependency (p: f32[3]) -> f32[3] {
    %p = f32[3]{0} parameter(0)
    %after-all = token[] after-all()
    %recv.4 = (f32[3]{0}, u32[], token[]) recv(token[] %after-all), channel_id=7
    %recv-done.4 = (f32[3]{0}, token[]) recv-done((f32[3]{0}, u32[], token[]) %recv.4), channel_id=7
    %token.1 = token[] get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=1
    %data = f32[3]{0} get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=0
    %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %data, token[] %token.1), channel_id=2
    %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
    ROOT %add = f32[3]{0} add(f32[3]{0} %p, f32[3]{0} %data)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // Scan the entry computation for Send/Recv ops and verify the request
  // identifier subshape at tuple index {1} stays out of alternate memory.
  for (const HloInstruction* instr :
       module->entry_computation()->instructions()) {
    const HloOpcode opcode = instr->opcode();
    if (opcode != HloOpcode::kSend && opcode != HloOpcode::kRecv) {
      continue;
    }
    EXPECT_NE(
        ShapeUtil::GetSubshape(instr->shape(), {1}).layout().memory_space(),
        kAlternateMemorySpace);
  }
}
2390
TEST_P(MemorySpaceAssignmentTest, SendDoneShouldHaveSendOperand) {
  // Ensure that SendDone has only a Send operand (i.e. memory space assignment
  // does not insert a copy between the Send and its SendDone).
  absl::string_view hlo_string = R"(
  HloModule SendRecv, is_scheduled=true

  ENTRY %AddDependency (p: f32[3]) -> f32[3] {
    %p0 = f32[3]{0} parameter(0)
    %p1 = f32[3]{0} parameter(1)
    %neg0 = f32[3]{0} negate(f32[3]{0} %p1)
    %neg1 = f32[3]{0} negate(f32[3]{0} %neg0)
    %neg2 = f32[3]{0} negate(f32[3]{0} %neg1)
    %neg3 = f32[3]{0} negate(f32[3]{0} %neg2)
    %neg4 = f32[3]{0} negate(f32[3]{0} %neg3)
    %neg5 = f32[3]{0} negate(f32[3]{0} %neg4)
    %neg6 = f32[3]{0} negate(f32[3]{0} %neg5)
    %after-all = token[] after-all()
    %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %p0, token[] %after-all), channel_id=2
    %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
    ROOT %add = f32[3]{0} add(f32[3]{0} %p0, f32[3]{0} %neg6)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // The assertion is implicit: assignment (including HLO verification run by
  // the test harness) must succeed with the Send/SendDone pair intact.
  AssignMemorySpace(module.get());
}
2417
TEST_P(MemorySpaceAssignmentTest, SendAndSendDoneShouldGetSameAllocation) {
  // Ensure that Send and SendDone have the same allocation, even though the
  // negate chain between them creates opportunities for prefetching.
  absl::string_view hlo_string = R"(
  HloModule SendRecv, is_scheduled=true

  ENTRY %AddDependency (p: f32[3]) -> f32[3] {
    %p0 = f32[3]{0} parameter(0)
    %p1 = f32[3]{0} parameter(1)
    %after-all = token[] after-all()
    %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %p0, token[] %after-all), channel_id=2
    %neg0 = f32[3]{0} negate(f32[3]{0} %p1)
    %neg1 = f32[3]{0} negate(f32[3]{0} %neg0)
    %neg2 = f32[3]{0} negate(f32[3]{0} %neg1)
    %neg3 = f32[3]{0} negate(f32[3]{0} %neg2)
    %neg4 = f32[3]{0} negate(f32[3]{0} %neg3)
    %neg5 = f32[3]{0} negate(f32[3]{0} %neg4)
    %neg6 = f32[3]{0} negate(f32[3]{0} %neg5)
    %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
    ROOT %add = f32[3]{0} add(f32[3]{0} %p0, f32[3]{0} %neg6)
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  // Use an aggressive prefetch window so the pass is tempted to prefetch
  // across the Send/SendDone span; the test passes if assignment succeeds.
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/4);
}
2445
TEST_P(MemorySpaceAssignmentTest, LastUseOpt) {
  // Test that checks the last use optimization. It uses two buffers that should
  // be placed in alternate memory.
  //
  //      +-------+
  //     /         \
  // add1--->sub1   +-------->mul2
  //             mul1===>add2
  //
  // Without the last use optimization, the mul1 buffer will be assigned first
  // (because it is larger) to offset 0. Then, add1 will be scheduled for the
  // add1 to sub1 segment. Because offset 0 is available, it will get that
  // offset. But because offset 0 is not available in the sub1 to mul2 offset,
  // it will end up in unnecessary copies. With the last use optimization, these
  // copies can be optimized away.
  //
  // NOTE: The instruction creation order below also determines the schedule
  // (see set_sequence); do not reorder the AddInstruction calls.
  HloComputation::Builder builder(TestName());
  Shape shape1 = ShapeUtil::MakeShape(F32, {2, 3});
  Shape shape2 = ShapeUtil::MakeShape(F32, {2, 4});
  PaddingConfig padding_config = MakeEdgePaddingConfig({{0, 0}, {0, 1}});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape1, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape2, "p1"));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape1, HloOpcode::kAdd, p0, p0));
  HloInstruction* sub1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape1, HloOpcode::kSubtract, p0, add1));
  HloInstruction* mul1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape2, HloOpcode::kMultiply, p1, p1));
  HloInstruction* add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape2, HloOpcode::kAdd, mul1, p1));
  HloInstruction* mul2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape1, HloOpcode::kMultiply, add1, sub1));
  HloInstruction* padding_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
  HloInstruction* padded_mul2 = builder.AddInstruction(
      HloInstruction::CreatePad(shape2, mul2, padding_value, padding_config));
  HloInstruction* add3 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape2, HloOpcode::kAdd, add2, padded_mul2));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, add1, sub1, mul1, add2, mul2,
                                      padding_value, padded_mul2, add3});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // mul2 should consume add1 and sub1 directly (no extra copies between add1
  // and mul2); the only async copy is the prefetch of p0 feeding sub1.
  EXPECT_THAT(
      mul2,
      op::Multiply(
          op::Add(op::Parameter(0), op::Parameter(0)),
          op::Subtract(op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                     op::Parameter(0)),
                       op::Add(op::Parameter(0), op::Parameter(0)))));
}
2504
TEST_P(MemorySpaceAssignmentTest, CopyOrdering) {
  // Test to make sure the CopyStarts follow the same CopyDone order. The shapes
  // are picked in increasing order to exploit the fact that heap simulator
  // processes larger tensors first. This checks the ability of the compiler to
  // reschedule:
  //
  //   CS          CD
  //   +--------------+
  //      +-----------+
  //      CS          CD
  //
  // into:
  //
  //   CS             CD
  //   +------------+
  //      +-----------+
  //      CS         CD
  //
  // NOTE: The instruction creation order below also determines the schedule
  // (see set_sequence); do not reorder the AddInstruction calls.
  HloComputation::Builder builder(TestName());
  Shape shape1 = ShapeUtil::MakeShape(F32, {2, 1});
  Shape shape2 = ShapeUtil::MakeShape(F32, {2, 2});
  Shape shape3 = ShapeUtil::MakeShape(F32, {2, 3});
  Shape shape4 = ShapeUtil::MakeShape(F32, {2, 4});
  PaddingConfig padding_config = MakeEdgePaddingConfig({{0, 0}, {0, 1}});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape3, shape4});
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* p4 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape4, p0, 1));
  HloInstruction* p3 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape3, p0, 0));
  HloInstruction* p2 =
      builder.AddInstruction(HloInstruction::CreateParameter(2, shape2, "p2"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape1, "p1"));
  // Long negate chain on the smallest shape so the larger tensors' prefetches
  // span many instructions.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, p1));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate5));
  HloInstruction* padding_value = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape1, HloOpcode::kAdd, negate6, p1));
  HloInstruction* padded_add1 = builder.AddInstruction(
      HloInstruction::CreatePad(shape2, add1, padding_value, padding_config));
  HloInstruction* add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape2, HloOpcode::kAdd, padded_add1, p2));
  HloInstruction* padded_add2 = builder.AddInstruction(
      HloInstruction::CreatePad(shape3, add2, padding_value, padding_config));
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape4, HloOpcode::kNegate, p4));
  HloInstruction* add3 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape3, HloOpcode::kAdd, padded_add2, p3));
  HloInstruction* padded_add3 = builder.AddInstruction(
      HloInstruction::CreatePad(shape4, add3, padding_value, padding_config));
  HloInstruction* add4 = builder.AddInstruction(HloInstruction::CreateBinary(
      shape4, HloOpcode::kAdd, padded_add3, negate7));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0,
                                      p4,
                                      p3,
                                      p2,
                                      p1,
                                      negate0,
                                      negate1,
                                      negate2,
                                      negate3,
                                      negate4,
                                      negate5,
                                      negate6,
                                      padding_value,
                                      add1,
                                      padded_add1,
                                      add2,
                                      padded_add2,
                                      negate7,
                                      add3,
                                      padded_add3,
                                      add4});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Use a large max prefetch interval to force CopyStart/CopyDone right after
  // the parameters.
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/50);

  // Iterate over the schedule to make sure CopyStart order and the
  // corresponding CopyDone order match: each CopyDone must consume the
  // earliest still-outstanding CopyStart (FIFO order).
  std::list<const HloInstruction*> copy_starts;
  for (HloInstruction* instruction : module->schedule()
                                         .sequence(module->entry_computation())
                                         .instructions()) {
    if (instruction->opcode() == HloOpcode::kCopyStart) {
      copy_starts.push_back(instruction);
    }
    if (instruction->opcode() == HloOpcode::kCopyDone) {
      EXPECT_EQ(copy_starts.front(), instruction->operand(0));
      copy_starts.pop_front();
    }
  }
}
2619
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule1) {
  // Test to ensure CopyStart/CopyDone is placed only in the entry computation.
  // Builds a while loop whose body squares/increments the data tuple element
  // and bumps the iteration counter, then runs memory space assignment with a
  // large max prefetch interval.
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape});

  // While condition: loop while tuple element 1 (the iteration counter) < 50.
  auto cond_builder = HloComputation::Builder("WhileCond");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_limit = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
  // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_limit, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  // While body: data' = (data + increment) + data * data; iter' = iter + 1.
  auto body_builder = HloComputation::Builder("WhileBody");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloInstruction* body_iter = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
  HloInstruction* body_data = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
  HloInstruction* body_iter_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
  HloInstruction* body_iter_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
  HloInstruction* body_data_increment =
      body_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}})));
  HloInstruction* body_data_mul =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kMultiply, body_data, body_data));
  HloInstruction* body_data_add =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, body_data, body_data_increment));
  HloInstruction* body_data_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, body_data_add, body_data_mul));
  HloInstruction* body_out = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_data_next, body_iter_next}));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  // Entry computation: tuple up (data, iter), run the while, then add the
  // unrelated parameter p2 to the loop's data output.
  // NOTE: the HLO parameter names previously had "param_iter"/"param_data"
  // swapped relative to the C++ variables (cf. NonEntryComputationSchedule5);
  // they are now consistent: `data` -> "param_data", `iter` -> "param_iter".
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param_data"));
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
  HloInstruction* p2 =
      builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "p2"));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({data, iter}));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloInstruction* while_data = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, while_op, 0));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, while_data, p2));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_limit, cond_lt});
  schedule.set_sequence(body_computation,
                        {body_param, body_iter, body_data, body_iter_increment,
                         body_iter_next, body_data_increment, body_data_mul,
                         body_data_add, body_data_next, body_out});
  schedule.set_sequence(entry_computation,
                        {iter, data, p2, tuple, while_op, while_data, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), -1, 50);
}
2703
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule2) {
  // Exercises memory space assignment across a kCall boundary: the called
  // computation holds a long negate chain, and the entry computation keeps
  // values (add1, add2, add3) live across the call. Run with a small max
  // prefetch interval (5).
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});

  // Called computation: add0 = call_param + negate^8(call_param *
  // slice(call_param2)).
  auto call_builder = HloComputation::Builder("Call");
  HloInstruction* call_param = call_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "call_param"));
  HloInstruction* call_param2 = call_builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape2, "call_param2"));
  HloInstruction* slice = call_builder.AddInstruction(
      HloInstruction::CreateSlice(shape, call_param2, {0, 0}, {2, 3}, {1, 1}));
  HloInstruction* mul =
      call_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kMultiply, call_param, slice));
  HloInstruction* negate0 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
  HloInstruction* negate1 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = call_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* add0 =
      call_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, call_param, negate7));
  HloComputation* call_computation =
      module->AddEmbeddedComputation(call_builder.Build());

  // Entry computation: add1 is both a call operand and used after the call
  // (add3), so its buffer must stay valid across the callee.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape2, "p1"));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
  HloInstruction* add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
  HloInstruction* negate8 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape2, HloOpcode::kNegate, p1));
  HloInstruction* call = builder.AddInstruction(
      HloInstruction::CreateCall(shape, {add1, negate8}, call_computation));
  HloInstruction* add3 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, add1));
  HloInstruction* add4 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, call, add3));
  HloInstruction* add5 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add2, add4));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      call_computation,
      {call_param, call_param2, slice, mul, negate0, negate1, negate2, negate3,
       negate4, negate5, negate6, negate7, add0});
  schedule.set_sequence(entry_computation,
                        {p0, p1, add1, add2, negate8, call, add3, add4, add5});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), -1, 5);
}
2774
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
  // Exercises memory space assignment when a kCall callee's own allocations
  // (a large iota scheduled early) compete for alternate memory with a value
  // (add1) that is both passed into the call and used after it.
  auto module = CreateNewVerifiedModule();
  Shape small_shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape large_shape = ShapeUtil::MakeShape(xla::F32, {3, 3});

  // Called computation: an eight-deep negate chain over
  // call_param * slice(iota), added back to call_param.
  auto callee_builder = HloComputation::Builder("Call");
  HloInstruction* callee_param = callee_builder.AddInstruction(
      HloInstruction::CreateParameter(0, small_shape, "call_param"));
  // Use the larger shape here which is scheduled earlier so it occupies
  // alternate memory at the beginning. This should cause a situation where
  // the prefetch of add1 later in the function body gets the wrong offset
  // which cannot be communicated to the outside the function.
  HloInstruction* iota = callee_builder.AddInstruction(
      HloInstruction::CreateIota(large_shape, 0));
  HloInstruction* slice =
      callee_builder.AddInstruction(HloInstruction::CreateSlice(
          small_shape, iota, {0, 0}, {2, 3}, {1, 1}));
  HloInstruction* product =
      callee_builder.AddInstruction(HloInstruction::CreateBinary(
          small_shape, HloOpcode::kMultiply, callee_param, slice));
  HloInstruction* neg0 = callee_builder.AddInstruction(
      HloInstruction::CreateUnary(small_shape, HloOpcode::kNegate, product));
  HloInstruction* neg1 = callee_builder.AddInstruction(
      HloInstruction::CreateUnary(small_shape, HloOpcode::kNegate, neg0));
  HloInstruction* neg2 = callee_builder.AddInstruction(
      HloInstruction::CreateUnary(small_shape, HloOpcode::kNegate, neg1));
  HloInstruction* neg3 = callee_builder.AddInstruction(
      HloInstruction::CreateUnary(small_shape, HloOpcode::kNegate, neg2));
  HloInstruction* neg4 = callee_builder.AddInstruction(
      HloInstruction::CreateUnary(small_shape, HloOpcode::kNegate, neg3));
  HloInstruction* neg5 = callee_builder.AddInstruction(
      HloInstruction::CreateUnary(small_shape, HloOpcode::kNegate, neg4));
  HloInstruction* neg6 = callee_builder.AddInstruction(
      HloInstruction::CreateUnary(small_shape, HloOpcode::kNegate, neg5));
  HloInstruction* neg7 = callee_builder.AddInstruction(
      HloInstruction::CreateUnary(small_shape, HloOpcode::kNegate, neg6));
  HloInstruction* callee_sum =
      callee_builder.AddInstruction(HloInstruction::CreateBinary(
          small_shape, HloOpcode::kAdd, callee_param, neg7));
  HloComputation* call_computation =
      module->AddEmbeddedComputation(callee_builder.Build());

  // Entry computation: add1 feeds the call and is also consumed afterwards.
  auto entry_builder = HloComputation::Builder(TestName());
  HloInstruction* p0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, small_shape, "p0"));
  HloInstruction* add1 = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(small_shape, HloOpcode::kAdd, p0, p0));
  HloInstruction* add2 = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(small_shape, HloOpcode::kAdd, add1, p0));
  HloInstruction* call = entry_builder.AddInstruction(
      HloInstruction::CreateCall(small_shape, {add1}, call_computation));
  HloInstruction* add3 = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(small_shape, HloOpcode::kAdd, call, add1));
  HloComputation* entry_computation =
      module->AddEntryComputation(entry_builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(call_computation,
                        {callee_param, iota, slice, product, neg0, neg1, neg2,
                         neg3, neg4, neg5, neg6, neg7, callee_sum});
  schedule.set_sequence(entry_computation, {p0, add1, add2, call, add3});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), -1, 5);
}
2840
2841 // TODO(berkin): This might be an incorrect input graph, investigate.
TEST_P(MemorySpaceAssignmentTest, DISABLED_NonEntryComputationSchedule4) {
  // Exercises memory space assignment with a kConditional whose true branch
  // holds a long negate chain and whose false branch is a pass-through
  // parameter. Disabled pending the TODO above (possibly invalid input HLO).
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});

  // True branch: add0 = true_param + negate^8(true_param * slice(iota)).
  auto true_builder = HloComputation::Builder("True");
  HloInstruction* true_param = true_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "true_param"));
  HloInstruction* iota =
      true_builder.AddInstruction(HloInstruction::CreateIota(shape2, 0));
  HloInstruction* slice = true_builder.AddInstruction(
      HloInstruction::CreateSlice(shape, iota, {0, 0}, {2, 3}, {1, 1}));
  HloInstruction* mul =
      true_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kMultiply, true_param, slice));
  HloInstruction* negate0 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
  HloInstruction* negate1 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = true_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* add0 =
      true_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, true_param, negate7));
  HloComputation* true_computation =
      module->AddEmbeddedComputation(true_builder.Build());

  // False branch: identity on its parameter.
  auto false_builder = HloComputation::Builder("False");
  HloInstruction* false_param = false_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "false_param"));
  HloComputation* false_computation =
      module->AddEmbeddedComputation(false_builder.Build());

  // Entry computation: predicate is the constant true; add1 is both the
  // true-branch operand and used again after the conditional (add3).
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
  HloInstruction* add2 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
  HloInstruction* pred = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
  HloInstruction* conditional =
      builder.AddInstruction(HloInstruction::CreateConditional(
          shape, pred, add1, true_computation, add2, false_computation));
  HloInstruction* add3 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, conditional, add1));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      true_computation,
      {true_param, iota, slice, mul, negate0, negate1, negate2, negate3,
       negate4, negate5, negate6, negate7, add0});
  schedule.set_sequence(false_computation, {false_param});
  schedule.set_sequence(entry_computation,
                        {p0, add1, add2, pred, conditional, add3});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), -1, 5);
}
2914
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
  // This test reproduces the failure in b/143288178. Given a graph like the
  // following:
  //
  // ... = foo(a)
  // tuple = tuple((..., a)
  // ... = while(tuple) {
  //   p = param(0)
  //   a1 = get-tuple-element(p), index=n-1
  //   ...
  //   ROOT tuple((..., a1))
  // }
  //
  // If a copy to alternate memory is inserted before foo, and if the size of
  // the while body is less than max prefetch interval so that the copy-done is
  // kept in the alternate memory, then we end up referring to the copy-done in
  // the root instruction of the while loop body. I.e.,
  //
  // cs = copy-start(a)
  // ...
  // cd = copy-done(cs)
  // ... = foo(cd)
  // tuple = tuple((..., cd)
  // ... = while(tuple) {
  //   p = param(0)
  //   a1 = get-tuple-element(p), index=n-1
  //   ...
  //   ROOT tuple((..., cd))  <-- Error: cd belongs to outside computation.
  // }
  //
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  Shape tuple_shape =
      ShapeUtil::MakeTupleShape({shape, scalar_shape, scalar_shape});

  // While condition: loop while tuple element 1 (the iteration counter) < 50.
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_limit = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_limit, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  // While body: increments the counter; tuple elements 0 and 2 (body_data,
  // body_data2) are passed through unchanged into the root tuple — element 2
  // plays the role of "a1" in the comment above.
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloInstruction* body_iter = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
  HloInstruction* body_data = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
  HloInstruction* body_iter_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
  HloInstruction* body_iter_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
  HloInstruction* body_data2 = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 2));
  HloInstruction* body_out = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_data, body_iter_next, body_data2}));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  // Entry computation: data2 plays the role of "a" — it is consumed by sub
  // (the "foo(a)" use) and also tupled into the while input.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param_data"));
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
  HloInstruction* data2 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_shape, "param_data2"));
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, data));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
      scalar_shape, HloOpcode::kSubtract, iter, data2));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({negate7, iter, data2}));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloInstruction* while_data = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, while_op, 1));
  HloInstruction* root =
      builder.AddInstruction(HloInstruction::CreateTuple({while_data, sub}));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_limit, cond_lt});
  schedule.set_sequence(body_computation,
                        {body_param, body_iter, body_data, body_iter_increment,
                         body_iter_next, body_data2, body_out});
  schedule.set_sequence(
      entry_computation,
      {iter, data, data2, negate0, negate1, negate2, negate3, negate4, negate5,
       negate6, negate7, sub, tuple, while_op, while_data, root});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Set a large max prefetch interval so that the buffer can be kept in
  // alternate memory.
  AssignMemorySpace(module.get(), -1, 20);
}
3035
TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) {
  // Builds a while loop with a 3-element tuple state: element 0 is passed
  // through unmodified by the body, element 1 is the iteration counter, and
  // element 2 is overwritten by the body (with body_negate7). After memory
  // space assignment, verifies the resulting layouts: element 0 ends up in
  // the alternate memory space, elements 1 and 2 in the default space.
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape, shape});

  // While condition: loop while tuple element 1 (the iteration counter) < 50.
  auto cond_builder = HloComputation::Builder("WhileCond");
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_limit = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_limit, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  // While body: element 0 passes through; element 2 is replaced by an
  // 8-deep negate chain over element 0; the counter is incremented.
  auto body_builder = HloComputation::Builder("WhileBody");
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloInstruction* body_iter = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
  HloInstruction* body_data = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
  HloInstruction* body_negate0 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_data));
  HloInstruction* body_negate1 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate0));
  HloInstruction* body_negate2 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate1));
  HloInstruction* body_negate3 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate2));
  HloInstruction* body_negate4 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate3));
  HloInstruction* body_negate5 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate4));
  HloInstruction* body_negate6 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate5));
  HloInstruction* body_negate7 = body_builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate6));
  HloInstruction* body_iter_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
  HloInstruction* body_iter_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
  HloInstruction* body_out = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_data, body_iter_next, body_negate7}));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  // Entry computation: feeds (data, iter, negate^8(data)) into the while and
  // adds the two shaped outputs of the loop.
  auto builder = HloComputation::Builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param_data"));
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, data));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({data, iter, negate7}));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloInstruction* while_data = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, while_op, 0));
  HloInstruction* while_data2 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, while_op, 2));
  HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, while_data, while_data2));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_limit, cond_lt});
  schedule.set_sequence(
      body_computation,
      {body_param, body_iter, body_data, body_negate0, body_negate1,
       body_negate2, body_negate3, body_negate4, body_negate5, body_negate6,
       body_negate7, body_iter_increment, body_iter_next, body_out});
  schedule.set_sequence(
      entry_computation,
      {iter, data, negate0, negate1, negate2, negate3, negate4, negate5,
       negate6, negate7, tuple, while_op, while_data, while_data2, root});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Pick a large max prefetch interval to ensure all the while inputs are
  // allocated in the alternate memory.
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/25);

  // Build the expected tuple layout in-place in tuple_shape, assigning the
  // memory space of each leaf, then compare against the annotated module.
  // Index {0} of the while loop argument is not written inside the while loop,
  // so it can be trivially placed in the alternate memory space.
  *ShapeUtil::GetMutableSubshape(&tuple_shape, {0})->mutable_layout() =
      LayoutUtil::MakeLayout(
          /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
          kAlternateMemorySpace);
  // Index {1} is a scalar, so it is always placed in the default memory.
  *ShapeUtil::GetMutableSubshape(&tuple_shape, {1})->mutable_layout() =
      LayoutUtil::MakeLayout(
          /*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0,
          kDefaultMemorySpace);
  // Index {2} of the while loop is placed in the default memory.
  *ShapeUtil::GetMutableSubshape(&tuple_shape, {2})->mutable_layout() =
      LayoutUtil::MakeLayout(
          /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
          kDefaultMemorySpace);

  // Expect the layout for the while loop and its aliased buffers.
  EXPECT_THAT(while_op, op::ShapeWithLayout(tuple_shape));
  EXPECT_THAT(while_op->operand(0), op::ShapeWithLayout(tuple_shape));
  EXPECT_THAT(cond_param, op::ShapeWithLayout(tuple_shape));
  EXPECT_THAT(body_param, op::ShapeWithLayout(tuple_shape));
  EXPECT_THAT(body_out, op::ShapeWithLayout(tuple_shape));
}
3165
TEST_P(MemorySpaceAssignmentTest, DanglingCopy) {
  // This situation was encountered in vss, where there is a mismatch in the
  // memory space in preset assignments and the output graph.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});

  // The entry parameter is a (f32[2,3], f32[2,3]) tuple. Element 1 is read
  // twice through two distinct get-tuple-elements (p1a and p1b).
  HloInstruction* p = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p, 0));
  HloInstruction* p1a = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p, 1));
  // This copy of p1a has no users — it is the "dangling" copy the test is
  // named after.
  HloInstruction* copy = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kCopy, p1a));
  // A chain of negates separates the two reads of tuple element 1 in the
  // schedule.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* p1b = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p, 1));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1b));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5,
                    negate6, p1a, copy, p1b, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Only checks that assignment completes without a memory-space mismatch;
  // there are no layout expectations.
  AssignMemorySpace(module.get());
}
3211
TEST_P(MemorySpaceAssignmentTest, MultiOutputFusion) {
  // Verifies memory space assignment on a multi-output (tuple-producing)
  // fusion whose two elements are both consumed by a subsequent add.
  HloComputation::Builder entry_builder(TestName());
  Shape elem_shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape result_shape = ShapeUtil::MakeTupleShape({elem_shape, elem_shape});
  auto module = CreateNewVerifiedModule();

  // Fused computation: returns both of its parameters packed into a tuple.
  HloComputation::Builder fused_builder("fusion");
  HloInstruction* fused_p0 = fused_builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "p0"));
  HloInstruction* fused_p1 = fused_builder.AddInstruction(
      HloInstruction::CreateParameter(1, elem_shape, "p1"));
  fused_builder.AddInstruction(
      HloInstruction::CreateTuple({fused_p0, fused_p1}));
  HloComputation* fused_computation =
      module->AddEmbeddedComputation(fused_builder.Build());

  // Entry computation: feed the same parameter into both fusion operands,
  // then add together the two tuple elements it produces.
  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "p0"));
  HloInstruction* multi_out_fusion =
      entry_builder.AddInstruction(HloInstruction::CreateFusion(
          result_shape, HloInstruction::FusionKind::kCustom, {param, param},
          fused_computation));
  HloInstruction* gte0 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, multi_out_fusion, 0));
  HloInstruction* gte1 = entry_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, multi_out_fusion, 1));
  HloInstruction* sum = entry_builder.AddInstruction(
      HloInstruction::CreateBinary(elem_shape, HloOpcode::kAdd, gte0, gte1));

  HloComputation* entry = module->AddEntryComputation(entry_builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(entry, {param, multi_out_fusion, gte0, gte1, sum});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Smoke test: assignment must succeed; no layout expectations.
  AssignMemorySpace(module.get());
}
3248
TEST_P(MemorySpaceAssignmentTest, TupleInput) {
  // Verifies memory space assignment when a fusion takes a tuple-shaped
  // operand and unpacks both elements internally.
  HloComputation::Builder entry_builder(TestName());
  Shape elem_shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape pair_shape = ShapeUtil::MakeTupleShape({elem_shape, elem_shape});
  auto module = CreateNewVerifiedModule();

  // Fused computation: unpack the tuple parameter and add its two elements.
  HloComputation::Builder fused_builder("fusion");
  HloInstruction* fused_param = fused_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p"));
  HloInstruction* fused_lhs = fused_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, fused_param, 0));
  HloInstruction* fused_rhs = fused_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, fused_param, 1));
  fused_builder.AddInstruction(HloInstruction::CreateBinary(
      elem_shape, HloOpcode::kAdd, fused_lhs, fused_rhs));
  HloComputation* fused_computation =
      module->AddEmbeddedComputation(fused_builder.Build());

  // Entry computation: negate each parameter, pack the results into a tuple,
  // and feed that tuple to the fusion.
  HloInstruction* param0 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "p0"));
  HloInstruction* param1 = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(1, elem_shape, "p1"));
  HloInstruction* neg0 = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(elem_shape, HloOpcode::kNegate, param0));
  HloInstruction* neg1 = entry_builder.AddInstruction(
      HloInstruction::CreateUnary(elem_shape, HloOpcode::kNegate, param1));
  HloInstruction* packed = entry_builder.AddInstruction(
      HloInstruction::CreateTuple({neg0, neg1}));
  HloInstruction* fusion =
      entry_builder.AddInstruction(HloInstruction::CreateFusion(
          elem_shape, HloInstruction::FusionKind::kCustom, {packed},
          fused_computation));

  HloComputation* entry = module->AddEntryComputation(entry_builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(entry, {param0, param1, neg0, neg1, packed, fusion});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Smoke test: assignment must succeed; no layout expectations.
  AssignMemorySpace(module.get());
}
3288
TEST_P(MemorySpaceAssignmentTest, TupleToTuple1) {
  // A tuple-producing fusion (fusion0) feeds a tuple-consuming fusion
  // (fusion1) across a long chain of negates. Expects that both elements of
  // fusion0's output are prefetched into the alternate memory via async
  // copies before fusion1 consumes them.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
  auto module = CreateNewVerifiedModule();

  // fusion0: returns its two parameters packed into a tuple.
  HloComputation::Builder fusion0_builder("fusion0");
  HloInstruction* fusion0_param0 = fusion0_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* fusion0_param1 = fusion0_builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "p1"));
  fusion0_builder.AddInstruction(
      HloInstruction::CreateTuple({fusion0_param0, fusion0_param1}));
  HloComputation* fusion0_computation =
      module->AddEmbeddedComputation(fusion0_builder.Build());

  // fusion1: unpacks a tuple parameter and adds its elements.
  HloComputation::Builder fusion1_builder("fusion1");
  HloInstruction* fusion1_param = fusion1_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* fusion1_element0 = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion1_param, 0));
  HloInstruction* fusion1_element1 = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion1_param, 1));
  fusion1_builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, fusion1_element0, fusion1_element1));
  HloComputation* fusion1_computation =
      module->AddEmbeddedComputation(fusion1_builder.Build());

  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
      tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
      fusion0_computation));
  HloInstruction* element0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion0, 0));
  HloInstruction* element1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion0, 1));
  // The negate chain creates schedule distance between fusion0 and fusion1,
  // giving the prefetches room to be placed.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* add0 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, element0, element1));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, negate6));
  HloInstruction* fusion1 = builder.AddInstruction(
      HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kCustom,
                                   {fusion0}, fusion1_computation));
  HloInstruction* mul = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add1, fusion1));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation,
      {p0, fusion0, element0, element1, negate0, negate1, negate2, negate3,
       negate4, negate5, negate6, add0, add1, fusion1, mul});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5);
  // Both tuple elements of fusion0's result should reach fusion1 through
  // alternate-memory async copies.
  EXPECT_THAT(fusion1,
              op::Fusion(op::Tuple(
                  op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                op::GetTupleElement(op::Fusion(), 0)),
                  op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                op::GetTupleElement(op::Fusion(), 1)))));
}
3367
TEST_P(MemorySpaceAssignmentTest, TupleToTuple2) {
  // Like TupleToTuple1, but fusion0 produces a NESTED tuple
  // (f32[2,3], (f32[2,3], f32[2,3])) consumed by fusion1. Expects async-copy
  // prefetches for every leaf buffer of the nested tuple.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
  Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({shape, tuple_shape});
  auto module = CreateNewVerifiedModule();

  // fusion0: returns (p0, (p0, p1)).
  HloComputation::Builder fusion0_builder("fusion0");
  HloInstruction* fusion0_param0 = fusion0_builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* fusion0_param1 = fusion0_builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape, "p1"));
  HloInstruction* fusion0_tuple = fusion0_builder.AddInstruction(
      HloInstruction::CreateTuple({fusion0_param0, fusion0_param1}));
  fusion0_builder.AddInstruction(
      HloInstruction::CreateTuple({fusion0_param0, fusion0_tuple}));
  HloComputation* fusion0_computation =
      module->AddEmbeddedComputation(fusion0_builder.Build());

  // fusion1: adds the outer element {0} and the inner element {1,1} of its
  // nested-tuple parameter.
  HloComputation::Builder fusion1_builder("fusion1");
  HloInstruction* fusion1_param = fusion1_builder.AddInstruction(
      HloInstruction::CreateParameter(0, nested_tuple_shape, "p"));
  HloInstruction* fusion1_element0 = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion1_param, 0));
  HloInstruction* fusion1_element1 = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(tuple_shape, fusion1_param, 1));
  HloInstruction* fusion1_element2 = fusion1_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion1_element1, 1));
  fusion1_builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, fusion1_element0, fusion1_element2));
  HloComputation* fusion1_computation =
      module->AddEmbeddedComputation(fusion1_builder.Build());

  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
      nested_tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
      fusion0_computation));
  // Negate chain creates schedule distance between fusion0 and fusion1.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* fusion1 = builder.AddInstruction(
      HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kCustom,
                                   {fusion0}, fusion1_computation));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation, {p0, fusion0, negate0, negate1, negate2, negate3, negate4,
                    negate5, negate6, fusion1});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5);

  // Every leaf of the nested tuple (outer {0}, inner {1,0} and {1,1}) is
  // expected to be prefetched into the alternate memory.
  EXPECT_THAT(
      fusion1,
      op::Fusion(op::Tuple(
          op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                        op::GetTupleElement(op::Fusion(), 0)),
          op::Tuple(
              op::AsyncCopy(
                  kAlternateMemorySpace, kDefaultMemorySpace,
                  op::GetTupleElement(op::GetTupleElement(op::Fusion(), 1), 0)),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::GetTupleElement(
                                op::GetTupleElement(op::Fusion(), 1), 1))))));
}
3447
TEST_P(MemorySpaceAssignmentTest, TupleToTuple3) {
  // A tuple-producing fusion feeds a tuple-consuming fusion back-to-back
  // (no intervening instructions), so no async copies are expected: fusion1
  // should consume fusion0's output directly.
  HloComputation::Builder entry_builder(TestName());
  Shape elem_shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape pair_shape = ShapeUtil::MakeTupleShape({elem_shape, elem_shape});
  auto module = CreateNewVerifiedModule();

  // Producer fusion: returns both parameters packed into a tuple.
  HloComputation::Builder producer_builder("fusion0");
  HloInstruction* producer_p0 = producer_builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "p0"));
  HloInstruction* producer_p1 = producer_builder.AddInstruction(
      HloInstruction::CreateParameter(1, elem_shape, "p1"));
  producer_builder.AddInstruction(
      HloInstruction::CreateTuple({producer_p0, producer_p1}));
  HloComputation* producer_computation =
      module->AddEmbeddedComputation(producer_builder.Build());

  // Consumer fusion: unpacks the tuple parameter and adds its elements.
  HloComputation::Builder consumer_builder("fusion1");
  HloInstruction* consumer_param = consumer_builder.AddInstruction(
      HloInstruction::CreateParameter(0, pair_shape, "p"));
  HloInstruction* consumer_lhs = consumer_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, consumer_param, 0));
  HloInstruction* consumer_rhs = consumer_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(elem_shape, consumer_param, 1));
  consumer_builder.AddInstruction(HloInstruction::CreateBinary(
      elem_shape, HloOpcode::kAdd, consumer_lhs, consumer_rhs));
  HloComputation* consumer_computation =
      module->AddEmbeddedComputation(consumer_builder.Build());

  HloInstruction* param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, elem_shape, "p0"));
  HloInstruction* producer =
      entry_builder.AddInstruction(HloInstruction::CreateFusion(
          pair_shape, HloInstruction::FusionKind::kCustom, {param, param},
          producer_computation));
  HloInstruction* consumer = entry_builder.AddInstruction(
      HloInstruction::CreateFusion(elem_shape,
                                   HloInstruction::FusionKind::kCustom,
                                   {producer}, consumer_computation));

  HloComputation* entry = module->AddEntryComputation(entry_builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(entry, {param, producer, consumer});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());
  // The consumer fusion takes the producer fusion directly (no copies).
  EXPECT_THAT(consumer, op::Fusion(op::Fusion()));
}
3494
TEST_P(MemorySpaceAssignmentTest, InputOutputAlias) {
  // Verifies that buffers aliased between the entry parameter and the output
  // stay in the default memory space even when intermediate values could be
  // promoted to the alternate memory.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
  HloInstruction* p = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p, 0));
  // Negate chain creates schedule distance before tuple element 1 is read.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p, 1));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));
  // negate7 has no users but is present in the schedule after add.
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, add));
  // Root tuple reuses p0 (aliased with input {0}) and add (aliased with
  // input {1}).
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({p0, add}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5,
                    negate6, p1, add, negate7, tuple});
  TF_CHECK_OK(module->set_schedule(schedule));

  // Make input {0} alias with output {0} and input {1} alias with output {1}.
  TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0}));
  TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1}));

  AssignMemorySpace(module.get());

  // Make sure the input is in the default memory space.
  EXPECT_EQ(p->shape().tuple_shapes(0).layout().memory_space(),
            kDefaultMemorySpace);
  EXPECT_EQ(p->shape().tuple_shapes(1).layout().memory_space(),
            kDefaultMemorySpace);
}
3547
TEST_P(MemorySpaceAssignmentTest, CostAnalysis) {
  // This is mostly a smoke test since it's difficult and brittle to work out
  // the cost of the HLO instructions.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  // A simple chain of negates feeding an add of the second parameter.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, negate0, negate1, negate2,
                                      negate3, negate4, negate5, negate6, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpaceUsingCostAnalysis(module.get());
  // Parameters are in the default memory space.
  EXPECT_THAT(p0, op::ShapeWithLayout(shape));
  EXPECT_THAT(p1, op::ShapeWithLayout(shape));
  // Negate instructions are in the alternate memory space (1).
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate2, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate3, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate4, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate5, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate6, op::ShapeWithLayout(shape_in_alternate_mem));
}
3599
TEST_P(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
  // This test is carefully crafted to force only negates to be allocated to the
  // alternate memory. The graph consists of interleaving negate and tanh
  // operations:
  //
  //        +------+      +-------+      +-----
  //       /        \    /         \    /
  //  negate  tanh  negate  tanh   negate  tanh
  //                   \          /
  //                    +--------+      +---------+
  //
  // The alternate memory is sized to fit only two f32[4,3] tensors at a time.
  // Also, transcendentals are made to be lower bandwidth than FLOPs. So, the
  // MemoryBoundednessBufferIntervalCompare should prioritize the negates, which
  // are more memory bound.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {4, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  // Two independent chains: a tanh chain rooted at p0 and a negate chain
  // rooted at p1, interleaved in the schedule below.
  HloInstruction* tanh0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p1));
  HloInstruction* tanh1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* tanh2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh1));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* tanh3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh2));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* tanh4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh3));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({tanh4, negate4}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation,
                        {p0, p1, tanh0, negate0, tanh1, negate1, tanh2, negate2,
                         tanh3, negate3, tanh4, negate4, tuple});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpaceUsingCostAnalysis(module.get());
  // Parameters are in the default memory space.
  EXPECT_THAT(p0, op::ShapeWithLayout(shape));
  EXPECT_THAT(p1, op::ShapeWithLayout(shape));
  Shape shape_in_default_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {4, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kDefaultMemorySpace);
  // Expect only negates to be in alternate memory space. Not all might fit but
  // make sure at least one does.
  std::vector<HloInstruction*> negate_instructions = {negate0, negate1, negate2,
                                                      negate3, negate4};
  int64 num_negates_in_alternate_mem = absl::c_count_if(
      negate_instructions, [&](const HloInstruction* instruction) {
        return instruction->shape().layout().memory_space() ==
               kAlternateMemorySpace;
      });
  EXPECT_GE(num_negates_in_alternate_mem, 1);
  // None of the (less memory-bound) tanhs should be promoted.
  EXPECT_THAT(tanh0, op::ShapeWithLayout(shape_in_default_mem));
  EXPECT_THAT(tanh1, op::ShapeWithLayout(shape_in_default_mem));
  EXPECT_THAT(tanh2, op::ShapeWithLayout(shape_in_default_mem));
  EXPECT_THAT(tanh3, op::ShapeWithLayout(shape_in_default_mem));
  EXPECT_THAT(tanh4, op::ShapeWithLayout(shape_in_default_mem));
}
3677
TEST_P(MemorySpaceAssignmentTest, SimpleWhileTupleTest) {
  // Builds a simple counted while loop over a (s32, f32[1]) tuple and checks
  // that the entry parameter and the while result are both laid out in the
  // default memory space.
  Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
  Shape f32v1 = ShapeUtil::MakeShape(F32, {1});
  Shape t_s32_f32v1 = ShapeUtil::MakeTupleShape({s32, f32v1});
  auto module = CreateNewVerifiedModule("SimpleWhile");
  HloSchedule schedule(module.get());

  // A simple compare-to-limit (x < 4) computation for a While.
  //
  // condition:
  //   const4[s32] -----------------------------------\
  //                                                   \
  //   param[(s32,f32[4])] --- get-tuple-element[0] --- less-than
  //
  HloComputation* cond_computation;
  {
    auto builder = HloComputation::Builder("WhileCond");
    auto const4 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, t_s32_f32v1, "x"));
    auto index = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
    auto compare = builder.AddInstruction(
        HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
                                      const4, ComparisonDirection::kLt));
    cond_computation = module->AddEmbeddedComputation(builder.Build());
    schedule.set_sequence(cond_computation, {const4, param, index, compare});
  }

  // Builds a simple body computation for a While.
  //
  // body:
  //   constv[f32[1]] --------------------------------------\
  //                                                         \
  //                  /--- get-tuple-elementv[1] --- addv ---\
  //   param[(s32,f32[1])] ---|                               tuple
  //                  \--- get-tuple-elementc[0] --- addc ---/
  //                                                         /
  //   const1[s32] -----------------------------------------/
  //
  HloComputation* body_computation;
  {
    auto builder = HloComputation::Builder("WhileBody");
    auto const1 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
    auto constv = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.1f})));
    auto param = builder.AddInstruction(
        HloInstruction::CreateParameter(0, t_s32_f32v1, "x"));
    auto indexc = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
    auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
        indexc->shape(), HloOpcode::kAdd, indexc, const1));
    auto indexv = builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
    auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
        constv->shape(), HloOpcode::kAdd, indexv, constv));
    auto tuple =
        builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
    body_computation = module->AddEmbeddedComputation(builder.Build());
    schedule.set_sequence(body_computation, {const1, constv, param, indexc,
                                             addc, indexv, addv, tuple});
  }

  // This tests a simple while loop where the parameters are aliased with the
  // output buffers.
  auto builder = HloComputation::Builder("SimpleWhile");
  auto param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, t_s32_f32v1, "param"));
  auto gte0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(s32, param, 0));
  auto gte1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32v1, param, 1));
  auto tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
  auto while0 = builder.AddInstruction(HloInstruction::CreateWhile(
      t_s32_f32v1, cond_computation, body_computation, tuple));

  HloComputation* computation = module->AddEntryComputation(builder.Build());
  schedule.set_sequence(computation, {param, gte0, gte1, tuple, while0});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/50);

  // Ensure all parameters and while are placed in default memory.
  // (A previously declared f32[4,6] shape was removed here: it was never
  // referenced by any expectation in this test.)
  Shape s32_in_default_mem = ShapeUtil::MakeShapeWithLayout(
      xla::S32, {},
      /*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kDefaultMemorySpace);
  Shape f32v1_in_default_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {1},
      /*minor_to_major=*/{0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kDefaultMemorySpace);
  Shape t_s32_f32v1_in_default_mem =
      ShapeUtil::MakeTupleShape({s32_in_default_mem, f32v1_in_default_mem});
  EXPECT_THAT(param, op::ShapeWithLayout(t_s32_f32v1_in_default_mem));
  EXPECT_THAT(while0, op::ShapeWithLayout(t_s32_f32v1_in_default_mem));
}
3782
TEST_P(MemorySpaceAssignmentTest, EvictionsShouldntBeDelayed) {
  // This test reproduces an eviction scheduling bug where evictions to default
  // memory can happen later than intended, causing memory corruption. This test
  // is a variant of MemoryBoundednessBufferIntervalCompare but uses f32[4,3]
  // tensors instead, so at most two tensors should fit in the alternate memory
  // space at a given time. We have a number of redundant operations
  // (tanh_redundant ops) that do not have users. The bug was due to
  // SimplifyGraph removing dead instructions, and removing them from the
  // schedule. However, the CopyStart/CopyDone insertion relies on the schedule
  // indexes, so they could be inserted too late.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {4, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* tanh0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  // The following tanh ops are dead (they have no users) but are kept in the
  // schedule below to exercise the dead-instruction-removal path described in
  // the comment above.
  HloInstruction* tanh_redundant0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  HloInstruction* tanh_redundant1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  HloInstruction* tanh_redundant2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  HloInstruction* tanh_redundant3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  HloInstruction* tanh_redundant4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  HloInstruction* tanh_redundant5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  HloInstruction* tanh_redundant6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  // Two interleaved tanh/negate chains; tanh0 stays live until the root tuple.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, tanh0));
  HloInstruction* tanh1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, negate0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* tanh2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh1));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* tanh3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh2));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({tanh3, negate3, tanh0}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation,
      {p0, tanh0, tanh_redundant0, tanh_redundant1, tanh_redundant2,
       tanh_redundant3, tanh_redundant4, tanh_redundant5, tanh_redundant6,
       negate0, tanh1, negate1, tanh2, negate2, tanh3, negate3, tuple});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpaceUsingCostAnalysis(module.get());

  TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis,
                          HloAliasAnalysis::Run(module.get()));
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_live_range,
                          HloLiveRange::Run(module->schedule(), *alias_analysis,
                                            module->entry_computation()));

  std::vector<int> num_live_buffers_in_alternate_mem(
      hlo_live_range->flattened_instruction_sequence().size() + 1, 0);

  // Go through each value and for those that are allocated in the alternate
  // memory space, increment (inclusive) num_live_buffers_in_alternate_mem for
  // every time step that they are live.
  for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
    // Renamed from "shape" to avoid shadowing the f32[4,3] shape above.
    const Shape& value_shape = value->shape();
    if (!value_shape.has_layout() ||
        value_shape.layout().memory_space() == kDefaultMemorySpace) {
      continue;
    }

    HloLiveRange::TimeBound time_bound =
        hlo_live_range->buffer_live_ranges().at(value);
    // Use int64 to match TimeBound's field types and avoid narrowing.
    for (int64 i = time_bound.start; i <= time_bound.end; ++i) {
      ++num_live_buffers_in_alternate_mem[i];
    }
  }

  // The test memory can at most hold two f32[4,3] buffers at a time. If there
  // is more than that, it means we have memory corruption. A range-based loop
  // avoids the signed/unsigned comparison of the original `int i < size()`.
  for (int num_live_buffers : num_live_buffers_in_alternate_mem) {
    EXPECT_LE(num_live_buffers, 2);
  }
}
3875
TEST_P(MemorySpaceAssignmentTest,
       InputOutputsInAlternateMemShouldntBeAssigned) {
  // When input/outputs are marked to be in the alternate memory (e.g.
  // go/tpu-fast-mem-inference), do not allocate those and assume they will live
  // in the alternate memory for the entire computation. The BufferAssignment
  // pass, which is run after this, will allocate those buffers.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  // Same dimensions as `shape`, but with a layout whose memory space is the
  // alternate memory.
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  // p0 is in the default memory space.
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  // p1 is in the alternate memory space.
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, shape_in_alternate_mem, "p1"));
  // A chain of negates on p0; only the final add consumes p1.
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  // The add is explicitly given an alternate-memory output layout.
  HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
      shape_in_alternate_mem, HloOpcode::kAdd, negate6, p1));
  // Index {0} of the root instruction is in the alternate memory space, index
  // {1} is in the default memory space.
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({add, negate5}));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  // Schedule in creation order.
  HloSchedule schedule(module.get());
  schedule.set_sequence(computation,
                        {p0, p1, negate0, negate1, negate2, negate3, negate4,
                         negate5, negate6, add, tuple});
  TF_CHECK_OK(module->set_schedule(schedule));

  std::unique_ptr<PresetAssignments> preset_assignments =
      AssignMemorySpace(module.get());

  // Ensure that p1 is in the alternate memory and add, which has p1 as an
  // operand, has a direct dependency to p1 (no CopyStart/CopyDone).
  EXPECT_THAT(p1, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(add, op::Add(op::Negate(), op::Parameter(1)));
  // Make sure add is still in the alternate memory space.
  EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem));

  // Check the preset assignments and ensure the inputs/outputs in the alternate
  // memory space aren't in the preset assignments. Inputs/outputs in the
  // alternate memory space are left to BufferAssignment to be allocated.
  for (const auto& position_and_chunk : preset_assignments->chunks()) {
    const HloPosition& position = position_and_chunk.first;
    EXPECT_NE(position.instruction, p1);
    EXPECT_NE(position.instruction, add);
  }
}
3943
TEST_P(MemorySpaceAssignmentTest, PendingChunkMemoryCorruptionBug) {
  // Regression test for a memory corruption bug where an allocated chunk could
  // overlap with a pending chunk. A custom buffer interval compare prioritizes
  // sine, cosine, and tanh so that a, b, and n are placed first:
  //
  //    Max memory
  //  -------------------------------------------
  //      +------------+
  //      |     b      |
  //      +------------+
  //  +-------+
  //  |       |
  //  |       |
  //  |   a   |
  //  |       |                 +------------+
  //  |       |                 |     n      |
  //  +-------+                 +------------+
  //  -------------------------------------------
  //    Min memory              time ->
  //
  // When buffer d is subsequently allocated, the two prefetch buffers end up
  // overlapping:
  //
  //    Max memory
  //  -------------------------------------------
  //      +------------+ +----------+
  //      |     b      | | prefetch |
  //      +------------+ | for o    |
  //  +-------+     +---------+     |
  //  |       |     |    |    |     |
  //  |       |     |    |    |     |
  //  |   a   |     |    +----|-----+
  //  |       |     | prefetch| +------------+
  //  |       |     | for m   | |     n      |
  //  +-------+     +---------+ +------------+
  //  -------------------------------------------
  //    Min memory              time ->
  //
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  ENTRY %Entry {
    %param0 = f32[8,3] parameter(0)
    %param1 = f32[2,4] parameter(1)
    %a = f32[8,3] sine(%param0)
    %b = f32[2,4] cosine(%param1)
    %d = f32[8,3] tanh(%a)
    %c = f32[8,3] negate(%a)
    %e = f32[2,4] negate(%b)
    %f = f32[2,4] negate(%e)
    %g = f32[2,4] negate(%f)
    %h = f32[2,4] negate(%g)
    %i = f32[2,4] negate(%h)
    %j = f32[2,4] negate(%i)
    %k = f32[2,4] negate(%j)
    %l = f32[2,4] negate(%k)
    %m = f32[8,3] negate(%d)
    %n = f32[2,4] sine(%l)
    %o = f32[8,3] negate(%d)
    %p = f32[2,4] negate(%n)
    %q = f32[8,3] negate(%m)
    ROOT %tuple = (f32[2,4], f32[8,3], f32[8,3]) tuple(%p, %q, %o)
  }
  )";

  // Allocate sine first, then cosine, then tanh, then everything else.
  MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
      [](const MemorySpaceAssignment::BufferInterval& lhs,
         const MemorySpaceAssignment::BufferInterval& rhs) {
        auto opcode_priority = [](HloOpcode opcode) {
          if (opcode == HloOpcode::kSin) {
            return 0;
          }
          if (opcode == HloOpcode::kCos) {
            return 1;
          }
          if (opcode == HloOpcode::kTanh) {
            return 2;
          }
          return 3;
        };
        return opcode_priority(lhs.buffer->defining_instruction()->opcode()) <
               opcode_priority(rhs.buffer->defining_instruction()->opcode());
      };
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    buffer_interval_compare, &prefetch_interval_picker);
}
4037
TEST_P(MemorySpaceAssignmentTest, MoveCopyDoneEarlier) {
  // This tests the case where an earlier placed smaller buffer may block a
  // larger buffer due to asynchronous copy ordering. The smaller buffer (the
  // operand of sin) will be placed first. The cos, whose operand is 3 times
  // larger than sin's, needs longer time for the asynchronous copy. The cos is
  // placed right after sin, leading to a copy ordering violation:
  //
  // param1------------------>CS----->CD->sin
  // param0------------->CS------------------->CD->cos
  //
  // To fix this, we need to move copy done for cos earlier and ensure both of
  // these buffers get alternate memory allocations:
  //
  // param1------------------>CS----->CD->sin
  // param0-->CS------------------->CD------------>cos
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  ENTRY Entry {
    param0 = f32[8,3] parameter(0)
    param1 = f32[2,4] parameter(1)
    a = f32[2,4] negate(param1)
    b = f32[2,4] negate(a)
    c = f32[2,4] negate(b)
    d = f32[2,4] negate(c)
    e = f32[2,4] negate(d)
    f = f32[2,4] negate(e)
    g = f32[2,4] negate(f)
    h = f32[2,4] negate(g)
    i = f32[2,4] negate(h)
    j = f32[2,4] negate(i)
    k = f32[2,4] negate(j)
    l = f32[2,4] negate(k)
    m = f32[2,4] negate(l)
    n = f32[2,4] negate(m)
    sin = f32[2,4] sine(param1)
    o = f32[2,4] negate(n)
    cos = f32[8,3] cosine(param0)
    ROOT tuple = (f32[8,3], f32[2,4], f32[2,4]) tuple(cos, sin, o)
  }
  )";

  // Prioritize buffers by the highest-priority (sin < cos < tanh < other) user
  // among their uses, so sin's operand is allocated before cos's.
  MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
      [](const MemorySpaceAssignment::BufferInterval& a,
         const MemorySpaceAssignment::BufferInterval& b) {
        auto get_opcode_priority = [](const HloOpcode& opcode) {
          switch (opcode) {
            case HloOpcode::kSin:
              return 0;
            case HloOpcode::kCos:
              return 1;
            case HloOpcode::kTanh:
              return 2;
            default:
              return 3;
          }
        };

        auto get_user_priority = [&](const HloValue& value) {
          int priority = INT_MAX;
          for (const auto& use : value.uses()) {
            priority = std::min(priority,
                                get_opcode_priority(use.instruction->opcode()));
          }
          return priority;
        };

        return get_user_priority(*a.buffer) < get_user_priority(*b.buffer);
      };
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // Override the async copy cost so that copy duration is proportional to the
  // buffer size.
  HloCostAnalysis hlo_cost_analysis(ShapeSize);
  TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
                          FakeMemorySpaceAssignmentCostAnalysis::Create(
                              hlo_cost_analysis, *module));
  cost_analysis->SetOverrideForGetAsyncCopyElapsed([](const Shape& shape) {
    // This should return 2 for f32[2,4] and 6 for f32[8,3].
    return ShapeSize(shape) / 16;
  });
  CostAnalysisPrefetchIntervalPicker interval_picker(
      *cost_analysis,
      /*min_async_copy_to_overlap_ratio=*/1.0,
      /*max_async_copy_to_overlap_ratio=*/4.0,
      /*preferred_async_copy_to_overlap_ratio=*/1.5);
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    buffer_interval_compare, &interval_picker);

  // Check that both cos and sin could get their operands prefetched.
  const HloInstruction* cos =
      module->entry_computation()->GetInstructionWithName("cos");
  const HloInstruction* sin =
      module->entry_computation()->GetInstructionWithName("sin");
  EXPECT_THAT(sin->operand(0),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::Parameter(1)));
  EXPECT_THAT(cos->operand(0),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::Parameter(0)));

  // Sanity check that the cos' operand copy-done is scheduled earlier than
  // sin's operand.
  // Returns the position of `instruction` in the entry computation's schedule;
  // CHECK-fails if it is not scheduled.
  auto find_schedule_index = [&](const HloInstruction* instruction) {
    const auto& instructions =
        module->schedule().sequence(module->entry_computation()).instructions();
    for (int i = 0; i < instructions.size(); ++i) {
      if (instruction == instructions[i]) {
        return i;
      }
    }
    CHECK(false);
    return -1;
  };
  EXPECT_GT(find_schedule_index(sin->operand(0)),
            find_schedule_index(cos->operand(0)));
}
4154
TEST_P(MemorySpaceAssignmentTest, WhileAliasedArgumentRequiredAssignmentBug) {
  // Regression test for an overly pessimistic assertion that fired when the
  // same HloValue was passed to a while HLO more than once. AllocateSegment
  // already forces such aliased arguments to share the same allocation, so the
  // assertion was unnecessary; this just needs to run without failing.
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  while_condition {
    param1 = (f32[2,4], f32[2,4], f32[2,4]) parameter(0)
    ROOT cond = pred[] constant(true)
  }

  while_body {
    param2 = (f32[2,4], f32[2,4], f32[2,4]) parameter(0)
    gte2 = f32[2,4] get-tuple-element(param2), index=0
    gte3 = f32[2,4] get-tuple-element(param2), index=1
    gte4 = f32[2,4] get-tuple-element(param2), index=2
    add = f32[2,4] add(gte2, gte3)
    ROOT tuple2 = (f32[2,4], f32[2,4], f32[2,4]) tuple(add, gte3, gte4)
  }

  ENTRY Entry {
    param0 = f32[2,4] parameter(0)
    a = f32[2,4] negate(param0)
    b = f32[2,4] negate(param0)
    tuple = (f32[2,4], f32[2,4], f32[2,4]) tuple(a, b, b)
    while = (f32[2,4], f32[2,4], f32[2,4]) while(tuple), condition=while_condition, body=while_body
    gte1 = f32[2,4] get-tuple-element(while), index=0
    gte2 = f32[2,4] get-tuple-element(while), index=1
    ROOT root = f32[2,4] add(gte1, gte2)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(hlo_module.get());
}
4192
TEST_P(MemorySpaceAssignmentTest, DisallowedUseBug) {
  // When we have a disallowed use (tanh here), that use may not be served from
  // the alternate memory. A later use of the same buffer (o) must then refer
  // to the evicted value rather than to "a" directly, since "a" itself will be
  // placed in the alternate memory space.
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  ENTRY Entry {
    param0 = f32[8,3] parameter(0)
    param1 = f32[2,4] parameter(1)
    a = f32[8,3] cosine(param0)
    b = f32[2,4] negate(param1)
    d = f32[8,3] negate(a)
    c = f32[2,4] negate(b)
    e = f32[2,4] negate(c)
    f = f32[8,3] tanh(a)
    g = f32[2,4] negate(e)
    h = f32[2,4] negate(g)
    i = f32[2,4] negate(h)
    j = f32[2,4] negate(i)
    k = f32[2,4] negate(j)
    l = f32[2,4] negate(k)
    m = f32[2,4] negate(l)
    n = f32[2,4] sine(m)
    o = f32[8,3] negate(a)
    p = f32[2,4] negate(n)
    q = f32[8,3] add(o, f)
    r = f32[8,3] add(q, d)
    ROOT tuple = (f32[2,4], f32[8,3]) tuple(p, r)
  }
  )";

  // Allocate sine first, then cosine, then tanh, then everything else.
  MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
      [](const MemorySpaceAssignment::BufferInterval& lhs,
         const MemorySpaceAssignment::BufferInterval& rhs) {
        auto opcode_priority = [](HloOpcode opcode) {
          if (opcode == HloOpcode::kSin) {
            return 0;
          }
          if (opcode == HloOpcode::kCos) {
            return 1;
          }
          if (opcode == HloOpcode::kTanh) {
            return 2;
          }
          return 3;
        };
        return opcode_priority(lhs.buffer->defining_instruction()->opcode()) <
               opcode_priority(rhs.buffer->defining_instruction()->opcode());
      };
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
  MemorySpaceAssignment::Options options;
  options.max_size_in_bytes = 128;
  options.alignment_in_bytes = 8;
  options.verify = true;
  // Disallow tanh from consuming its operand in the alternate memory space.
  options.is_use_allowed_in_alternate_mem_fn = [](const HloUse& use) {
    return use.instruction->opcode() != HloOpcode::kTanh;
  };
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    buffer_interval_compare, &prefetch_interval_picker,
                    options);
}
4261
TEST_P(MemorySpaceAssignmentTest, DisallowedUseBugInWhile) {
  // Covers the situation where a use (tanh) is disallowed in the alternate
  // memory space and a subsequent use also requires the buffer in default
  // memory. The default-memory allocation is then not necessarily the last one
  // in the allocation sequence, so the implementation has to search the
  // sequence for it rather than assume it is at the end.
  absl::string_view hlo_string = R"(
  HloModule module, is_scheduled=true

  while_cond {
    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    ROOT gte = pred[] get-tuple-element(p0), index=3
  }

  while_body {
    p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
    gte0 = f32[3]{0} get-tuple-element(p0), index=0
    gte1 = f32[3]{0} get-tuple-element(p0), index=1
    gte2 = f32[3]{0} get-tuple-element(p0), index=2
    gte3 = pred[] get-tuple-element(p0), index=3
    add = f32[3]{0} add(gte0, gte0)
    negate0 = f32[3]{0} negate(add)
    negate1 = f32[3]{0} negate(negate0)
    negate2 = f32[3]{0} negate(negate1)
    negate3 = f32[3]{0} negate(negate2)
    negate4 = f32[3]{0} negate(negate3)
    negate5 = f32[3]{0} negate(negate4)
    negate6 = f32[3]{0} negate(negate5)
    negate7 = f32[3]{0} negate(negate6)
    negate8 = f32[3]{0} negate(negate7)
    negate9 = f32[3]{0} negate(negate8)
    negate10 = f32[3]{0} negate(negate9)
    negate11 = f32[3]{0} negate(negate10)
    negate12 = f32[3]{0} negate(negate11)
    negate13 = f32[3]{0} negate(negate12)
    negate14 = f32[3]{0} negate(negate13)
    negate15 = f32[3]{0} negate(gte2)
    tanh = f32[3]{0} tanh(gte2)
    ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(negate14, tanh, gte2, gte3)
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0)
    p1 = pred[] parameter(1)
    copy0 = f32[3]{0} copy(p0)
    copy1 = f32[3]{0} copy(p0)
    tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy0, copy1, p1)
    while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
    ROOT gte = f32[3]{0} get-tuple-element(while), index=2
  }
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(hlo_string));
  MemorySpaceAssignment::Options options;
  options.max_size_in_bytes = 128;
  options.alignment_in_bytes = 8;
  options.verify = true;
  // Disallow tanh from consuming its operand in the alternate memory space.
  options.is_use_allowed_in_alternate_mem_fn = [](const HloUse& use) {
    return use.instruction->opcode() != HloOpcode::kTanh;
  };
  AssignMemorySpace(hlo_module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2,
                    options);
}
4328
TEST_P(MemorySpaceAssignmentTest, BitcastRoot) {
  // Tests against a bug where the root of entry computation is a bitcast
  // instruction and it ends up getting an allocation in the alternate memory.
  //
  // NOTE: %tuple.8 previously declared a four-element tuple shape
  // (s32[], f32[3,10,5], s32[3,1], f32[3,3,3]) while providing only two
  // operands, which fails HLO verification and also mismatches the while's
  // (s32[], f32[3,3,3]) parameter shape. The declared shape is fixed to match
  // the operands.
  absl::string_view hlo_string = R"(
HloModule primitive_computation_gather.4, is_scheduled=true

%while_body {
  %param.1 = (s32[], f32[3,3,3]) parameter(0)
  %get-tuple-element.32 = s32[] get-tuple-element(%param.1), index=0
  %copy.6 = s32[] copy(s32[] %get-tuple-element.32)
  %constant.8 = s32[] constant(1)
  %add = s32[] add(s32[] %copy.6, s32[] %constant.8)
  %get-tuple-element.35 = f32[3,3,3] get-tuple-element(%param.1), index=1
  negate = f32[3,3,3] negate(get-tuple-element.35)
  ROOT %tuple.10 = (s32[], f32[3,3,3]) tuple(s32[] %add, f32[3,3,3] negate)
}

%while_cond {
  %param.0 = (s32[], f32[3,3,3]) parameter(0)
  %get-tuple-element = s32[] get-tuple-element(%param.0), index=0
  %constant.3 = s32[] constant(3)
  ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant.3), direction=LT
}

ENTRY %primitive_computation_gather.4 (parameter.1: f32[3,10,5], parameter.2: s32[3,1]) -> f32[3,3,3] {
  %constant.1 = s32[] constant(0)
  %copy.11 = s32[] copy(s32[] %constant.1)
  %constant = f32[] constant(0)
  %broadcast = f32[3,3,3] broadcast(f32[] %constant), dimensions={}
  %tuple.8 = (s32[], f32[3,3,3]) tuple(s32[] %copy.11, f32[3,3,3] %broadcast)
  %while = (s32[], f32[3,3,3]) while(%tuple.8), condition=%while_cond, body=%while_body
  %get-tuple-element.7 = f32[3,3,3] get-tuple-element(%while), index=1
  ROOT %bitcast.1 = f32[3,3,3] bitcast(f32[3,3,3] %get-tuple-element.7)
}
  )";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  AssignMemorySpace(module.get());

  // The bitcast root must stay in the default memory space.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_TRUE(!root->shape().has_layout() ||
              root->shape().layout().memory_space() == kDefaultMemorySpace);
}
4373
4374 // A mock MemorySpaceAssignmentRepacker class that accepst a map of
4375 // (start_time,offset) -> new_offset values. Using this map, the repacker
4376 // repacks the allocations to the new_offset.
4377 class FakeMemorySpaceAssignmentRepacker : public MemorySpaceAssignmentRepacker {
4378 public:
FakeMemorySpaceAssignmentRepacker(absl::flat_hash_map<std::pair<int64,int64>,int64> & repack_map,std::function<void (absl::Span<AllocationBlock * >)> check_fun=nullptr,bool always_return_modified=false)4379 explicit FakeMemorySpaceAssignmentRepacker(
4380 absl::flat_hash_map<std::pair<int64, int64>, int64>& repack_map,
4381 std::function<void(absl::Span<AllocationBlock*>)> check_fun = nullptr,
4382 bool always_return_modified = false)
4383 : MemorySpaceAssignmentRepacker(/*max_size=*/128, /*alignment=*/8),
4384 repack_map_(repack_map),
4385 check_fun_(check_fun),
4386 always_return_modified_(always_return_modified) {}
4387
Repack(absl::Span<AllocationBlock * > allocations)4388 StatusOr<bool> Repack(absl::Span<AllocationBlock*> allocations) override {
4389 bool modified = false;
4390 for (AllocationBlock* block : allocations) {
4391 absl::flat_hash_set<int64> colocations;
4392 std::string colocations_str;
4393 for (const AllocationBlock* colocation : block->colocations) {
4394 absl::StrAppend(&colocations_str, colocation->id, ", ");
4395 colocations.insert(colocation->id);
4396 }
4397 VLOG(1) << "Alloc id: " << block->id << " time: [" << block->start_time
4398 << ", " << block->end_time << "] size: " << block->size
4399 << " init offset: " << block->initial_offset << " colocations: {"
4400 << colocations_str << "}";
4401 auto it = repack_map_.find({block->start_time, block->initial_offset});
4402 if (it != repack_map_.end()) {
4403 modified = true;
4404 block->offset = it->second;
4405 } else {
4406 block->offset = block->initial_offset;
4407 }
4408 for (AllocationBlock* colocation : block->colocations) {
4409 if (it != repack_map_.end()) {
4410 colocation->offset = it->second;
4411 } else {
4412 colocation->offset = colocation->initial_offset;
4413 }
4414 }
4415 }
4416 if (check_fun_) {
4417 check_fun_(allocations);
4418 }
4419
4420 return always_return_modified_ || modified;
4421 }
4422
4423 private:
4424 // A map from (start_time, offset) to new_offset.
4425 absl::flat_hash_map<std::pair<int64, int64>, int64> repack_map_;
4426 std::function<void(absl::Span<AllocationBlock*>)> check_fun_;
4427 bool always_return_modified_;
4428 };
4429
TEST_P(MemorySpaceAssignmentTest, Repack) {
  // The initial allocations land at these offsets (time on the x axis, offset
  // on the y axis):
  //
  //    Max memory
  //  -------------------------------------------
  //      +------------+
  //      |     b      |
  //      +------------+
  //  +-------+        +------------+
  //  |   a   |        |     n      |
  //  +-------+        +------------+
  //  -------------------------------------------
  //    Min memory              time ->
  //
  // The prefetch for m then fails because fragmentation leaves no chunk large
  // enough for d:
  //
  //    Max memory
  //  -------------------------------------------
  //                  +---------+
  //      +------------+        |
  //      |     b     ||        |
  //      +------------+        |
  //  +-------+       |         | +------------+
  //  |   a   |       |    d    | |     n      |
  //  +-------+       +---------+ +------------+
  //  -------------------------------------------
  //    Min memory              time ->
  //
  // Repacking the existing allocations (swapping a and b's offsets) creates a
  // contiguous free region, allowing the prefetch for m to succeed:
  //
  //    Max memory
  //  -------------------------------------------
  //                  +---------+
  //  +-------+       |         |
  //  |   a   |       |    d    |
  //  +-------+       +---------+
  //  +------------+ +------------+
  //  |     b      | |     n      |
  //  +------------+ +------------+
  //  -------------------------------------------
  //    Min memory              time ->
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  ENTRY Entry {
    param0 = f32[8,3] parameter(0)
    param1 = f32[2,4] parameter(1)
    a = f32[2,4] sine(param1)
    b = f32[2,4] cosine(param1)
    c = f32[8,3] negate(param0)
    j = f32[2,4] negate(a)
    d = f32[8,3] tanh(param0)
    k = f32[2,4] negate(j)
    l = f32[2,4] add(b, k)
    m = f32[8,3] negate(d)
    n = f32[2,4] sine(l)
    o = f32[8,3] negate(m)
    p = f32[2,4] negate(n)
    q = f32[8,3] negate(m)
    ROOT tuple = (f32[2,4], f32[8,3], f32[8,3]) tuple(p, q, o)
  }
  )";

  // Allocate sine first, then cosine, then tanh, then everything else, to
  // deterministically produce the fragmented layout pictured above.
  MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
      [](const MemorySpaceAssignment::BufferInterval& lhs,
         const MemorySpaceAssignment::BufferInterval& rhs) {
        auto opcode_priority = [](HloOpcode opcode) {
          if (opcode == HloOpcode::kSin) {
            return 0;
          }
          if (opcode == HloOpcode::kCos) {
            return 1;
          }
          if (opcode == HloOpcode::kTanh) {
            return 2;
          }
          return 3;
        };
        return opcode_priority(lhs.buffer->defining_instruction()->opcode()) <
               opcode_priority(rhs.buffer->defining_instruction()->opcode());
      };
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
  absl::flat_hash_map<std::pair<int64, int64>, int64> repack_map;
  // Move "a" from offset 0 to 32.
  repack_map[{2, 0}] = 32;
  // Move "b" from offset 32 to 0.
  repack_map[{3, 32}] = 0;
  FakeMemorySpaceAssignmentRepacker repacker(repack_map);
  MemorySpaceAssignment::Options options;
  options.max_size_in_bytes = 128;
  options.alignment_in_bytes = 8;
  options.verify = true;
  options.max_repacks = 1;
  options.repacker = &repacker;
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    buffer_interval_compare, &prefetch_interval_picker,
                    options);

  // If repacking succeeds, we should find the buffer for d in alternate memory.
  const HloInstruction* d =
      module->entry_computation()->GetInstructionWithName("d");
  EXPECT_EQ(d->shape().layout().memory_space(), kAlternateMemorySpace);
}
4550
TEST_P(MemorySpaceAssignmentTest, RepackExportsAliasedOffsets) {
  // Checks that aliased offsets are correctly exported for repacking. The
  // buffer produced at HLO "a" is allocated first (the test prioritizes the
  // sine HLO) and consists of four allocations:
  //   1) a produced in the alternate memory (then evicted to default memory).
  //   2) a prefetched to the alternate memory, used by q and the while HLO.
  //   3) a used within the while loop body.
  //   4) the output of the while HLO, used by u.
  // All four should get the same (zero) offset. Allocations 2, 3, and 4 must
  // be colocated with each other, but allocation 1 need not be colocated with
  // the other three.
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  while_condition {
    param1 = (f32[2,4], f32[2,4]) parameter(0)
    ROOT cond = pred[] constant(true)
  }

  while_body {
    param2 = (f32[2,4], f32[2,4]) parameter(0)
    gte2 = f32[2,4] get-tuple-element(param2), index=0
    gte3 = f32[2,4] get-tuple-element(param2), index=1
    add = f32[2,4] add(gte2, gte3)
    ROOT tuple2 = (f32[2,4], f32[2,4]) tuple(add, gte3)
  }

  ENTRY Entry {
    param0 = f32[2,4] parameter(0)
    a = f32[2,4] sine(param0)
    b = f32[2,4] negate(a)
    c = f32[2,4] negate(b)
    d = f32[2,4] negate(c)
    e = f32[2,4] negate(d)
    f = f32[2,4] negate(e)
    g = f32[2,4] negate(f)
    h = f32[2,4] negate(g)
    i = f32[2,4] negate(h)
    j = f32[2,4] negate(i)
    k = f32[2,4] negate(j)
    l = f32[2,4] negate(k)
    m = f32[2,4] negate(l)
    n = f32[2,4] negate(m)
    o = f32[2,4] negate(n)
    p = f32[2,4] negate(o)
    q = f32[2,4] add(p, a)
    tuple = (f32[2,4], f32[2,4]) tuple(q, a)
    while = (f32[2,4], f32[2,4]) while(tuple), condition=while_condition, body=while_body
    gte0 = f32[2,4] get-tuple-element(while), index=0
    gte1 = f32[2,4] get-tuple-element(while), index=1
    r = f32[2,4] negate(gte0)
    s = f32[2,4] negate(r)
    t = f32[2,4] negate(s)
    constant = f32[] constant(0)
    broadcast = f32[8,4] broadcast(constant), dimensions={}
    cos = f32[8,4] cosine(broadcast)
    u = f32[2,4] add(t, gte1)
    v = f32[2,4] add(u, param0)
    w = f32[8,4] negate(cos)
    ROOT tuple3 = (f32[2,4], f32[8,4]) tuple(v, w)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  // Prioritize sine, then cosine, then tanh; everything else ties for last.
  MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
      [](const MemorySpaceAssignment::BufferInterval& lhs,
         const MemorySpaceAssignment::BufferInterval& rhs) {
        auto priority_of = [](const HloOpcode& opcode) {
          switch (opcode) {
            case HloOpcode::kSin:
              return 0;
            case HloOpcode::kCos:
              return 1;
            case HloOpcode::kTanh:
              return 2;
            default:
              return 3;
          }
        };
        return priority_of(lhs.buffer->defining_instruction()->opcode()) <
               priority_of(rhs.buffer->defining_instruction()->opcode());
      };

  InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
  absl::flat_hash_map<std::pair<int64, int64>, int64> repack_map;

  // Of the four separate allocations for the "a" buffer, the first and the
  // last may each be either alone or colocated with the middle pair; the
  // middle two always carry three colocations.
  auto check_fun =
      [](absl::Span<MemorySpaceAssignmentRepacker::AllocationBlock*>
             allocations) {
        EXPECT_TRUE(allocations.at(0)->colocations.size() == 1 ||
                    allocations.at(0)->colocations.size() == 3);
        EXPECT_EQ(allocations.at(1)->colocations.size(), 3);
        EXPECT_EQ(allocations.at(2)->colocations.size(), 3);
        EXPECT_TRUE(allocations.at(3)->colocations.size() == 1 ||
                    allocations.at(3)->colocations.size() == 3);
      };
  FakeMemorySpaceAssignmentRepacker repacker(repack_map, check_fun);
  MemorySpaceAssignment::Options options;
  options.max_size_in_bytes = 128;
  options.alignment_in_bytes = 8;
  options.verify = true;
  options.max_repacks = 1;
  options.repacker = &repacker;
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    buffer_interval_compare, &prefetch_interval_picker,
                    options);
}
4665
TEST_P(MemorySpaceAssignmentTest,
       RepackShouldntEraseRequiredAssignmentForConditionalOutput) {
  // Regression test for b/171040271. Repacks erase required assignments
  // (some are inserted conditionally based on allocation decisions),
  // including the requirement that conditional outputs always get default
  // memory assignments. That requirement used to never be re-added after a
  // repack, letting conditionals end up with alternate-memory allocations.
  absl::string_view hlo_string = R"(
  HloModule CondAllocation, is_scheduled=true

  true_computation {
    p0 = (f32[3]) parameter(0)
    gte = f32[3] get-tuple-element(p0), index=0
    neg1 = f32[3] negate(gte)
    ROOT tuple1 = (f32[3]) tuple(neg1)
  }

  false_computation {
    p0 = (f32[3]) parameter(0)
    gte = f32[3] get-tuple-element(p0), index=0
    neg2 = f32[3] negate(gte)
    ROOT tuple2 = (f32[3]) tuple(neg2)
  }

  ENTRY entry {
    p0 = f32[3] parameter(0)
    p1 = pred[] parameter(1)
    copy = f32[3] copy(p0)
    tuple = (f32[3]) tuple(copy)
    conditional = (f32[3]) conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation
    ROOT gte = f32[3] get-tuple-element(conditional), index=0
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  absl::flat_hash_map<std::pair<int64, int64>, int64> repack_map;
  // Force the repacker to report a modification on every invocation so the
  // repack path is exercised after each allocation.
  FakeMemorySpaceAssignmentRepacker repacker(repack_map, nullptr,
                                             /*always_return_modified=*/true);
  MemorySpaceAssignment::Options options;
  options.max_size_in_bytes = 128;
  options.alignment_in_bytes = 8;
  options.verify = true;
  options.max_repacks = 10;
  options.repacker = &repacker;
  options.repack_after_every_allocation = true;
  InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*buffer_interval_compare=*/{}, &prefetch_interval_picker,
                    options);

  // The entry root (the conditional's output) must stay in default memory.
  const Shape& root_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_EQ(root_shape.layout().memory_space(), kDefaultMemorySpace);
}
4725
TEST_P(MemorySpaceAssignmentTest, Determinism) {
  // Memory space assignment must be deterministic: running it repeatedly on
  // identical modules has to compile to the identical result.
  std::unique_ptr<HloModule> reference = CreateEvictAndPrefetchModule();
  AssignMemorySpace(reference.get());
  const std::string reference_str = reference->ToString();

  for (int trial = 0; trial < 10; ++trial) {
    std::unique_ptr<HloModule> other_module = CreateEvictAndPrefetchModule();
    AssignMemorySpace(other_module.get());
    EXPECT_EQ(reference_str, other_module->ToString());
  }
}
4740
TEST_P(MemorySpaceAssignmentTest, InPlaceOp) {
  // In-place ops such as DynamicUpdateSlice must share the allocation of the
  // operand they update in place.
  absl::string_view hlo_string = R"(
HloModule Module, is_scheduled=true

fused_computation {
  param0 = f32[2,3] parameter(0)
  constant.1 = f32[] constant(0)
  broadcast = f32[2,1] broadcast(constant.1), dimensions={}
  constant.3 = s32[] constant(0)
  ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
}

ENTRY main {
  param = f32[2,3] parameter(0)
  negate = f32[2,3] negate(param)
  fusion = f32[2,3] fusion(negate), kind=kLoop, calls=fused_computation
  ROOT add = f32[2,3] add(fusion, fusion)
}
)";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  auto preset_assignments = AssignMemorySpace(module.get());

  HloComputation* entry = module->entry_computation();
  int64 negate_offset = GetAlternateMemoryOffset(
      *preset_assignments, entry->GetInstructionWithName("negate"));
  int64 fusion_offset = GetAlternateMemoryOffset(
      *preset_assignments, entry->GetInstructionWithName("fusion"));
  // negate feeds the in-place fusion, so both get the same offset.
  EXPECT_EQ(negate_offset, fusion_offset);
  if (/*allocate_across_sequential_calls=*/GetParam()) {
    // With cross-call allocation enabled, the buffer must actually land in
    // the alternate memory (offset is valid).
    EXPECT_NE(negate_offset, -1);
  }
}
4781
// Instantiate every MemorySpaceAssignmentTest twice, with the boolean test
// parameter (read via GetParam()) set to false and then to true.
INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
                         MemorySpaceAssignmentTest,
                         ::testing::Values(false, true));

// Plain (non-parameterized) fixture alias for the AsynchronousCopyOrdering
// tests below.
using AsynchronousCopyOrderingTest = ::testing::Test;
4787
TEST_F(AsynchronousCopyOrderingTest, Simple) {
  // Asynchronous copies must keep pipelining order: a copy with an earlier
  // start time must also have an earlier (or equal) end time.
  // 3,11       +-------+         OK
  // 1,8      +------+            OK
  // 5,14         +--------+      OK
  // 7,14           +------+      OK
  // 2,16      +-------------+    Violate
  // 9,12             +--+        Violate
  // 6,17          +----------+   Violate
  // 5,13         +-------+       OK (same start as 5,14)
  // 5,14         +--------+      OK (same as 5,14)
  auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate;
  AsynchronousCopyOrdering ordering;
  // Asserts the interval doesn't violate ordering, then records it.
  auto add_copy = [&](int64 start, int64 end) {
    EXPECT_FALSE(ordering.ViolatesOrdering(start, end));
    ordering.AddCopy({start, end, alternate_mem_space});
  };
  add_copy(3, 11);
  add_copy(1, 8);
  add_copy(5, 14);
  add_copy(7, 14);
  EXPECT_TRUE(ordering.ViolatesOrdering(2, 16));
  EXPECT_TRUE(ordering.ViolatesOrdering(9, 12));
  EXPECT_TRUE(ordering.ViolatesOrdering(6, 17));
  add_copy(5, 13);  // Same start time as an existing copy is allowed.
  add_copy(5, 14);  // An exact duplicate interval is allowed too.
}
4818
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTest) {
  // A dot whose rhs comes from an entry parameter tuple element should have
  // that element cross-program prefetched.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));

  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param, lhs, rhs, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Exactly one cross-program prefetch: parameter 0, tuple index {1}.
  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
  if (!cross_program_prefetches.empty()) {
    EXPECT_EQ(cross_program_prefetches[0].first, 0);
    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
  }
}
4860
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchBitcastTest) {
  // Same as CrossProgramPrefetchTest, but the rhs parameter element is
  // consumed through a bitcast; the prefetch should still be attributed to
  // the parameter tuple index {1}.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {kOutput, kFeature});
  const Shape bitcast_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));

  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateBitcast(bitcast_shape, rhs));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, bitcast, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param, lhs, rhs, bitcast, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Exactly one cross-program prefetch: parameter 0, tuple index {1}.
  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
  if (!cross_program_prefetches.empty()) {
    EXPECT_EQ(cross_program_prefetches[0].first, 0);
    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
  }
}
4906
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNestedTupleTest) {
  // The dot operands live inside a nested tuple parameter; no cross-program
  // prefetch is expected in this case.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
  const Shape tuple_tuple_shape = ShapeUtil::MakeTupleShape({tuple_shape});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_tuple_shape, "p0"));

  HloInstruction* gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(tuple_shape, param, 0));
  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, gte, 0));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, gte, 1));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param, gte, lhs, rhs, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  EXPECT_EQ(module->CrossProgramPrefetches().size(), 0);
}
4948
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchUnusedParamTest) {
  // A parameter with no uses must not be cross-program prefetched.
  HloComputation::Builder builder(TestName());

  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, rhs_shape, "p0"));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  EXPECT_EQ(module->CrossProgramPrefetches().size(), 0);
}
4971
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTooBigTest) {
  // With kOutput = 8, the rhs buffer no longer fits into the alternate
  // memory, so no cross-program prefetch is expected.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 8;

  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));

  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param, lhs, rhs, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  EXPECT_EQ(module->CrossProgramPrefetches().size(), 0);
}
5009
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchFusionTest) {
  // The fusion operands here are constants rather than entry parameters, so
  // nothing qualifies for cross-program prefetching.
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 2;
  constexpr int kFeature = 2;
  constexpr int kOutput = 2;

  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  const Shape rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  const Shape result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});

  auto module = CreateNewVerifiedModule();
  // Build the fused computation: a (lhs, rhs) tuple parameter feeding a dot.
  HloComputation::Builder fusion_builder("fusion");
  {
    HloInstruction* param = fusion_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tuple_shape, "p0"));
    HloInstruction* lhs = fusion_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
    HloInstruction* rhs = fusion_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
    DotDimensionNumbers dot_dnums;
    dot_dnums.add_lhs_contracting_dimensions(1);
    dot_dnums.add_rhs_contracting_dimensions(0);
    HloInstruction* dot =
        fusion_builder.AddInstruction(HloInstruction::CreateDot(
            result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
    (void)dot;  // The dot becomes the fusion root; the pointer is unused.
  }
  HloComputation* fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  HloInstruction* activations =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
  HloInstruction* weights =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
  HloInstruction* tuple = builder.AddInstruction(
      HloInstruction::CreateTuple({activations, weights}));
  HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      result_shape, HloInstruction::FusionKind::kCustom, {tuple},
      fusion_computation));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {activations, weights, tuple, fusion});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  EXPECT_EQ(module->CrossProgramPrefetches().size(), 0);
}
5062
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchPinnedTest) {
  HloComputation::Builder builder(TestName());

  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  const Shape lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  // The rhs layout pins its buffer to the alternate memory space up front,
  // so it must not be selected for cross-program prefetching.
  const Shape rhs_shape = ShapeUtil::MakeShapeWithLayout(
      F32, {kFeature, kOutput},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  const Shape result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  const Shape tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));

  HloInstruction* lhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
  HloInstruction* rhs = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {param, lhs, rhs, dot});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  EXPECT_EQ(module->CrossProgramPrefetches().size(), 0);
}
5103
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchReuse) {
  // Checks that a cross-program-prefetched buffer is freed after its last use
  // and that an end-of-program prefetch is inserted to restore it.
  absl::string_view hlo_string = R"(
  HloModule cross_program_prefetch, is_scheduled=true

  ENTRY CrossProgramPrefetch {
    p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0)
    get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0
    get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1
    dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    negate.1 = f32[8,2]{1,0} negate(dot)
    negate.2 = f32[8,2]{1,0} negate(negate.1)
    negate.3 = f32[8,2]{1,0} negate(negate.2)
    negate.4 = f32[8,2]{1,0} negate(negate.3)
    negate.5 = f32[8,2]{1,0} negate(negate.4)
    negate.6 = f32[8,2]{1,0} negate(negate.5)
    negate.7 = f32[8,2]{1,0} negate(negate.6)
    negate.8 = f32[8,2]{1,0} negate(negate.7)
    ROOT negate.9 = f32[8,2]{1,0} negate(negate.8)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);

  // Exactly one cross-program prefetch: parameter 0, tuple index {1}.
  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
  if (!cross_program_prefetches.empty()) {
    EXPECT_EQ(cross_program_prefetches[0].first, 0);
    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
      HloDataflowAnalysis::Run(*module));
  const HloValue& prefetched_value = dataflow_analysis->GetValueDefinedAt(
      module->entry_computation()->parameter_instruction(0), {1});
  // Counts the CopyStart uses of the prefetched value that match the given
  // cross-program-prefetch flag.
  auto count_copy_starts = [&](bool cross_program) {
    return absl::c_count_if(
        prefetched_value.uses(), [&](const HloUse& use) {
          return use.instruction->opcode() == HloOpcode::kCopyStart &&
                 use.instruction->is_cross_program_prefetch() == cross_program;
        });
  };
  // Two prefetches use this value: one cross-program prefetch and one
  // end-of-program prefetch.
  EXPECT_EQ(count_copy_starts(/*cross_program=*/true), 1);
  EXPECT_EQ(count_copy_starts(/*cross_program=*/false), 1);
}
5162
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) {
  // When the cross-program-prefetched buffer is used again close to the end
  // of the computation, it is better to keep it live: no end-of-program
  // prefetch should be emitted.
  absl::string_view hlo_string = R"(
  HloModule cross_program_prefetch, is_scheduled=true

  ENTRY CrossProgramPrefetch {
    p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0)
    get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0
    get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1
    dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
    negate.1 = f32[8,2]{1,0} negate(dot)
    negate.2 = f32[8,2]{1,0} negate(negate.1)
    negate.3 = f32[8,2]{1,0} negate(negate.2)
    negate.4 = f32[8,2]{1,0} negate(negate.3)
    negate.5 = f32[8,2]{1,0} negate(negate.4)
    negate.6 = f32[8,2]{1,0} negate(negate.5)
    negate.7 = f32[8,2]{1,0} negate(negate.6)
    negate.8 = f32[8,2]{1,0} negate(negate.7)
    ROOT dot.2 = f32[2,2]{1,0} dot(negate.8, get-tuple-element.1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);

  // Exactly one cross-program prefetch: parameter 0, tuple index {1}.
  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
  if (!cross_program_prefetches.empty()) {
    EXPECT_EQ(cross_program_prefetches[0].first, 0);
    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
  }

  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
      HloDataflowAnalysis::Run(*module));
  const HloValue& prefetched_value = dataflow_analysis->GetValueDefinedAt(
      module->entry_computation()->parameter_instruction(0), {1});
  // Counts the CopyStart uses of the prefetched value that match the given
  // cross-program-prefetch flag.
  auto count_copy_starts = [&](bool cross_program) {
    return absl::c_count_if(
        prefetched_value.uses(), [&](const HloUse& use) {
          return use.instruction->opcode() == HloOpcode::kCopyStart &&
                 use.instruction->is_cross_program_prefetch() == cross_program;
        });
  };
  // Only the cross-program prefetch uses this value; there must be no
  // end-of-program prefetch.
  EXPECT_EQ(count_copy_starts(/*cross_program=*/true), 1);
  EXPECT_EQ(count_copy_starts(/*cross_program=*/false), 0);
}
5222
// Fixture alias for the CostAnalysisPrefetchIntervalPicker tests below; no
// test parameterization is needed for these.
using CostAnalysisPrefetchIntervalPickerTest = HloTestBase;
5224
TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) {
  // Verifies the order in which CostAnalysisPrefetchIntervalPicker proposes
  // prefetch start times for a use at the end of a long chain of negates.
  absl::string_view hlo_string = R"(
  HloModule bug, is_scheduled=true

  ENTRY Entry {
    param0 = f32[2,4] parameter(0)
    a = f32[2,4] negate(param0)
    b = f32[2,4] negate(a)
    c = f32[2,4] negate(b)
    d = f32[2,4] negate(c)
    e = f32[2,4] negate(d)
    f = f32[2,4] negate(e)
    g = f32[2,4] negate(f)
    h = f32[2,4] negate(g)
    i = f32[2,4] negate(h)
    j = f32[2,4] negate(i)
    k = f32[2,4] negate(j)
    l = f32[2,4] negate(k)
    m = f32[2,4] negate(l)
    n = f32[2,4] negate(m)
    o = f32[2,4] negate(n)
    p = f32[2,4] negate(o)
    q = f32[2,4] negate(p)
    r = f32[2,4] negate(q)
    s = f32[2,4] negate(r)
    t = f32[2,4] negate(s)
    u = f32[2,4] negate(t)
    ROOT v = f32[2,4] add(u, param0)
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));

  HloCostAnalysis hlo_cost_analysis(ShapeSize);
  TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
                          FakeMemorySpaceAssignmentCostAnalysis::Create(
                              hlo_cost_analysis, *module));
  CostAnalysisPrefetchIntervalPicker interval_picker(
      *cost_analysis,
      /*min_async_copy_to_overlap_ratio=*/1.0,
      /*max_async_copy_to_overlap_ratio=*/4.0,
      /*preferred_async_copy_to_overlap_ratio=*/2.0);

  HloInstruction* root = module->entry_computation()->root_instruction();
  const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
  interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/22);

  // The first interval is (15, 22): elapsed time 6.0, twice the async copy
  // elapsed time (3.0). From there, intervals are visited in alternating
  // increasing and decreasing order until the min and max async copy overlap
  // ratios are hit at intervals (18, 22) and (9, 22) respectively.
  for (int64 expected_start : {15, 16, 14, 17, 13,
                               18,  // Min async overlap ratio reached.
                               12, 11, 10,
                               9}) {  // Max async overlap ratio reached.
    LOG(INFO) << interval_picker.ToDebugString();
    EXPECT_EQ(interval_picker.Next(), expected_start);
  }
  LOG(INFO) << interval_picker.ToDebugString();
  EXPECT_TRUE(interval_picker.Done());

  // If the window between start_time and end_time is too short, there are no
  // available intervals at all.
  interval_picker.Begin(use, /*start_time=*/19, /*end_time=*/22);
  LOG(INFO) << interval_picker.ToDebugString();
  EXPECT_TRUE(interval_picker.Done());
}
5306
TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrderWhile) {
  // Builds a schedule where a while loop occupies logical times [19, 24] and
  // verifies the order in which the picker proposes prefetch start times.
  absl::string_view hlo_text = R"(
  HloModule bug, is_scheduled=true

  while_condition {
    param1 = (f32[2,4]) parameter(0) // 19
    ROOT cond = pred[] constant(true) // 20
  }

  while_body {
    param2 = (f32[2,4]) parameter(0) // 21
    gte2 = f32[2,4] get-tuple-element(param2), index=0 // 22
    add = f32[2,4] add(gte2, gte2) // 23
    ROOT tuple2 = (f32[2,4]) tuple(add) // 24
  }

  ENTRY Entry {
    param0 = f32[2,4] parameter(0) // 0
    a = f32[2,4] negate(param0) // 1
    b = f32[2,4] negate(a) // 2
    c = f32[2,4] negate(b) // 3
    d = f32[2,4] negate(c) // 4
    e = f32[2,4] negate(d) // 5
    f = f32[2,4] negate(e) // 6
    g = f32[2,4] negate(f) // 7
    h = f32[2,4] negate(g) // 8
    i = f32[2,4] negate(h) // 9
    j = f32[2,4] negate(i) // 10
    k = f32[2,4] negate(j) // 11
    l = f32[2,4] negate(k) // 12
    m = f32[2,4] negate(l) // 13
    n = f32[2,4] negate(m) // 14
    o = f32[2,4] negate(n) // 15
    p = f32[2,4] negate(o) // 16
    q = f32[2,4] negate(p) // 17
    tuple = (f32[2,4]) tuple(q) // 18
    while = (f32[2,4]) while(tuple), condition=while_condition, body=while_body // 25
    gte1 = f32[2,4] get-tuple-element(while), index=0 // 26
    r = f32[2,4] negate(gte1) // 27
    s = f32[2,4] negate(r) // 28
    t = f32[2,4] negate(s) // 29
    u = f32[2,4] negate(t) // 30
    ROOT v = f32[2,4] add(u, param0) // 31
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  HloCostAnalysis base_cost_analysis(ShapeSize);
  TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
                          FakeMemorySpaceAssignmentCostAnalysis::Create(
                              base_cost_analysis, *parsed_module));
  CostAnalysisPrefetchIntervalPicker picker(
      *cost_analysis,
      /*min_async_copy_to_overlap_ratio=*/1.0,
      /*max_async_copy_to_overlap_ratio=*/12.0,
      /*preferred_async_copy_to_overlap_ratio=*/2.0);

  HloInstruction* root = parsed_module->entry_computation()->root_instruction();
  const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
  picker.Begin(use, /*start_time=*/0, /*end_time=*/31);

  // Because there are while loop computations between [19, 24], the picker has
  // to avoid that interval. The expected sequence walks outward from the
  // preferred overlap ratio (25, 26, 18), hits the min async overlap ratio at
  // 27, and the max async overlap ratio at 17.
  for (const int64 expected_start : {25, 26, 18, 27, 17}) {
    LOG(INFO) << picker.ToDebugString();
    EXPECT_EQ(picker.Next(), expected_start);
  }
  LOG(INFO) << picker.ToDebugString();
  EXPECT_TRUE(picker.Done());
}
5384
TEST_F(CostAnalysisPrefetchIntervalPickerTest, NestedWhile) {
  // Regression test: while_nest_level_ used to be left unassigned (defaulting
  // to 0) for while instructions, which let the prefetch interval logic treat
  // a nested while instruction as if it lived at the outermost computation's
  // nesting level.
  absl::string_view hlo_text = R"(
  HloModule bug, is_scheduled=true

  while_condition.2 {
    param1 = (f32[2,4]) parameter(0) // 11
    ROOT cond = pred[] constant(true) // 12
  }

  while_body.2 {
    param2 = (f32[2,4]) parameter(0) // 13
    gte2 = f32[2,4] get-tuple-element(param2), index=0 // 14
    add = f32[2,4] add(gte2, gte2) // 15
    ROOT tuple2 = (f32[2,4]) tuple(add) // 16
  }

  while_condition.1 {
    param3 = (f32[2,4]) parameter(0) // 5
    ROOT cond = pred[] constant(true) // 6
  }

  while_body.1 {
    param4 = (f32[2,4]) parameter(0) // 7
    gte1 = f32[2,4] get-tuple-element(param4), index=0 // 8
    add1 = f32[2,4] add(gte1, gte1) // 9
    tuple1 = (f32[2,4]) tuple(add1) // 10
    while = (f32[2,4]) while(tuple1), condition=while_condition.2, body=while_body.2 // 17
    gte2 = f32[2,4] get-tuple-element(while), index=0 // 18
    add2 = f32[2,4] add(gte2, gte2) // 19
    ROOT tuple2 = (f32[2,4]) tuple(add2) // 20
  }

  ENTRY Entry {
    param0 = f32[2,4] parameter(0) // 0
    a = f32[2,4] negate(param0) // 1
    b = f32[2,4] negate(a) // 2
    c = f32[2,4] negate(b) // 3
    tuple = (f32[2,4]) tuple(c) // 4
    while = (f32[2,4]) while(tuple), condition=while_condition.1, body=while_body.1 // 21
    gte1 = f32[2,4] get-tuple-element(while), index=0 // 22
    ROOT root = f32[2,4] add(gte1, param0) // 23
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  HloCostAnalysis base_cost_analysis(ShapeSize);
  TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
                          FakeMemorySpaceAssignmentCostAnalysis::Create(
                              base_cost_analysis, *parsed_module));
  CostAnalysisPrefetchIntervalPicker picker(
      *cost_analysis,
      /*min_async_copy_to_overlap_ratio=*/1.0,
      /*max_async_copy_to_overlap_ratio=*/12.0,
      /*preferred_async_copy_to_overlap_ratio=*/2.0);

  HloInstruction* root = parsed_module->entry_computation()->root_instruction();
  const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
  const Shape& prefetch_shape = root->operand(1)->shape();

  // The root's latest prefetch start time must fall before the outer while
  // loop, i.e. at logical time 4.
  EXPECT_EQ(picker.LatestPrefetchStartTime(prefetch_shape, /*start_time=*/0,
                                           /*end_time=*/23, &use),
            4);
}
5455
TEST_F(CostAnalysisPrefetchIntervalPickerTest, ConsecutiveConditionals) {
  // Regression test for b/170668492: when two conditionals are scheduled back
  // to back, a prefetch for the second conditional could be placed so that it
  // started inside the first conditional's called computations.
  absl::string_view hlo_text = R"(
  HloModule bug, is_scheduled=true

  true_computation.0 {
    p0 = (f32[3]{0}) parameter(0) // 5
    gte = f32[3]{0} get-tuple-element(p0), index=0 // 6
    ROOT neg1 = f32[3]{0} negate(gte) // 7
  }

  false_computation.0 {
    p0 = (f32[3]{0}) parameter(0) // 8
    gte = f32[3]{0} get-tuple-element(p0), index=0 // 9
    ROOT neg2 = f32[3]{0} negate(gte) // 10
  }

  true_computation.1 {
    p0 = (f32[3]{0}) parameter(0) // 12
    gte = f32[3]{0} get-tuple-element(p0), index=0 // 13
    ROOT neg1 = f32[3]{0} negate(gte) // 14
  }

  false_computation.1 {
    p0 = (f32[3]{0}) parameter(0) // 15
    gte = f32[3]{0} get-tuple-element(p0), index=0 // 16
    ROOT neg2 = f32[3]{0} negate(gte) // 17
  }

  ENTRY entry {
    p0 = f32[3]{0} parameter(0) // 0
    p1 = f32[3]{0} parameter(1) // 1
    p2 = pred[] parameter(2) // 2
    tuple0 = (f32[3]{0}) tuple(p0) // 3
    tuple1 = (f32[3]{0}) tuple(p1) // 4
    conditional0 = f32[3]{0} conditional(p2, tuple0, tuple0), true_computation=true_computation.0, false_computation=false_computation.0 // 11
    conditional1 = f32[3]{0} conditional(p2, tuple1, tuple1), true_computation=true_computation.1, false_computation=false_computation.1 // 18
    ROOT tuple2 = (f32[3]{0}, f32[3]{0}) tuple(conditional0, conditional1) // 19
  }
  )";
  TF_ASSERT_OK_AND_ASSIGN(auto parsed_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  HloCostAnalysis base_cost_analysis(ShapeSize);
  TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
                          FakeMemorySpaceAssignmentCostAnalysis::Create(
                              base_cost_analysis, *parsed_module));
  CostAnalysisPrefetchIntervalPicker picker(
      *cost_analysis,
      /*min_async_copy_to_overlap_ratio=*/1.0,
      /*max_async_copy_to_overlap_ratio=*/12.0,
      /*preferred_async_copy_to_overlap_ratio=*/2.0);

  LOG(INFO) << parsed_module->ToString();

  HloInstruction* conditional1 =
      parsed_module->entry_computation()->GetInstructionWithName(
          "conditional1");
  const HloUse use{conditional1, /*operand_number=*/1, /*operand_index=*/{0}};
  const Shape& prefetch_shape =
      parsed_module->entry_computation()->parameter_instruction(0)->shape();

  // The prefetch must start strictly before conditional0's called
  // computations begin (logical time 5).
  EXPECT_LT(picker.LatestPrefetchStartTime(prefetch_shape, /*start_time=*/0,
                                           /*end_time=*/11, &use),
            5);
}
5525
5526 } // namespace
5527 } // namespace xla
5528