/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/memory_space_assignment.h"

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"

namespace xla {
namespace {

namespace op = xla::testing::opcode_matchers;

constexpr int64 kPointerSize = 8;
constexpr float kAsyncCopyBandwidth = 100;
constexpr float kAlternateMemBandwidth = 1000;
constexpr float kBytesPerSecond = 100;
constexpr float kFlopsPerSecond = 1000;
constexpr float kTranscendentalsPerSecond = 10;

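// Returns the byte size of the given shape, counting kPointerSize bytes for
// each pointer in tuple shapes.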
int64 ShapeSize(const Shape& shape) {
  return ShapeUtil::ByteSizeOf(shape, kPointerSize);
}

class MemorySpaceAssignmentTest : public HloTestBase,
                                  public ::testing::WithParamInterface<bool> {
 protected:
  // We use the following two memory space values to describe the default (slow
  // and large) and alternate (fast and small) memory spaces.
  const int64 kDefaultMemorySpace = 0;
  const int64 kAlternateMemorySpace = 1;

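  // Runs memory space assignment with a buffer interval comparator and a
  // prefetch interval picker derived from an HloCostAnalysis built with the
  // fake bandwidth/FLOPS constants defined above.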
  std::unique_ptr<PresetAssignments> AssignMemorySpaceUsingCostAnalysis(
      HloModule* module) {
    HloCostAnalysis hlo_cost_analysis(ShapeSize);
    hlo_cost_analysis.set_flops_per_second(kFlopsPerSecond);
    hlo_cost_analysis.set_bytes_per_second(kBytesPerSecond);
    hlo_cost_analysis.set_transcendentals_per_second(kTranscendentalsPerSecond);
    for (HloComputation* computation : module->MakeNonfusionComputations()) {
      TF_CHECK_OK(computation->Accept(&hlo_cost_analysis));
    }
    auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
    auto cost_analysis = MemorySpaceAssignmentCostAnalysis::Create(
                             hlo_cost_analysis, kAsyncCopyBandwidth,
                             kAlternateMemBandwidth, *module)
                             .ValueOrDie();
    CostAnalysisPrefetchIntervalPicker prefetch_interval_picker(
        *cost_analysis, /*min_async_copy_to_overlap_ratio=*/0.8,
        /*max_async_copy_to_overlap_ratio=*/10.0,
        /*preferred_async_copy_to_overlap_ratio=*/1.5);
    return AssignMemorySpace(
        module, /*max_outstanding_async_copies=*/-1,
        MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
            *cost_analysis, &cache_),
        &prefetch_interval_picker);
  }

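  // Convenience overload that uses an instruction-count-based prefetch
  // interval picker.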
  std::unique_ptr<PresetAssignments> AssignMemorySpace(
      HloModule* module, int64 max_outstanding_async_copies = -1,
      int64 max_prefetch_interval = 10, int64 min_prefetch_interval = 2,
      absl::optional<MemorySpaceAssignment::Options> options = absl::nullopt) {
    InstructionCountPrefetchIntervalPicker prefetch_interval_picker(
        min_prefetch_interval, max_prefetch_interval);
    return AssignMemorySpace(module, max_outstanding_async_copies,
                             /*buffer_interval_compare=*/{},
                             &prefetch_interval_picker, options);
  }

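  // Runs memory space assignment on the module with the given constraints,
  // then verifies the exported preset assignments and, unless the input module
  // already had parameters in the alternate memory, that entry parameters
  // remain in the default memory.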
  std::unique_ptr<PresetAssignments> AssignMemorySpace(
      HloModule* module, int64 max_outstanding_async_copies,
      absl::optional<MemorySpaceAssignment::BufferIntervalCompare>
          buffer_interval_compare,
      PrefetchIntervalPicker* prefetch_interval_picker,
      absl::optional<MemorySpaceAssignment::Options>
          memory_space_assignment_options = absl::nullopt) {
    auto size_fn = [](const BufferValue& buffer) {
      return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
    };

    auto is_allowed_in_alternate_mem = [](const HloValue& value) {
      // Check if the value belongs to the entry computation.
      HloInstruction* instruction = value.instruction();
      HloComputation* computation = instruction->parent();
      bool in_entry_computation =
          (computation == computation->parent()->entry_computation());
      if (in_entry_computation &&
          instruction->opcode() == HloOpcode::kParameter) {
        return false;
      }
      return true;
    };

    // Only check parameters in default memory if the original module didn't
    // have the parameters in alternate memory.
    bool check_parameters_in_default_memory = true;
    for (const HloInstruction* parameter :
         module->entry_computation()->parameter_instructions()) {
      ShapeUtil::ForEachSubshape(
          parameter->shape(),
          [&](const Shape& subshape, const ShapeIndex& /*index*/) {
            if (subshape.has_layout() &&
                subshape.layout().memory_space() == kAlternateMemorySpace) {
              check_parameters_in_default_memory = false;
            }
          });
    }

    MemorySpaceAssignment::Options options;
    if (memory_space_assignment_options) {
      options = *memory_space_assignment_options;
    } else {
      options.max_size_in_bytes = 128;
      options.alignment_in_bytes = 8;
      options.verify = true;
    }

    options.alternate_memory_space = kAlternateMemorySpace;
    options.buffer_interval_compare = buffer_interval_compare;
    options.prefetch_interval_picker = prefetch_interval_picker;
    options.size_fn = size_fn;
    options.is_allowed_in_alternate_mem_fn = is_allowed_in_alternate_mem;
    options.max_outstanding_prefetches = max_outstanding_async_copies;
    options.max_outstanding_evictions = max_outstanding_async_copies;
    options.allocate_across_sequential_calls = GetParam();

    auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
    std::unique_ptr<HloLiveRange> hlo_live_range =
        HloLiveRange::Run(module->schedule(), *alias_analysis,
                          module->entry_computation())
            .ValueOrDie();

    std::unique_ptr<PresetAssignments> preset_assignments =
        MemorySpaceAssignment::Run(module, *hlo_live_range, *alias_analysis,
                                   options)
            .ValueOrDie();
    if (check_parameters_in_default_memory) {
      CheckParametersInDefaultMemory(module);
    }
    CheckPresetAssignments(preset_assignments.get());
    return preset_assignments;
  }

  void CheckPresetAssignments(const PresetAssignments* preset_assignments) {
    // Ensure that the exported preset assignments point to layouts in the
    // alternate memory.  Also ensure that the positions are unique. Note that
    // we're using a std::set instead of absl::flat_hash_set because we can make
    // use of HloPosition's comparator logic instead of providing a hasher.
    std::set<HloPosition> positions_in_preset_assignments;
    for (auto& position_and_chunk : preset_assignments->chunks()) {
      HloPosition position = position_and_chunk.first;
      EXPECT_EQ(positions_in_preset_assignments.find(position),
                positions_in_preset_assignments.end());
      positions_in_preset_assignments.insert(position);
      const Shape& subshape =
          ShapeUtil::GetSubshape(position.instruction->shape(), position.index);
      EXPECT_EQ(subshape.layout().memory_space(), kAlternateMemorySpace)
          << "Exported position is not in alternate mem: "
          << position.ToString();
    }
  }

  void CheckParametersInDefaultMemory(const HloModule* module) {
    // Check that all the entry parameter subshapes are placed in default
    // memory.
    const HloComputation* entry_computation = module->entry_computation();
    for (const HloInstruction* parameter :
         entry_computation->parameter_instructions()) {
      ShapeUtil::ForEachSubshape(
          parameter->shape(),
          [&](const Shape& subshape, const ShapeIndex& /*index*/) {
            if (subshape.has_layout()) {
              EXPECT_NE(subshape.layout().memory_space(), kAlternateMemorySpace)
                  << "Parameter not in default memory: "
                  << parameter->ToString();
            }
          });
    }
  }

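  // Maximum numbers of asynchronous copies that are outstanding at the same
  // time, broken down into prefetches (into alternate memory) and evictions
  // (back to default memory).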
  struct OutstandingAsyncCopies {
    int64 max_copies;
    int64 max_prefetches;
    int64 max_evictions;
  };

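  // Walks the entry computation's schedule and records the maximum number of
  // CopyStart operations that are in flight (not yet matched by a CopyDone) at
  // any point.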
  /*static*/ OutstandingAsyncCopies CountMaximumOutstandingAsyncCopies(
      const HloModule& module) {
    OutstandingAsyncCopies copies{0, 0, 0};
    int64 current_copies = 0;
    int64 current_prefetches = 0;
    int64 current_evictions = 0;
    for (HloInstruction* instruction : module.schedule()
                                           .sequence(module.entry_computation())
                                           .instructions()) {
      if (instruction->opcode() == HloOpcode::kCopyStart) {
        current_copies++;
        if (ShapeUtil::GetSubshape(instruction->shape(), {0})
                .layout()
                .memory_space() == kAlternateMemorySpace) {
          current_prefetches++;
        } else {
          current_evictions++;
        }
      } else if (instruction->opcode() == HloOpcode::kCopyDone) {
        current_copies--;
        if (instruction->shape().layout().memory_space() ==
            kAlternateMemorySpace) {
          current_prefetches--;
        } else {
          current_evictions--;
        }
      }
      copies.max_copies = std::max(copies.max_copies, current_copies);
      copies.max_prefetches =
          std::max(copies.max_prefetches, current_prefetches);
      copies.max_evictions = std::max(copies.max_evictions, current_evictions);
    }
    return copies;
  }

  int64 GetAlternateMemoryOffset(const PresetAssignments& preset_assignments,
                                 const HloInstruction* instruction,
                                 const ShapeIndex& index = {}) const {
    // Returns the offset of the assignment, -1 if it's not in the alternate
    // memory.
    const HloModule* module = instruction->parent()->parent();
    auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
    HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(instruction, index);
    for (auto& pos_and_chunk : preset_assignments.chunks()) {
      for (auto& value : buffer.values()) {
        if (pos_and_chunk.first == value->defining_position()) {
          return pos_and_chunk.second.offset;
        }
      }
    }
    return -1;
  }

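  // Builds a module in which tanh is used both early and at the very end, with
  // enough intervening contention that its buffer must be evicted to default
  // memory and prefetched back for the final use.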
  std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() {
    HloComputation::Builder builder(TestName());
    Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
    HloInstruction* p0 =
        builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
    HloInstruction* p1 =
        builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
    HloInstruction* tanh = builder.AddInstruction(
        HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
    // tanh should be placed in the alternate memory since there isn't much
    // contention in the beginning. However, tanh has another consumer at the
    // end. So it should be kicked out to default memory and prefetched back in.
    // The graph below is meant to increase the contention to force
    // eviction/prefetch behavior.
    HloInstruction* a = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh));
    HloInstruction* b = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
    HloInstruction* c = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
    HloInstruction* d = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
    HloInstruction* e = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b));
    HloInstruction* f = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c));
    HloInstruction* g = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d));
    HloInstruction* h = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c));
    HloInstruction* i = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d));
    HloInstruction* j = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d));
    HloInstruction* k = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f));
    HloInstruction* l = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h));
    HloInstruction* m = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j));
    HloInstruction* n = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l));
    HloInstruction* o = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m));
    // tanh is being used at the root instruction, and this should be
    // prefetched.
    HloInstruction* add = builder.AddInstruction(
        HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh));

    auto module = CreateNewVerifiedModule();
    HloComputation* computation = module->AddEntryComputation(builder.Build());

    HloSchedule schedule(module.get());
    schedule.set_sequence(computation, {p0, p1, tanh, a, b, c, d, e, f, g, h, i,
                                        j, k, l, m, n, o, add});
    TF_CHECK_OK(module->set_schedule(schedule));
    return module;
  }

  MemorySpaceAssignmentCostAnalysis::Cache cache_;
};

// For testing purposes, we define a cost analysis where we can control the
// elapsed times of each HLO and asynchronous copy.
class FakeMemorySpaceAssignmentCostAnalysis
    : public MemorySpaceAssignmentCostAnalysis {
 public:
  static StatusOr<std::unique_ptr<FakeMemorySpaceAssignmentCostAnalysis>>
  Create(const HloCostAnalysis& cost_analysis, const HloModule& module) {
    TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
    TF_ASSIGN_OR_RETURN(auto hlo_live_range,
                        HloLiveRange::Run(module.schedule(), *alias_analysis,
                                          module.entry_computation()));
    auto call_graph = CallGraph::Build(&module);
    return absl::WrapUnique(new FakeMemorySpaceAssignmentCostAnalysis(
        cost_analysis, /*async_copy_bandwidth_bytes_per_second=*/1,
        /*alternate_mem_bandwidth_bytes_per_second=*/1,
        std::move(alias_analysis), std::move(hlo_live_range),
        std::move(call_graph)));
  }

  float GetInstructionElapsed(
      const HloInstruction& instruction) const override {
    if (get_instruction_elapsed_override_) {
      return get_instruction_elapsed_override_(instruction);
    }
    return 1.0;
  }

  float GetInstructionElapsedInAlternateMemory(
      const HloInstruction& instruction,
      absl::optional<int64> operand_in_alternate_mem,
      bool output_in_alternate_mem) const override {
    if (get_instruction_elapsed_in_alternate_memory_override_) {
      return get_instruction_elapsed_in_alternate_memory_override_(
          instruction, operand_in_alternate_mem, output_in_alternate_mem);
    }
    if (operand_in_alternate_mem) {
      return 0.5;
    } else {
      return 1.0;
    }
  }

  float GetAsyncCopyElapsed(const Shape& shape) const override {
    if (get_async_copy_elapsed_override_) {
      return get_async_copy_elapsed_override_(shape);
    }
    return 3.0;
  }

  // The following methods can be used to override what the above API calls
  // return.
  void SetOverrideForGetInstructionElapsed(
      std::function<float(const HloInstruction&)> function) {
    get_instruction_elapsed_override_ = function;
  }
  void SetOverrideForGetInstructionElapsedInAlternateMemory(
      std::function<float(const HloInstruction&, absl::optional<int64>, bool)>
          function) {
    get_instruction_elapsed_in_alternate_memory_override_ = function;
  }
  void SetOverrideForGetAsyncCopyElapsed(
      std::function<float(const Shape&)> function) {
    get_async_copy_elapsed_override_ = function;
  }

 protected:
  FakeMemorySpaceAssignmentCostAnalysis(
      const HloCostAnalysis& cost_analysis,
      float async_copy_bandwidth_bytes_per_second,
      float alternate_mem_bandwidth_bytes_per_second,
      std::unique_ptr<HloAliasAnalysis> alias_analysis,
      std::unique_ptr<HloLiveRange> hlo_live_range,
      std::unique_ptr<CallGraph> call_graph)
      : MemorySpaceAssignmentCostAnalysis(
            cost_analysis, async_copy_bandwidth_bytes_per_second,
            alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
            std::move(hlo_live_range), std::move(call_graph)) {}

 private:
  std::function<float(const HloInstruction&)>
      get_instruction_elapsed_override_ = nullptr;
  std::function<float(const HloInstruction&, absl::optional<int64>, bool)>
      get_instruction_elapsed_in_alternate_memory_override_ = nullptr;
  std::function<float(const Shape&)> get_async_copy_elapsed_override_ = nullptr;
};

TEST_P(MemorySpaceAssignmentTest, ParameterOnly) {
  // A module consisting of a single parameter. Inputs/outputs are currently
  // excluded from memory space assignment.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  EXPECT_THAT(p0, op::ShapeWithLayout(shape));
}

TEST_P(MemorySpaceAssignmentTest, Simple) {
  // A simple module with a few simple instructions. Expect this to be
  // transformed with CopyStart and CopyDone instructions inserted after inputs
  // and before outputs.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p1));
  HloInstruction* sub = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  HloInstruction* mul = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, sub));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, add, sub, mul});
  TF_CHECK_OK(module->set_schedule(schedule));

  auto preset_assignments = AssignMemorySpace(module.get());

  // Inputs and outputs are currently placed in the default memory. Everything
  // else should be in the alternate memory.
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  EXPECT_THAT(p0, op::ShapeWithLayout(shape));
  EXPECT_THAT(p1, op::ShapeWithLayout(shape));
  EXPECT_THAT(mul, op::ShapeWithLayout(shape));
  EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(sub, op::ShapeWithLayout(shape_in_alternate_mem));

  // Make sure the preset assignments are sane.
  EXPECT_EQ(preset_assignments->chunks().size(), 3);
  EXPECT_EQ(preset_assignments->assignment_informations().size(), 1);
  // Ensure the offsets assigned to add and sub are different.
  EXPECT_NE(preset_assignments->chunks()[0].second.offset,
            preset_assignments->chunks()[1].second.offset);
}

TEST_P(MemorySpaceAssignmentTest, NegateChain) {
  // The negate chain is long enough for an asynchronous copy to be inserted
  // between p1 and add.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, negate0, negate1, negate2,
                                      negate3, negate4, negate5, negate6, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  EXPECT_THAT(add, op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace,
                                                       kDefaultMemorySpace,
                                                       op::Parameter(1))));
  // Parameters are in the default memory space.
  EXPECT_THAT(p0, op::ShapeWithLayout(shape));
  EXPECT_THAT(p1, op::ShapeWithLayout(shape));
  // Negate instructions are in the alternate memory space (1).
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate2, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate3, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate4, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate5, op::ShapeWithLayout(shape_in_alternate_mem));
  EXPECT_THAT(negate6, op::ShapeWithLayout(shape_in_alternate_mem));
  // Ensure the CopyStart/CopyDone schedules.
  const HloInstructionSequence& sequence =
      module->schedule().sequence(computation);
  EXPECT_THAT(sequence.instructions()[0], op::Parameter(0));
  EXPECT_THAT(sequence.instructions()[1], op::Parameter(1));
  EXPECT_THAT(sequence.instructions()[2], op::CopyStart());
  EXPECT_THAT(sequence.instructions()[10], op::CopyDone());
}

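// Expects the tanh value to be evicted to default memory and prefetched back
// into the alternate memory for its use at the root.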
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetch) {
  std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();

  AssignMemorySpace(module.get());

  EXPECT_THAT(
      module->entry_computation()->root_instruction(),
      op::Add(op::Add(),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::AsyncCopy(kDefaultMemorySpace,
                                          kAlternateMemorySpace, op::Tanh()))));
}

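// The next three tests verify that max_outstanding_async_copies bounds the
// number of simultaneously outstanding prefetches and evictions.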
TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies0) {
  std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/0);

  EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_prefetches, 0);
  EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_evictions, 0);
}

TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies1) {
  std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1);

  EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_prefetches, 1);
  EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_evictions, 1);
}

TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchLimitAsyncCopies2) {
  std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/2);

  EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_prefetches, 2);
  EXPECT_LE(CountMaximumOutstandingAsyncCopies(*module).max_evictions, 2);
}

// TODO(berkin): This test is broken with some prefetch timing improvements.
TEST_P(MemorySpaceAssignmentTest,
       DISABLED_DontEvictWhenThereIsDefaultMemAllocation) {
  // This test is the same as EvictAndPrefetchLimitAsyncCopies1, except we check
  // that there is no eviction if not necessary (due to an existing allocation
  // in default memory).
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  HloInstruction* tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  // tanh should be placed in the alternate memory since there isn't much
  // contention in the beginning. However, tanh has another consumer at the end.
  // So it should be kicked out to default memory and prefetched back in.  The
  // graph below is meant to increase the contention to force eviction/prefetch
  // behavior.
  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
  HloInstruction* d = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  HloInstruction* e = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b));
  HloInstruction* f = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c));
  HloInstruction* g = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d));
  HloInstruction* h = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c));
  HloInstruction* i = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d));
  HloInstruction* j = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d));
  HloInstruction* k = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f));
  HloInstruction* l = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h));
  HloInstruction* m = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j));
  HloInstruction* n = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l));
  HloInstruction* o = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m));
  // tanh is being used at the root instruction, and this should be
  // prefetched.
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, tanh, a, b, c, d, e, f, g, h, i,
                                      j, k, l, m, n, o, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/1);

  // We expect the second argument to this multiply to be the prefetched c.
  EXPECT_THAT(f, op::Multiply(op::Add(), op::CopyDone()));
  // Make sure that the second argument to this multiply is not a CopyDone from
  // an eviction, but the original c.
  EXPECT_THAT(h, op::Multiply(op::Subtract(), op::Multiply()));
}

TEST_P(MemorySpaceAssignmentTest, EvictAndPrefetchAndPrefetch) {
  // Test for a memory corruption bug involving an evict/prefetch/prefetch
  // pattern, where the last prefetch copied from the original buffer in the
  // alternate memory instead of from the evicted buffer in the default memory.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 =
      builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
  HloInstruction* tanh = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, tanh));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, p0, p1));
  HloInstruction* d = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kSubtract, p0, p1));
  HloInstruction* e = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, b));
  HloInstruction* f = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, c));
  HloInstruction* g = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, a, d));
  HloInstruction* h = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, c));
  HloInstruction* i = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, b, d));
  HloInstruction* j = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, c, d));
  HloInstruction* k = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, e, f));
  HloInstruction* l = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, g, h));
  HloInstruction* m = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, i, j));
  HloInstruction* n = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, k, l));
  HloInstruction* o = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, n, m));
  HloInstruction* add0 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, o, tanh));
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, add0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* negate7 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
  HloInstruction* negate8 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate7));
  HloInstruction* negate9 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate8));
  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate9, tanh));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation,
      {p0,      p1,      tanh,    a,       b,       c,       d,       e,
       f,       g,       h,       i,       j,       k,       l,       m,
       n,       o,       add0,    negate0, negate1, negate2, negate3, negate4,
       negate5, negate6, negate7, negate8, negate9, add1});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Check that both prefetches (for add0 and add1) copy from the eviction
  // instead of from tanh, which is placed in the alternate memory directly.
  EXPECT_THAT(
      add0,
      op::Add(op::Add(),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::AsyncCopy(kDefaultMemorySpace,
                                          kAlternateMemorySpace, op::Tanh()))));
  EXPECT_THAT(
      add1,
      op::Add(op::Negate(),
              op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                            op::AsyncCopy(kDefaultMemorySpace,
                                          kAlternateMemorySpace, op::Tanh()))));
}

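// Tests memory space assignment across a while loop: buffers flowing through
// the while stay in the default memory unless allocation across sequential
// calls is enabled, while loop-local values such as body_data_mul may use the
// alternate memory.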
TEST_P(MemorySpaceAssignmentTest, While) {
  auto module = CreateNewVerifiedModule();
  Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
  Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape});

  auto cond_builder = HloComputation::Builder("WhileCond");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* cond_param = cond_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
  HloInstruction* cond_iter = cond_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
  HloInstruction* cond_limit = cond_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
  // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
  HloInstruction* cond_lt = cond_builder.AddInstruction(
      HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
                                    cond_limit, ComparisonDirection::kLt));
  HloComputation* cond_computation =
      module->AddEmbeddedComputation(cond_builder.Build());

  auto body_builder = HloComputation::Builder("WhileBody");
  // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
  HloInstruction* body_param = body_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
  HloInstruction* body_iter = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
  HloInstruction* body_data = body_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, body_param, 0));
  HloInstruction* body_iter_increment = body_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
  HloInstruction* body_iter_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
  HloInstruction* body_data_increment =
      body_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}})));
  HloInstruction* body_data_mul =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kMultiply, body_data, body_data));
  HloInstruction* body_data_add =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, body_data, body_data_increment));
  HloInstruction* body_data_next =
      body_builder.AddInstruction(HloInstruction::CreateBinary(
          shape, HloOpcode::kAdd, body_data_add, body_data_mul));
  HloInstruction* body_out = body_builder.AddInstruction(
      HloInstruction::CreateTuple({body_data_next, body_iter_next}));
  HloComputation* body_computation =
      module->AddEmbeddedComputation(body_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* data = builder.AddInstruction(
      HloInstruction::CreateParameter(0, shape, "param_data"));
  HloInstruction* iter = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({data, iter}));
  HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
      tuple_shape, cond_computation, body_computation, tuple));
  HloComputation* entry_computation =
      module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(cond_computation,
                        {cond_param, cond_iter, cond_limit, cond_lt});
  schedule.set_sequence(body_computation,
                        {body_param, body_iter, body_data, body_iter_increment,
                         body_iter_next, body_data_increment, body_data_mul,
                         body_data_add, body_data_next, body_out});
  schedule.set_sequence(entry_computation, {iter, data, tuple, while_op});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // Ensure the tuple value and buffers used in the while instruction are
  // exempted from using the alternate memory when allocating across sequential
  // calls is disabled. However, body_data_mul is independent and can safely be
  // placed in the alternate memory.
  const bool allocate_across_sequential_calls = GetParam();
  if (!allocate_across_sequential_calls) {
    EXPECT_THAT(tuple, op::ShapeWithLayout(tuple_shape));
    EXPECT_THAT(data, op::ShapeWithLayout(shape));
    EXPECT_THAT(iter, op::ShapeWithLayout(scalar_shape));
    EXPECT_THAT(body_data, op::ShapeWithLayout(shape));
    EXPECT_THAT(body_iter, op::ShapeWithLayout(scalar_shape));
    EXPECT_THAT(cond_iter, op::ShapeWithLayout(scalar_shape));
  }
  Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
      F32, {2, 3},
      /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
      kAlternateMemorySpace);
  EXPECT_THAT(body_data_mul, op::ShapeWithLayout(shape_in_alternate_mem));
}

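// Tests that elements read out of a tuple parameter are prefetched into the
// alternate memory before their uses.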
TEST_P(MemorySpaceAssignmentTest, Tuple) {
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape inner_tuple_shape = ShapeUtil::MakeTupleShape({shape});
  Shape tuple_shape =
      ShapeUtil::MakeTupleShape({shape, shape, inner_tuple_shape});
  HloInstruction* p = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p, 0));
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* negate5 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
  HloInstruction* negate6 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p, 1));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));
  HloInstruction* p2 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(inner_tuple_shape, p, 2));
  HloInstruction* p2_0 = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, p2, 0));
  HloInstruction* mul = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add, p2_0));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(
      computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5,
                    negate6, p1, add, p2, p2_0, mul});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  EXPECT_THAT(
      mul,
      op::Multiply(op::Add(op::Negate(), op::AsyncCopy(kAlternateMemorySpace,
                                                       kDefaultMemorySpace,
                                                       op::GetTupleElement())),
                   op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                 op::GetTupleElement(op::GetTupleElement()))));
}

TEST_P(MemorySpaceAssignmentTest, Bitcast) {
  // Bitcasts can cause the position in the alternate memory to appear multiple
  // times in the preset assignments. This test ensures the preset assignments
  // refer to unique positions.
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, param_shape, "p1"));
  HloInstruction* negate = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* bitcast = builder.AddInstruction(
      HloInstruction::CreateBitcast(param_shape, negate));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(param_shape, HloOpcode::kAdd, bitcast, p1));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, negate, bitcast, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  bitcast = add->mutable_operand(0);
  EXPECT_EQ(bitcast->opcode(), HloOpcode::kBitcast);
  EXPECT_EQ(bitcast->shape().layout().memory_space(), kAlternateMemorySpace);
}

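// Tests that a bitcast of a parameter that feeds a late use is placed in the
// alternate memory.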
TEST_P(MemorySpaceAssignmentTest, Bitcast2) {
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, param_shape, "p1"));
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* bitcast =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate4));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation, {p0, p1, negate0, negate1, negate2,
                                      negate3, negate4, bitcast, add});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  EXPECT_EQ(add->operand(0)->shape().layout().memory_space(),
            kAlternateMemorySpace);
}

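// Tests chains of bitcasts feeding a multiply: bitcast(bitcast(x)) is
// simplified to bitcast(x), and the remaining bitcast operands are expected to
// be in the alternate memory.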
TEST_P(MemorySpaceAssignmentTest, Bitcast3) {
  HloComputation::Builder builder(TestName());
  Shape shape1 = ShapeUtil::MakeShape(F32, {2, 3});
  Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
  Shape shape3 = ShapeUtil::MakeShape(F32, {1, 6});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape1, "p0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, param_shape, "p1"));
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate3));
  HloInstruction* bitcast1 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape1, p1));
  HloInstruction* add = builder.AddInstruction(
      HloInstruction::CreateBinary(shape1, HloOpcode::kAdd, bitcast1, negate4));
  HloInstruction* bitcast2 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape3, p1));
  HloInstruction* bitcast3 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape2, bitcast2));
  HloInstruction* bitcast4 =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape2, add));
  HloInstruction* mul = builder.AddInstruction(HloInstruction::CreateBinary(
      shape2, HloOpcode::kMultiply, bitcast3, bitcast4));

  auto module = CreateNewVerifiedModule();
  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation,
                        {p0, p1, negate0, negate1, negate2, negate3, negate4,
                         bitcast1, add, bitcast2, bitcast3, bitcast4, mul});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());

  // We expect one bitcast on the LHS of multiply since bitcast(bitcast(foo)) is
  // converted to bitcast(foo).
  EXPECT_THAT(
      mul,
      op::Multiply(
          op::Bitcast(op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
                                    op::Parameter(1))),
          op::Bitcast(op::Add(
              op::Bitcast(op::AsyncCopy(kAlternateMemorySpace,
                                        kDefaultMemorySpace, op::Parameter(1))),
              op::Negate()))));
  EXPECT_EQ(add->operand(0)->shape().layout().memory_space(),
            kAlternateMemorySpace);
  EXPECT_EQ(add->shape().layout().memory_space(), kAlternateMemorySpace);
  // bitcast2 will no longer have a consumer and should get DCE'd, so we don't
  // care about its memory space.
  EXPECT_EQ(mul->operand(0)->shape().layout().memory_space(),
            kAlternateMemorySpace);
  EXPECT_EQ(mul->operand(1)->shape().layout().memory_space(),
            kAlternateMemorySpace);
}

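// Tests a bitcast feeding a tuple that is passed to a fusion; this only checks
// that memory space assignment completes on the pattern.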
TEST_P(MemorySpaceAssignmentTest, BitcastTuple) {
  HloComputation::Builder builder(TestName());
  Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
  Shape param_shape = ShapeUtil::MakeShape(F32, {6});
  Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});

  auto module = CreateNewVerifiedModule();
  HloComputation::Builder fusion_builder("fusion");
  HloInstruction* fusion_param = fusion_builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p"));
  HloInstruction* fusion_element0 = fusion_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion_param, 0));
  HloInstruction* fusion_element1 = fusion_builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(shape, fusion_param, 1));
  fusion_builder.AddInstruction(HloInstruction::CreateBinary(
      shape, HloOpcode::kAdd, fusion_element0, fusion_element1));
  HloComputation* fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  HloInstruction* p0 =
      builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
  HloInstruction* p1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, param_shape, "p1"));
  HloInstruction* negate0 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
  HloInstruction* negate1 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
  HloInstruction* negate2 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
  HloInstruction* negate3 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
  HloInstruction* negate4 = builder.AddInstruction(
      HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
  HloInstruction* bitcast =
      builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1));
  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({bitcast, p0}));
  HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
      shape, HloInstruction::FusionKind::kCustom, {tuple}, fusion_computation));

  HloComputation* computation = module->AddEntryComputation(builder.Build());

  HloSchedule schedule(module.get());
  schedule.set_sequence(computation,
                        {p0, p1, negate0, negate1, negate2, negate3, negate4,
                         bitcast, tuple, fusion});
  TF_CHECK_OK(module->set_schedule(schedule));

  AssignMemorySpace(module.get());
}

1074 TEST_P(MemorySpaceAssignmentTest, BitcastGetTupleElementTuple) {
1075   // This test pattern was encountered in
1076   // //third_party/tensorflow/compiler/xla/tests:slice_test and caused a
1077   // breakage due to a GetTupleElement(Tuple(Bitcast())) pattern. We also
1078   // added a GetTupleElement(GetTupleElement(Tuple(Tuple(Bitcast())))) pattern.
1079   absl::string_view hlo_string = R"(
1080   HloModule DoIt_S64_10_0_5_1.3, is_scheduled=true
1081 
1082   ENTRY %DoIt_S64_10_0_5_1.3 (p0.1: (u32[10], u32[10])) -> (u32[5], u32[5]) {
1083     %p0.1 = (u32[10]{0:T(128)}, u32[10]{0:T(128)}) parameter(0)
1084     %get-tuple-element.1 = u32[10]{0:T(128)} get-tuple-element((u32[10]{0:T(128)}, u32[10]{0:T(128)}) %p0.1), index=1
1085     %bitcast.1 = u32[5]{0:T(128)} bitcast(u32[10]{0:T(128)} %get-tuple-element.1)
1086     %get-tuple-element = u32[10]{0:T(128)} get-tuple-element((u32[10]{0:T(128)}, u32[10]{0:T(128)}) %p0.1), index=0
1087     %bitcast = u32[5]{0:T(128)} bitcast(u32[10]{0:T(128)} %get-tuple-element)
1088     %tuple.1 = (u32[5]{0:T(128)}, u32[5]{0:T(128)}) tuple(u32[5]{0:T(128)} %bitcast, u32[5]{0:T(128)} %bitcast.1)
1089     %tuple.3 = ((u32[5]{0:T(128)}, u32[5]{0:T(128)}), (u32[5]{0:T(128)}, u32[5]{0:T(128)})) tuple(%tuple.1, %tuple.1)
1090     %get-tuple-element.4 = u32[5]{0:T(128)} get-tuple-element((u32[5]{0:T(128)}, u32[5]{0:T(128)}) %tuple.1), index=0
1091     %get-tuple-element.5 = (u32[5]{0:T(128)}, u32[5]{0:T(128)}) get-tuple-element(%tuple.3), index=0
1092     %get-tuple-element.6 = u32[5]{0:T(128)} get-tuple-element((u32[5]{0:T(128)}, u32[5]{0:T(128)}) %get-tuple-element.5), index=1
1093     %copy.2 = u32[5]{0:T(128)} copy(u32[5]{0:T(128)} %get-tuple-element.4)
1094     %copy.3 = u32[5]{0:T(128)} copy(u32[5]{0:T(128)} %get-tuple-element.6)
1095     ROOT %tuple.2 = (u32[5]{0:T(128)}, u32[5]{0:T(128)}) tuple(u32[5]{0:T(128)} %copy.2, u32[5]{0:T(128)} %copy.3)
1096   }
1097   )";
1098 
1099   TF_ASSERT_OK_AND_ASSIGN(auto module,
1100                           ParseAndReturnVerifiedModule(hlo_string));
1101   AssignMemorySpace(module.get());
1102 }
1103 
1104 TEST_P(MemorySpaceAssignmentTest, GetSimplifiedOperandBug) {
1105   // Test case for a bug in finding Bitcasts within a GTE(Tuple(...)) pattern.
1106   absl::string_view hlo_string = R"(
1107   HloModule sort.16, is_scheduled=true
1108 
1109   ENTRY %sort.16 (param.0.1: s32[1], param.1.2: f32[1], param.2.3: u32[1], param.3.4: s32[1]) -> (s32[1], f32[1], u32[1], s32[1]) {
1110     %param.3.4 = s32[1]{0:T(128)} parameter(3)
1111     %param.2.3 = u32[1]{0:T(128)} parameter(2)
1112     %param.1.2 = f32[1]{0:T(128)} parameter(1)
1113     %param.0.1 = s32[1]{0:T(128)} parameter(0)
1114     %tuple.1 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %param.0.1, f32[1]{0:T(128)} %param.1.2, u32[1]{0:T(128)} %param.2.3, s32[1]{0:T(128)} %param.3.4)
1115     %get-tuple-element.4 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=0
1116     %get-tuple-element.5 = f32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=1
1117     %get-tuple-element.6 = u32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=2
1118     %get-tuple-element.7 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=3
1119     %copy.4 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.4)
1120     %copy.5 = f32[1]{0:T(128)} copy(f32[1]{0:T(128)} %get-tuple-element.5)
1121     %copy.6 = u32[1]{0:T(128)} copy(u32[1]{0:T(128)} %get-tuple-element.6)
1122     %copy.7 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.7)
1123     ROOT %tuple.2 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %copy.4, f32[1]{0:T(128)} %copy.5, u32[1]{0:T(128)} %copy.6, s32[1]{0:T(128)} %copy.7)
1124 }
1125   )";
1126 
1127   TF_ASSERT_OK_AND_ASSIGN(auto module,
1128                           ParseAndReturnVerifiedModule(hlo_string));
1129   AssignMemorySpace(module.get());
1130 }
1131 
1132 TEST_P(MemorySpaceAssignmentTest, BitcastMultiUse) {
1133   // When a bitcast has multiple uses (negate0 and add), with one use in the
1134   // default memory and the other in the alternate memory, each use needs its
1135   // own bitcast.
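  // Rough sketch of the intended rewrite (instruction names are illustrative,
  // not the actual names produced by the pass):
  //   bitcast_default   = bitcast(p0)            -> feeds negate0 (default mem)
  //   bitcast_alternate = bitcast(prefetch(p0))  -> feeds add (alternate mem)
  // The EXPECT_THATs at the end check the operand layouts accordingly.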
1136   HloComputation::Builder builder(TestName());
1137   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
1138   Shape param_shape = ShapeUtil::MakeShape(F32, {6});
1139   HloInstruction* p0 = builder.AddInstruction(
1140       HloInstruction::CreateParameter(0, param_shape, "p1"));
1141   HloInstruction* bitcast =
1142       builder.AddInstruction(HloInstruction::CreateBitcast(shape, p0));
1143   HloInstruction* negate0 = builder.AddInstruction(
1144       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, bitcast));
1145   HloInstruction* negate1 = builder.AddInstruction(
1146       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
1147   HloInstruction* negate2 = builder.AddInstruction(
1148       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
1149   HloInstruction* negate3 = builder.AddInstruction(
1150       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
1151   HloInstruction* negate4 = builder.AddInstruction(
1152       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
1153   HloInstruction* add = builder.AddInstruction(
1154       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate4));
1155 
1156   auto module = CreateNewVerifiedModule();
1157   HloComputation* computation = module->AddEntryComputation(builder.Build());
1158 
1159   HloSchedule schedule(module.get());
1160   schedule.set_sequence(computation, {p0, bitcast, negate0, negate1, negate2,
1161                                       negate3, negate4, add});
1162   TF_CHECK_OK(module->set_schedule(schedule));
1163 
1164   AssignMemorySpace(module.get());
1165   Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
1166       F32, {2, 3},
1167       /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
1168       kAlternateMemorySpace);
1169   EXPECT_THAT(negate0->operand(0), op::ShapeWithLayout(shape));
1170   EXPECT_THAT(add->operand(0), op::ShapeWithLayout(shape_in_alternate_mem));
1171 }
1172 
1173 TEST_P(MemorySpaceAssignmentTest, BitcastMultiUseTuple) {
1174   // Same as BitcastMultiUse but the second use is a tuple.
1175   HloComputation::Builder builder(TestName());
1176   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
1177   Shape param_shape = ShapeUtil::MakeShape(F32, {6});
1178   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
1179 
1180   auto module = CreateNewVerifiedModule();
1181   HloComputation::Builder fusion_builder("fusion");
1182   HloInstruction* fusion_param = fusion_builder.AddInstruction(
1183       HloInstruction::CreateParameter(0, tuple_shape, "p"));
1184   HloInstruction* fusion_element0 = fusion_builder.AddInstruction(
1185       HloInstruction::CreateGetTupleElement(shape, fusion_param, 0));
1186   HloInstruction* fusion_element1 = fusion_builder.AddInstruction(
1187       HloInstruction::CreateGetTupleElement(shape, fusion_param, 1));
1188   fusion_builder.AddInstruction(HloInstruction::CreateBinary(
1189       shape, HloOpcode::kAdd, fusion_element0, fusion_element1));
1190   HloComputation* fusion_computation =
1191       module->AddEmbeddedComputation(fusion_builder.Build());
1192 
1193   HloInstruction* p0 = builder.AddInstruction(
1194       HloInstruction::CreateParameter(0, param_shape, "p1"));
1195   HloInstruction* bitcast =
1196       builder.AddInstruction(HloInstruction::CreateBitcast(shape, p0));
1197   HloInstruction* negate0 = builder.AddInstruction(
1198       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, bitcast));
1199   HloInstruction* negate1 = builder.AddInstruction(
1200       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
1201   HloInstruction* negate2 = builder.AddInstruction(
1202       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
1203   HloInstruction* negate3 = builder.AddInstruction(
1204       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
1205   HloInstruction* negate4 = builder.AddInstruction(
1206       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
1207   HloInstruction* tuple =
1208       builder.AddInstruction(HloInstruction::CreateTuple({bitcast, negate4}));
1209   HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
1210       shape, HloInstruction::FusionKind::kCustom, {tuple}, fusion_computation));
1211 
1212   HloComputation* computation = module->AddEntryComputation(builder.Build());
1213 
1214   HloSchedule schedule(module.get());
1215   schedule.set_sequence(computation, {p0, bitcast, negate0, negate1, negate2,
1216                                       negate3, negate4, tuple, fusion});
1217   TF_CHECK_OK(module->set_schedule(schedule));
1218 
1219   AssignMemorySpace(module.get());
1220   Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
1221       F32, {2, 3},
1222       /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
1223       kAlternateMemorySpace);
1224   EXPECT_THAT(negate0->operand(0), op::ShapeWithLayout(shape));
1225   EXPECT_THAT(fusion->operand(0)->operand(0),
1226               op::ShapeWithLayout(shape_in_alternate_mem));
1227 }
1228 
1229 TEST_P(MemorySpaceAssignmentTest, BitcastScheduleBug) {
1230   // Bitcasts can force asynchronous copies to be scheduled too early, possibly
1231   // leading to memory corruption.
1232   //  Bug:
1233   //    p0------------------>neg-->neg-->neg ... -->neg-->neg-->neg->add
1234   //                                                                 /
1235   //    p1->cs->cd->bitcast-----------------------------------------+
1236   //
1237   //  Expected:
1238   //    p0-->neg-->neg-->neg ... -->neg-->neg-->neg------------->add
1239   //                                                             /
1240   //    p1--------------------->cs----------------->cd->bitcast-+
1241   HloComputation::Builder builder(TestName());
1242   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
1243   Shape param_shape = ShapeUtil::MakeShape(F32, {6});
1244   HloInstruction* p0 =
1245       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
1246   HloInstruction* p1 = builder.AddInstruction(
1247       HloInstruction::CreateParameter(1, param_shape, "p1"));
1248   HloInstruction* bitcast =
1249       builder.AddInstruction(HloInstruction::CreateBitcast(shape, p1));
1250   HloInstruction* negate0 = builder.AddInstruction(
1251       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
1252   HloInstruction* negate1 = builder.AddInstruction(
1253       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
1254   HloInstruction* negate2 = builder.AddInstruction(
1255       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
1256   HloInstruction* negate3 = builder.AddInstruction(
1257       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
1258   HloInstruction* negate4 = builder.AddInstruction(
1259       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
1260   HloInstruction* negate5 = builder.AddInstruction(
1261       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
1262   HloInstruction* negate6 = builder.AddInstruction(
1263       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
1264   HloInstruction* negate7 = builder.AddInstruction(
1265       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
1266   HloInstruction* negate8 = builder.AddInstruction(
1267       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate7));
1268   HloInstruction* negate9 = builder.AddInstruction(
1269       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate8));
1270   HloInstruction* add = builder.AddInstruction(
1271       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, bitcast, negate9));
1272 
1273   auto module = CreateNewVerifiedModule();
1274   HloComputation* computation = module->AddEntryComputation(builder.Build());
1275 
1276   HloSchedule schedule(module.get());
1277   schedule.set_sequence(
1278       computation, {p0, p1, bitcast, negate0, negate1, negate2, negate3,
1279                     negate4, negate5, negate6, negate7, negate8, negate9, add});
1280   TF_CHECK_OK(module->set_schedule(schedule));
1281 
1282   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
1283                     /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/4);
1284 
1285   EXPECT_EQ(add->operand(0)->shape().layout().memory_space(),
1286             kAlternateMemorySpace);
1287   const auto& instructions =
1288       module->schedule().sequence(module->entry_computation()).instructions();
1289   for (int i = 0; i < instructions.size(); ++i) {
1290     // Expect that there is a negate immediately before and after the CopyStart,
1291     // and a negate immediately before the CopyDone.
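    // In other words, the schedule around the async copy should locally look
    // like: ..., negate, copy-start, negate, ..., negate, copy-done, ...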
1292     if (instructions.at(i)->opcode() == HloOpcode::kCopyStart) {
1293       EXPECT_EQ(instructions.at(i - 1)->opcode(), HloOpcode::kNegate);
1294       EXPECT_EQ(instructions.at(i + 1)->opcode(), HloOpcode::kNegate);
1295     } else if (instructions.at(i)->opcode() == HloOpcode::kCopyDone) {
1296       EXPECT_EQ(instructions.at(i - 1)->opcode(), HloOpcode::kNegate);
1297     }
1298   }
1299 }
1300 
1301 TEST_P(MemorySpaceAssignmentTest, TupleSelect) {
1302   // Make sure tuple-select is not optimized away.
1303   absl::string_view hlo_string = R"(
1304   HloModule tuple, is_scheduled=true
1305 
1306   ENTRY %main (a: f32[2], b: f32[2], c: f32[2], d: f32[2], cond: pred[]) -> f32[2] {
1307     %cond = pred[]{:T(128)E(32)} parameter(4)
1308     %token0 = token[] after-all()
1309     %d = f32[2]{0:T(128)} parameter(3)
1310     %c = f32[2]{0:T(128)} parameter(2)
1311     %b = f32[2]{0:T(128)} parameter(1)
1312     %a = f32[2]{0:T(128)} parameter(0)
1313     %tup0 = (f32[2]{0:T(128)}, f32[2]{0:T(128)}) tuple(f32[2]{0:T(128)} %a, f32[2]{0:T(128)} %b)
1314     %tup1 = (f32[2]{0:T(128)}, f32[2]{0:T(128)}) tuple(f32[2]{0:T(128)} %c, f32[2]{0:T(128)} %d)
1315     %s = (f32[2]{0:T(128)}, f32[2]{0:T(128)}) tuple-select(pred[]{:T(128)E(32)} %cond, (f32[2]{0:T(128)}, f32[2]{0:T(128)}) %tup0, (f32[2]{0:T(128)}, f32[2]{0:T(128)}) %tup1)
1316     %gte = f32[2]{0:T(128)} get-tuple-element((f32[2]{0:T(128)}, f32[2]{0:T(128)}) %s), index=0
1317     ROOT %negate = f32[2]{0:T(128)} negate(f32[2]{0:T(128)} %gte)
1318   }
1319   )";
1320 
1321   TF_ASSERT_OK_AND_ASSIGN(auto module,
1322                           ParseAndReturnVerifiedModule(hlo_string));
1323   AssignMemorySpace(module.get());
1324 
1325   EXPECT_THAT(module->entry_computation()->root_instruction(),
1326               op::Negate(op::GetTupleElement(op::TupleSelect())));
1327 }
1328 
1329 TEST_P(MemorySpaceAssignmentTest, AddDependency) {
1330   // Make sure add-dependency is not optimized away.
1331   absl::string_view hlo_string = R"(
1332   HloModule AddDependency, is_scheduled=true
1333 
1334   ENTRY %AddDependency (p: f32[3]) -> f32[3] {
1335     %p = f32[3]{0} parameter(0)
1336     %neg0 = f32[3]{0} negate(f32[3]{0} %p)
1337     %neg1 = f32[3]{0} negate(f32[3]{0} %neg0)
1338     %neg2 = f32[3]{0} negate(f32[3]{0} %neg1)
1339     %neg3 = f32[3]{0} negate(f32[3]{0} %neg2)
1340     %neg4 = f32[3]{0} negate(f32[3]{0} %neg3)
1341     %neg5 = f32[3]{0} negate(f32[3]{0} %neg4)
1342     %neg6 = f32[3]{0} negate(f32[3]{0} %neg5)
1343     %token0 = token[] after-all()
1344     %add_dep = f32[3]{0} add-dependency(f32[3]{0} %p, token[] %token0)
1345     ROOT %add = f32[3]{0} add(f32[3]{0} %add_dep, f32[3]{0} %neg6)
1346   }
1347   )";
1348 
1349   TF_ASSERT_OK_AND_ASSIGN(auto module,
1350                           ParseAndReturnVerifiedModule(hlo_string));
1351   AssignMemorySpace(module.get());
1352 
1353   EXPECT_THAT(module->entry_computation()->root_instruction(),
1354               op::Add(op::AddDependency(), op::Negate()));
1355 }
1356 
1357 TEST_P(MemorySpaceAssignmentTest, WhileAllocationBug) {
1358   // This test is carefully crafted to include two multiply ops sized [4,3] in a
1359   // while body. For testing purposes, we provide a BufferIntervalCompare such
1360   // that multiplies are allocated first, then tanhs, then the other HloValues.
1361   // The memory is sized just enough to fit two [4,3] buffers.
1362   // Because the multiplies in the while body are allocated in the alternate
1363   // memory first, the tanh that is fed into the while loop should not be placed
1364   // in the alternate memory. Otherwise, we would corrupt memory.
1365   absl::string_view hlo_string = R"(
1366   HloModule WhileAllocationBug, is_scheduled=true
1367 
1368   %WhileBody (body_param: (f32[4,3], f32[])) -> (f32[4,3], f32[]) {
1369     %body_param = (f32[4,3]{1,0}, f32[]) parameter(0)
1370     %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[]) %body_param), index=1
1371     %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[]) %body_param), index=0
1372     %constant.1 = f32[] constant(1)
1373     %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
1374     %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
1375     %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.2)
1376     %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
1377     %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
1378     %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
1379     ROOT %tuple = (f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[] %add)
1380   }
1381 
1382   %WhileCond (cond_param: (f32[4,3], f32[])) -> pred[] {
1383     %cond_param = (f32[4,3]{1,0}, f32[]) parameter(0)
1384     %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[]) %cond_param), index=1
1385     %constant = f32[] constant(50)
1386     ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
1387   }
1388 
1389   ENTRY %Entry (param_iter: f32[4,3], param_data: f32[], p2: f32[4,3]) -> f32[4,3] {
1390     %param_data = f32[] parameter(1)
1391     %param_iter = f32[4,3]{1,0} parameter(0)
1392     %p2 = f32[4,3]{1,0} parameter(2)
1393     %tanh = f32[4,3]{1,0} tanh(f32[4,3]{1,0} %param_iter)
1394     %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
1395     %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
1396     %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
1397     %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
1398     %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
1399     %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
1400     %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
1401     %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %tanh)
1402     %tuple.1 = (f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %tanh, f32[] %param_data)
1403     %while = (f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
1404     %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[]) %while), index=0
1405     ROOT %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.3, f32[4,3]{1,0} %add.4)
1406   }
1407   )";
1408 
1409   MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
1410       [](const MemorySpaceAssignment::BufferInterval& a,
1411          const MemorySpaceAssignment::BufferInterval& b) {
1412         bool a_is_mul =
1413             a.buffer->defining_instruction()->opcode() == HloOpcode::kMultiply;
1414         bool b_is_mul =
1415             b.buffer->defining_instruction()->opcode() == HloOpcode::kMultiply;
1416         if (a_is_mul && !b_is_mul) {
1417           return true;
1418         }
1419         if (!a_is_mul && b_is_mul) {
1420           return false;
1421         }
1422         bool a_is_tanh =
1423             a.buffer->defining_instruction()->opcode() == HloOpcode::kTanh;
1424         bool b_is_tanh =
1425             b.buffer->defining_instruction()->opcode() == HloOpcode::kTanh;
1426         if (a_is_tanh && !b_is_tanh) {
1427           return true;
1428         }
1429         if (!a_is_tanh && b_is_tanh) {
1430           return false;
1431         }
1432         return a.buffer->id() < b.buffer->id();
1433       };
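  // The comparator above orders buffer intervals so that multiplies are
  // considered first, then tanhs, then everything else (ties broken by buffer
  // id), which is what forces the multiplies into alternate memory ahead of
  // the tanh that is fed into the while loop.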
1434   TF_ASSERT_OK_AND_ASSIGN(auto module,
1435                           ParseAndReturnVerifiedModule(hlo_string));
1436 
1437   InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
1438   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
1439                     buffer_interval_compare, &prefetch_interval_picker);
1440 
1441   for (const HloInstruction* instruction :
1442        module->entry_computation()->instructions()) {
1443     if (instruction->opcode() == HloOpcode::kWhile) {
1444       const Shape& while_subshape =
1445           ShapeUtil::GetSubshape(instruction->shape(), {0});
1446       // We expect subshape {0} either to stay in default memory for the entire
1447       // while loop, or for there to be an eviction within the while loop.
1448       if (while_subshape.layout().memory_space() == kAlternateMemorySpace) {
1449         const HloInstruction* body_param =
1450             instruction->while_body()->parameter_instruction(0);
1451         const HloInstruction* gte = nullptr;
1452         for (const HloInstruction* user : body_param->users()) {
1453           if (user->opcode() == HloOpcode::kGetTupleElement &&
1454               user->tuple_index() == 0) {
1455             gte = user;
1456             break;
1457           }
1458         }
1459         EXPECT_NE(gte, nullptr);
1460         const HloInstruction* copy_start = nullptr;
1461         for (const HloInstruction* user : gte->users()) {
1462           if (user->opcode() == HloOpcode::kCopyStart) {
1463             copy_start = user;
1464             break;
1465           }
1466         }
1467         EXPECT_NE(copy_start, nullptr);
1468         const Shape& copy_start_subshape =
1469             ShapeUtil::GetSubshape(copy_start->shape(), {0});
1470 
1471         EXPECT_NE(copy_start_subshape.layout().memory_space(),
1472                   kAlternateMemorySpace);
1473       }
1474     }
1475   }
1476 }
1477 
1478 TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoops) {
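  // Two back-to-back while loops where the second loop consumes values produced
  // by the first (get-tuple-element.4 and get-tuple-element.5 feed tuple.2).
  // The test simply runs memory space assignment over this module; it has no
  // explicit expectations beyond successful completion.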
1479   absl::string_view hlo_string = R"(
1480   HloModule WhileAllocationBug, is_scheduled=true
1481 
1482   %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
1483     %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1484     %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
1485     %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
1486     %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
1487     %constant.1 = f32[] constant(1)
1488     %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
1489     %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
1490     %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3)
1491     %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
1492     %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
1493     %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
1494     ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
1495   }
1496 
1497   %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
1498     %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1499     %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
1500     %constant = f32[] constant(50)
1501     ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
1502   }
1503 
1504   %WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
1505     %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1506     %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
1507     %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
1508     %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
1509     %constant.1 = f32[] constant(1)
1510     %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
1511     %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
1512     %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %get-tuple-element.3)
1513     %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
1514     %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.2, f32[4,3]{1,0} %constant.2)
1515     %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
1516     ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
1517   }
1518 
1519   %WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
1520     %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1521     %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
1522     %constant = f32[] constant(50)
1523     ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
1524   }
1525 
1526   ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
1527     %param_iter = f32[] parameter(1)
1528     %param_data = f32[4,3]{1,0} parameter(0)
1529     %p2 = f32[4,3]{1,0} parameter(2)
1530     %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
1531     %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
1532     %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
1533     %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
1534     %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
1535     %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
1536     %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
1537     %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
1538     %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
1539     %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
1540     %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
1541     %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
1542     %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1
1543     %tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} get-tuple-element.5, f32[] %param_iter)
1544     %while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2
1545     %get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0
1546     ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.6, f32[4,3]{1,0} %add.3)
1547   }
1548   )";
1549 
1550   TF_ASSERT_OK_AND_ASSIGN(auto module,
1551                           ParseAndReturnVerifiedModule(hlo_string));
1552   AssignMemorySpace(module.get());
1553 }
1554 
1555 TEST_P(MemorySpaceAssignmentTest, WhileLiveRangeBug) {
1556   // Tests against incorrect while live ranges, which made the verifier
1557   // complain about a conflict.
1558   absl::string_view hlo_string = R"(
1559   HloModule WhileAllocationBug, is_scheduled=true
1560 
1561   %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
1562     %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1563     %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
1564     %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
1565     %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
1566     %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
1567     %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
1568     %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
1569     %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
1570     %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
1571     %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
1572     %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
1573     %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
1574     %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
1575     %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
1576     %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
1577     %constant.1 = f32[] constant(1)
1578     %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
1579     %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
1580     %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
1581     %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
1582     %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
1583     %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
1584     ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
1585   }
1586 
1587   %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
1588     %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1589     %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
1590     %constant = f32[] constant(50)
1591     ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
1592   }
1593 
1594   ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
1595     %param_iter = f32[] parameter(1)
1596     %param_data = f32[4,3]{1,0} parameter(0)
1597     %p2 = f32[4,3]{1,0} parameter(2)
1598     %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
1599     %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
1600     %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
1601     %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
1602     %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
1603     %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
1604     %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
1605     %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
1606     %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
1607     %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
1608     %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
1609     %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=1
1610     %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
1611     ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %add.3)
1612   }
1613   )";
1614 
1615   TF_ASSERT_OK_AND_ASSIGN(auto module,
1616                           ParseAndReturnVerifiedModule(hlo_string));
1617   AssignMemorySpace(module.get());
1618 }
1619 
1620 TEST_P(MemorySpaceAssignmentTest, ConsecutiveWhileLoopsOneBuffer) {
1621   // Tests against a bug where, with consecutive while loops sharing one buffer
1622   // (the value in the buffer doesn't change), the parameter could incorrectly
1623   // be colored in the alternate memory space.
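  // Concretely, %param_data is passed unchanged at tuple index 1 of both %while
  // and %while.1 below, so a single buffer is live across both loops.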
1624   absl::string_view hlo_string = R"(
1625   HloModule WhileAllocationBug, is_scheduled=true
1626 
1627   %WhileBody (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
1628     %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1629     %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
1630     %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
1631     %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
1632     %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
1633     %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
1634     %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
1635     %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
1636     %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
1637     %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
1638     %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
1639     %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
1640     %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
1641     %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
1642     %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
1643     %constant.1 = f32[] constant(1)
1644     %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
1645     %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
1646     %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
1647     %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
1648     %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
1649     %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
1650     ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
1651   }
1652 
1653   %WhileCond (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
1654     %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1655     %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
1656     %constant = f32[] constant(50)
1657     ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
1658   }
1659 
1660   %WhileBody2 (body_param: (f32[4,3], f32[4,3], f32[])) -> (f32[4,3], f32[4,3], f32[]) {
1661     %body_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1662     %get-tuple-element.1 = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=2
1663     %get-tuple-element.2 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=0
1664     %get-tuple-element.3 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %body_param), index=1
1665     %neg10 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %get-tuple-element.2)
1666     %neg11 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg10)
1667     %neg12 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg11)
1668     %neg13 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg12)
1669     %neg14 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg13)
1670     %neg15 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg14)
1671     %neg16 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg15)
1672     %neg17 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg16)
1673     %neg18 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg17)
1674     %neg19 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg18)
1675     %neg20 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg19)
1676     %constant.1 = f32[] constant(1)
1677     %add = f32[] add(f32[] %get-tuple-element.1, f32[] %constant.1)
1678     %constant.2 = f32[4,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 }, { 1, 2, 3 }, { 4, 5, 6 } })
1679     %multiply = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %neg20, f32[4,3]{1,0} %neg20)
1680     %multiply2 = f32[4,3]{1,0} multiply(f32[4,3]{1,0} %multiply, f32[4,3]{1,0} %multiply)
1681     %add.1 = f32[4,3]{1,0} add(f32[4,3]{1,0} get-tuple-element.3, f32[4,3]{1,0} %constant.2)
1682     %add.2 = f32[4,3]{1,0} add(f32[4,3]{1,0} %add.1, f32[4,3]{1,0} %multiply2)
1683     ROOT %tuple = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} %add.2, f32[4,3]{1,0} %get-tuple-element.3, f32[] %add)
1684   }
1685 
1686   %WhileCond2 (cond_param: (f32[4,3], f32[4,3], f32[])) -> pred[] {
1687     %cond_param = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) parameter(0)
1688     %get-tuple-element = f32[] get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %cond_param), index=2
1689     %constant = f32[] constant(50)
1690     ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
1691   }
1692 
1693   ENTRY %Entry (param_data: f32[4,3], param_iter: f32[], p2: f32[4,3]) -> f32[4,3] {
1694     %param_iter = f32[] parameter(1)
1695     %param_data = f32[4,3]{1,0} parameter(0)
1696     %p2 = f32[4,3]{1,0} parameter(2)
1697     %neg0 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %p2)
1698     %neg1 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg0)
1699     %neg2 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg1)
1700     %neg3 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg2)
1701     %neg4 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg3)
1702     %neg5 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg4)
1703     %neg6 = f32[4,3]{1,0} negate(f32[4,3]{1,0} %neg5)
1704     %add.4 = f32[4,3]{1,0} add(f32[4,3]{1,0} %neg6, f32[4,3]{1,0} %p2)
1705     %tuple.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.4, f32[4,3]{1,0} param_data, f32[] %param_iter)
1706     %while = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
1707     %get-tuple-element.4 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while), index=0
1708     %add.3 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.4, f32[4,3]{1,0} %add.4)
1709     %tuple.2 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) tuple(f32[4,3]{1,0} add.3, f32[4,3]{1,0} param_data, f32[] %param_iter)
1710     %while.1 = (f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) while((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %tuple.2), condition=%WhileCond2, body=%WhileBody2
1711     %get-tuple-element.5 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=0
1712     %get-tuple-element.6 = f32[4,3]{1,0} get-tuple-element((f32[4,3]{1,0}, f32[4,3]{1,0}, f32[]) %while.1), index=1
1713     ROOT %add.5 = f32[4,3]{1,0} add(f32[4,3]{1,0} %get-tuple-element.5, f32[4,3]{1,0} %get-tuple-element.6)
1714   }
1715   )";
1716 
1717   TF_ASSERT_OK_AND_ASSIGN(auto module,
1718                           ParseAndReturnVerifiedModule(hlo_string));
1719   AssignMemorySpace(module.get());
1720 }
1721 
1722 TEST_P(MemorySpaceAssignmentTest, WhileCondAliasBug) {
1723   // The while loop is the root of the entry computation. We should ensure that
1724   // the output of the entry computation remains in the default memory space.
1725   // Test from //third_party/tensorflow/compiler/xla/tests:while_test
1726   // WhileTest.WhileWithPrngScalarResult.
1727   absl::string_view hlo_string = R"(
1728   HloModule WhileWithPrngScalarResult.18, is_scheduled=true
1729 
1730   %fused_computation (param_0.1: s32[6], param_1.3: s32[1], param_2.3: s32[5]) -> s32[6] {
1731     %param_1.3 = s32[1]{0:T(128)} parameter(1)
1732     %constant.2 = s32[]{:T(128)} constant(-2147483648)
1733     %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5
1734     %param_2.3 = s32[5]{0:T(128)} parameter(2)
1735     %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0
1736     %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3)
1737     %param_0.1 = s32[6]{0:T(128)} parameter(0)
1738     ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1)
1739   }
1740 
1741   %body.3 (prev.4: s32[6]) -> s32[6] {
1742     %constant.7 = s32[]{:T(128)} constant(100)
1743     %constant.6 = s32[]{:T(128)} constant(0)
1744     %constant.5 = s32[1]{0:T(128)} constant({1})
1745     %prev.4 = s32[6]{0:T(128)} parameter(0)
1746     %rng.8 = s32[5]{0:T(128)} rng(s32[]{:T(128)} %constant.6, s32[]{:T(128)} %constant.7), distribution=rng_uniform
1747     %neg = s32[1]{0:T(128)} negate(s32[1]{0:T(128)} %constant.5)
1748     ROOT %fusion = s32[6]{0:T(128)} fusion(s32[6]{0:T(128)} %prev.4, s32[1]{0:T(128)} %neg, s32[5]{0:T(128)} %rng.8), kind=kLoop, calls=%fused_computation
1749   }
1750 
1751   %WhileWithPrngScalarResult.11 (prev.12: s32[6]) -> pred[] {
1752     %constant.15 = s32[]{:T(128)} constant(1)
1753     %prev.12 = s32[6]{0:T(128)} parameter(0)
1754     %bitcast.1 = s32[1]{0:T(128)} bitcast(s32[6]{0:T(128)} %prev.12)
1755     %bitcast = s32[]{:T(128)} bitcast(s32[1]{0:T(128)} %bitcast.1)
1756     ROOT %compare.16 = pred[]{:T(128)E(32)} compare(s32[]{:T(128)} %constant.15, s32[]{:T(128)} %bitcast), direction=GT
1757   }
1758 
1759   ENTRY %WhileWithPrngScalarResult.18 () -> s32[6] {
1760     %constant.1 = s32[]{:T(128)} constant(0)
1761     %broadcast.2 = s32[6]{0:T(128)} broadcast(s32[]{:T(128)} %constant.1), dimensions={}
1762     ROOT %while.17 = s32[6]{0:T(128)} while(s32[6]{0:T(128)} %broadcast.2), condition=%WhileWithPrngScalarResult.11, body=%body.3
1763   }
1764   )";
1765 
1766   TF_ASSERT_OK_AND_ASSIGN(auto module,
1767                           ParseAndReturnVerifiedModule(hlo_string));
1768   AssignMemorySpace(module.get());
1769   // Expect the output to have default memory space.
1770   EXPECT_EQ(module->entry_computation()
1771                 ->root_instruction()
1772                 ->shape()
1773                 .layout()
1774                 .memory_space(),
1775             kDefaultMemorySpace);
1776 }
1777 
1778 TEST_P(MemorySpaceAssignmentTest, WhileInPlaceBuffer) {
1779   // Ensure that a dynamic update slice within a while loop is able to get an
1780   // alternate memory allocation.
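  // The dynamic-update-slice lives in %fused_computation and rewrites tuple
  // element 1 in place on every iteration; when the test parameter is set
  // (GetParam()), the check at the end expects that element's layout to be in
  // the alternate memory space.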
1781   absl::string_view hlo_string = R"(
1782   HloModule Module, is_scheduled=true
1783 
1784   fused_computation {
1785     param0 = f32[2,3] parameter(0)
1786     constant.1 = f32[] constant(0)
1787     broadcast = f32[2,1] broadcast(constant.1), dimensions={}
1788     constant.3 = s32[] constant(0)
1789     ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
1790   }
1791 
1792   %WhileBody (body_param: (f32[2,3], f32[2,3], f32[])) -> (f32[2,3], f32[2,3], f32[]) {
1793     %body_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
1794     %get-tuple-element.1 = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=2
1795     %get-tuple-element.2 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=0
1796     %get-tuple-element.3 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %body_param), index=1
1797     %fusion = f32[2,3]{1,0} fusion(get-tuple-element.3), kind=kLoop, calls=fused_computation
1798     %multiply = f32[2,3]{1,0} multiply(f32[2,3]{1,0} %get-tuple-element.2, f32[2,3]{1,0} %fusion)
1799     ROOT %tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} %multiply, f32[2,3]{1,0} %fusion, f32[] %get-tuple-element.1)
1800   }
1801 
1802   %WhileCond (cond_param: (f32[2,3], f32[2,3], f32[])) -> pred[] {
1803     %cond_param = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) parameter(0)
1804     %get-tuple-element = f32[] get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %cond_param), index=2
1805     %constant = f32[] constant(50)
1806     ROOT %compare = pred[] compare(f32[] %get-tuple-element, f32[] %constant), direction=LT
1807   }
1808 
1809   ENTRY %Entry (param_data: f32[2,3], param_iter: f32[], p2: f32[2,3]) -> f32[2,3] {
1810     %param_iter = f32[] parameter(1)
1811     %param_data = f32[2,3]{1,0} parameter(0)
1812     %p2 = f32[2,3]{1,0} parameter(2)
1813     %copy1 = f32[2,3]{1,0} copy(param_data)
1814     %copy2 = f32[2,3]{1,0} copy(p2)
1815     %tuple.1 = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) tuple(f32[2,3]{1,0} copy1, f32[2,3]{1,0} copy2, f32[] %param_iter)
1816     %while = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) while((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %tuple.1), condition=%WhileCond, body=%WhileBody
1817     %get-tuple-element.4 = f32[2,3]{1,0} get-tuple-element((f32[2,3]{1,0}, f32[2,3]{1,0}, f32[]) %while), index=0
1818     ROOT %copy3 = f32[2,3]{1,0} copy(get-tuple-element.4)
1819   }
1820   )";
1821 
1822   TF_ASSERT_OK_AND_ASSIGN(auto module,
1823                           ParseAndReturnVerifiedModule(hlo_string));
1824   AssignMemorySpace(module.get());
1825   const HloInstruction* while_op =
1826       module->entry_computation()->GetInstructionWithName("while");
1827   if (GetParam()) {
1828     EXPECT_EQ(
1829         ShapeUtil::GetSubshape(while_op->shape(), {1}).layout().memory_space(),
1830         kAlternateMemorySpace);
1831   }
1832 }
1833 
1834 TEST_P(MemorySpaceAssignmentTest, WhileSharedBufferVerificationBug) {
1835   // Tests a spurious verification failure when a while has the same value
1836   // passed in twice (copy0) and that value is evicted within the while loop.
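  // In the module below, copy0 appears at both tuple indices 0 and 1 of the
  // while operand; the test only checks that assignment completes without the
  // spurious verification failure described above.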
1837   absl::string_view hlo_string = R"(
1838   HloModule module, is_scheduled=true
1839 
1840   while_cond {
1841     p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
1842     ROOT gte = pred[] get-tuple-element(p0), index=3
1843   }
1844 
1845   while_body {
1846     p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
1847     gte0 = f32[3]{0} get-tuple-element(p0), index=0
1848     gte1 = f32[3]{0} get-tuple-element(p0), index=1
1849     gte2 = f32[3]{0} get-tuple-element(p0), index=2
1850     gte3 = pred[] get-tuple-element(p0), index=3
1851     add = f32[3]{0} add(gte0, gte0)
1852     negate0 = f32[3]{0} negate(add)
1853     negate1 = f32[3]{0} negate(negate0)
1854     negate2 = f32[3]{0} negate(negate1)
1855     negate3 = f32[3]{0} negate(negate2)
1856     negate4 = f32[3]{0} negate(negate3)
1857     negate5 = f32[3]{0} negate(negate4)
1858     negate6 = f32[3]{0} negate(negate5)
1859     negate7 = f32[3]{0} negate(negate6)
1860     negate8 = f32[3]{0} negate(negate7)
1861     negate9 = f32[3]{0} negate(negate8)
1862     negate10 = f32[3]{0} negate(negate9)
1863     negate11 = f32[3]{0} negate(negate10)
1864     negate12 = f32[3]{0} negate(negate11)
1865     negate13 = f32[3]{0} negate(negate12)
1866     negate14 = f32[3]{0} negate(negate13)
1867     negate15 = f32[3]{0} negate(negate14)
1868     negate16 = f32[3]{0} negate(negate15)
1869     ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, gte0, negate16, gte3)
1870   }
1871 
1872   ENTRY entry {
1873     p0 = f32[3]{0} parameter(0)
1874     p1 = pred[] parameter(1)
1875     copy0 = f32[3]{0} copy(p0)
1876     copy1 = f32[3]{0} copy(p0)
1877     tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy0, copy1, p1)
1878     while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
1879     ROOT gte = f32[3]{0} get-tuple-element(while), index=2
1880   }
1881   )";
1882   TF_ASSERT_OK_AND_ASSIGN(auto module,
1883                           ParseAndReturnVerifiedModule(hlo_string));
1884   AssignMemorySpace(module.get());
1885 }
1886 
1887 TEST_P(MemorySpaceAssignmentTest, b172243149) {
1888   // Tests for the failure in b/172243149, where skipping the processing of
1889   // non-copy allocations that are in default memory can actually cause
1890   // failures. In this case, the problem tensor is copy0, which is fed to
1891   // negate, while, and add0. The copy0->negate dependency can be allocated
1892   // in the alternate memory. Then the algorithm attempts to place the
1893   // copy0->while edge in the alternate memory, but since this value isn't used
1894   // in the while loop, it won't get an alternate memory allocation. Finally, for
1895   // the copy0->add0 edge, the algorithm will actually replace it with
1896   // while{0}->add0, since this is equivalent and while is defined later than
1897   // copy0. However, if we skip processing this while{0}->add0
1898   // allocation, we won't replace this edge and will end up with the
1899   // copy0->add0 edge, which illegally extends the lifetime of the alternate
1900   // memory buffer in copy0.
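  // Summary of the problematic edges described above:
  //   copy0 -> negate : may be placed in the alternate memory.
  //   copy0 -> while  : stays in default memory (the value is unused in the body).
  //   copy0 -> add0   : must be rewritten as while{0} -> add0; skipping that
  //                     rewrite illegally extends copy0's alternate-memory
  //                     lifetime.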
1901   absl::string_view hlo_string = R"(
1902   HloModule module, is_scheduled=true
1903 
1904   while_cond {
1905     p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
1906     ROOT gte = pred[] get-tuple-element(p0), index=3
1907   }
1908 
1909   while_body {
1910     p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
1911     gte0 = f32[3]{0} get-tuple-element(p0), index=0
1912     gte1 = f32[3]{0} get-tuple-element(p0), index=1
1913     gte2 = f32[3]{0} get-tuple-element(p0), index=2
1914     gte3 = pred[] get-tuple-element(p0), index=3
1915     add = f32[3]{0} add(gte1, gte2)
1916     negate0 = f32[3]{0} negate(add)
1917     negate1 = f32[3]{0} negate(negate0)
1918     negate2 = f32[3]{0} negate(negate1)
1919     negate3 = f32[3]{0} negate(negate2)
1920     negate4 = f32[3]{0} negate(negate3)
1921     negate5 = f32[3]{0} negate(negate4)
1922     negate6 = f32[3]{0} negate(negate5)
1923     negate7 = f32[3]{0} negate(negate6)
1924     negate8 = f32[3]{0} negate(negate7)
1925     negate9 = f32[3]{0} negate(negate8)
1926     negate10 = f32[3]{0} negate(negate9)
1927     negate11 = f32[3]{0} negate(negate10)
1928     negate12 = f32[3]{0} negate(negate11)
1929     negate13 = f32[3]{0} negate(negate12)
1930     negate14 = f32[3]{0} negate(negate13)
1931     negate15 = f32[3]{0} negate(negate14)
1932     negate16 = f32[3]{0} negate(negate15)
1933     ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, add, negate16, gte3)
1934   }
1935 
1936   ENTRY entry {
1937     p0 = f32[3]{0} parameter(0)
1938     p1 = pred[] parameter(1)
1939     copy0 = f32[3]{0} copy(p0)
1940     copy1 = f32[3]{0} copy(p0)
1941     copy2 = f32[3]{0} copy(p0)
1942     negate = f32[3]{0} negate(copy0)
1943     tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy1, copy2, p1)
1944     while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
1945     gte = f32[3]{0} get-tuple-element(while), index=2
1946     add0 = f32[3]{0} add(negate, copy0)
1947     ROOT add1 = f32[3]{0} add(add0, gte)
1948   }
1949   )";
1950   TF_ASSERT_OK_AND_ASSIGN(auto module,
1951                           ParseAndReturnVerifiedModule(hlo_string));
1952   AssignMemorySpace(module.get());
1953 }
1954 
1955 TEST_P(MemorySpaceAssignmentTest, ControlPredecessorsBug) {
1956   // Having control_predecessors on an HLO was preventing us from DCEing an op
1957   // that doesn't have any users (tuple.1). The scheduler assumes the graph is
1958   // fully DCEed, which causes some instructions not to be scheduled.
1959   absl::string_view hlo_string = R"(
1960   HloModule sort.16, is_scheduled=true
1961 
1962   ENTRY %sort.16 (param.0.1: s32[1], param.1.2: f32[1], param.2.3: u32[1], param.3.4: s32[1]) -> (s32[1], f32[1], u32[1], s32[1]) {
1963     %param.3.4 = s32[1]{0:T(128)} parameter(3)
1964     %param.2.3 = u32[1]{0:T(128)} parameter(2)
1965     %param.1.2 = f32[1]{0:T(128)} parameter(1)
1966     %param.0.1 = s32[1]{0:T(128)} parameter(0)
1967     %tuple.1 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %param.0.1, f32[1]{0:T(128)} %param.1.2, u32[1]{0:T(128)} %param.2.3, s32[1]{0:T(128)} %param.3.4), control-predecessors={%param.0.1}
1968     %get-tuple-element.4 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=0
1969     %get-tuple-element.5 = f32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=1
1970     %get-tuple-element.6 = u32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=2
1971     %get-tuple-element.7 = s32[1]{0:T(128)} get-tuple-element((s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) %tuple.1), index=3
1972     %copy.4 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.4)
1973     %copy.5 = f32[1]{0:T(128)} copy(f32[1]{0:T(128)} %get-tuple-element.5)
1974     %copy.6 = u32[1]{0:T(128)} copy(u32[1]{0:T(128)} %get-tuple-element.6)
1975     %copy.7 = s32[1]{0:T(128)} copy(s32[1]{0:T(128)} %get-tuple-element.7)
1976     ROOT %tuple.2 = (s32[1]{0:T(128)}, f32[1]{0:T(128)}, u32[1]{0:T(128)}, s32[1]{0:T(128)}) tuple(s32[1]{0:T(128)} %copy.4, f32[1]{0:T(128)} %copy.5, u32[1]{0:T(128)} %copy.6, s32[1]{0:T(128)} %copy.7)
1977 }
1978   )";
1979 
1980   TF_ASSERT_OK_AND_ASSIGN(auto module,
1981                           ParseAndReturnVerifiedModule(hlo_string));
1982   AssignMemorySpace(module.get());
1983 }
1984 
1985 TEST_P(MemorySpaceAssignmentTest, ConditionalShouldBeAllocatedInAlternateMem) {
1986   // Checks that simple conditionals get alternate memory allocations.
1987   absl::string_view hlo_string = R"(
1988   HloModule CondAllocation, is_scheduled=true
1989 
1990   true_computation {
1991     p0 = (f32[3]{0}) parameter(0)
1992     gte = f32[3]{0} get-tuple-element(p0), index=0
1993     ROOT neg1 = f32[3]{0} negate(gte)
1994   }
1995 
1996   false_computation {
1997     p0 = (f32[3]{0}) parameter(0)
1998     gte = f32[3]{0} get-tuple-element(p0), index=0
1999     ROOT neg2 = f32[3]{0} negate(gte)
2000   }
2001 
2002   ENTRY entry {
2003     p0 = f32[3]{0} parameter(0)
2004     p1 = pred[] parameter(1)
2005     copy = f32[3]{0} copy(p0)
2006     tuple = (f32[3]{0}) tuple(copy)
2007     ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation
2008   }
2009   )";
2010   TF_ASSERT_OK_AND_ASSIGN(auto module,
2011                           ParseAndReturnVerifiedModule(hlo_string));
2012   AssignMemorySpace(module.get());
2013 
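  // (Optional) To debug the resulting assignment, the annotated module can be
  // dumped; a minimal sketch, assuming the usual XLA logging macros:
  //   XLA_VLOG_LINES(3, module->ToString());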
2014   if (GetParam()) {
2015     // Check that copy and gtes got alternate memory allocations.
2016     auto copy =
2017         module->GetComputationWithName("entry")->GetInstructionWithName("copy");
2018     EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace);
2019     auto neg1 = module->GetComputationWithName("true_computation")
2020                     ->GetInstructionWithName("neg1");
2021     auto neg1_operand = neg1->operand(0);
2022     EXPECT_EQ(neg1_operand->shape().layout().memory_space(),
2023               kAlternateMemorySpace);
2024     auto neg2 = module->GetComputationWithName("false_computation")
2025                     ->GetInstructionWithName("neg2");
2026     auto neg2_operand = neg2->operand(0);
2027     EXPECT_EQ(neg2_operand->shape().layout().memory_space(),
2028               kAlternateMemorySpace);
2029   }
2030 }
2031 
2032 TEST_P(MemorySpaceAssignmentTest, ConditionalAvoidsUnnecessaryPrefetch) {
2033   // Checks if we avoid unnecessary allocation in alternate memory if the input
2034   // won't be used in the computation for a long time.
2035   absl::string_view hlo_string = R"(
2036   HloModule CondAllocation, is_scheduled=true
2037 
2038   true_computation {
2039     p0 = (f32[3]{0}, f32[3]{0}) parameter(0)
2040     gte0 = f32[3]{0} get-tuple-element(p0), index=0
2041     neg0 = f32[3]{0} negate(gte0)
2042     neg1 = f32[3]{0} negate(neg0)
2043     neg2 = f32[3]{0} negate(neg1)
2044     neg3 = f32[3]{0} negate(neg2)
2045     neg4 = f32[3]{0} negate(neg3)
2046     neg5 = f32[3]{0} negate(neg4)
2047     neg6 = f32[3]{0} negate(neg5)
2048     neg7 = f32[3]{0} negate(neg6)
2049     neg8 = f32[3]{0} negate(neg7)
2050     neg9 = f32[3]{0} negate(neg8)
2051     gte1 = f32[3]{0} get-tuple-element(p0), index=1
2052     ROOT add = f32[3]{0} add(neg9, gte1)
2053   }
2054 
2055   false_computation {
2056     p0 = (f32[3]{0}) parameter(0)
2057     gte = f32[3]{0} get-tuple-element(p0), index=0
2058     ROOT neg = f32[3]{0} negate(gte)
2059   }
2060 
2061   ENTRY entry {
2062     p0 = f32[3]{0} parameter(0)
2063     p1 = pred[] parameter(1)
2064     copy0 = f32[3]{0} copy(p0)
2065     copy1 = f32[3]{0} copy(p0)
2066     tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1)
2067     tuple1 = (f32[3]{0}) tuple(copy0)
2068     ROOT conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation
2069   }
2070   )";
2071   TF_ASSERT_OK_AND_ASSIGN(auto module,
2072                           ParseAndReturnVerifiedModule(hlo_string));
2073   AssignMemorySpace(module.get());
2074 
2075   if (GetParam()) {
2076     // Check that copy1 doesn't get unnecessarily allocated in alternate mem
2077     // (due to long negate chain in true_computation) but is prefetched before
2078     // add.
2079     auto copy0 =
2080         module->GetComputationWithName("entry")->GetInstructionWithName(
2081             "copy0");
2082     EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace);
2083     auto copy1 =
2084         module->GetComputationWithName("entry")->GetInstructionWithName(
2085             "copy1");
2086     EXPECT_EQ(copy1->shape().layout().memory_space(), kDefaultMemorySpace);
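    // copy1 stays in default memory at its definition; its use inside
    // true_computation (add's second operand, checked below) is instead
    // expected to arrive in alternate memory via a prefetch close to the use.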
2087     auto add = module->GetComputationWithName("true_computation")
2088                    ->GetInstructionWithName("add");
2089     auto add_operand = add->operand(1);
2090     EXPECT_EQ(add_operand->shape().layout().memory_space(),
2091               kAlternateMemorySpace);
2092   }
2093 }
2094 
2095 TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUse) {
2096   // Make sure there is an eviction when there is a conditional use followed
2097   // by another use.
2098   absl::string_view hlo_string = R"(
2099   HloModule CondAllocation, is_scheduled=true
2100 
2101   true_computation {
2102     p0 = (f32[3]{0}, f32[3]{0}) parameter(0)
2103     gte0 = f32[3]{0} get-tuple-element(p0), index=0
2104     gte1 = f32[3]{0} get-tuple-element(p0), index=1
2105     add0 = f32[3]{0} add(gte0, gte1)
2106     neg0 = f32[3]{0} negate(add0)
2107     neg1 = f32[3]{0} negate(neg0)
2108     neg2 = f32[3]{0} negate(neg1)
2109     neg3 = f32[3]{0} negate(neg2)
2110     neg4 = f32[3]{0} negate(neg3)
2111     neg5 = f32[3]{0} negate(neg4)
2112     neg6 = f32[3]{0} negate(neg5)
2113     neg7 = f32[3]{0} negate(neg6)
2114     neg8 = f32[3]{0} negate(neg7)
2115     ROOT neg9 = f32[3]{0} negate(neg8)
2116   }
2117 
2118   false_computation {
2119     p0 = (f32[3]{0}) parameter(0)
2120     gte = f32[3]{0} get-tuple-element(p0), index=0
2121     ROOT neg = f32[3]{0} negate(gte)
2122   }
2123 
2124   ENTRY entry {
2125     p0 = f32[3]{0} parameter(0)
2126     p1 = pred[] parameter(1)
2127     copy0 = f32[3]{0} copy(p0)
2128     copy1 = f32[3]{0} copy(p0)
2129     tuple0 = (f32[3]{0}, f32[3]{0}) tuple(copy0, copy1)
2130     tuple1 = (f32[3]{0}) tuple(copy0)
2131     conditional = f32[3]{0} conditional(p1, tuple0, tuple1), true_computation=true_computation, false_computation=false_computation
2132     ROOT add1 = f32[3]{0} add(copy1, conditional)
2133   }
2134   )";
2135   TF_ASSERT_OK_AND_ASSIGN(auto module,
2136                           ParseAndReturnVerifiedModule(hlo_string));
2137   AssignMemorySpace(module.get());
2138 
2139   if (GetParam()) {
2140     // Make sure the copy1->add0 edge (the use inside the conditional) is in
2141     // alternate memory. The buffer is then evicted to default memory, and the
2142     // add1 use after the conditional reads the input from default memory.
2143     auto copy1 =
2144         module->GetComputationWithName("entry")->GetInstructionWithName(
2145             "copy1");
2146     EXPECT_EQ(copy1->shape().layout().memory_space(), kAlternateMemorySpace);
2147     auto add0 = module->GetComputationWithName("true_computation")
2148                     ->GetInstructionWithName("add0");
2149     auto add0_operand = add0->operand(1);
2150     EXPECT_EQ(add0_operand->shape().layout().memory_space(),
2151               kAlternateMemorySpace);
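    // For the second use (add1 in the entry computation), the operand is
    // expected to be the eviction's CopyDone, i.e. the value is read back
    // from default memory rather than kept resident in alternate memory.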
2152     auto add1 =
2153         module->GetComputationWithName("entry")->GetInstructionWithName("add1");
2154     auto add1_operand = add1->operand(0);
2155     EXPECT_EQ(add1_operand->shape().layout().memory_space(),
2156               kDefaultMemorySpace);
2157     EXPECT_EQ(add1_operand->opcode(), HloOpcode::kCopyDone);
2158   }
2159 }
2160 
2161 TEST_P(MemorySpaceAssignmentTest, ConditionalMultiUseInWhile) {
2162   absl::string_view hlo_string = R"(
2163   HloModule CondAllocation, is_scheduled=true
2164 
2165   true_computation {
2166     p0 = (f32[3]{0}) parameter(0)
2167     gte = f32[3]{0} get-tuple-element(p0), index=0
2168     ROOT neg1 = f32[3]{0} negate(gte)
2169   }
2170 
2171   false_computation {
2172     p0 = (f32[3]{0}) parameter(0)
2173     gte = f32[3]{0} get-tuple-element(p0), index=0
2174     ROOT neg2 = f32[3]{0} negate(gte)
2175   }
2176 
2177   while_cond {
2178     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
2179     ROOT gte = pred[] get-tuple-element(p0), index=2
2180   }
2181 
2182   while_body {
2183     p0 = (f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
2184     gte0 = f32[3]{0} get-tuple-element(p0), index=0
2185     gte1 = f32[3]{0} get-tuple-element(p0), index=1
2186     gte2 = pred[] get-tuple-element(p0), index=2
2187     cond_tuple = (f32[3]{0}) tuple(gte0)
2188     conditional = f32[3]{0} conditional(gte2, cond_tuple, cond_tuple), true_computation=true_computation, false_computation=false_computation
2189     add = f32[3]{0} add(conditional, gte1)
2190     neg0 = f32[3]{0} negate(add)
2191     neg1 = f32[3]{0} negate(neg0)
2192     ROOT tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(gte0, neg1, gte2)
2193   }
2194 
2195   ENTRY entry {
2196     p0 = f32[3]{0} parameter(0)
2197     p1 = pred[] parameter(1)
2198     copy0 = f32[3]{0} copy(p0)
2199     copy1 = f32[3]{0} copy(p0)
2200     tuple = (f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy1, p1)
2201     while = (f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
2202     ROOT gte = f32[3]{0} get-tuple-element(while), index=1
2203   }
2204   )";
2205   TF_ASSERT_OK_AND_ASSIGN(auto module,
2206                           ParseAndReturnVerifiedModule(hlo_string));
2207   AssignMemorySpace(module.get());
2208 
2209   if (GetParam()) {
2210     // Make sure copy0/while{0}/cond_tuple{0} gets alternate memory allocation.
2211     // This will force an eviction and a prefetch for while body root.
2212     auto copy0 =
2213         module->GetComputationWithName("entry")->GetInstructionWithName(
2214             "copy0");
2215     EXPECT_EQ(copy0->shape().layout().memory_space(), kAlternateMemorySpace);
2216     auto conditional = module->GetComputationWithName("while_body")
2217                            ->GetInstructionWithName("conditional");
2218     auto conditional_operand = conditional->operand(1);
2219     EXPECT_EQ(ShapeUtil::GetSubshape(conditional_operand->shape(), {0})
2220                   .layout()
2221                   .memory_space(),
2222               kAlternateMemorySpace);
2223     auto while_root =
2224         module->GetComputationWithName("while_body")->root_instruction();
2225     auto while_root_operand = while_root->operand(0);
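    // Reading the matcher inside-out: the inner AsyncCopy (to default from
    // alternate) is the eviction, and the outer AsyncCopy (to alternate from
    // default) is the prefetch feeding the while-body root.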
2226     EXPECT_THAT(
2227         while_root_operand,
2228         op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
2229                       op::AsyncCopy(kDefaultMemorySpace, kAlternateMemorySpace,
2230                                     op::GetTupleElement(op::Parameter(0)))));
2231   }
2232 }
2233 
2234 TEST_P(MemorySpaceAssignmentTest, NestedConditional) {
2235   absl::string_view hlo_string = R"(
2236   HloModule CondAllocation, is_scheduled=true
2237 
2238   true_computation2 {
2239     p0 = (f32[3]{0}) parameter(0)
2240     gte = f32[3]{0} get-tuple-element(p0), index=0
2241     ROOT neg1 = f32[3]{0} negate(gte)
2242   }
2243 
2244   false_computation2 {
2245     p0 = (f32[3]{0}) parameter(0)
2246     gte = f32[3]{0} get-tuple-element(p0), index=0
2247     ROOT neg2 = f32[3]{0} negate(gte)
2248   }
2249 
2250   true_computation1 {
2251     p0 = (f32[3]{0}) parameter(0)
2252     gte = f32[3]{0} get-tuple-element(p0), index=0
2253     slice = f32[1]{0} slice(gte), slice={[0:1]}
2254     bitcast = f32[] bitcast(slice)
2255     constant = f32[] constant(0.0)
2256     compare = pred[] compare(bitcast, constant), direction=GT
2257     ROOT conditional = f32[3]{0} conditional(compare, p0, p0), true_computation=true_computation2, false_computation=false_computation2
2258   }
2259 
2260   false_computation1 {
2261     p0 = (f32[3]{0}) parameter(0)
2262     gte = f32[3]{0} get-tuple-element(p0), index=0
2263     ROOT neg3 = f32[3]{0} negate(gte)
2264   }
2265 
2266 
2267   ENTRY entry {
2268     p0 = f32[3]{0} parameter(0)
2269     p1 = pred[] parameter(1)
2270     copy = f32[3]{0} copy(p0)
2271     tuple = (f32[3]{0}) tuple(copy)
2272     ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1
2273   }
2274   )";
2275   TF_ASSERT_OK_AND_ASSIGN(auto module,
2276                           ParseAndReturnVerifiedModule(hlo_string));
2277   AssignMemorySpace(module.get());
2278 
2279   if (GetParam()) {
2280     // Make sure alternate memory allocation gets propagated into both levels of
2281     // conditional.
2282     auto copy =
2283         module->GetComputationWithName("entry")->GetInstructionWithName("copy");
2284     EXPECT_EQ(copy->shape().layout().memory_space(), kAlternateMemorySpace);
2285     auto neg1_operand = module->GetComputationWithName("true_computation2")
2286                             ->GetInstructionWithName("neg1")
2287                             ->operand(0);
2288     auto neg2_operand = module->GetComputationWithName("false_computation2")
2289                             ->GetInstructionWithName("neg2")
2290                             ->operand(0);
2291     auto neg3_operand = module->GetComputationWithName("false_computation1")
2292                             ->GetInstructionWithName("neg3")
2293                             ->operand(0);
2294     EXPECT_EQ(neg1_operand->shape().layout().memory_space(),
2295               kAlternateMemorySpace);
2296     EXPECT_EQ(neg2_operand->shape().layout().memory_space(),
2297               kAlternateMemorySpace);
2298     EXPECT_EQ(neg3_operand->shape().layout().memory_space(),
2299               kAlternateMemorySpace);
2300   }
2301 }
2302 
2303 TEST_P(MemorySpaceAssignmentTest, NestedConditionalBufferReuseVerificationBug) {
2304   // Tests that verification does not spuriously fail when there are nested
2305   // conditionals and the innermost conditional computation reuses the buffer.
2306   // Here, both the parameter of true_computation2 and neg2 will get the same
2307   // buffer; verification should not claim a failure in this case.
2308   absl::string_view hlo_string = R"(
2309   HloModule CondAllocation, is_scheduled=true
2310 
2311   true_computation2 {
2312     p0 = (f32[3]{0}) parameter(0)
2313     gte = f32[3]{0} get-tuple-element(p0), index=0
2314     neg1 = f32[3]{0} negate(gte)
2315     neg2 = f32[3]{0} negate(neg1)
2316     ROOT neg3 = f32[3]{0} negate(neg2)
2317   }
2318 
2319   false_computation2 {
2320     p0 = (f32[3]{0}) parameter(0)
2321     gte = f32[3]{0} get-tuple-element(p0), index=0
2322     ROOT neg4 = f32[3]{0} negate(gte)
2323   }
2324 
2325   true_computation1 {
2326     p0 = (f32[3]{0}) parameter(0)
2327     gte = f32[3]{0} get-tuple-element(p0), index=0
2328     slice = f32[1]{0} slice(gte), slice={[0:1]}
2329     bitcast = f32[] bitcast(slice)
2330     constant = f32[] constant(0.0)
2331     compare = pred[] compare(bitcast, constant), direction=GT
2332     tuple = (f32[3]{0}) tuple(gte)
2333     ROOT conditional = f32[3]{0} conditional(compare, tuple, tuple), true_computation=true_computation2, false_computation=false_computation2
2334   }
2335 
2336   false_computation1 {
2337     p0 = (f32[3]{0}) parameter(0)
2338     gte = f32[3]{0} get-tuple-element(p0), index=0
2339     ROOT neg5 = f32[3]{0} negate(gte)
2340   }
2341 
2342   ENTRY entry {
2343     p0 = f32[3]{0} parameter(0)
2344     p1 = pred[] parameter(1)
2345     copy = f32[3]{0} copy(p0)
2346     tuple = (f32[3]{0}) tuple(copy)
2347     ROOT conditional = f32[3]{0} conditional(p1, tuple, tuple), true_computation=true_computation1, false_computation=false_computation1
2348   }
2349   )";
2350   TF_ASSERT_OK_AND_ASSIGN(auto module,
2351                           ParseAndReturnVerifiedModule(hlo_string));
2352   AssignMemorySpace(module.get());
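  // No expectations are needed: a spurious verification failure would surface
  // as an error while running AssignMemorySpace.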
2353 }
2354 
2355 TEST_P(MemorySpaceAssignmentTest,
2356        RequestIdentifierShouldNotBeAllocatedInAlternateMem) {
2357   // Ensure that the request identifiers returned by Send/Recv HLOs are not
2358   // allocated in the alternate memory.
2359   absl::string_view hlo_string = R"(
2360   HloModule SendRecv, is_scheduled=true
2361 
2362   ENTRY %AddDependency (p: f32[3]) -> f32[3] {
2363     %p = f32[3]{0} parameter(0)
2364     %after-all = token[] after-all()
2365     %recv.4 = (f32[3]{0}, u32[], token[]) recv(token[] %after-all), channel_id=7
2366     %recv-done.4 = (f32[3]{0}, token[]) recv-done((f32[3]{0}, u32[], token[]) %recv.4), channel_id=7
2367     %token.1 = token[] get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=1
2368     %data = f32[3]{0} get-tuple-element((f32[3]{0}, token[]) %recv-done.4), index=0
2369     %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %data, token[] %token.1), channel_id=2
2370     %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
2371     ROOT %add = f32[3]{0} add(f32[3]{0} %p, f32[3]{0} %data)
2372   }
2373   )";
2374 
2375   TF_ASSERT_OK_AND_ASSIGN(auto module,
2376                           ParseAndReturnVerifiedModule(hlo_string));
2377   AssignMemorySpace(module.get());
2378 
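  // In the (data, request-id, token) result tuples produced by Send/Recv, the
  // request identifier is tuple element {1}, hence the GetSubshape below.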
2379   for (const HloInstruction* instruction :
2380        module->entry_computation()->instructions()) {
2381     if (instruction->opcode() == HloOpcode::kSend ||
2382         instruction->opcode() == HloOpcode::kRecv) {
2383       const Shape& request_identifier_shape =
2384           ShapeUtil::GetSubshape(instruction->shape(), {1});
2385       EXPECT_NE(request_identifier_shape.layout().memory_space(),
2386                 kAlternateMemorySpace);
2387     }
2388   }
2389 }
2390 
2391 TEST_P(MemorySpaceAssignmentTest, SendDoneShouldHaveSendOperand) {
2392   // Ensure that SendDone has only a Send operand.
2393   absl::string_view hlo_string = R"(
2394   HloModule SendRecv, is_scheduled=true
2395 
2396   ENTRY %AddDependency (p: f32[3]) -> f32[3] {
2397     %p0 = f32[3]{0} parameter(0)
2398     %p1 = f32[3]{0} parameter(1)
2399     %neg0 = f32[3]{0} negate(f32[3]{0} %p1)
2400     %neg1 = f32[3]{0} negate(f32[3]{0} %neg0)
2401     %neg2 = f32[3]{0} negate(f32[3]{0} %neg1)
2402     %neg3 = f32[3]{0} negate(f32[3]{0} %neg2)
2403     %neg4 = f32[3]{0} negate(f32[3]{0} %neg3)
2404     %neg5 = f32[3]{0} negate(f32[3]{0} %neg4)
2405     %neg6 = f32[3]{0} negate(f32[3]{0} %neg5)
2406     %after-all = token[] after-all()
2407     %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %p0, token[] %after-all), channel_id=2
2408     %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
2409     ROOT %add = f32[3]{0} add(f32[3]{0} %p0, f32[3]{0} %neg6)
2410   }
2411   )";
2412 
2413   TF_ASSERT_OK_AND_ASSIGN(auto module,
2414                           ParseAndReturnVerifiedModule(hlo_string));
2415   AssignMemorySpace(module.get());
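  // The test passes as long as assignment leaves %send-done consuming %send
  // directly, i.e. no copy is inserted between the two.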
2416 }
2417 
2418 TEST_P(MemorySpaceAssignmentTest, SendAndSendDoneShouldGetSameAllocation) {
2419   // Ensure that Send and SendDone have the same allocation.
2420   absl::string_view hlo_string = R"(
2421   HloModule SendRecv, is_scheduled=true
2422 
2423   ENTRY %AddDependency (p: f32[3]) -> f32[3] {
2424     %p0 = f32[3]{0} parameter(0)
2425     %p1 = f32[3]{0} parameter(1)
2426     %after-all = token[] after-all()
2427     %send = (f32[3]{0}, u32[], token[]) send(f32[3]{0} %p0, token[] %after-all), channel_id=2
2428     %neg0 = f32[3]{0} negate(f32[3]{0} %p1)
2429     %neg1 = f32[3]{0} negate(f32[3]{0} %neg0)
2430     %neg2 = f32[3]{0} negate(f32[3]{0} %neg1)
2431     %neg3 = f32[3]{0} negate(f32[3]{0} %neg2)
2432     %neg4 = f32[3]{0} negate(f32[3]{0} %neg3)
2433     %neg5 = f32[3]{0} negate(f32[3]{0} %neg4)
2434     %neg6 = f32[3]{0} negate(f32[3]{0} %neg5)
2435     %send-done = token[] send-done((f32[3]{0}, u32[], token[]) %send), channel_id=2
2436     ROOT %add = f32[3]{0} add(f32[3]{0} %p0, f32[3]{0} %neg6)
2437   }
2438   )";
2439 
2440   TF_ASSERT_OK_AND_ASSIGN(auto module,
2441                           ParseAndReturnVerifiedModule(hlo_string));
2442   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
2443                     /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/4);
2444 }
2445 
2446 TEST_P(MemorySpaceAssignmentTest, LastUseOpt) {
2447   // Test that checks the last use optimization. It uses two buffers that should
2448   // be placed in alternate memory.
2449   //
2450   //      +-------+
2451   //     /         \
2452   // add1--->sub1   +-------->mul2
2453   //              mul1===>add2
2454   //
2455   // Without the last use optimization, the mul1 buffer will be assigned first
2456   // (because it is larger) to offset 0. Then, add1 will be scheduled for the
2457   // add1 to sub1 segment. Because offset 0 is available, it will get that
2458   // offset. But because offset 0 is not available in the sub1 to mul2
2459   // segment, this would result in unnecessary copies. With the last use
2460   // optimization, these copies can be optimized away.
2461   HloComputation::Builder builder(TestName());
2462   Shape shape1 = ShapeUtil::MakeShape(F32, {2, 3});
2463   Shape shape2 = ShapeUtil::MakeShape(F32, {2, 4});
2464   PaddingConfig padding_config = MakeEdgePaddingConfig({{0, 0}, {0, 1}});
2465   HloInstruction* p0 =
2466       builder.AddInstruction(HloInstruction::CreateParameter(0, shape1, "p0"));
2467   HloInstruction* p1 =
2468       builder.AddInstruction(HloInstruction::CreateParameter(1, shape2, "p1"));
2469   HloInstruction* add1 = builder.AddInstruction(
2470       HloInstruction::CreateBinary(shape1, HloOpcode::kAdd, p0, p0));
2471   HloInstruction* sub1 = builder.AddInstruction(
2472       HloInstruction::CreateBinary(shape1, HloOpcode::kSubtract, p0, add1));
2473   HloInstruction* mul1 = builder.AddInstruction(
2474       HloInstruction::CreateBinary(shape2, HloOpcode::kMultiply, p1, p1));
2475   HloInstruction* add2 = builder.AddInstruction(
2476       HloInstruction::CreateBinary(shape2, HloOpcode::kAdd, mul1, p1));
2477   HloInstruction* mul2 = builder.AddInstruction(
2478       HloInstruction::CreateBinary(shape1, HloOpcode::kMultiply, add1, sub1));
2479   HloInstruction* padding_value = builder.AddInstruction(
2480       HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
2481   HloInstruction* padded_mul2 = builder.AddInstruction(
2482       HloInstruction::CreatePad(shape2, mul2, padding_value, padding_config));
2483   HloInstruction* add3 = builder.AddInstruction(
2484       HloInstruction::CreateBinary(shape2, HloOpcode::kAdd, add2, padded_mul2));
2485 
2486   auto module = CreateNewVerifiedModule();
2487   HloComputation* computation = module->AddEntryComputation(builder.Build());
2488 
2489   HloSchedule schedule(module.get());
2490   schedule.set_sequence(computation, {p0, p1, add1, sub1, mul1, add2, mul2,
2491                                       padding_value, padded_mul2, add3});
2492   TF_CHECK_OK(module->set_schedule(schedule));
2493 
2494   AssignMemorySpace(module.get());
2495 
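  // Expected structure of mul2: add1 is used as-is, while sub1's first
  // operand (p0) is read through an async copy into alternate memory instead
  // of through the extra copies described above.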
2496   EXPECT_THAT(
2497       mul2,
2498       op::Multiply(
2499           op::Add(op::Parameter(0), op::Parameter(0)),
2500           op::Subtract(op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
2501                                      op::Parameter(0)),
2502                        op::Add(op::Parameter(0), op::Parameter(0)))));
2503 }
2504 
2505 TEST_P(MemorySpaceAssignmentTest, CopyOrdering) {
2506   // Test to make sure the CopyDones are issued in the same order as their
2507   // CopyStarts. The shapes are picked in increasing order to exploit the fact
2508   // that the heap simulator processes larger tensors first. This checks the
2509   // ability of the compiler to reschedule:
2510   //
2511   //  CS1            CD1
2512   //   +--------------+
2513   //    +-----------+
2514   //   CS2         CD2
2515   //
2516   // into:
2517   //
2518   //    CS1          CD1
2519   //     +------------+
2520   //    +-----------+
2521   //   CS2         CD2
2522   HloComputation::Builder builder(TestName());
2523   Shape shape1 = ShapeUtil::MakeShape(F32, {2, 1});
2524   Shape shape2 = ShapeUtil::MakeShape(F32, {2, 2});
2525   Shape shape3 = ShapeUtil::MakeShape(F32, {2, 3});
2526   Shape shape4 = ShapeUtil::MakeShape(F32, {2, 4});
2527   PaddingConfig padding_config = MakeEdgePaddingConfig({{0, 0}, {0, 1}});
2528   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape3, shape4});
2529   HloInstruction* p0 = builder.AddInstruction(
2530       HloInstruction::CreateParameter(0, tuple_shape, "p"));
2531   HloInstruction* p4 = builder.AddInstruction(
2532       HloInstruction::CreateGetTupleElement(shape4, p0, 1));
2533   HloInstruction* p3 = builder.AddInstruction(
2534       HloInstruction::CreateGetTupleElement(shape3, p0, 0));
2535   HloInstruction* p2 =
2536       builder.AddInstruction(HloInstruction::CreateParameter(2, shape2, "p2"));
2537   HloInstruction* p1 =
2538       builder.AddInstruction(HloInstruction::CreateParameter(1, shape1, "p1"));
2539   HloInstruction* negate0 = builder.AddInstruction(
2540       HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, p1));
2541   HloInstruction* negate1 = builder.AddInstruction(
2542       HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate0));
2543   HloInstruction* negate2 = builder.AddInstruction(
2544       HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate1));
2545   HloInstruction* negate3 = builder.AddInstruction(
2546       HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate2));
2547   HloInstruction* negate4 = builder.AddInstruction(
2548       HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate3));
2549   HloInstruction* negate5 = builder.AddInstruction(
2550       HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate4));
2551   HloInstruction* negate6 = builder.AddInstruction(
2552       HloInstruction::CreateUnary(shape1, HloOpcode::kNegate, negate5));
2553   HloInstruction* padding_value = builder.AddInstruction(
2554       HloInstruction::CreateConstant(LiteralUtil::Zero(F32)));
2555   HloInstruction* add1 = builder.AddInstruction(
2556       HloInstruction::CreateBinary(shape1, HloOpcode::kAdd, negate6, p1));
2557   HloInstruction* padded_add1 = builder.AddInstruction(
2558       HloInstruction::CreatePad(shape2, add1, padding_value, padding_config));
2559   HloInstruction* add2 = builder.AddInstruction(
2560       HloInstruction::CreateBinary(shape2, HloOpcode::kAdd, padded_add1, p2));
2561   HloInstruction* padded_add2 = builder.AddInstruction(
2562       HloInstruction::CreatePad(shape3, add2, padding_value, padding_config));
2563   HloInstruction* negate7 = builder.AddInstruction(
2564       HloInstruction::CreateUnary(shape4, HloOpcode::kNegate, p4));
2565   HloInstruction* add3 = builder.AddInstruction(
2566       HloInstruction::CreateBinary(shape3, HloOpcode::kAdd, padded_add2, p3));
2567   HloInstruction* padded_add3 = builder.AddInstruction(
2568       HloInstruction::CreatePad(shape4, add3, padding_value, padding_config));
2569   HloInstruction* add4 = builder.AddInstruction(HloInstruction::CreateBinary(
2570       shape4, HloOpcode::kAdd, padded_add3, negate7));
2571 
2572   auto module = CreateNewVerifiedModule();
2573   HloComputation* computation = module->AddEntryComputation(builder.Build());
2574 
2575   HloSchedule schedule(module.get());
2576   schedule.set_sequence(computation, {p0,
2577                                       p4,
2578                                       p3,
2579                                       p2,
2580                                       p1,
2581                                       negate0,
2582                                       negate1,
2583                                       negate2,
2584                                       negate3,
2585                                       negate4,
2586                                       negate5,
2587                                       negate6,
2588                                       padding_value,
2589                                       add1,
2590                                       padded_add1,
2591                                       add2,
2592                                       padded_add2,
2593                                       negate7,
2594                                       add3,
2595                                       padded_add3,
2596                                       add4});
2597   TF_CHECK_OK(module->set_schedule(schedule));
2598 
2599   // Use a large max prefetch interval to force CopyStart/CopyDone right after
2600   // the parameters.
2601   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
2602                     /*max_prefetch_interval=*/50);
2603 
2604   // Iterate over the schedule to make sure CopyStart order and the
2605   // corresponding CopyDone order match.
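  // copy_starts acts as a FIFO queue: every CopyDone encountered in the
  // schedule must consume the oldest outstanding CopyStart.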
2606   std::list<const HloInstruction*> copy_starts;
2607   for (HloInstruction* instruction : module->schedule()
2608                                          .sequence(module->entry_computation())
2609                                          .instructions()) {
2610     if (instruction->opcode() == HloOpcode::kCopyStart) {
2611       copy_starts.push_back(instruction);
2612     }
2613     if (instruction->opcode() == HloOpcode::kCopyDone) {
2614       EXPECT_EQ(copy_starts.front(), instruction->operand(0));
2615       copy_starts.pop_front();
2616     }
2617   }
2618 }
2619 
2620 TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule1) {
2621   // Test to ensure CopyStart/CopyDone is placed only in the entry computation.
2622   auto module = CreateNewVerifiedModule();
2623   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
2624   Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
2625   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape});
2626 
2627   auto cond_builder = HloComputation::Builder("WhileCond");
2628   // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
2629   HloInstruction* cond_param = cond_builder.AddInstruction(
2630       HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
2631   HloInstruction* cond_iter = cond_builder.AddInstruction(
2632       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
2633   HloInstruction* cond_limit = cond_builder.AddInstruction(
2634       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
2635   // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
2636   HloInstruction* cond_lt = cond_builder.AddInstruction(
2637       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
2638                                     cond_limit, ComparisonDirection::kLt));
2639   HloComputation* cond_computation =
2640       module->AddEmbeddedComputation(cond_builder.Build());
2641 
2642   auto body_builder = HloComputation::Builder("WhileBody");
2643   // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
2644   HloInstruction* body_param = body_builder.AddInstruction(
2645       HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
2646   HloInstruction* body_iter = body_builder.AddInstruction(
2647       HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
2648   HloInstruction* body_data = body_builder.AddInstruction(
2649       HloInstruction::CreateGetTupleElement(shape, body_param, 0));
2650   HloInstruction* body_iter_increment = body_builder.AddInstruction(
2651       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
2652   HloInstruction* body_iter_next =
2653       body_builder.AddInstruction(HloInstruction::CreateBinary(
2654           scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
2655   HloInstruction* body_data_increment =
2656       body_builder.AddInstruction(HloInstruction::CreateConstant(
2657           LiteralUtil::CreateR2<float>({{1.f, 2.f, 3.f}, {4.f, 5.f, 6.f}})));
2658   HloInstruction* body_data_mul =
2659       body_builder.AddInstruction(HloInstruction::CreateBinary(
2660           shape, HloOpcode::kMultiply, body_data, body_data));
2661   HloInstruction* body_data_add =
2662       body_builder.AddInstruction(HloInstruction::CreateBinary(
2663           shape, HloOpcode::kAdd, body_data, body_data_increment));
2664   HloInstruction* body_data_next =
2665       body_builder.AddInstruction(HloInstruction::CreateBinary(
2666           shape, HloOpcode::kAdd, body_data_add, body_data_mul));
2667   HloInstruction* body_out = body_builder.AddInstruction(
2668       HloInstruction::CreateTuple({body_data_next, body_iter_next}));
2669   HloComputation* body_computation =
2670       module->AddEmbeddedComputation(body_builder.Build());
2671 
2672   auto builder = HloComputation::Builder(TestName());
2673   HloInstruction* data = builder.AddInstruction(
2674       HloInstruction::CreateParameter(0, shape, "param_data"));
2675   HloInstruction* iter = builder.AddInstruction(
2676       HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
2677   HloInstruction* p2 =
2678       builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "p2"));
2679   HloInstruction* tuple =
2680       builder.AddInstruction(HloInstruction::CreateTuple({data, iter}));
2681   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
2682       tuple_shape, cond_computation, body_computation, tuple));
2683   HloInstruction* while_data = builder.AddInstruction(
2684       HloInstruction::CreateGetTupleElement(shape, while_op, 0));
2685   HloInstruction* add = builder.AddInstruction(
2686       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, while_data, p2));
2687   HloComputation* entry_computation =
2688       module->AddEntryComputation(builder.Build());
2689 
2690   HloSchedule schedule(module.get());
2691   schedule.set_sequence(cond_computation,
2692                         {cond_param, cond_iter, cond_limit, cond_lt});
2693   schedule.set_sequence(body_computation,
2694                         {body_param, body_iter, body_data, body_iter_increment,
2695                          body_iter_next, body_data_increment, body_data_mul,
2696                          body_data_add, body_data_next, body_out});
2697   schedule.set_sequence(entry_computation,
2698                         {iter, data, p2, tuple, while_op, while_data, add});
2699   TF_CHECK_OK(module->set_schedule(schedule));
2700 
2701   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/50);
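  // Like the other non-entry-computation tests, there are no explicit layout
  // checks; the test passes as long as AssignMemorySpace succeeds on the
  // scheduled module.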
2702 }
2703 
2704 TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule2) {
2705   auto module = CreateNewVerifiedModule();
2706   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
2707   Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
2708 
2709   auto call_builder = HloComputation::Builder("Call");
2710   HloInstruction* call_param = call_builder.AddInstruction(
2711       HloInstruction::CreateParameter(0, shape, "call_param"));
2712   HloInstruction* call_param2 = call_builder.AddInstruction(
2713       HloInstruction::CreateParameter(1, shape2, "call_param2"));
2714   HloInstruction* slice = call_builder.AddInstruction(
2715       HloInstruction::CreateSlice(shape, call_param2, {0, 0}, {2, 3}, {1, 1}));
2716   HloInstruction* mul =
2717       call_builder.AddInstruction(HloInstruction::CreateBinary(
2718           shape, HloOpcode::kMultiply, call_param, slice));
2719   HloInstruction* negate0 = call_builder.AddInstruction(
2720       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
2721   HloInstruction* negate1 = call_builder.AddInstruction(
2722       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
2723   HloInstruction* negate2 = call_builder.AddInstruction(
2724       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
2725   HloInstruction* negate3 = call_builder.AddInstruction(
2726       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
2727   HloInstruction* negate4 = call_builder.AddInstruction(
2728       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
2729   HloInstruction* negate5 = call_builder.AddInstruction(
2730       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
2731   HloInstruction* negate6 = call_builder.AddInstruction(
2732       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
2733   HloInstruction* negate7 = call_builder.AddInstruction(
2734       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
2735   HloInstruction* add0 =
2736       call_builder.AddInstruction(HloInstruction::CreateBinary(
2737           shape, HloOpcode::kAdd, call_param, negate7));
2738   HloComputation* call_computation =
2739       module->AddEmbeddedComputation(call_builder.Build());
2740 
2741   auto builder = HloComputation::Builder(TestName());
2742   HloInstruction* p0 =
2743       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
2744   HloInstruction* p1 =
2745       builder.AddInstruction(HloInstruction::CreateParameter(1, shape2, "p1"));
2746   HloInstruction* add1 = builder.AddInstruction(
2747       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
2748   HloInstruction* add2 = builder.AddInstruction(
2749       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
2750   HloInstruction* negate8 = builder.AddInstruction(
2751       HloInstruction::CreateUnary(shape2, HloOpcode::kNegate, p1));
2752   HloInstruction* call = builder.AddInstruction(
2753       HloInstruction::CreateCall(shape, {add1, negate8}, call_computation));
2754   HloInstruction* add3 = builder.AddInstruction(
2755       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, add1));
2756   HloInstruction* add4 = builder.AddInstruction(
2757       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, call, add3));
2758   HloInstruction* add5 = builder.AddInstruction(
2759       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add2, add4));
2760   HloComputation* entry_computation =
2761       module->AddEntryComputation(builder.Build());
2762 
2763   HloSchedule schedule(module.get());
2764   schedule.set_sequence(
2765       call_computation,
2766       {call_param, call_param2, slice, mul, negate0, negate1, negate2, negate3,
2767        negate4, negate5, negate6, negate7, add0});
2768   schedule.set_sequence(entry_computation,
2769                         {p0, p1, add1, add2, negate8, call, add3, add4, add5});
2770   TF_CHECK_OK(module->set_schedule(schedule));
2771 
2772   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5);
2773 }
2774 
2775 TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule3) {
2776   auto module = CreateNewVerifiedModule();
2777   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
2778   Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
2779 
2780   auto call_builder = HloComputation::Builder("Call");
2781   HloInstruction* call_param = call_builder.AddInstruction(
2782       HloInstruction::CreateParameter(0, shape, "call_param"));
2783   // Use shape2 here which is larger (scheduled earlier) to occupy alternate
2784   // memory at the beginning. This should cause a situation where the prefetch
2785   // of add1 later in the function body gets the wrong offset which cannot be
2786   // communicated to the outside of the function.
2787   HloInstruction* iota =
2788       call_builder.AddInstruction(HloInstruction::CreateIota(shape2, 0));
2789   HloInstruction* slice = call_builder.AddInstruction(
2790       HloInstruction::CreateSlice(shape, iota, {0, 0}, {2, 3}, {1, 1}));
2791   HloInstruction* mul =
2792       call_builder.AddInstruction(HloInstruction::CreateBinary(
2793           shape, HloOpcode::kMultiply, call_param, slice));
2794   HloInstruction* negate0 = call_builder.AddInstruction(
2795       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
2796   HloInstruction* negate1 = call_builder.AddInstruction(
2797       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
2798   HloInstruction* negate2 = call_builder.AddInstruction(
2799       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
2800   HloInstruction* negate3 = call_builder.AddInstruction(
2801       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
2802   HloInstruction* negate4 = call_builder.AddInstruction(
2803       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
2804   HloInstruction* negate5 = call_builder.AddInstruction(
2805       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
2806   HloInstruction* negate6 = call_builder.AddInstruction(
2807       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
2808   HloInstruction* negate7 = call_builder.AddInstruction(
2809       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
2810   HloInstruction* add0 =
2811       call_builder.AddInstruction(HloInstruction::CreateBinary(
2812           shape, HloOpcode::kAdd, call_param, negate7));
2813   HloComputation* call_computation =
2814       module->AddEmbeddedComputation(call_builder.Build());
2815 
2816   auto builder = HloComputation::Builder(TestName());
2817   HloInstruction* p0 =
2818       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
2819   HloInstruction* add1 = builder.AddInstruction(
2820       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
2821   HloInstruction* add2 = builder.AddInstruction(
2822       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
2823   HloInstruction* call = builder.AddInstruction(
2824       HloInstruction::CreateCall(shape, {add1}, call_computation));
2825   HloInstruction* add3 = builder.AddInstruction(
2826       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, call, add1));
2827   HloComputation* entry_computation =
2828       module->AddEntryComputation(builder.Build());
2829 
2830   HloSchedule schedule(module.get());
2831   schedule.set_sequence(
2832       call_computation,
2833       {call_param, iota, slice, mul, negate0, negate1, negate2, negate3,
2834        negate4, negate5, negate6, negate7, add0});
2835   schedule.set_sequence(entry_computation, {p0, add1, add2, call, add3});
2836   TF_CHECK_OK(module->set_schedule(schedule));
2837 
2838   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5);
2839 }
2840 
2841 // TODO(berkin): This might be an incorrect input graph, investigate.
2842 TEST_P(MemorySpaceAssignmentTest, DISABLED_NonEntryComputationSchedule4) {
2843   auto module = CreateNewVerifiedModule();
2844   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
2845   Shape shape2 = ShapeUtil::MakeShape(xla::F32, {3, 3});
2846 
2847   auto true_builder = HloComputation::Builder("True");
2848   HloInstruction* true_param = true_builder.AddInstruction(
2849       HloInstruction::CreateParameter(0, shape, "true_param"));
2850   HloInstruction* iota =
2851       true_builder.AddInstruction(HloInstruction::CreateIota(shape2, 0));
2852   HloInstruction* slice = true_builder.AddInstruction(
2853       HloInstruction::CreateSlice(shape, iota, {0, 0}, {2, 3}, {1, 1}));
2854   HloInstruction* mul =
2855       true_builder.AddInstruction(HloInstruction::CreateBinary(
2856           shape, HloOpcode::kMultiply, true_param, slice));
2857   HloInstruction* negate0 = true_builder.AddInstruction(
2858       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, mul));
2859   HloInstruction* negate1 = true_builder.AddInstruction(
2860       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
2861   HloInstruction* negate2 = true_builder.AddInstruction(
2862       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
2863   HloInstruction* negate3 = true_builder.AddInstruction(
2864       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
2865   HloInstruction* negate4 = true_builder.AddInstruction(
2866       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
2867   HloInstruction* negate5 = true_builder.AddInstruction(
2868       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
2869   HloInstruction* negate6 = true_builder.AddInstruction(
2870       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
2871   HloInstruction* negate7 = true_builder.AddInstruction(
2872       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
2873   HloInstruction* add0 =
2874       true_builder.AddInstruction(HloInstruction::CreateBinary(
2875           shape, HloOpcode::kAdd, true_param, negate7));
2876   HloComputation* true_computation =
2877       module->AddEmbeddedComputation(true_builder.Build());
2878 
2879   auto false_builder = HloComputation::Builder("False");
2880   HloInstruction* false_param = false_builder.AddInstruction(
2881       HloInstruction::CreateParameter(0, shape, "false_param"));
2882   HloComputation* false_computation =
2883       module->AddEmbeddedComputation(false_builder.Build());
2884 
2885   auto builder = HloComputation::Builder(TestName());
2886   HloInstruction* p0 =
2887       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
2888   HloInstruction* add1 = builder.AddInstruction(
2889       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, p0, p0));
2890   HloInstruction* add2 = builder.AddInstruction(
2891       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add1, p0));
2892   HloInstruction* pred = builder.AddInstruction(
2893       HloInstruction::CreateConstant(LiteralUtil::CreateR0<bool>(true)));
2894   HloInstruction* conditional =
2895       builder.AddInstruction(HloInstruction::CreateConditional(
2896           shape, pred, add1, true_computation, add2, false_computation));
2897   HloInstruction* add3 = builder.AddInstruction(
2898       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, conditional, add1));
2899   HloComputation* entry_computation =
2900       module->AddEntryComputation(builder.Build());
2901 
2902   HloSchedule schedule(module.get());
2903   schedule.set_sequence(
2904       true_computation,
2905       {true_param, iota, slice, mul, negate0, negate1, negate2, negate3,
2906        negate4, negate5, negate6, negate7, add0});
2907   schedule.set_sequence(false_computation, {false_param});
2908   schedule.set_sequence(entry_computation,
2909                         {p0, add1, add2, pred, conditional, add3});
2910   TF_CHECK_OK(module->set_schedule(schedule));
2911 
2912   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/5);
2913 }
2914 
2915 TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule5) {
2916   // This test reproduces the failure in b/143288178.  Given a graph like the
2917   // following:
2918   //
2919   // ... = foo(a)
2920   // tuple = tuple(..., a)
2921   // ... = while(tuple) {
2922   //   p = param(0)
2923   //   a1 = get-tuple-element(p), index=n-1
2924   //   ...
2925   //   ROOT tuple(..., a1)
2926   // }
2927   //
2928   // If a copy to alternate memory is inserted before foo, and if the size of
2929   // the while body is less than max prefetch interval so that the copy-done is
2930   // kept in the alternate memory, then we end up referring to the copy-done in
2931   // the root instruction of the while loop body. I.e.,
2932   //
2933   // cs = copy-start(a)
2934   // ...
2935   // cd = copy-done(cs)
2936   // ... = foo(cd)
2937   // tuple = tuple(..., cd)
2938   // ... = while(tuple) {
2939   //   p = param(0)
2940   //   a1 = get-tuple-element(p), index=n-1
2941   //   ...
2942   //   ROOT tuple(..., cd)  <-- Error: cd belongs to outside computation.
2943   // }
2944   //
2945   auto module = CreateNewVerifiedModule();
2946   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
2947   Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
2948   Shape tuple_shape =
2949       ShapeUtil::MakeTupleShape({shape, scalar_shape, scalar_shape});
2950 
2951   auto cond_builder = HloComputation::Builder("WhileCond");
2952   HloInstruction* cond_param = cond_builder.AddInstruction(
2953       HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
2954   HloInstruction* cond_iter = cond_builder.AddInstruction(
2955       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
2956   HloInstruction* cond_limit = cond_builder.AddInstruction(
2957       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
2958   HloInstruction* cond_lt = cond_builder.AddInstruction(
2959       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
2960                                     cond_limit, ComparisonDirection::kLt));
2961   HloComputation* cond_computation =
2962       module->AddEmbeddedComputation(cond_builder.Build());
2963 
2964   auto body_builder = HloComputation::Builder("WhileBody");
2965   HloInstruction* body_param = body_builder.AddInstruction(
2966       HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
2967   HloInstruction* body_iter = body_builder.AddInstruction(
2968       HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
2969   HloInstruction* body_data = body_builder.AddInstruction(
2970       HloInstruction::CreateGetTupleElement(shape, body_param, 0));
2971   HloInstruction* body_iter_increment = body_builder.AddInstruction(
2972       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
2973   HloInstruction* body_iter_next =
2974       body_builder.AddInstruction(HloInstruction::CreateBinary(
2975           scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
2976   HloInstruction* body_data2 = body_builder.AddInstruction(
2977       HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 2));
2978   HloInstruction* body_out = body_builder.AddInstruction(
2979       HloInstruction::CreateTuple({body_data, body_iter_next, body_data2}));
2980   HloComputation* body_computation =
2981       module->AddEmbeddedComputation(body_builder.Build());
2982 
2983   auto builder = HloComputation::Builder(TestName());
2984   HloInstruction* data = builder.AddInstruction(
2985       HloInstruction::CreateParameter(0, shape, "param_data"));
2986   HloInstruction* iter = builder.AddInstruction(
2987       HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
2988   HloInstruction* data2 = builder.AddInstruction(
2989       HloInstruction::CreateParameter(2, scalar_shape, "param_data2"));
2990   HloInstruction* negate0 = builder.AddInstruction(
2991       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, data));
2992   HloInstruction* negate1 = builder.AddInstruction(
2993       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
2994   HloInstruction* negate2 = builder.AddInstruction(
2995       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
2996   HloInstruction* negate3 = builder.AddInstruction(
2997       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
2998   HloInstruction* negate4 = builder.AddInstruction(
2999       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3000   HloInstruction* negate5 = builder.AddInstruction(
3001       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
3002   HloInstruction* negate6 = builder.AddInstruction(
3003       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
3004   HloInstruction* negate7 = builder.AddInstruction(
3005       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
3006   HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
3007       scalar_shape, HloOpcode::kSubtract, iter, data2));
3008   HloInstruction* tuple = builder.AddInstruction(
3009       HloInstruction::CreateTuple({negate7, iter, data2}));
3010   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
3011       tuple_shape, cond_computation, body_computation, tuple));
3012   HloInstruction* while_data = builder.AddInstruction(
3013       HloInstruction::CreateGetTupleElement(scalar_shape, while_op, 1));
3014   HloInstruction* root =
3015       builder.AddInstruction(HloInstruction::CreateTuple({while_data, sub}));
3016   HloComputation* entry_computation =
3017       module->AddEntryComputation(builder.Build());
3018 
3019   HloSchedule schedule(module.get());
3020   schedule.set_sequence(cond_computation,
3021                         {cond_param, cond_iter, cond_limit, cond_lt});
3022   schedule.set_sequence(body_computation,
3023                         {body_param, body_iter, body_data, body_iter_increment,
3024                          body_iter_next, body_data2, body_out});
3025   schedule.set_sequence(
3026       entry_computation,
3027       {iter, data, data2, negate0, negate1, negate2, negate3, negate4, negate5,
3028        negate6, negate7, sub, tuple, while_op, while_data, root});
3029   TF_CHECK_OK(module->set_schedule(schedule));
3030 
3031   // Set a large max prefetch interval so that the buffer can be kept in
3032   // alternate memory.
3033   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
                    /*max_prefetch_interval=*/20);
3034 }
3035 
3036 TEST_P(MemorySpaceAssignmentTest, NonEntryComputationSchedule6) {
3037   auto module = CreateNewVerifiedModule();
3038   Shape shape = ShapeUtil::MakeShape(xla::F32, {2, 3});
3039   Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
3040   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, scalar_shape, shape});
3041 
3042   auto cond_builder = HloComputation::Builder("WhileCond");
3043   HloInstruction* cond_param = cond_builder.AddInstruction(
3044       HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
3045   HloInstruction* cond_iter = cond_builder.AddInstruction(
3046       HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
3047   HloInstruction* cond_limit = cond_builder.AddInstruction(
3048       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(50.f)));
3049   HloInstruction* cond_lt = cond_builder.AddInstruction(
3050       HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), cond_iter,
3051                                     cond_limit, ComparisonDirection::kLt));
3052   HloComputation* cond_computation =
3053       module->AddEmbeddedComputation(cond_builder.Build());
3054 
3055   auto body_builder = HloComputation::Builder("WhileBody");
3056   HloInstruction* body_param = body_builder.AddInstruction(
3057       HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
3058   HloInstruction* body_iter = body_builder.AddInstruction(
3059       HloInstruction::CreateGetTupleElement(scalar_shape, body_param, 1));
3060   HloInstruction* body_data = body_builder.AddInstruction(
3061       HloInstruction::CreateGetTupleElement(shape, body_param, 0));
3062   HloInstruction* body_negate0 = body_builder.AddInstruction(
3063       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_data));
3064   HloInstruction* body_negate1 = body_builder.AddInstruction(
3065       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate0));
3066   HloInstruction* body_negate2 = body_builder.AddInstruction(
3067       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate1));
3068   HloInstruction* body_negate3 = body_builder.AddInstruction(
3069       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate2));
3070   HloInstruction* body_negate4 = body_builder.AddInstruction(
3071       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate3));
3072   HloInstruction* body_negate5 = body_builder.AddInstruction(
3073       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate4));
3074   HloInstruction* body_negate6 = body_builder.AddInstruction(
3075       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate5));
3076   HloInstruction* body_negate7 = body_builder.AddInstruction(
3077       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, body_negate6));
3078   HloInstruction* body_iter_increment = body_builder.AddInstruction(
3079       HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.f)));
3080   HloInstruction* body_iter_next =
3081       body_builder.AddInstruction(HloInstruction::CreateBinary(
3082           scalar_shape, HloOpcode::kAdd, body_iter, body_iter_increment));
3083   HloInstruction* body_out = body_builder.AddInstruction(
3084       HloInstruction::CreateTuple({body_data, body_iter_next, body_negate7}));
3085   HloComputation* body_computation =
3086       module->AddEmbeddedComputation(body_builder.Build());
3087 
3088   auto builder = HloComputation::Builder(TestName());
3089   HloInstruction* data = builder.AddInstruction(
3090       HloInstruction::CreateParameter(0, shape, "param_data"));
3091   HloInstruction* iter = builder.AddInstruction(
3092       HloInstruction::CreateParameter(1, scalar_shape, "param_iter"));
3093   HloInstruction* negate0 = builder.AddInstruction(
3094       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, data));
3095   HloInstruction* negate1 = builder.AddInstruction(
3096       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3097   HloInstruction* negate2 = builder.AddInstruction(
3098       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3099   HloInstruction* negate3 = builder.AddInstruction(
3100       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3101   HloInstruction* negate4 = builder.AddInstruction(
3102       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3103   HloInstruction* negate5 = builder.AddInstruction(
3104       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
3105   HloInstruction* negate6 = builder.AddInstruction(
3106       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
3107   HloInstruction* negate7 = builder.AddInstruction(
3108       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate6));
3109   HloInstruction* tuple = builder.AddInstruction(
3110       HloInstruction::CreateTuple({data, iter, negate7}));
3111   HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
3112       tuple_shape, cond_computation, body_computation, tuple));
3113   HloInstruction* while_data = builder.AddInstruction(
3114       HloInstruction::CreateGetTupleElement(shape, while_op, 0));
3115   HloInstruction* while_data2 = builder.AddInstruction(
3116       HloInstruction::CreateGetTupleElement(shape, while_op, 2));
3117   HloInstruction* root = builder.AddInstruction(HloInstruction::CreateBinary(
3118       shape, HloOpcode::kAdd, while_data, while_data2));
3119   HloComputation* entry_computation =
3120       module->AddEntryComputation(builder.Build());
3121 
3122   HloSchedule schedule(module.get());
3123   schedule.set_sequence(cond_computation,
3124                         {cond_param, cond_iter, cond_limit, cond_lt});
3125   schedule.set_sequence(
3126       body_computation,
3127       {body_param, body_iter, body_data, body_negate0, body_negate1,
3128        body_negate2, body_negate3, body_negate4, body_negate5, body_negate6,
3129        body_negate7, body_iter_increment, body_iter_next, body_out});
3130   schedule.set_sequence(
3131       entry_computation,
3132       {iter, data, negate0, negate1, negate2, negate3, negate4, negate5,
3133        negate6, negate7, tuple, while_op, while_data, while_data2, root});
3134   TF_CHECK_OK(module->set_schedule(schedule));
3135 
3136   // Pick a large max prefetch interval to ensure all the while inputs are
3137   // allocated in the alternate memory.
3138   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
3139                     /*max_prefetch_interval=*/25);
3140 
3141   // Index {0} of the while loop argument is not written inside the while loop,
3142   // so it can be trivially placed in the alternate memory space.
3143   *ShapeUtil::GetMutableSubshape(&tuple_shape, {0})->mutable_layout() =
3144       LayoutUtil::MakeLayout(
3145           /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
3146           kAlternateMemorySpace);
3147   // Index {1} is a scalar, so it is always placed in the default memory.
3148   *ShapeUtil::GetMutableSubshape(&tuple_shape, {1})->mutable_layout() =
3149       LayoutUtil::MakeLayout(
3150           /*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0,
3151           kDefaultMemorySpace);
3152   // Index {2} of the while loop is placed in the default memory.
3153   *ShapeUtil::GetMutableSubshape(&tuple_shape, {2})->mutable_layout() =
3154       LayoutUtil::MakeLayout(
3155           /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
3156           kDefaultMemorySpace);
3157 
3158   // Expect the layout for the while loop and its aliased buffers.
3159   EXPECT_THAT(while_op, op::ShapeWithLayout(tuple_shape));
3160   EXPECT_THAT(while_op->operand(0), op::ShapeWithLayout(tuple_shape));
3161   EXPECT_THAT(cond_param, op::ShapeWithLayout(tuple_shape));
3162   EXPECT_THAT(body_param, op::ShapeWithLayout(tuple_shape));
3163   EXPECT_THAT(body_out, op::ShapeWithLayout(tuple_shape));
3164 }
3165 
3166 TEST_P(MemorySpaceAssignmentTest, DanglingCopy) {
3167   // This situation was encountered in vss, where there was a mismatch between
3168   // the memory space in the preset assignments and the one in the output graph.
3169   HloComputation::Builder builder(TestName());
3170   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3171   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
3172 
3173   HloInstruction* p = builder.AddInstruction(
3174       HloInstruction::CreateParameter(0, tuple_shape, "p"));
3175   HloInstruction* p0 = builder.AddInstruction(
3176       HloInstruction::CreateGetTupleElement(shape, p, 0));
3177   HloInstruction* p1a = builder.AddInstruction(
3178       HloInstruction::CreateGetTupleElement(shape, p, 1));
3179   HloInstruction* copy = builder.AddInstruction(
3180       HloInstruction::CreateUnary(shape, HloOpcode::kCopy, p1a));
3181   HloInstruction* negate0 = builder.AddInstruction(
3182       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
3183   HloInstruction* negate1 = builder.AddInstruction(
3184       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3185   HloInstruction* negate2 = builder.AddInstruction(
3186       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3187   HloInstruction* negate3 = builder.AddInstruction(
3188       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3189   HloInstruction* negate4 = builder.AddInstruction(
3190       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3191   HloInstruction* negate5 = builder.AddInstruction(
3192       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
3193   HloInstruction* negate6 = builder.AddInstruction(
3194       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
3195   HloInstruction* p1b = builder.AddInstruction(
3196       HloInstruction::CreateGetTupleElement(shape, p, 1));
3197   HloInstruction* add = builder.AddInstruction(
3198       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1b));
3199 
3200   auto module = CreateNewVerifiedModule();
3201   HloComputation* computation = module->AddEntryComputation(builder.Build());
3202 
3203   HloSchedule schedule(module.get());
3204   schedule.set_sequence(
3205       computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5,
3206                     negate6, p1a, copy, p1b, add});
3207   TF_CHECK_OK(module->set_schedule(schedule));
3208 
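  // No explicit expectations here: the test passes as long as memory space
  // assignment completes without reintroducing the preset-assignment/output
  // graph mismatch described above.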
3209   AssignMemorySpace(module.get());
3210 }
3211 
3212 TEST_P(MemorySpaceAssignmentTest, MultiOutputFusion) {
3213   HloComputation::Builder builder(TestName());
3214   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3215   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
3216   auto module = CreateNewVerifiedModule();
3217 
3218   HloComputation::Builder fusion_builder("fusion");
3219   HloInstruction* fusion_param0 = fusion_builder.AddInstruction(
3220       HloInstruction::CreateParameter(0, shape, "p0"));
3221   HloInstruction* fusion_param1 = fusion_builder.AddInstruction(
3222       HloInstruction::CreateParameter(1, shape, "p1"));
3223   fusion_builder.AddInstruction(
3224       HloInstruction::CreateTuple({fusion_param0, fusion_param1}));
3225   HloComputation* fusion_computation =
3226       module->AddEmbeddedComputation(fusion_builder.Build());
3227 
3228   HloInstruction* p0 =
3229       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3230   HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
3231       tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
3232       fusion_computation));
3233   HloInstruction* element0 = builder.AddInstruction(
3234       HloInstruction::CreateGetTupleElement(shape, fusion, 0));
3235   HloInstruction* element1 = builder.AddInstruction(
3236       HloInstruction::CreateGetTupleElement(shape, fusion, 1));
3237   HloInstruction* add = builder.AddInstruction(
3238       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, element0, element1));
3239 
3240   HloComputation* computation = module->AddEntryComputation(builder.Build());
3241 
3242   HloSchedule schedule(module.get());
3243   schedule.set_sequence(computation, {p0, fusion, element0, element1, add});
3244   TF_CHECK_OK(module->set_schedule(schedule));
3245 
3246   AssignMemorySpace(module.get());
3247 }
3248 
3249 TEST_P(MemorySpaceAssignmentTest, TupleInput) {
3250   HloComputation::Builder builder(TestName());
3251   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3252   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
3253   auto module = CreateNewVerifiedModule();
3254 
3255   HloComputation::Builder fusion_builder("fusion");
3256   HloInstruction* fusion_param = fusion_builder.AddInstruction(
3257       HloInstruction::CreateParameter(0, tuple_shape, "p"));
3258   HloInstruction* fusion_element0 = fusion_builder.AddInstruction(
3259       HloInstruction::CreateGetTupleElement(shape, fusion_param, 0));
3260   HloInstruction* fusion_element1 = fusion_builder.AddInstruction(
3261       HloInstruction::CreateGetTupleElement(shape, fusion_param, 1));
3262   fusion_builder.AddInstruction(HloInstruction::CreateBinary(
3263       shape, HloOpcode::kAdd, fusion_element0, fusion_element1));
3264   HloComputation* fusion_computation =
3265       module->AddEmbeddedComputation(fusion_builder.Build());
3266 
3267   HloInstruction* p0 =
3268       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3269   HloInstruction* p1 =
3270       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
3271   HloInstruction* negate0 = builder.AddInstruction(
3272       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
3273   HloInstruction* negate1 = builder.AddInstruction(
3274       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p1));
3275   HloInstruction* tuple =
3276       builder.AddInstruction(HloInstruction::CreateTuple({negate0, negate1}));
3277   HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
3278       shape, HloInstruction::FusionKind::kCustom, {tuple}, fusion_computation));
3279 
3280   HloComputation* computation = module->AddEntryComputation(builder.Build());
3281 
3282   HloSchedule schedule(module.get());
3283   schedule.set_sequence(computation, {p0, p1, negate0, negate1, tuple, fusion});
3284   TF_CHECK_OK(module->set_schedule(schedule));
3285 
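  // As in the previous tests, this only checks that assignment handles a
  // tuple-shaped fusion operand without crashing; there are no layout
  // expectations.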
3286   AssignMemorySpace(module.get());
3287 }
3288 
3289 TEST_P(MemorySpaceAssignmentTest, TupleToTuple1) {
3290   HloComputation::Builder builder(TestName());
3291   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3292   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
3293   auto module = CreateNewVerifiedModule();
3294 
3295   HloComputation::Builder fusion0_builder("fusion0");
3296   HloInstruction* fusion0_param0 = fusion0_builder.AddInstruction(
3297       HloInstruction::CreateParameter(0, shape, "p0"));
3298   HloInstruction* fusion0_param1 = fusion0_builder.AddInstruction(
3299       HloInstruction::CreateParameter(1, shape, "p1"));
3300   fusion0_builder.AddInstruction(
3301       HloInstruction::CreateTuple({fusion0_param0, fusion0_param1}));
3302   HloComputation* fusion0_computation =
3303       module->AddEmbeddedComputation(fusion0_builder.Build());
3304 
3305   HloComputation::Builder fusion1_builder("fusion1");
3306   HloInstruction* fusion1_param = fusion1_builder.AddInstruction(
3307       HloInstruction::CreateParameter(0, tuple_shape, "p"));
3308   HloInstruction* fusion1_element0 = fusion1_builder.AddInstruction(
3309       HloInstruction::CreateGetTupleElement(shape, fusion1_param, 0));
3310   HloInstruction* fusion1_element1 = fusion1_builder.AddInstruction(
3311       HloInstruction::CreateGetTupleElement(shape, fusion1_param, 1));
3312   fusion1_builder.AddInstruction(HloInstruction::CreateBinary(
3313       shape, HloOpcode::kAdd, fusion1_element0, fusion1_element1));
3314   HloComputation* fusion1_computation =
3315       module->AddEmbeddedComputation(fusion1_builder.Build());
3316 
3317   HloInstruction* p0 =
3318       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3319   HloInstruction* fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
3320       tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
3321       fusion0_computation));
3322   HloInstruction* element0 = builder.AddInstruction(
3323       HloInstruction::CreateGetTupleElement(shape, fusion0, 0));
3324   HloInstruction* element1 = builder.AddInstruction(
3325       HloInstruction::CreateGetTupleElement(shape, fusion0, 1));
3326   HloInstruction* negate0 = builder.AddInstruction(
3327       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
3328   HloInstruction* negate1 = builder.AddInstruction(
3329       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3330   HloInstruction* negate2 = builder.AddInstruction(
3331       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3332   HloInstruction* negate3 = builder.AddInstruction(
3333       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3334   HloInstruction* negate4 = builder.AddInstruction(
3335       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3336   HloInstruction* negate5 = builder.AddInstruction(
3337       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
3338   HloInstruction* negate6 = builder.AddInstruction(
3339       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
3340   HloInstruction* add0 = builder.AddInstruction(
3341       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, element0, element1));
3342   HloInstruction* add1 = builder.AddInstruction(
3343       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, negate6));
3344   HloInstruction* fusion1 = builder.AddInstruction(
3345       HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kCustom,
3346                                    {fusion0}, fusion1_computation));
3347   HloInstruction* mul = builder.AddInstruction(
3348       HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, add1, fusion1));
3349 
3350   HloComputation* computation = module->AddEntryComputation(builder.Build());
3351 
3352   HloSchedule schedule(module.get());
3353   schedule.set_sequence(
3354       computation,
3355       {p0, fusion0, element0, element1, negate0, negate1, negate2, negate3,
3356        negate4, negate5, negate6, add0, add1, fusion1, mul});
3357   TF_CHECK_OK(module->set_schedule(schedule));
3358 
3359   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, /*max_prefetch_interval=*/5);
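  // fusion1's tuple operand is expected to be reconstructed from the elements
  // of fusion0's result, with each element prefetched (async-copied) from the
  // default to the alternate memory space.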
3360   EXPECT_THAT(fusion1,
3361               op::Fusion(op::Tuple(
3362                   op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
3363                                 op::GetTupleElement(op::Fusion(), 0)),
3364                   op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
3365                                 op::GetTupleElement(op::Fusion(), 1)))));
3366 }
3367 
3368 TEST_P(MemorySpaceAssignmentTest, TupleToTuple2) {
3369   HloComputation::Builder builder(TestName());
3370   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3371   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
3372   Shape nested_tuple_shape = ShapeUtil::MakeTupleShape({shape, tuple_shape});
3373   auto module = CreateNewVerifiedModule();
3374 
3375   HloComputation::Builder fusion0_builder("fusion0");
3376   HloInstruction* fusion0_param0 = fusion0_builder.AddInstruction(
3377       HloInstruction::CreateParameter(0, shape, "p0"));
3378   HloInstruction* fusion0_param1 = fusion0_builder.AddInstruction(
3379       HloInstruction::CreateParameter(1, shape, "p1"));
3380   HloInstruction* fusion0_tuple = fusion0_builder.AddInstruction(
3381       HloInstruction::CreateTuple({fusion0_param0, fusion0_param1}));
3382   fusion0_builder.AddInstruction(
3383       HloInstruction::CreateTuple({fusion0_param0, fusion0_tuple}));
3384   HloComputation* fusion0_computation =
3385       module->AddEmbeddedComputation(fusion0_builder.Build());
3386 
3387   HloComputation::Builder fusion1_builder("fusion1");
3388   HloInstruction* fusion1_param = fusion1_builder.AddInstruction(
3389       HloInstruction::CreateParameter(0, nested_tuple_shape, "p"));
3390   HloInstruction* fusion1_element0 = fusion1_builder.AddInstruction(
3391       HloInstruction::CreateGetTupleElement(shape, fusion1_param, 0));
3392   HloInstruction* fusion1_element1 = fusion1_builder.AddInstruction(
3393       HloInstruction::CreateGetTupleElement(tuple_shape, fusion1_param, 1));
3394   HloInstruction* fusion1_element2 = fusion1_builder.AddInstruction(
3395       HloInstruction::CreateGetTupleElement(shape, fusion1_element1, 1));
3396   fusion1_builder.AddInstruction(HloInstruction::CreateBinary(
3397       shape, HloOpcode::kAdd, fusion1_element0, fusion1_element2));
3398   HloComputation* fusion1_computation =
3399       module->AddEmbeddedComputation(fusion1_builder.Build());
3400 
3401   HloInstruction* p0 =
3402       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3403   HloInstruction* fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
3404       nested_tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
3405       fusion0_computation));
3406   HloInstruction* negate0 = builder.AddInstruction(
3407       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
3408   HloInstruction* negate1 = builder.AddInstruction(
3409       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3410   HloInstruction* negate2 = builder.AddInstruction(
3411       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3412   HloInstruction* negate3 = builder.AddInstruction(
3413       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3414   HloInstruction* negate4 = builder.AddInstruction(
3415       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3416   HloInstruction* negate5 = builder.AddInstruction(
3417       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
3418   HloInstruction* negate6 = builder.AddInstruction(
3419       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
3420   HloInstruction* fusion1 = builder.AddInstruction(
3421       HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kCustom,
3422                                    {fusion0}, fusion1_computation));
3423 
3424   HloComputation* computation = module->AddEntryComputation(builder.Build());
3425 
3426   HloSchedule schedule(module.get());
3427   schedule.set_sequence(
3428       computation, {p0, fusion0, negate0, negate1, negate2, negate3, negate4,
3429                     negate5, negate6, fusion1});
3430   TF_CHECK_OK(module->set_schedule(schedule));
3431 
3432   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, /*max_prefetch_interval=*/5);
3433 
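  // The nested tuple operand of fusion1 is likewise expected to be rebuilt
  // element by element, with every leaf buffer prefetched into the alternate
  // memory space.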
3434   EXPECT_THAT(
3435       fusion1,
3436       op::Fusion(op::Tuple(
3437           op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
3438                         op::GetTupleElement(op::Fusion(), 0)),
3439           op::Tuple(
3440               op::AsyncCopy(
3441                   kAlternateMemorySpace, kDefaultMemorySpace,
3442                   op::GetTupleElement(op::GetTupleElement(op::Fusion(), 1), 0)),
3443               op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
3444                             op::GetTupleElement(
3445                                 op::GetTupleElement(op::Fusion(), 1), 1))))));
3446 }
3447 
3448 TEST_P(MemorySpaceAssignmentTest, TupleToTuple3) {
3449   HloComputation::Builder builder(TestName());
3450   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3451   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
3452   auto module = CreateNewVerifiedModule();
3453 
3454   HloComputation::Builder fusion0_builder("fusion0");
3455   HloInstruction* fusion0_param0 = fusion0_builder.AddInstruction(
3456       HloInstruction::CreateParameter(0, shape, "p0"));
3457   HloInstruction* fusion0_param1 = fusion0_builder.AddInstruction(
3458       HloInstruction::CreateParameter(1, shape, "p1"));
3459   fusion0_builder.AddInstruction(
3460       HloInstruction::CreateTuple({fusion0_param0, fusion0_param1}));
3461   HloComputation* fusion0_computation =
3462       module->AddEmbeddedComputation(fusion0_builder.Build());
3463 
3464   HloComputation::Builder fusion1_builder("fusion1");
3465   HloInstruction* fusion1_param = fusion1_builder.AddInstruction(
3466       HloInstruction::CreateParameter(0, tuple_shape, "p"));
3467   HloInstruction* fusion1_element0 = fusion1_builder.AddInstruction(
3468       HloInstruction::CreateGetTupleElement(shape, fusion1_param, 0));
3469   HloInstruction* fusion1_element1 = fusion1_builder.AddInstruction(
3470       HloInstruction::CreateGetTupleElement(shape, fusion1_param, 1));
3471   fusion1_builder.AddInstruction(HloInstruction::CreateBinary(
3472       shape, HloOpcode::kAdd, fusion1_element0, fusion1_element1));
3473   HloComputation* fusion1_computation =
3474       module->AddEmbeddedComputation(fusion1_builder.Build());
3475 
3476   HloInstruction* p0 =
3477       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3478   HloInstruction* fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
3479       tuple_shape, HloInstruction::FusionKind::kCustom, {p0, p0},
3480       fusion0_computation));
3481   HloInstruction* fusion1 = builder.AddInstruction(
3482       HloInstruction::CreateFusion(shape, HloInstruction::FusionKind::kCustom,
3483                                    {fusion0}, fusion1_computation));
3484 
3485   HloComputation* computation = module->AddEntryComputation(builder.Build());
3486 
3487   HloSchedule schedule(module.get());
3488   schedule.set_sequence(computation, {p0, fusion0, fusion1});
3489   TF_CHECK_OK(module->set_schedule(schedule));
3490 
3491   AssignMemorySpace(module.get());
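  // Since fusion1 is scheduled immediately after fusion0, there is no room to
  // insert a prefetch, so fusion1 is expected to consume fusion0's result
  // directly (no async copies).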
3492   EXPECT_THAT(fusion1, op::Fusion(op::Fusion()));
3493 }
3494 
3495 TEST_P(MemorySpaceAssignmentTest, InputOutputAlias) {
3496   HloComputation::Builder builder(TestName());
3497   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3498   Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
3499   HloInstruction* p = builder.AddInstruction(
3500       HloInstruction::CreateParameter(0, tuple_shape, "p"));
3501   HloInstruction* p0 = builder.AddInstruction(
3502       HloInstruction::CreateGetTupleElement(shape, p, 0));
3503   HloInstruction* negate0 = builder.AddInstruction(
3504       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
3505   HloInstruction* negate1 = builder.AddInstruction(
3506       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3507   HloInstruction* negate2 = builder.AddInstruction(
3508       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3509   HloInstruction* negate3 = builder.AddInstruction(
3510       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3511   HloInstruction* negate4 = builder.AddInstruction(
3512       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3513   HloInstruction* negate5 = builder.AddInstruction(
3514       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
3515   HloInstruction* negate6 = builder.AddInstruction(
3516       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
3517   HloInstruction* p1 = builder.AddInstruction(
3518       HloInstruction::CreateGetTupleElement(shape, p, 1));
3519   HloInstruction* add = builder.AddInstruction(
3520       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));
3521   HloInstruction* negate7 = builder.AddInstruction(
3522       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, add));
3523   HloInstruction* tuple =
3524       builder.AddInstruction(HloInstruction::CreateTuple({p0, add}));
3525 
3526   auto module = CreateNewVerifiedModule();
3527   HloComputation* computation = module->AddEntryComputation(builder.Build());
3528 
3529   HloSchedule schedule(module.get());
3530   schedule.set_sequence(
3531       computation, {p, p0, negate0, negate1, negate2, negate3, negate4, negate5,
3532                     negate6, p1, add, negate7, tuple});
3533   TF_CHECK_OK(module->set_schedule(schedule));
3534 
3535   // Make input {0} alias with output {0} and input {1} alias with output {1}.
3536   TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({0}, 0, {0}));
3537   TF_CHECK_OK(module->input_output_alias_config().SetUpAlias({1}, 0, {1}));
3538 
3539   AssignMemorySpace(module.get());
3540 
3541   // Make sure both aliased elements of the input stay in the default memory
3541   // space.
3542   EXPECT_EQ(p->shape().tuple_shapes(0).layout().memory_space(),
3543             kDefaultMemorySpace);
3544   EXPECT_EQ(p->shape().tuple_shapes(1).layout().memory_space(),
3545             kDefaultMemorySpace);
3546 }
3547 
3548 TEST_P(MemorySpaceAssignmentTest, CostAnalysis) {
3549   // This is mostly a smoke test since it's difficult and brittle to work out
3550   // the cost of the HLO instructions.
3551   HloComputation::Builder builder(TestName());
3552   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3553   HloInstruction* p0 =
3554       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3555   HloInstruction* p1 =
3556       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
3557   HloInstruction* negate0 = builder.AddInstruction(
3558       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
3559   HloInstruction* negate1 = builder.AddInstruction(
3560       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3561   HloInstruction* negate2 = builder.AddInstruction(
3562       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3563   HloInstruction* negate3 = builder.AddInstruction(
3564       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3565   HloInstruction* negate4 = builder.AddInstruction(
3566       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3567   HloInstruction* negate5 = builder.AddInstruction(
3568       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
3569   HloInstruction* negate6 = builder.AddInstruction(
3570       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
3571   HloInstruction* add = builder.AddInstruction(
3572       HloInstruction::CreateBinary(shape, HloOpcode::kAdd, negate6, p1));
3573 
3574   auto module = CreateNewVerifiedModule();
3575   HloComputation* computation = module->AddEntryComputation(builder.Build());
3576 
3577   HloSchedule schedule(module.get());
3578   schedule.set_sequence(computation, {p0, p1, negate0, negate1, negate2,
3579                                       negate3, negate4, negate5, negate6, add});
3580   TF_CHECK_OK(module->set_schedule(schedule));
3581 
3582   AssignMemorySpaceUsingCostAnalysis(module.get());
3583   // Parameters are in the default memory space.
3584   EXPECT_THAT(p0, op::ShapeWithLayout(shape));
3585   EXPECT_THAT(p1, op::ShapeWithLayout(shape));
3586   // Negate instructions are in the alternate memory space (1).
3587   Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
3588       F32, {2, 3},
3589       /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
3590       kAlternateMemorySpace);
3591   EXPECT_THAT(negate0, op::ShapeWithLayout(shape_in_alternate_mem));
3592   EXPECT_THAT(negate1, op::ShapeWithLayout(shape_in_alternate_mem));
3593   EXPECT_THAT(negate2, op::ShapeWithLayout(shape_in_alternate_mem));
3594   EXPECT_THAT(negate3, op::ShapeWithLayout(shape_in_alternate_mem));
3595   EXPECT_THAT(negate4, op::ShapeWithLayout(shape_in_alternate_mem));
3596   EXPECT_THAT(negate5, op::ShapeWithLayout(shape_in_alternate_mem));
3597   EXPECT_THAT(negate6, op::ShapeWithLayout(shape_in_alternate_mem));
3598 }
3599 
3600 TEST_P(MemorySpaceAssignmentTest, MemoryBoundednessBufferIntervalCompare) {
3601   // This test is carefully crafted to force only negates to be allocated to the
3602   // alternate memory. The graph consists of interleaving negate and tanh
3603   // operations:
3604   //
3605   //        +------+      +-------+      +-----
3606   //       /        \    /         \    /
3607   //  negate  tanh  negate  tanh   negate  tanh
3608   //             \          /  \           /
3609   //              +--------+    +---------+
3610   //
3611   // The alternate memory is sized to fit only two f32[4,3] tensors at a time.
3612   // Also, transcendentals are given much lower throughput than FLOPs, so the
3613   // tanhs are compute bound while the negates are memory bound.
3614   // MemoryBoundednessBufferIntervalCompare should therefore prioritize the negates.
3615   HloComputation::Builder builder(TestName());
3616   Shape shape = ShapeUtil::MakeShape(F32, {4, 3});
3617   HloInstruction* p0 =
3618       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3619   HloInstruction* p1 =
3620       builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1"));
3621   HloInstruction* tanh0 = builder.AddInstruction(
3622       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3623   HloInstruction* negate0 = builder.AddInstruction(
3624       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p1));
3625   HloInstruction* tanh1 = builder.AddInstruction(
3626       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh0));
3627   HloInstruction* negate1 = builder.AddInstruction(
3628       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3629   HloInstruction* tanh2 = builder.AddInstruction(
3630       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh1));
3631   HloInstruction* negate2 = builder.AddInstruction(
3632       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3633   HloInstruction* tanh3 = builder.AddInstruction(
3634       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh2));
3635   HloInstruction* negate3 = builder.AddInstruction(
3636       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3637   HloInstruction* tanh4 = builder.AddInstruction(
3638       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh3));
3639   HloInstruction* negate4 = builder.AddInstruction(
3640       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3641   HloInstruction* tuple =
3642       builder.AddInstruction(HloInstruction::CreateTuple({tanh4, negate4}));
3643 
3644   auto module = CreateNewVerifiedModule();
3645   HloComputation* computation = module->AddEntryComputation(builder.Build());
3646 
3647   HloSchedule schedule(module.get());
3648   schedule.set_sequence(computation,
3649                         {p0, p1, tanh0, negate0, tanh1, negate1, tanh2, negate2,
3650                          tanh3, negate3, tanh4, negate4, tuple});
3651   TF_CHECK_OK(module->set_schedule(schedule));
3652 
3653   AssignMemorySpaceUsingCostAnalysis(module.get());
3654   // Parameters are in the default memory space.
3655   EXPECT_THAT(p0, op::ShapeWithLayout(shape));
3656   EXPECT_THAT(p1, op::ShapeWithLayout(shape));
3657   Shape shape_in_default_mem = ShapeUtil::MakeShapeWithLayout(
3658       F32, {4, 3},
3659       /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
3660       kDefaultMemorySpace);
3661   // Expect only negates to be in alternate memory space. Not all might fit but
3662   // make sure at least one does.
3663   std::vector<HloInstruction*> negate_instructions = {negate0, negate1, negate2,
3664                                                       negate3, negate4};
3665   int64 num_negates_in_alternate_mem = absl::c_count_if(
3666       negate_instructions, [&](const HloInstruction* instruction) {
3667         return instruction->shape().layout().memory_space() ==
3668                kAlternateMemorySpace;
3669       });
3670   EXPECT_GE(num_negates_in_alternate_mem, 1);
3671   EXPECT_THAT(tanh0, op::ShapeWithLayout(shape_in_default_mem));
3672   EXPECT_THAT(tanh1, op::ShapeWithLayout(shape_in_default_mem));
3673   EXPECT_THAT(tanh2, op::ShapeWithLayout(shape_in_default_mem));
3674   EXPECT_THAT(tanh3, op::ShapeWithLayout(shape_in_default_mem));
3675   EXPECT_THAT(tanh4, op::ShapeWithLayout(shape_in_default_mem));
3676 }
3677 
3678 TEST_P(MemorySpaceAssignmentTest, SimpleWhileTupleTest) {
3679   Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
3680   Shape f32v1 = ShapeUtil::MakeShape(F32, {1});
3681   Shape t_s32_f32v1 = ShapeUtil::MakeTupleShape({s32, f32v1});
3682   auto module = CreateNewVerifiedModule("SimpleWhile");
3683   HloSchedule schedule(module.get());
3684 
3685   // A simple compare-to-limit (x < 4) computation for a While.
3686   //
3687   // condition:
3688   //   const4[s32] -----------------------------------\
3689   //                                                   \
3690   //   param[(s32,f32[4])] --- get-tuple-element[0] --- less-than
3691   //
3692   HloComputation* cond_computation;
3693   {
3694     auto builder = HloComputation::Builder("WhileCond");
3695     auto const4 = builder.AddInstruction(
3696         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(4)));
3697     auto param = builder.AddInstruction(
3698         HloInstruction::CreateParameter(0, t_s32_f32v1, "x"));
3699     auto index = builder.AddInstruction(
3700         HloInstruction::CreateGetTupleElement(const4->shape(), param, 0));
3701     auto compare = builder.AddInstruction(
3702         HloInstruction::CreateCompare(ShapeUtil::MakeShape(PRED, {}), index,
3703                                       const4, ComparisonDirection::kLt));
3704     cond_computation = module->AddEmbeddedComputation(builder.Build());
3705     schedule.set_sequence(cond_computation, {const4, param, index, compare});
3706   }
3707 
3708   // Builds a simple body computation for a While.
3709   //
3710   // body:
3711   //   constv[f32[1]] --------------------------------------\
3712   //                                                         \
3713   //                           /--- get-tuple-elementv[1] --- addv ---\
3714   //   param[(s32,f32[1])] ---|                                    tuple
3715   //                           \--- get-tuple-elementc[0] --- addc ---/
3716   //                                                         /
3717   //   const1[s32] -----------------------------------------/
3718   //
3719   HloComputation* body_computation;
3720   {
3721     auto builder = HloComputation::Builder("WhileBody");
3722     auto const1 = builder.AddInstruction(
3723         HloInstruction::CreateConstant(LiteralUtil::CreateR0<int>(1)));
3724     auto constv = builder.AddInstruction(
3725         HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.1f})));
3726     auto param = builder.AddInstruction(
3727         HloInstruction::CreateParameter(0, t_s32_f32v1, "x"));
3728     auto indexc = builder.AddInstruction(
3729         HloInstruction::CreateGetTupleElement(const1->shape(), param, 0));
3730     auto addc = builder.AddInstruction(HloInstruction::CreateBinary(
3731         indexc->shape(), HloOpcode::kAdd, indexc, const1));
3732     auto indexv = builder.AddInstruction(
3733         HloInstruction::CreateGetTupleElement(constv->shape(), param, 1));
3734     auto addv = builder.AddInstruction(HloInstruction::CreateBinary(
3735         constv->shape(), HloOpcode::kAdd, indexv, constv));
3736     auto tuple =
3737         builder.AddInstruction(HloInstruction::CreateTuple({addc, addv}));
3738     body_computation = module->AddEmbeddedComputation(builder.Build());
3739     schedule.set_sequence(body_computation, {const1, constv, param, indexc,
3740                                              addc, indexv, addv, tuple});
3741   }
3742 
3743   // This tests a simple while loop where the parameters are aliased with the
3744   // output buffers.
3745   auto builder = HloComputation::Builder("SimpleWhile");
3746   auto param = builder.AddInstruction(
3747       HloInstruction::CreateParameter(0, t_s32_f32v1, "param"));
3748   auto gte0 = builder.AddInstruction(
3749       HloInstruction::CreateGetTupleElement(s32, param, 0));
3750   auto gte1 = builder.AddInstruction(
3751       HloInstruction::CreateGetTupleElement(f32v1, param, 1));
3752   auto tuple =
3753       builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1}));
3754   auto while0 = builder.AddInstruction(HloInstruction::CreateWhile(
3755       t_s32_f32v1, cond_computation, body_computation, tuple));
3756 
3757   HloComputation* computation = module->AddEntryComputation(builder.Build());
3758   schedule.set_sequence(computation, {param, gte0, gte1, tuple, while0});
3759   TF_CHECK_OK(module->set_schedule(schedule));
3760 
3761   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
3762                     /*max_prefetch_interval=*/50);
3763 
3764   // Ensure all parameters and while are placed in default memory.
3765   Shape shape_in_default_mem = ShapeUtil::MakeShapeWithLayout(
3766       F32, {4, 6},
3767       /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
3768       kDefaultMemorySpace);
3769   Shape s32_in_default_mem = ShapeUtil::MakeShapeWithLayout(
3770       xla::S32, {},
3771       /*minor_to_major=*/{}, /*tiles=*/{}, /*element_size_in_bits=*/0,
3772       kDefaultMemorySpace);
3773   Shape f32v1_in_default_mem = ShapeUtil::MakeShapeWithLayout(
3774       F32, {1},
3775       /*minor_to_major=*/{0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
3776       kDefaultMemorySpace);
3777   Shape t_s32_f32v1_in_default_mem =
3778       ShapeUtil::MakeTupleShape({s32_in_default_mem, f32v1_in_default_mem});
3779   EXPECT_THAT(param, op::ShapeWithLayout(t_s32_f32v1_in_default_mem));
3780   EXPECT_THAT(while0, op::ShapeWithLayout(t_s32_f32v1_in_default_mem));
3781 }
3782 
3783 TEST_P(MemorySpaceAssignmentTest, EvictionsShouldntBeDelayed) {
3784   // This test reproduces an eviction scheduling bug where evictions to default
3785   // memory can happen later than intended, causing memory corruption. This test
3786   // is a variant of MemoryBoundednessBufferIntervalCompare; it uses f32[4,3]
3787   // tensors, so at most two of them fit in the alternate memory space at a
3788   // given time. We have a number of redundant operations
3789   // (tanh_redundant ops) that do not have users. The bug was due to
3790   // SimplifyGraph removing dead instructions, and removing them from the
3791   // schedule. However, the CopyStart/CopyDone insertion relies on the schedule
3792   // indexes, so they could be inserted too late.
3793   HloComputation::Builder builder(TestName());
3794   Shape shape = ShapeUtil::MakeShape(F32, {4, 3});
3795   HloInstruction* p0 =
3796       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3797   HloInstruction* tanh0 = builder.AddInstruction(
3798       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3799   HloInstruction* tanh_redundant0 = builder.AddInstruction(
3800       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3801   HloInstruction* tanh_redundant1 = builder.AddInstruction(
3802       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3803   HloInstruction* tanh_redundant2 = builder.AddInstruction(
3804       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3805   HloInstruction* tanh_redundant3 = builder.AddInstruction(
3806       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3807   HloInstruction* tanh_redundant4 = builder.AddInstruction(
3808       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3809   HloInstruction* tanh_redundant5 = builder.AddInstruction(
3810       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3811   HloInstruction* tanh_redundant6 = builder.AddInstruction(
3812       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, p0));
3813   HloInstruction* negate0 = builder.AddInstruction(
3814       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, tanh0));
3815   HloInstruction* tanh1 = builder.AddInstruction(
3816       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, negate0));
3817   HloInstruction* negate1 = builder.AddInstruction(
3818       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3819   HloInstruction* tanh2 = builder.AddInstruction(
3820       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh1));
3821   HloInstruction* negate2 = builder.AddInstruction(
3822       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3823   HloInstruction* tanh3 = builder.AddInstruction(
3824       HloInstruction::CreateUnary(shape, HloOpcode::kTanh, tanh2));
3825   HloInstruction* negate3 = builder.AddInstruction(
3826       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3827   HloInstruction* tuple = builder.AddInstruction(
3828       HloInstruction::CreateTuple({tanh3, negate3, tanh0}));
3829 
3830   auto module = CreateNewVerifiedModule();
3831   HloComputation* computation = module->AddEntryComputation(builder.Build());
3832 
3833   HloSchedule schedule(module.get());
3834   schedule.set_sequence(
3835       computation,
3836       {p0, tanh0, tanh_redundant0, tanh_redundant1, tanh_redundant2,
3837        tanh_redundant3, tanh_redundant4, tanh_redundant5, tanh_redundant6,
3838        negate0, tanh1, negate1, tanh2, negate2, tanh3, negate3, tuple});
3839   TF_CHECK_OK(module->set_schedule(schedule));
3840 
3841   AssignMemorySpaceUsingCostAnalysis(module.get());
3842 
3843   TF_ASSERT_OK_AND_ASSIGN(auto alias_analysis,
3844                           HloAliasAnalysis::Run(module.get()));
3845   TF_ASSERT_OK_AND_ASSIGN(auto hlo_live_range,
3846                           HloLiveRange::Run(module->schedule(), *alias_analysis,
3847                                             module->entry_computation()));
3848 
3849   std::vector<int> num_live_buffers_in_alternate_mem(
3850       hlo_live_range->flattened_instruction_sequence().size() + 1, 0);
3851 
3852   // Go through each value and for those that are allocated in the alternate
3853   // memory space, increment (inclusive) num_live_buffers_in_alternate_mem for
3854   // every time step that they are live.
3855   for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
3856     const Shape& shape = value->shape();
3857     if (!shape.has_layout() ||
3858         shape.layout().memory_space() == kDefaultMemorySpace) {
3859       continue;
3860     }
3861 
3862     HloLiveRange::TimeBound time_bound =
3863         hlo_live_range->buffer_live_ranges().at(value);
3864     for (int i = time_bound.start; i <= time_bound.end; ++i) {
3865       ++num_live_buffers_in_alternate_mem[i];
3866     }
3867   }
3868 
3869   // The test memory can at most hold two f32[4,3] buffers at a time. If there
3870   // is more than that, it means we have memory corruption.
3871   for (int i = 0; i < num_live_buffers_in_alternate_mem.size(); ++i) {
3872     EXPECT_LE(num_live_buffers_in_alternate_mem[i], 2);
3873   }
3874 }
3875 
3876 TEST_P(MemorySpaceAssignmentTest,
3877        InputOutputsInAlternateMemShouldntBeAssigned) {
3878   // When input/outputs are marked to be in the alternate memory (e.g.
3879   // go/tpu-fast-mem-inference), do not allocate those and assume they will live
3880   // in the alternate memory for the entire computation. The BufferAssignment
3881   // pass, which is run after this, will allocate those buffers.
3882   HloComputation::Builder builder(TestName());
3883   Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
3884   Shape shape_in_alternate_mem = ShapeUtil::MakeShapeWithLayout(
3885       F32, {2, 3},
3886       /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
3887       kAlternateMemorySpace);
3888   // p0 is in the default memory space.
3889   HloInstruction* p0 =
3890       builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0"));
3891   // p1 is in the alternate memory space.
3892   HloInstruction* p1 = builder.AddInstruction(
3893       HloInstruction::CreateParameter(1, shape_in_alternate_mem, "p1"));
3894   HloInstruction* negate0 = builder.AddInstruction(
3895       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, p0));
3896   HloInstruction* negate1 = builder.AddInstruction(
3897       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate0));
3898   HloInstruction* negate2 = builder.AddInstruction(
3899       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate1));
3900   HloInstruction* negate3 = builder.AddInstruction(
3901       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate2));
3902   HloInstruction* negate4 = builder.AddInstruction(
3903       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate3));
3904   HloInstruction* negate5 = builder.AddInstruction(
3905       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate4));
3906   HloInstruction* negate6 = builder.AddInstruction(
3907       HloInstruction::CreateUnary(shape, HloOpcode::kNegate, negate5));
3908   HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary(
3909       shape_in_alternate_mem, HloOpcode::kAdd, negate6, p1));
3910   // Index {0} of the root instruction is in the alternate memory space, index
3911   // {1} is in the default memory space.
3912   HloInstruction* tuple =
3913       builder.AddInstruction(HloInstruction::CreateTuple({add, negate5}));
3914 
3915   auto module = CreateNewVerifiedModule();
3916   HloComputation* computation = module->AddEntryComputation(builder.Build());
3917 
3918   HloSchedule schedule(module.get());
3919   schedule.set_sequence(computation,
3920                         {p0, p1, negate0, negate1, negate2, negate3, negate4,
3921                          negate5, negate6, add, tuple});
3922   TF_CHECK_OK(module->set_schedule(schedule));
3923 
3924   std::unique_ptr<PresetAssignments> preset_assignments =
3925       AssignMemorySpace(module.get());
3926 
3927   // Ensure that p1 is in the alternate memory and add, which has p1 as an
3928   // operand, has a direct dependency to p1 (no CopyStart/CopyDone).
3929   EXPECT_THAT(p1, op::ShapeWithLayout(shape_in_alternate_mem));
3930   EXPECT_THAT(add, op::Add(op::Negate(), op::Parameter(1)));
3931   // Make sure add is still in the alternate memory space.
3932   EXPECT_THAT(add, op::ShapeWithLayout(shape_in_alternate_mem));
3933 
3934   // Check the preset assignments and ensure the inputs/outputs in the alternate
3935   // memory space aren't in the preset assignments. Inputs/outputs in the
3936   // alternate memory space are left to BufferAssignment to be allocated.
3937   for (const auto& position_and_chunk : preset_assignments->chunks()) {
3938     const HloPosition& position = position_and_chunk.first;
3939     EXPECT_NE(position.instruction, p1);
3940     EXPECT_NE(position.instruction, add);
3941   }
3942 }
3943 
3944 TEST_P(MemorySpaceAssignmentTest, PendingChunkMemoryCorruptionBug) {
3945   // Tests a memory corruption bug where the allocated chunk overlaps with a
3946   // pending chunk. To test this, we provide a new buffer interval compare where
3947   // we prioritize the allocation of sine, cosine, and tanh to create the
3948   // situation:
3949   //
3950   //    Max memory
3951   //  -------------------------------------------
3952   //      +------------+
3953   //      |     b      |
3954   //      +------------+
3955   //  +-------+
3956   //  |       |
3957   //  |       |
3958   //  |   a   |
3959   //  |       |                 +------------+
3960   //  |       |                 |     n      |
3961   //  +-------+                 +------------+
3962   //  -------------------------------------------
3963   //    Min memory          time ->
3964   //
3965   //
3966   // Then allocating for buffer d, we have these two prefetch buffers
3967   // overlapping:
3968   //
3969   //    Max memory
3970   //  -------------------------------------------
3971   //      +------------+ +----------+
3972   //      |     b      | | prefetch |
3973   //      +------------+ | for o    |
3974   //  +-------+     +---------+     |
3975   //  |       |     |    |    |     |
3976   //  |       |     |    |    |     |
3977   //  |   a   |     |    +----|-----+
3978   //  |       |     | prefetch| +------------+
3979   //  |       |     | for m   | |     n      |
3980   //  +-------+     +---------+ +------------+
3981   //  -------------------------------------------
3982   //    Min memory          time ->
3983   //
3984   absl::string_view hlo_string = R"(
3985   HloModule bug, is_scheduled=true
3986 
3987   ENTRY %Entry {
3988     %param0 = f32[8,3] parameter(0)
3989     %param1 = f32[2,4] parameter(1)
3990     %a = f32[8,3] sine(%param0)
3991     %b = f32[2,4] cosine(%param1)
3992     %d = f32[8,3] tanh(%a)
3993     %c = f32[8,3] negate(%a)
3994     %e = f32[2,4] negate(%b)
3995     %f = f32[2,4] negate(%e)
3996     %g = f32[2,4] negate(%f)
3997     %h = f32[2,4] negate(%g)
3998     %i = f32[2,4] negate(%h)
3999     %j = f32[2,4] negate(%i)
4000     %k = f32[2,4] negate(%j)
4001     %l = f32[2,4] negate(%k)
4002     %m = f32[8,3] negate(%d)
4003     %n = f32[2,4] sine(%l)
4004     %o = f32[8,3] negate(%d)
4005     %p = f32[2,4] negate(%n)
4006     %q = f32[8,3] negate(%m)
4007     ROOT %tuple = (f32[2,4], f32[8,3], f32[8,3]) tuple(%p, %q, %o)
4008   }
4009   )";
4010 
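  // Buffer interval comparator that allocates sine-defined values first, then
  // cosine, then tanh, and everything else last, to set up the overlap
  // scenario sketched above.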
4011   MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
4012       [](const MemorySpaceAssignment::BufferInterval& a,
4013          const MemorySpaceAssignment::BufferInterval& b) {
4014         auto get_opcode_priority = [](const HloOpcode& opcode) {
4015           switch (opcode) {
4016             case HloOpcode::kSin:
4017               return 0;
4018             case HloOpcode::kCos:
4019               return 1;
4020             case HloOpcode::kTanh:
4021               return 2;
4022             default:
4023               return 3;
4024           }
4025         };
4026 
4027         return get_opcode_priority(a.buffer->defining_instruction()->opcode()) <
4028                get_opcode_priority(b.buffer->defining_instruction()->opcode());
4029       };
4030   TF_ASSERT_OK_AND_ASSIGN(auto module,
4031                           ParseAndReturnVerifiedModule(hlo_string));
4032 
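  // The two arguments presumably bound the prefetch window by instruction
  // count (a minimum and maximum number of instructions between the copy
  // start and its use).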
4033   InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
4034   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
4035                     buffer_interval_compare, &prefetch_interval_picker);
4036 }
4037 
4038 TEST_P(MemorySpaceAssignmentTest, MoveCopyDoneEarlier) {
4039   // This tests the case where an earlier placed smaller buffer may block a
4040   // larger buffer due to asynchronous copy ordering. The smaller buffer (the
4041   // operand of sin) will be placed first. The cos, whose operand is 3 times
4042   // larger than sin's, needs more time for the asynchronous copy. The cos is
4043   // placed right after sin, leading to a copy ordering violation:
4044   //
4045   // param1------------------>CS----->CD->sin
4046   // param0------------->CS------------------->CD->cos
4047   //
4048   // To fix this, we need to move copy done for cos earlier and ensure both of
4049   // these buffers get alternate memory allocations:
4050   //
4051   // param1------------------>CS----->CD->sin
4052   // param0-->CS------------------->CD------------>cos
4053   absl::string_view hlo_string = R"(
4054   HloModule module, is_scheduled=true
4055 
4056   ENTRY Entry {
4057     param0 = f32[8,3] parameter(0)
4058     param1 = f32[2,4] parameter(1)
4059     a = f32[2,4] negate(param1)
4060     b = f32[2,4] negate(a)
4061     c = f32[2,4] negate(b)
4062     d = f32[2,4] negate(c)
4063     e = f32[2,4] negate(d)
4064     f = f32[2,4] negate(e)
4065     g = f32[2,4] negate(f)
4066     h = f32[2,4] negate(g)
4067     i = f32[2,4] negate(h)
4068     j = f32[2,4] negate(i)
4069     k = f32[2,4] negate(j)
4070     l = f32[2,4] negate(k)
4071     m = f32[2,4] negate(l)
4072     n = f32[2,4] negate(m)
4073     sin = f32[2,4] sine(param1)
4074     o = f32[2,4] negate(n)
4075     cos = f32[8,3] cosine(param0)
4076     ROOT tuple = (f32[8,3], f32[2,4], f32[2,4]) tuple(cos, sin, o)
4077   }
4078   )";
4079 
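  // Here the comparator prioritizes buffers by the highest-priority opcode
  // among their uses: operands of sin are allocated first, then cos, then
  // tanh, then the rest.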
4080   MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
4081       [](const MemorySpaceAssignment::BufferInterval& a,
4082          const MemorySpaceAssignment::BufferInterval& b) {
4083         auto get_opcode_priority = [](const HloOpcode& opcode) {
4084           switch (opcode) {
4085             case HloOpcode::kSin:
4086               return 0;
4087             case HloOpcode::kCos:
4088               return 1;
4089             case HloOpcode::kTanh:
4090               return 2;
4091             default:
4092               return 3;
4093           }
4094         };
4095 
4096         auto get_user_priority = [&](const HloValue& value) {
4097           int priority = INT_MAX;
4098           for (const auto& use : value.uses()) {
4099             priority = std::min(priority,
4100                                 get_opcode_priority(use.instruction->opcode()));
4101           }
4102           return priority;
4103         };
4104 
4105         return get_user_priority(*a.buffer) < get_user_priority(*b.buffer);
4106       };
4107   TF_ASSERT_OK_AND_ASSIGN(auto module,
4108                           ParseAndReturnVerifiedModule(hlo_string));
4109 
4110   HloCostAnalysis hlo_cost_analysis(ShapeSize);
4111   TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
4112                           FakeMemorySpaceAssignmentCostAnalysis::Create(
4113                               hlo_cost_analysis, *module));
4114   cost_analysis->SetOverrideForGetAsyncCopyElapsed([](const Shape& shape) {
4115     // This should return 2 for f32[2,4] and 6 for f32[8,3].
4116     return ShapeSize(shape) / 16;
4117   });
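  // With this override, the async copy time scales with buffer size, so cos's
  // f32[8,3] operand takes three times as long to prefetch as sin's f32[2,4]
  // operand (6 vs. 2 time units).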
4118   CostAnalysisPrefetchIntervalPicker interval_picker(
4119       *cost_analysis,
4120       /*min_async_copy_to_overlap_ratio=*/1.0,
4121       /*max_async_copy_to_overlap_ratio=*/4.0,
4122       /*preferred_async_copy_to_overlap_ratio=*/1.5);
4123   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
4124                     buffer_interval_compare, &interval_picker);
4125 
4126   // Check that both cos and sin could get their operands prefetched.
4127   const HloInstruction* cos =
4128       module->entry_computation()->GetInstructionWithName("cos");
4129   const HloInstruction* sin =
4130       module->entry_computation()->GetInstructionWithName("sin");
4131   EXPECT_THAT(sin->operand(0),
4132               op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
4133                             op::Parameter(1)));
4134   EXPECT_THAT(cos->operand(0),
4135               op::AsyncCopy(kAlternateMemorySpace, kDefaultMemorySpace,
4136                             op::Parameter(0)));
4137 
4138   // Sanity check that cos's operand copy-done is scheduled earlier than
4139   // sin's operand copy-done.
4140   auto find_schedule_index = [&](const HloInstruction* instruction) {
4141     const auto& instructions =
4142         module->schedule().sequence(module->entry_computation()).instructions();
4143     for (int i = 0; i < instructions.size(); ++i) {
4144       if (instruction == instructions[i]) {
4145         return i;
4146       }
4147     }
4148     CHECK(false);
4149     return -1;
4150   };
4151   EXPECT_GT(find_schedule_index(sin->operand(0)),
4152             find_schedule_index(cos->operand(0)));
4153 }
4154 
4155 TEST_P(MemorySpaceAssignmentTest, WhileAliasedArgumentRequiredAssignmentBug) {
4156   // Tests an overly pessimistic assertion that fired when the same HloValue was
4157   // passed multiple times to a while HLO. AllocateSegment already handles this
4158   // case (the two arguments must alias and get the same allocation), so the
4159   // assertion isn't necessary.
4160   absl::string_view hlo_string = R"(
4161   HloModule bug, is_scheduled=true
4162 
4163   while_condition {
4164     param1 = (f32[2,4], f32[2,4], f32[2,4]) parameter(0)
4165     ROOT cond = pred[] constant(true)
4166   }
4167 
4168   while_body {
4169     param2 = (f32[2,4], f32[2,4], f32[2,4]) parameter(0)
4170     gte2 = f32[2,4] get-tuple-element(param2), index=0
4171     gte3 = f32[2,4] get-tuple-element(param2), index=1
4172     gte4 = f32[2,4] get-tuple-element(param2), index=2
4173     add = f32[2,4] add(gte2, gte3)
4174     ROOT tuple2 = (f32[2,4], f32[2,4], f32[2,4]) tuple(add, gte3, gte4)
4175   }
4176 
4177   ENTRY Entry {
4178     param0 = f32[2,4] parameter(0)
4179     a = f32[2,4] negate(param0)
4180     b = f32[2,4] negate(param0)
4181     tuple = (f32[2,4], f32[2,4], f32[2,4]) tuple(a, b, b)
4182     while = (f32[2,4], f32[2,4], f32[2,4]) while(tuple), condition=while_condition, body=while_body
4183     gte1 = f32[2,4] get-tuple-element(while), index=0
4184     gte2 = f32[2,4] get-tuple-element(while), index=1
4185     ROOT root = f32[2,4] add(gte1, gte2)
4186   }
4187   )";
4188   TF_ASSERT_OK_AND_ASSIGN(auto module,
4189                           ParseAndReturnVerifiedModule(hlo_string));
4190   AssignMemorySpace(module.get());
4191 }
4192 
4193 TEST_P(MemorySpaceAssignmentTest, DisallowedUseBug) {
4194   // When we have a disallowed use (in this case tanh), we aren't allowed to
4195   // allocate this use in the alternate memory. However, if there is a later use
4196   // of the same buffer (o), that use may refer to "a" instead of the evicted
4197   // value, which is illegal because "a" will be allocated in the alternate
4198   // memory space.
4199   absl::string_view hlo_string = R"(
4200   HloModule bug, is_scheduled=true
4201 
4202   ENTRY Entry {
4203     param0 = f32[8,3] parameter(0)
4204     param1 = f32[2,4] parameter(1)
4205     a = f32[8,3] cosine(param0)
4206     b = f32[2,4] negate(param1)
4207     d = f32[8,3] negate(a)
4208     c = f32[2,4] negate(b)
4209     e = f32[2,4] negate(c)
4210     f = f32[8,3] tanh(a)
4211     g = f32[2,4] negate(e)
4212     h = f32[2,4] negate(g)
4213     i = f32[2,4] negate(h)
4214     j = f32[2,4] negate(i)
4215     k = f32[2,4] negate(j)
4216     l = f32[2,4] negate(k)
4217     m = f32[2,4] negate(l)
4218     n = f32[2,4] sine(m)
4219     o = f32[8,3] negate(a)
4220     p = f32[2,4] negate(n)
4221     q = f32[8,3] add(o, f)
4222     r = f32[8,3] add(q, d)
4223     ROOT tuple = (f32[2,4], f32[8,3]) tuple(p, r)
4224   }
4225   )";
4226 
4227   MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
4228       [](const MemorySpaceAssignment::BufferInterval& a,
4229          const MemorySpaceAssignment::BufferInterval& b) {
4230         auto get_opcode_priority = [](const HloOpcode& opcode) {
4231           switch (opcode) {
4232             case HloOpcode::kSin:
4233               return 0;
4234             case HloOpcode::kCos:
4235               return 1;
4236             case HloOpcode::kTanh:
4237               return 2;
4238             default:
4239               return 3;
4240           }
4241         };
4242 
4243         return get_opcode_priority(a.buffer->defining_instruction()->opcode()) <
4244                get_opcode_priority(b.buffer->defining_instruction()->opcode());
4245       };
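  // Unlike the use-based priority in the previous test, this comparator ranks
  // buffers by the opcode of their defining instruction, so values defined by
  // sine come first, then cosine, then tanh, then everything else.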
4246   TF_ASSERT_OK_AND_ASSIGN(auto module,
4247                           ParseAndReturnVerifiedModule(hlo_string));
4248 
4249   InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
4250   MemorySpaceAssignment::Options options;
4251   options.max_size_in_bytes = 128;
4252   options.alignment_in_bytes = 8;
4253   options.verify = true;
4254   options.is_use_allowed_in_alternate_mem_fn = [](const HloUse& use) {
4255     return use.instruction->opcode() != HloOpcode::kTanh;
4256   };
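  // Disallowing tanh uses in the alternate memory forces the scenario described
  // in the comment at the top of this test.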
4257   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
4258                     buffer_interval_compare, &prefetch_interval_picker,
4259                     options);
4260 }
4261 
4262 TEST_P(MemorySpaceAssignmentTest, DisallowedUseBugInWhile) {
4263   // Test for situations where we disallow a use (tanh in this case) in the
4264   // alternate memory space and there is a subsequent use that also requires the
4265   // buffer to be in the default memory space. In this case, the allocation in
4266   // the default memory space might not be the very last one, so we need to
4267   // search the allocation sequence and find the one in the default memory
4268   // space.
4269   absl::string_view hlo_string = R"(
4270   HloModule module, is_scheduled=true
4271 
4272   while_cond {
4273     p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4274     ROOT gte = pred[] get-tuple-element(p0), index=3
4275   }
4276 
4277   while_body {
4278     p0 = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) parameter(0)
4279     gte0 = f32[3]{0} get-tuple-element(p0), index=0
4280     gte1 = f32[3]{0} get-tuple-element(p0), index=1
4281     gte2 = f32[3]{0} get-tuple-element(p0), index=2
4282     gte3 = pred[] get-tuple-element(p0), index=3
4283     add = f32[3]{0} add(gte0, gte0)
4284     negate0 = f32[3]{0} negate(add)
4285     negate1 = f32[3]{0} negate(negate0)
4286     negate2 = f32[3]{0} negate(negate1)
4287     negate3 = f32[3]{0} negate(negate2)
4288     negate4 = f32[3]{0} negate(negate3)
4289     negate5 = f32[3]{0} negate(negate4)
4290     negate6 = f32[3]{0} negate(negate5)
4291     negate7 = f32[3]{0} negate(negate6)
4292     negate8 = f32[3]{0} negate(negate7)
4293     negate9 = f32[3]{0} negate(negate8)
4294     negate10 = f32[3]{0} negate(negate9)
4295     negate11 = f32[3]{0} negate(negate10)
4296     negate12 = f32[3]{0} negate(negate11)
4297     negate13 = f32[3]{0} negate(negate12)
4298     negate14 = f32[3]{0} negate(negate13)
4299     negate15 = f32[3]{0} negate(gte2)
4300     tanh = f32[3]{0} tanh(gte2)
4301     ROOT tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(negate14, tanh, gte2, gte3)
4302   }
4303 
4304   ENTRY entry {
4305     p0 = f32[3]{0} parameter(0)
4306     p1 = pred[] parameter(1)
4307     copy0 = f32[3]{0} copy(p0)
4308     copy1 = f32[3]{0} copy(p0)
4309     tuple = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) tuple(copy0, copy0, copy1, p1)
4310     while = (f32[3]{0}, f32[3]{0}, f32[3]{0}, pred[]) while(tuple), condition=while_cond, body=while_body
4311     ROOT gte = f32[3]{0} get-tuple-element(while), index=2
4312   }
4313   )";
4314 
4315   TF_ASSERT_OK_AND_ASSIGN(auto module,
4316                           ParseAndReturnVerifiedModule(hlo_string));
4317   MemorySpaceAssignment::Options options;
4318   options.max_size_in_bytes = 128;
4319   options.alignment_in_bytes = 8;
4320   options.verify = true;
4321   options.is_use_allowed_in_alternate_mem_fn = [](const HloUse& use) {
4322     return use.instruction->opcode() != HloOpcode::kTanh;
4323   };
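  // Tanh uses may not be placed in the alternate memory, triggering the
  // default-memory search described in the comment at the top of this test.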
4324   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
4325                     /*max_prefetch_interval=*/10, /*min_prefetch_interval=*/2,
4326                     options);
4327 }
4328 
4329 TEST_P(MemorySpaceAssignmentTest, BitcastRoot) {
4330   // Tests against a bug where the root of the entry computation is a bitcast
4331   // instruction and ends up getting an allocation in the alternate memory.
4332   absl::string_view hlo_string = R"(
4333 HloModule primitive_computation_gather.4, is_scheduled=true
4334 
4335 %while_body {
4336   %param.1 = (s32[], f32[3,3,3]) parameter(0)
4337   %get-tuple-element.32 = s32[] get-tuple-element(%param.1), index=0
4338   %copy.6 = s32[] copy(s32[] %get-tuple-element.32)
4339   %constant.8 = s32[] constant(1)
4340   %add = s32[] add(s32[] %copy.6, s32[] %constant.8)
4341   %get-tuple-element.35 = f32[3,3,3] get-tuple-element(%param.1), index=1
4342   negate = f32[3,3,3] negate(get-tuple-element.35)
4343   ROOT %tuple.10 = (s32[], f32[3,3,3]) tuple(s32[] %add, f32[3,3,3] negate)
4344 }
4345 
4346 %while_cond {
4347   %param.0 = (s32[], f32[3,3,3]) parameter(0)
4348   %get-tuple-element = s32[] get-tuple-element(%param.0), index=0
4349   %constant.3 = s32[] constant(3)
4350   ROOT %compare = pred[] compare(s32[] %get-tuple-element, s32[] %constant.3), direction=LT
4351 }
4352 
4353 ENTRY %primitive_computation_gather.4 (parameter.1: f32[3,10,5], parameter.2: s32[3,1]) -> f32[3,3,3] {
4354   %constant.1 = s32[] constant(0)
4355   %copy.11 = s32[] copy(s32[] %constant.1)
4356   %constant = f32[] constant(0)
4357   %broadcast = f32[3,3,3] broadcast(f32[] %constant), dimensions={}
4358   %tuple.8 = (s32[], f32[3,10,5], s32[3,1], f32[3,3,3]) tuple(s32[] %copy.11, f32[3,3,3] %broadcast)
4359   %while = (s32[], f32[3,3,3]) while(%tuple.8), condition=%while_cond, body=%while_body
4360   %get-tuple-element.7 = f32[3,3,3] get-tuple-element(%while), index=1
4361   ROOT %bitcast.1 = f32[3,3,3] bitcast(f32[3,3,3] %get-tuple-element.7)
4362 }
4363   )";
4364 
4365   TF_ASSERT_OK_AND_ASSIGN(auto module,
4366                           ParseAndReturnVerifiedModule(hlo_string));
4367   AssignMemorySpace(module.get());
4368 
4369   const HloInstruction* root = module->entry_computation()->root_instruction();
4370   EXPECT_TRUE(!root->shape().has_layout() ||
4371               root->shape().layout().memory_space() == kDefaultMemorySpace);
4372 }
4373 
4374 // A mock MemorySpaceAssignmentRepacker class that accepts a map of
4375 // (start_time, offset) -> new_offset values. Using this map, the repacker
4376 // repacks the allocations to the new_offset.
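// If check_fun is provided, Repack calls it on the allocation blocks after
// assigning offsets; if always_return_modified is true, Repack reports that it
// modified the allocations even when no repack_map entry matched.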
4377 class FakeMemorySpaceAssignmentRepacker : public MemorySpaceAssignmentRepacker {
4378  public:
4379   explicit FakeMemorySpaceAssignmentRepacker(
4380       absl::flat_hash_map<std::pair<int64, int64>, int64>& repack_map,
4381       std::function<void(absl::Span<AllocationBlock*>)> check_fun = nullptr,
4382       bool always_return_modified = false)
4383       : MemorySpaceAssignmentRepacker(/*max_size=*/128, /*alignment=*/8),
4384         repack_map_(repack_map),
4385         check_fun_(check_fun),
4386         always_return_modified_(always_return_modified) {}
4387 
4388   StatusOr<bool> Repack(absl::Span<AllocationBlock*> allocations) override {
4389     bool modified = false;
4390     for (AllocationBlock* block : allocations) {
4391       absl::flat_hash_set<int64> colocations;
4392       std::string colocations_str;
4393       for (const AllocationBlock* colocation : block->colocations) {
4394         absl::StrAppend(&colocations_str, colocation->id, ", ");
4395         colocations.insert(colocation->id);
4396       }
4397       VLOG(1) << "Alloc id: " << block->id << " time: [" << block->start_time
4398               << ", " << block->end_time << "] size: " << block->size
4399               << " init offset: " << block->initial_offset << " colocations: {"
4400               << colocations_str << "}";
4401       auto it = repack_map_.find({block->start_time, block->initial_offset});
4402       if (it != repack_map_.end()) {
4403         modified = true;
4404         block->offset = it->second;
4405       } else {
4406         block->offset = block->initial_offset;
4407       }
4408       for (AllocationBlock* colocation : block->colocations) {
4409         if (it != repack_map_.end()) {
4410           colocation->offset = it->second;
4411         } else {
4412           colocation->offset = colocation->initial_offset;
4413         }
4414       }
4415     }
4416     if (check_fun_) {
4417       check_fun_(allocations);
4418     }
4419 
4420     return always_return_modified_ || modified;
4421   }
4422 
4423  private:
4424   // A map from (start_time, offset) to new_offset.
4425   absl::flat_hash_map<std::pair<int64, int64>, int64> repack_map_;
4426   std::function<void(absl::Span<AllocationBlock*>)> check_fun_;
4427   bool always_return_modified_;
4428 };
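// For example (as in the Repack test below), a repack_map entry of
// {{/*start_time=*/2, /*initial_offset=*/0}, 32} moves the block that starts at
// time 2 with offset 0, along with all of its colocations, to offset 32.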
4429 
4430 TEST_P(MemorySpaceAssignmentTest, Repack) {
4431   // We initially perform the following allocations at these offsets.
4432   //
4433   //    Max memory
4434   //  -------------------------------------------
4435   //
4436   //
4437   //
4438   //
4439   //      +------------+
4440   //      |     b      |
4441   //      +------------+
4442   //  +-------+                 +------------+
4443   //  |   a   |                 |     n      |
4444   //  +-------+                 +------------+
4445   //  -------------------------------------------
4446   //    Min memory          time ->
4447   //
4448   // Next up, we try to allocate the prefetch for m. However, due to
4449   // fragmentation, this won't be possible:
4450   //
4451   //    Max memory
4452   //  -------------------------------------------
4453   //
4454   //
4455   //
4456   //                +---------+
4457   //      +------------+      |
4458   //      |     b   |  |      |
4459   //      +------------+      |
4460   //  +-------+     |         | +------------+
4461   //  |   a   |     |    d    | |     n      |
4462   //  +-------+     +---------+ +------------+
4463   //  -------------------------------------------
4464   //    Min memory          time ->
4465   //
4466   // We then call repack to repack the existing allocations, which allows us to
4467   // allocate the prefetch for m:
4468   //
4469   //    Max memory
4470   //  -------------------------------------------
4471   //                +---------+
4472   //                |         |
4473   //                |         |
4474   //                |         |
4475   //  +-------+     |         |
4476   //  |   a   |     |    d    |
4477   //  +-------+     +---------+
4478   //      +------------+        +------------+
4479   //      |      b     |        |     n      |
4480   //      +------------+        +------------+
4481   //  -------------------------------------------
4482   //    Min memory          time ->
4483   absl::string_view hlo_string = R"(
4484   HloModule bug, is_scheduled=true
4485 
4486   ENTRY Entry {
4487     param0 = f32[8,3] parameter(0)
4488     param1 = f32[2,4] parameter(1)
4489     a = f32[2,4] sine(param1)
4490     b = f32[2,4] cosine(param1)
4491     c = f32[8,3] negate(param0)
4492     j = f32[2,4] negate(a)
4493     d = f32[8,3] tanh(param0)
4494     k = f32[2,4] negate(j)
4495     l = f32[2,4] add(b, k)
4496     m = f32[8,3] negate(d)
4497     n = f32[2,4] sine(l)
4498     o = f32[8,3] negate(m)
4499     p = f32[2,4] negate(n)
4500     q = f32[8,3] negate(m)
4501     ROOT tuple = (f32[2,4], f32[8,3], f32[8,3]) tuple(p, q, o)
4502   }
4503   )";
4504 
4505   MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
4506       [](const MemorySpaceAssignment::BufferInterval& a,
4507          const MemorySpaceAssignment::BufferInterval& b) {
4508         auto get_opcode_priority = [](const HloOpcode& opcode) {
4509           switch (opcode) {
4510             case HloOpcode::kSin:
4511               return 0;
4512             case HloOpcode::kCos:
4513               return 1;
4514             case HloOpcode::kTanh:
4515               return 2;
4516             default:
4517               return 3;
4518           }
4519         };
4520 
4521         return get_opcode_priority(a.buffer->defining_instruction()->opcode()) <
4522                get_opcode_priority(b.buffer->defining_instruction()->opcode());
4523       };
4524   TF_ASSERT_OK_AND_ASSIGN(auto module,
4525                           ParseAndReturnVerifiedModule(hlo_string));
4526 
4527   InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
4528   absl::flat_hash_map<std::pair<int64, int64>, int64> repack_map;
4529   // Move "a" from offset 0 to 32.
4530   repack_map[{2, 0}] = 32;
4531   // Move "b" from offset 32 to 0.
4532   repack_map[{3, 32}] = 0;
4533   FakeMemorySpaceAssignmentRepacker repacker =
4534       FakeMemorySpaceAssignmentRepacker(repack_map);
4535   MemorySpaceAssignment::Options options;
4536   options.max_size_in_bytes = 128;
4537   options.alignment_in_bytes = 8;
4538   options.verify = true;
4539   options.max_repacks = 1;
4540   options.repacker = &repacker;
4541   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
4542                     buffer_interval_compare, &prefetch_interval_picker,
4543                     options);
4544 
4545   // If repacking succeeds, we should find the buffer for d in alternate memory.
4546   const HloInstruction* d =
4547       module->entry_computation()->GetInstructionWithName("d");
4548   EXPECT_EQ(d->shape().layout().memory_space(), kAlternateMemorySpace);
4549 }
4550 
4551 TEST_P(MemorySpaceAssignmentTest, RepackExportsAliasedOffsets) {
4552   // This test checks that we correctly export aliased offsets for repacking.
4553   // In this example, the buffer produced at HLO "a" will be allocated first,
4554   // and will consist of four allocations:
4555   //    1) a produced in the alternate memory (and then evicted to the default memory).
4556   //    2) a prefetched to the alternate memory to be used by q and while HLOs.
4557   //    3) a used within the while loop body.
4558   //    4) the output of while HLO, used by u.
4559   //
4560   // Since a will be allocated first (the test is crafted to prioritize sine
4561   // HLO), all four allocations should get the same (zero) offsets. However,
4562   // while allocations 2, 3, and 4 need to be colocated with each other,
4563   // allocation 1 doesn't need to be colocated with the other three.
4564   absl::string_view hlo_string = R"(
4565   HloModule bug, is_scheduled=true
4566 
4567   while_condition {
4568     param1 = (f32[2,4], f32[2,4]) parameter(0)
4569     ROOT cond = pred[] constant(true)
4570   }
4571 
4572   while_body {
4573     param2 = (f32[2,4], f32[2,4]) parameter(0)
4574     gte2 = f32[2,4] get-tuple-element(param2), index=0
4575     gte3 = f32[2,4] get-tuple-element(param2), index=1
4576     add = f32[2,4] add(gte2, gte3)
4577     ROOT tuple2 = (f32[2,4], f32[2,4]) tuple(add, gte3)
4578   }
4579 
4580   ENTRY Entry {
4581     param0 = f32[2,4] parameter(0)
4582     a = f32[2,4] sine(param0)
4583     b = f32[2,4] negate(a)
4584     c = f32[2,4] negate(b)
4585     d = f32[2,4] negate(c)
4586     e = f32[2,4] negate(d)
4587     f = f32[2,4] negate(e)
4588     g = f32[2,4] negate(f)
4589     h = f32[2,4] negate(g)
4590     i = f32[2,4] negate(h)
4591     j = f32[2,4] negate(i)
4592     k = f32[2,4] negate(j)
4593     l = f32[2,4] negate(k)
4594     m = f32[2,4] negate(l)
4595     n = f32[2,4] negate(m)
4596     o = f32[2,4] negate(n)
4597     p = f32[2,4] negate(o)
4598     q = f32[2,4] add(p, a)
4599     tuple = (f32[2,4], f32[2,4]) tuple(q, a)
4600     while = (f32[2,4], f32[2,4]) while(tuple), condition=while_condition, body=while_body
4601     gte0 = f32[2,4] get-tuple-element(while), index=0
4602     gte1 = f32[2,4] get-tuple-element(while), index=1
4603     r = f32[2,4] negate(gte0)
4604     s = f32[2,4] negate(r)
4605     t = f32[2,4] negate(s)
4606     constant = f32[] constant(0)
4607     broadcast = f32[8,4] broadcast(constant), dimensions={}
4608     cos = f32[8,4] cosine(broadcast)
4609     u = f32[2,4] add(t, gte1)
4610     v = f32[2,4] add(u, param0)
4611     w = f32[8,4] negate(cos)
4612     ROOT tuple3 = (f32[2,4], f32[8,4]) tuple(v, w)
4613   }
4614   )";
4615 
4616   MemorySpaceAssignment::BufferIntervalCompare buffer_interval_compare =
4617       [](const MemorySpaceAssignment::BufferInterval& a,
4618          const MemorySpaceAssignment::BufferInterval& b) {
4619         auto get_opcode_priority = [](const HloOpcode& opcode) {
4620           switch (opcode) {
4621             case HloOpcode::kSin:
4622               return 0;
4623             case HloOpcode::kCos:
4624               return 1;
4625             case HloOpcode::kTanh:
4626               return 2;
4627             default:
4628               return 3;
4629           }
4630         };
4631 
4632         return get_opcode_priority(a.buffer->defining_instruction()->opcode()) <
4633                get_opcode_priority(b.buffer->defining_instruction()->opcode());
4634       };
4635   TF_ASSERT_OK_AND_ASSIGN(auto module,
4636                           ParseAndReturnVerifiedModule(hlo_string));
4637 
4638   InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
4639   absl::flat_hash_map<std::pair<int64, int64>, int64> repack_map;
4640 
4641   // Expect that of the four separate allocations for the "a" buffer, the first
4642   // and the next three are in separate colocations.
4643   auto check_fun =
4644       [](absl::Span<MemorySpaceAssignmentRepacker::AllocationBlock*>
4645              allocations) {
4646         EXPECT_TRUE(allocations.at(0)->colocations.size() == 1 ||
4647                     allocations.at(0)->colocations.size() == 3);
4648         EXPECT_EQ(allocations.at(1)->colocations.size(), 3);
4649         EXPECT_EQ(allocations.at(2)->colocations.size(), 3);
4650         EXPECT_TRUE(allocations.at(3)->colocations.size() == 1 ||
4651                     allocations.at(3)->colocations.size() == 3);
4652       };
4653   FakeMemorySpaceAssignmentRepacker repacker =
4654       FakeMemorySpaceAssignmentRepacker(repack_map, check_fun);
4655   MemorySpaceAssignment::Options options;
4656   options.max_size_in_bytes = 128;
4657   options.alignment_in_bytes = 8;
4658   options.verify = true;
4659   options.max_repacks = 1;
4660   options.repacker = &repacker;
4661   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
4662                     buffer_interval_compare, &prefetch_interval_picker,
4663                     options);
4664 }
4665 
4666 TEST_P(MemorySpaceAssignmentTest,
4667        RepackShouldntEraseRequiredAssignmentForConditionalOutput) {
4668   // This is a test case for b/171040271. Repacks erase the required assignments
4669   // (since some required assignments are inserted conditionally based on
4670   // allocation decisions), including the fact that conditional outputs are
4671   // always required to get assignments in the default memory. After repacking,
4672   // this required assignment was never added back, causing conditionals to get
4673   // alternate-memory allocations.
4674   absl::string_view hlo_string = R"(
4675   HloModule CondAllocation, is_scheduled=true
4676 
4677   true_computation {
4678     p0 = (f32[3]) parameter(0)
4679     gte = f32[3] get-tuple-element(p0), index=0
4680     neg1 = f32[3] negate(gte)
4681     ROOT tuple1 = (f32[3]) tuple(neg1)
4682   }
4683 
4684   false_computation {
4685     p0 = (f32[3]) parameter(0)
4686     gte = f32[3] get-tuple-element(p0), index=0
4687     neg2 = f32[3] negate(gte)
4688     ROOT tuple2 = (f32[3]) tuple(neg2)
4689   }
4690 
4691   ENTRY entry {
4692     p0 = f32[3] parameter(0)
4693     p1 = pred[] parameter(1)
4694     copy = f32[3] copy(p0)
4695     tuple = (f32[3]) tuple(copy)
4696     conditional = (f32[3]) conditional(p1, tuple, tuple), true_computation=true_computation, false_computation=false_computation
4697     ROOT gte = f32[3] get-tuple-element(conditional), index=0
4698   }
4699   )";
4700   TF_ASSERT_OK_AND_ASSIGN(auto module,
4701                           ParseAndReturnVerifiedModule(hlo_string));
4702   absl::flat_hash_map<std::pair<int64, int64>, int64> repack_map;
4703   FakeMemorySpaceAssignmentRepacker repacker =
4704       FakeMemorySpaceAssignmentRepacker(repack_map, nullptr,
4705                                         /*always_return_modified=*/true);
4706   MemorySpaceAssignment::Options options;
4707   options.max_size_in_bytes = 128;
4708   options.alignment_in_bytes = 8;
4709   options.verify = true;
4710   options.max_repacks = 10;
4711   options.repacker = &repacker;
4712   options.repack_after_every_allocation = true;
4713   InstructionCountPrefetchIntervalPicker prefetch_interval_picker(2, 10);
4714   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
4715                     /*buffer_interval_compare=*/{}, &prefetch_interval_picker,
4716                     options);
4717   // Make sure the root of the entry computation is in the default memory space.
4718   EXPECT_EQ(module->entry_computation()
4719                 ->root_instruction()
4720                 ->shape()
4721                 .layout()
4722                 .memory_space(),
4723             kDefaultMemorySpace);
4724 }
4725 
4726 TEST_P(MemorySpaceAssignmentTest, Determinism) {
4727   // Run memory space assignment a few times to make sure it produces the same
4728   // result every time.
4729   std::unique_ptr<HloModule> module = CreateEvictAndPrefetchModule();
4730 
4731   AssignMemorySpace(module.get());
4732   std::string module_str = module->ToString();
4733 
4734   for (int i = 0; i < 10; ++i) {
4735     std::unique_ptr<HloModule> other_module = CreateEvictAndPrefetchModule();
4736     AssignMemorySpace(other_module.get());
4737     EXPECT_EQ(module_str, other_module->ToString());
4738   }
4739 }
4740 
4741 TEST_P(MemorySpaceAssignmentTest, InPlaceOp) {
4742   // Tests that in-place ops like DynamicUpdateSlice get the same allocation as
4743   // their input.
4744   absl::string_view hlo_string = R"(
4745 HloModule Module, is_scheduled=true
4746 
4747 fused_computation {
4748   param0 = f32[2,3] parameter(0)
4749   constant.1 = f32[] constant(0)
4750   broadcast = f32[2,1] broadcast(constant.1), dimensions={}
4751   constant.3 = s32[] constant(0)
4752   ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
4753 }
4754 
4755 ENTRY main {
4756   param = f32[2,3] parameter(0)
4757   negate = f32[2,3] negate(param)
4758   fusion = f32[2,3] fusion(negate), kind=kLoop, calls=fused_computation
4759   ROOT add = f32[2,3] add(fusion, fusion)
4760 }
4761   )";
4762 
4763   TF_ASSERT_OK_AND_ASSIGN(auto module,
4764                           ParseAndReturnVerifiedModule(hlo_string));
4765   auto preset_assignments = AssignMemorySpace(module.get());
4766   HloInstruction* negate_instruction =
4767       module->entry_computation()->GetInstructionWithName("negate");
4768   int64 negate_offset =
4769       GetAlternateMemoryOffset(*preset_assignments, negate_instruction);
4770   HloInstruction* fusion_instruction =
4771       module->entry_computation()->GetInstructionWithName("fusion");
4772   int64 fusion_offset =
4773       GetAlternateMemoryOffset(*preset_assignments, fusion_instruction);
4774   // We expect negate and fusion to get the same offsets.
4775   EXPECT_EQ(negate_offset, fusion_offset);
4776   const bool allocate_across_sequential_calls = GetParam();
4777   if (allocate_across_sequential_calls) {
4778     EXPECT_NE(negate_offset, -1);
4779   }
4780 }
4781 
4782 INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
4783                          MemorySpaceAssignmentTest,
4784                          ::testing::Values(false, true));
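// The boolean test parameter is read via GetParam() as
// allocate_across_sequential_calls (see the InPlaceOp test above).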
4785 
4786 using AsynchronousCopyOrderingTest = ::testing::Test;
4787 
4788 TEST_F(AsynchronousCopyOrderingTest, Simple) {
4789   // Given asynchronous copies like the following, ensure the pipelining order
4790   // is maintained (earlier start time must have earlier end time).
4791   // 3,11       +-------+         OK
4792   // 1,8      +------+            OK
4793   // 5,14         +--------+      OK
4794   // 7,14           +------+      OK
4795   // 2,16      +-------------+    Violate
4796   // 9,12             +--+        Violate
4797   // 6,17          +----------+   Violate
4798   // 5,13         +-------+       OK (same start as 5,14)
4799   // 5,14         +--------+      OK (same as 5,14)
4800   auto alternate_mem_space = MemorySpaceAssignment::MemorySpace::kAlternate;
4801   AsynchronousCopyOrdering ordering;
4802   EXPECT_FALSE(ordering.ViolatesOrdering(3, 11));
4803   ordering.AddCopy({3, 11, alternate_mem_space});
4804   EXPECT_FALSE(ordering.ViolatesOrdering(1, 8));
4805   ordering.AddCopy({1, 8, alternate_mem_space});
4806   EXPECT_FALSE(ordering.ViolatesOrdering(5, 14));
4807   ordering.AddCopy({5, 14, alternate_mem_space});
4808   EXPECT_FALSE(ordering.ViolatesOrdering(7, 14));
4809   ordering.AddCopy({7, 14, alternate_mem_space});
4810   EXPECT_TRUE(ordering.ViolatesOrdering(2, 16));
4811   EXPECT_TRUE(ordering.ViolatesOrdering(9, 12));
4812   EXPECT_TRUE(ordering.ViolatesOrdering(6, 17));
4813   EXPECT_FALSE(ordering.ViolatesOrdering(5, 13));
4814   ordering.AddCopy({5, 13, alternate_mem_space});
4815   EXPECT_FALSE(ordering.ViolatesOrdering(5, 14));
4816   ordering.AddCopy({5, 14, alternate_mem_space});
4817 }
4818 
4819 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTest) {
4820   HloComputation::Builder builder(TestName());
4821 
4822   constexpr int kBatch = 8;
4823   constexpr int kFeature = 8;
4824   constexpr int kOutput = 2;
4825 
4826   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
4827   auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
4828   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
4829   auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
4830   HloInstruction* param = builder.AddInstruction(
4831       HloInstruction::CreateParameter(0, tuple_shape, "p0"));
4832 
4833   auto lhs = builder.AddInstruction(
4834       HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
4835   auto rhs = builder.AddInstruction(
4836       HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
4837 
4838   DotDimensionNumbers dot_dnums;
4839   dot_dnums.add_lhs_contracting_dimensions(1);
4840   dot_dnums.add_rhs_contracting_dimensions(0);
4841   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
4842       result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
4843 
4844   auto module = CreateNewVerifiedModule();
4845   HloComputation* computation = module->AddEntryComputation(builder.Build());
4846 
4847   HloSchedule schedule(module.get());
4848   schedule.set_sequence(computation, {param, lhs, rhs, dot});
4849   TF_CHECK_OK(module->set_schedule(schedule));
4850 
4851   AssignMemorySpace(module.get());
4852 
4853   auto cross_program_prefetches = module->CrossProgramPrefetches();
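  // Expect a single cross-program prefetch: parameter 0 at tuple index {1},
  // i.e. the rhs operand of the dot.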
4854   EXPECT_EQ(cross_program_prefetches.size(), 1);
4855   if (!cross_program_prefetches.empty()) {
4856     EXPECT_EQ(cross_program_prefetches[0].first, 0);
4857     EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
4858   }
4859 }
4860 
4861 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchBitcastTest) {
4862   HloComputation::Builder builder(TestName());
4863 
4864   constexpr int kBatch = 8;
4865   constexpr int kFeature = 8;
4866   constexpr int kOutput = 2;
4867 
4868   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
4869   auto rhs_shape = ShapeUtil::MakeShape(F32, {kOutput, kFeature});
4870   auto bitcast_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
4871   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
4872   auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
4873   HloInstruction* param = builder.AddInstruction(
4874       HloInstruction::CreateParameter(0, tuple_shape, "p0"));
4875 
4876   auto lhs = builder.AddInstruction(
4877       HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
4878   auto rhs = builder.AddInstruction(
4879       HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
4880 
4881   auto bitcast =
4882       builder.AddInstruction(HloInstruction::CreateBitcast(bitcast_shape, rhs));
4883 
4884   DotDimensionNumbers dot_dnums;
4885   dot_dnums.add_lhs_contracting_dimensions(1);
4886   dot_dnums.add_rhs_contracting_dimensions(0);
4887   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
4888       result_shape, lhs, bitcast, dot_dnums, DefaultPrecisionConfig(2)));
4889 
4890   auto module = CreateNewVerifiedModule();
4891   HloComputation* computation = module->AddEntryComputation(builder.Build());
4892 
4893   HloSchedule schedule(module.get());
4894   schedule.set_sequence(computation, {param, lhs, rhs, bitcast, dot});
4895   TF_CHECK_OK(module->set_schedule(schedule));
4896 
4897   AssignMemorySpace(module.get());
4898 
4899   auto cross_program_prefetches = module->CrossProgramPrefetches();
4900   EXPECT_EQ(cross_program_prefetches.size(), 1);
4901   if (!cross_program_prefetches.empty()) {
4902     EXPECT_EQ(cross_program_prefetches[0].first, 0);
4903     EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
4904   }
4905 }
4906 
4907 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNestedTupleTest) {
4908   HloComputation::Builder builder(TestName());
4909 
4910   constexpr int kBatch = 8;
4911   constexpr int kFeature = 8;
4912   constexpr int kOutput = 2;
4913 
4914   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
4915   auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
4916   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
4917   auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
4918   auto tuple_tuple_shape = ShapeUtil::MakeTupleShape({tuple_shape});
4919   HloInstruction* param = builder.AddInstruction(
4920       HloInstruction::CreateParameter(0, tuple_tuple_shape, "p0"));
4921 
4922   auto gte = builder.AddInstruction(
4923       HloInstruction::CreateGetTupleElement(tuple_shape, param, 0));
4924 
4925   auto lhs = builder.AddInstruction(
4926       HloInstruction::CreateGetTupleElement(lhs_shape, gte, 0));
4927   auto rhs = builder.AddInstruction(
4928       HloInstruction::CreateGetTupleElement(rhs_shape, gte, 1));
4929 
4930   DotDimensionNumbers dot_dnums;
4931   dot_dnums.add_lhs_contracting_dimensions(1);
4932   dot_dnums.add_rhs_contracting_dimensions(0);
4933   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
4934       result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
4935 
4936   auto module = CreateNewVerifiedModule();
4937   HloComputation* computation = module->AddEntryComputation(builder.Build());
4938 
4939   HloSchedule schedule(module.get());
4940   schedule.set_sequence(computation, {param, gte, lhs, rhs, dot});
4941   TF_CHECK_OK(module->set_schedule(schedule));
4942 
4943   AssignMemorySpace(module.get());
4944 
4945   auto cross_program_prefetches = module->CrossProgramPrefetches();
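  // The dot operands live inside a nested tuple parameter, which is not
  // cross-program prefetched.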
4946   EXPECT_EQ(cross_program_prefetches.size(), 0);
4947 }
4948 
4949 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchUnusedParamTest) {
4950   HloComputation::Builder builder(TestName());
4951 
4952   constexpr int kFeature = 8;
4953   constexpr int kOutput = 2;
4954 
4955   auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
4956   HloInstruction* param = builder.AddInstruction(
4957       HloInstruction::CreateParameter(0, rhs_shape, "p0"));
4958 
4959   auto module = CreateNewVerifiedModule();
4960   HloComputation* computation = module->AddEntryComputation(builder.Build());
4961 
4962   HloSchedule schedule(module.get());
4963   schedule.set_sequence(computation, {param});
4964   TF_CHECK_OK(module->set_schedule(schedule));
4965 
4966   AssignMemorySpace(module.get());
4967 
4968   auto cross_program_prefetches = module->CrossProgramPrefetches();
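  // The parameter has no uses, so nothing is cross-program prefetched.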
4969   EXPECT_EQ(cross_program_prefetches.size(), 0);
4970 }
4971 
4972 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTooBigTest) {
4973   HloComputation::Builder builder(TestName());
4974 
4975   constexpr int kBatch = 8;
4976   constexpr int kFeature = 8;
4977   constexpr int kOutput = 8;
4978 
4979   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
4980   auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
4981   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
4982   auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
4983   HloInstruction* param = builder.AddInstruction(
4984       HloInstruction::CreateParameter(0, tuple_shape, "p0"));
4985 
4986   auto lhs = builder.AddInstruction(
4987       HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
4988   auto rhs = builder.AddInstruction(
4989       HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
4990 
4991   DotDimensionNumbers dot_dnums;
4992   dot_dnums.add_lhs_contracting_dimensions(1);
4993   dot_dnums.add_rhs_contracting_dimensions(0);
4994   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
4995       result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
4996 
4997   auto module = CreateNewVerifiedModule();
4998   HloComputation* computation = module->AddEntryComputation(builder.Build());
4999 
5000   HloSchedule schedule(module.get());
5001   schedule.set_sequence(computation, {param, lhs, rhs, dot});
5002   TF_CHECK_OK(module->set_schedule(schedule));
5003 
5004   AssignMemorySpace(module.get());
5005 
5006   auto cross_program_prefetches = module->CrossProgramPrefetches();
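  // With kOutput = 8, the candidate buffers are too big for the alternate
  // memory, so no cross-program prefetch is expected.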
5007   EXPECT_EQ(cross_program_prefetches.size(), 0);
5008 }
5009 
5010 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchFusionTest) {
5011   HloComputation::Builder builder(TestName());
5012 
5013   constexpr int kBatch = 2;
5014   constexpr int kFeature = 2;
5015   constexpr int kOutput = 2;
5016 
5017   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
5018   auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
5019   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
5020   auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
5021 
5022   auto module = CreateNewVerifiedModule();
5023   HloComputation::Builder fusion_builder("fusion");
5024   {
5025     HloInstruction* param = fusion_builder.AddInstruction(
5026         HloInstruction::CreateParameter(0, tuple_shape, "p0"));
5027     auto lhs = fusion_builder.AddInstruction(
5028         HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
5029     auto rhs = fusion_builder.AddInstruction(
5030         HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
5031     DotDimensionNumbers dot_dnums;
5032     dot_dnums.add_lhs_contracting_dimensions(1);
5033     dot_dnums.add_rhs_contracting_dimensions(0);
5034     auto dot = fusion_builder.AddInstruction(HloInstruction::CreateDot(
5035         result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
5036     (void)dot;
5037   }
5038   HloComputation* fusion_computation =
5039       module->AddEmbeddedComputation(fusion_builder.Build());
5040 
5041   auto activations = builder.AddInstruction(HloInstruction::CreateConstant(
5042       LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
5043   auto weights = builder.AddInstruction(HloInstruction::CreateConstant(
5044       LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
5045   HloInstruction* tuple = builder.AddInstruction(
5046       HloInstruction::CreateTuple({activations, weights}));
5047   HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
5048       result_shape, HloInstruction::FusionKind::kCustom, {tuple},
5049       fusion_computation));
5050 
5051   HloComputation* computation = module->AddEntryComputation(builder.Build());
5052 
5053   HloSchedule schedule(module.get());
5054   schedule.set_sequence(computation, {activations, weights, tuple, fusion});
5055   TF_CHECK_OK(module->set_schedule(schedule));
5056 
5057   AssignMemorySpace(module.get());
5058 
5059   auto cross_program_prefetches = module->CrossProgramPrefetches();
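  // The fusion operands are constants rather than entry parameters, so there is
  // nothing to cross-program prefetch.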
5060   EXPECT_EQ(cross_program_prefetches.size(), 0);
5061 }
5062 
5063 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchPinnedTest) {
5064   HloComputation::Builder builder(TestName());
5065 
5066   constexpr int kBatch = 8;
5067   constexpr int kFeature = 8;
5068   constexpr int kOutput = 2;
5069 
5070   auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
5071   auto rhs_shape = ShapeUtil::MakeShapeWithLayout(
5072       F32, {kFeature, kOutput},
5073       /*minor_to_major=*/{1, 0}, /*tiles=*/{}, /*element_size_in_bits=*/0,
5074       kAlternateMemorySpace);
5075   auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
5076   auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
5077   HloInstruction* param = builder.AddInstruction(
5078       HloInstruction::CreateParameter(0, tuple_shape, "p0"));
5079 
5080   auto lhs = builder.AddInstruction(
5081       HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
5082   auto rhs = builder.AddInstruction(
5083       HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
5084 
5085   DotDimensionNumbers dot_dnums;
5086   dot_dnums.add_lhs_contracting_dimensions(1);
5087   dot_dnums.add_rhs_contracting_dimensions(0);
5088   auto dot = builder.AddInstruction(HloInstruction::CreateDot(
5089       result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
5090 
5091   auto module = CreateNewVerifiedModule();
5092   HloComputation* computation = module->AddEntryComputation(builder.Build());
5093 
5094   HloSchedule schedule(module.get());
5095   schedule.set_sequence(computation, {param, lhs, rhs, dot});
5096   TF_CHECK_OK(module->set_schedule(schedule));
5097 
5098   AssignMemorySpace(module.get());
5099 
5100   auto cross_program_prefetches = module->CrossProgramPrefetches();
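  // The rhs is already pinned to the alternate memory space via its layout, so
  // it doesn't need a cross-program prefetch.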
5101   EXPECT_EQ(cross_program_prefetches.size(), 0);
5102 }
5103 
5104 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchReuse) {
5105   // This test checks that the cross-program-prefetched buffer is freed after
5106   // its last use and that there is an end-of-program prefetch.
5107   absl::string_view hlo_string = R"(
5108   HloModule cross_program_prefetch, is_scheduled=true
5109 
5110   ENTRY CrossProgramPrefetch {
5111     p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0)
5112     get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0
5113     get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1
5114     dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
5115     negate.1 = f32[8,2]{1,0} negate(dot)
5116     negate.2 = f32[8,2]{1,0} negate(negate.1)
5117     negate.3 = f32[8,2]{1,0} negate(negate.2)
5118     negate.4 = f32[8,2]{1,0} negate(negate.3)
5119     negate.5 = f32[8,2]{1,0} negate(negate.4)
5120     negate.6 = f32[8,2]{1,0} negate(negate.5)
5121     negate.7 = f32[8,2]{1,0} negate(negate.6)
5122     negate.8 = f32[8,2]{1,0} negate(negate.7)
5123     ROOT negate.9 = f32[8,2]{1,0} negate(negate.8)
5124   }
5125   )";
5126   TF_ASSERT_OK_AND_ASSIGN(auto module,
5127                           ParseAndReturnVerifiedModule(hlo_string));
5128 
5129   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
5130                     /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);
5131 
5132   auto cross_program_prefetches = module->CrossProgramPrefetches();
5133   EXPECT_EQ(cross_program_prefetches.size(), 1);
5134   if (!cross_program_prefetches.empty()) {
5135     EXPECT_EQ(cross_program_prefetches[0].first, 0);
5136     EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
5137   }
5138 
5139   TF_ASSERT_OK_AND_ASSIGN(
5140       std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
5141       HloDataflowAnalysis::Run(*module));
5142   const HloValue& cross_program_prefetched_value =
5143       dataflow_analysis->GetValueDefinedAt(
5144           module->entry_computation()->parameter_instruction(0), {1});
5145   // Expect that there are two prefetches that use this value: one is the
5146   // cross-program prefetch and the other is the end-of-program prefetch.
5147   auto is_cross_program_prefetch = [](const HloUse& use) {
5148     return use.instruction->opcode() == HloOpcode::kCopyStart &&
5149            use.instruction->is_cross_program_prefetch();
5150   };
5151   EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(),
5152                              is_cross_program_prefetch),
5153             1);
5154   auto is_end_of_program_prefetch = [](const HloUse& use) {
5155     return use.instruction->opcode() == HloOpcode::kCopyStart &&
5156            !use.instruction->is_cross_program_prefetch();
5157   };
5158   EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(),
5159                              is_end_of_program_prefetch),
5160             1);
5161 }
5162 
5163 TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) {
5164   // This tests the scenario where the cross-program-prefetched buffer is used
5165   // again close to the end of the computation. In this case, it is better not
5166   // to free the buffer.
5167   absl::string_view hlo_string = R"(
5168   HloModule cross_program_prefetch, is_scheduled=true
5169 
5170   ENTRY CrossProgramPrefetch {
5171     p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0)
5172     get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0
5173     get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1
5174     dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
5175     negate.1 = f32[8,2]{1,0} negate(dot)
5176     negate.2 = f32[8,2]{1,0} negate(negate.1)
5177     negate.3 = f32[8,2]{1,0} negate(negate.2)
5178     negate.4 = f32[8,2]{1,0} negate(negate.3)
5179     negate.5 = f32[8,2]{1,0} negate(negate.4)
5180     negate.6 = f32[8,2]{1,0} negate(negate.5)
5181     negate.7 = f32[8,2]{1,0} negate(negate.6)
5182     negate.8 = f32[8,2]{1,0} negate(negate.7)
5183     ROOT dot.2 = f32[2,2]{1,0} dot(negate.8, get-tuple-element.1), lhs_contracting_dims={0}, rhs_contracting_dims={0}
5184   }
5185   )";
5186   TF_ASSERT_OK_AND_ASSIGN(auto module,
5187                           ParseAndReturnVerifiedModule(hlo_string));
5188 
5189   AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1,
5190                     /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2);
5191 
5192   auto cross_program_prefetches = module->CrossProgramPrefetches();
5193   EXPECT_EQ(cross_program_prefetches.size(), 1);
5194   if (!cross_program_prefetches.empty()) {
5195     EXPECT_EQ(cross_program_prefetches[0].first, 0);
5196     EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
5197   }
5198 
5199   TF_ASSERT_OK_AND_ASSIGN(
5200       std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
5201       HloDataflowAnalysis::Run(*module));
5202   const HloValue& cross_program_prefetched_value =
5203       dataflow_analysis->GetValueDefinedAt(
5204           module->entry_computation()->parameter_instruction(0), {1});
5205   // Expect that there is one prefetch that uses this value: the cross-program
5206   // prefetch. There shouldn't be an end-of-program prefetch.
5207   auto is_cross_program_prefetch = [](const HloUse& use) {
5208     return use.instruction->opcode() == HloOpcode::kCopyStart &&
5209            use.instruction->is_cross_program_prefetch();
5210   };
5211   EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(),
5212                              is_cross_program_prefetch),
5213             1);
5214   auto is_end_of_program_prefetch = [](const HloUse& use) {
5215     return use.instruction->opcode() == HloOpcode::kCopyStart &&
5216            !use.instruction->is_cross_program_prefetch();
5217   };
5218   EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(),
5219                              is_end_of_program_prefetch),
5220             0);
5221 }
5222 
5223 using CostAnalysisPrefetchIntervalPickerTest = HloTestBase;
5224 
5225 TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) {
5226   absl::string_view hlo_string = R"(
5227   HloModule bug, is_scheduled=true
5228 
5229   ENTRY Entry {
5230     param0 = f32[2,4] parameter(0)
5231     a = f32[2,4] negate(param0)
5232     b = f32[2,4] negate(a)
5233     c = f32[2,4] negate(b)
5234     d = f32[2,4] negate(c)
5235     e = f32[2,4] negate(d)
5236     f = f32[2,4] negate(e)
5237     g = f32[2,4] negate(f)
5238     h = f32[2,4] negate(g)
5239     i = f32[2,4] negate(h)
5240     j = f32[2,4] negate(i)
5241     k = f32[2,4] negate(j)
5242     l = f32[2,4] negate(k)
5243     m = f32[2,4] negate(l)
5244     n = f32[2,4] negate(m)
5245     o = f32[2,4] negate(n)
5246     p = f32[2,4] negate(o)
5247     q = f32[2,4] negate(p)
5248     r = f32[2,4] negate(q)
5249     s = f32[2,4] negate(r)
5250     t = f32[2,4] negate(s)
5251     u = f32[2,4] negate(t)
5252     ROOT v = f32[2,4] add(u, param0)
5253   }
5254   )";
5255   TF_ASSERT_OK_AND_ASSIGN(auto module,
5256                           ParseAndReturnVerifiedModule(hlo_string));
5257 
5258   HloCostAnalysis hlo_cost_analysis(ShapeSize);
5259   TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
5260                           FakeMemorySpaceAssignmentCostAnalysis::Create(
5261                               hlo_cost_analysis, *module));
5262   CostAnalysisPrefetchIntervalPicker interval_picker(
5263       *cost_analysis,
5264       /*min_async_copy_to_overlap_ratio=*/1.0,
5265       /*max_async_copy_to_overlap_ratio=*/4.0,
5266       /*preferred_async_copy_to_overlap_ratio=*/2.0);
5267 
5268   HloInstruction* root = module->entry_computation()->root_instruction();
5269   const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
5270   interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/22);
5271 
5272   // Expect that the first interval is (15, 22), which has an elapsed time of
5273   // 6.0, twice the async copy elapsed time (3.0). Then we expect the intervals
5274   // to be visited in alternating increasing and decreasing order until hitting
5275   // the min and max async copy overlap ratios, which correspond to the
5276   // intervals (18, 22) and (9, 22) respectively.
5277   LOG(INFO) << interval_picker.ToDebugString();
5278   EXPECT_EQ(interval_picker.Next(), 15);
5279   LOG(INFO) << interval_picker.ToDebugString();
5280   EXPECT_EQ(interval_picker.Next(), 16);
5281   LOG(INFO) << interval_picker.ToDebugString();
5282   EXPECT_EQ(interval_picker.Next(), 14);
5283   LOG(INFO) << interval_picker.ToDebugString();
5284   EXPECT_EQ(interval_picker.Next(), 17);
5285   LOG(INFO) << interval_picker.ToDebugString();
5286   EXPECT_EQ(interval_picker.Next(), 13);
5287   LOG(INFO) << interval_picker.ToDebugString();
5288   EXPECT_EQ(interval_picker.Next(), 18);  // Min async overlap ratio reached.
5289   LOG(INFO) << interval_picker.ToDebugString();
5290   EXPECT_EQ(interval_picker.Next(), 12);
5291   LOG(INFO) << interval_picker.ToDebugString();
5292   EXPECT_EQ(interval_picker.Next(), 11);
5293   LOG(INFO) << interval_picker.ToDebugString();
5294   EXPECT_EQ(interval_picker.Next(), 10);
5295   LOG(INFO) << interval_picker.ToDebugString();
5296   EXPECT_EQ(interval_picker.Next(), 9);  // Max async overlap ratio reached.
5297   LOG(INFO) << interval_picker.ToDebugString();
5298   EXPECT_TRUE(interval_picker.Done());
5299 
5300   // Expect that if the time between start_time and end_time is too short, there
5301   // won't be any available intervals.
5302   interval_picker.Begin(use, /*start_time=*/19, /*end_time=*/22);
5303   LOG(INFO) << interval_picker.ToDebugString();
5304   EXPECT_TRUE(interval_picker.Done());
5305 }
5306 
5307 TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrderWhile) {
5308   absl::string_view hlo_string = R"(
5309   HloModule bug, is_scheduled=true
5310 
5311   while_condition {
5312     param1 = (f32[2,4]) parameter(0)    // 19
5313     ROOT cond = pred[] constant(true)   // 20
5314   }
5315 
5316   while_body {
5317     param2 = (f32[2,4]) parameter(0)    // 21
5318     gte2 = f32[2,4] get-tuple-element(param2), index=0  // 22
5319     add = f32[2,4] add(gte2, gte2)      // 23
5320     ROOT tuple2 = (f32[2,4]) tuple(add) // 24
5321   }
5322 
5323   ENTRY Entry {
5324     param0 = f32[2,4] parameter(0)  // 0
5325     a = f32[2,4] negate(param0)     // 1
5326     b = f32[2,4] negate(a)          // 2
5327     c = f32[2,4] negate(b)          // 3
5328     d = f32[2,4] negate(c)          // 4
5329     e = f32[2,4] negate(d)          // 5
5330     f = f32[2,4] negate(e)          // 6
5331     g = f32[2,4] negate(f)          // 7
5332     h = f32[2,4] negate(g)          // 8
5333     i = f32[2,4] negate(h)          // 9
5334     j = f32[2,4] negate(i)          // 10
5335     k = f32[2,4] negate(j)          // 11
5336     l = f32[2,4] negate(k)          // 12
5337     m = f32[2,4] negate(l)          // 13
5338     n = f32[2,4] negate(m)          // 14
5339     o = f32[2,4] negate(n)          // 15
5340     p = f32[2,4] negate(o)          // 16
5341     q = f32[2,4] negate(p)          // 17
5342     tuple = (f32[2,4]) tuple(q)     // 18
5343     while = (f32[2,4]) while(tuple), condition=while_condition, body=while_body  // 25
5344     gte1 = f32[2,4] get-tuple-element(while), index=0  // 26
5345     r = f32[2,4] negate(gte1)       // 27
5346     s = f32[2,4] negate(r)          // 28
5347     t = f32[2,4] negate(s)          // 29
5348     u = f32[2,4] negate(t)          // 30
5349     ROOT v = f32[2,4] add(u, param0)  // 31
5350   }
5351   )";
5352   TF_ASSERT_OK_AND_ASSIGN(auto module,
5353                           ParseAndReturnVerifiedModule(hlo_string));
5354 
5355   HloCostAnalysis hlo_cost_analysis(ShapeSize);
5356   TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
5357                           FakeMemorySpaceAssignmentCostAnalysis::Create(
5358                               hlo_cost_analysis, *module));
5359   CostAnalysisPrefetchIntervalPicker interval_picker(
5360       *cost_analysis,
5361       /*min_async_copy_to_overlap_ratio=*/1.0,
5362       /*max_async_copy_to_overlap_ratio=*/12.0,
5363       /*preferred_async_copy_to_overlap_ratio=*/2.0);
5364 
5365   HloInstruction* root = module->entry_computation()->root_instruction();
5366   const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
5367   interval_picker.Begin(use, /*start_time=*/0, /*end_time=*/31);
5368 
5369   // Because there are while loop computations between [19, 24], we ensure that
5370   // the interval picker avoids this interval.
5371   LOG(INFO) << interval_picker.ToDebugString();
5372   EXPECT_EQ(interval_picker.Next(), 25);
5373   LOG(INFO) << interval_picker.ToDebugString();
5374   EXPECT_EQ(interval_picker.Next(), 26);
5375   LOG(INFO) << interval_picker.ToDebugString();
5376   EXPECT_EQ(interval_picker.Next(), 18);
5377   LOG(INFO) << interval_picker.ToDebugString();
5378   EXPECT_EQ(interval_picker.Next(), 27);  // Min async overlap ratio reached.
5379   LOG(INFO) << interval_picker.ToDebugString();
5380   EXPECT_EQ(interval_picker.Next(), 17);  // Max async overlap ratio reached.
5381   LOG(INFO) << interval_picker.ToDebugString();
5382   EXPECT_TRUE(interval_picker.Done());
5383 }
5384 
5385 TEST_F(CostAnalysisPrefetchIntervalPickerTest, NestedWhile) {
5386   // This test checks against a bug where we didn't assign while_nest_level_
5387   // for while instructions, so it defaulted to 0. This could cause the
5388   // prefetch interval logic to think a nested while instruction is at the
5389   // same level as the outermost computation.
5390   absl::string_view hlo_string = R"(
5391   HloModule bug, is_scheduled=true
5392 
5393   while_condition.2 {
5394     param1 = (f32[2,4]) parameter(0)    // 11
5395     ROOT cond = pred[] constant(true)   // 12
5396   }
5397 
5398   while_body.2 {
5399     param2 = (f32[2,4]) parameter(0)    // 13
5400     gte2 = f32[2,4] get-tuple-element(param2), index=0  // 14
5401     add = f32[2,4] add(gte2, gte2)      // 15
5402     ROOT tuple2 = (f32[2,4]) tuple(add) // 16
5403   }
5404 
5405   while_condition.1 {
5406     param3 = (f32[2,4]) parameter(0)    // 5
5407     ROOT cond = pred[] constant(true)   // 6
5408   }
5409 
5410   while_body.1 {
5411     param4 = (f32[2,4]) parameter(0)    // 7
5412     gte1 = f32[2,4] get-tuple-element(param4), index=0  // 8
5413     add1 = f32[2,4] add(gte1, gte1)     // 9
5414     tuple1 = (f32[2,4]) tuple(add1)     // 10
5415     while = (f32[2,4]) while(tuple1), condition=while_condition.2, body=while_body.2  // 17
5416     gte2 = f32[2,4] get-tuple-element(while), index=0  // 18
5417     add2 = f32[2,4] add(gte2, gte2)     // 19
5418     ROOT tuple2 = (f32[2,4]) tuple(add2)  // 20
5419   }
5420 
5421   ENTRY Entry {
5422     param0 = f32[2,4] parameter(0)  // 0
5423     a = f32[2,4] negate(param0)     // 1
5424     b = f32[2,4] negate(a)          // 2
5425     c = f32[2,4] negate(b)          // 3
5426     tuple = (f32[2,4]) tuple(c)     // 4
5427     while = (f32[2,4]) while(tuple), condition=while_condition.1, body=while_body.1  // 21
5428     gte1 = f32[2,4] get-tuple-element(while), index=0  // 22
5429     ROOT root = f32[2,4] add(gte1, param0)  // 23
5430   }
5431   )";
5432   TF_ASSERT_OK_AND_ASSIGN(auto module,
5433                           ParseAndReturnVerifiedModule(hlo_string));
5434 
5435   HloCostAnalysis hlo_cost_analysis(ShapeSize);
5436   TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
5437                           FakeMemorySpaceAssignmentCostAnalysis::Create(
5438                               hlo_cost_analysis, *module));
5439   CostAnalysisPrefetchIntervalPicker interval_picker(
5440       *cost_analysis,
5441       /*min_async_copy_to_overlap_ratio=*/1.0,
5442       /*max_async_copy_to_overlap_ratio=*/12.0,
5443       /*preferred_async_copy_to_overlap_ratio=*/2.0);
5444 
5445   HloInstruction* root = module->entry_computation()->root_instruction();
5446   const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}};
5447   const Shape& shape = root->operand(1)->shape();
5448 
5449   // We expect the root's latest prefetch start time to be before the while
5450   // loop, at logical time 4 (the tuple feeding the while).
5451   EXPECT_EQ(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0,
5452                                                     /*end_time=*/23, &use),
5453             4);
5454 }
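// LatestPrefetchStartTime, as exercised above, returns the latest logical time
// at which a prefetch for the given use may begin under the picker's
// constraints. A minimal sketch, assuming only that call (the
// ClampPrefetchStart helper and the caller-proposed start time are
// hypothetical), of how a caller might cap a proposed prefetch start:
int64 ClampPrefetchStart(CostAnalysisPrefetchIntervalPicker* picker,
                         const Shape& shape, int64 proposed_start_time,
                         int64 start_time, int64 end_time, const HloUse& use) {
  const int64 latest =
      picker->LatestPrefetchStartTime(shape, start_time, end_time, &use);
  // Never start the async copy later than the picker allows; for the module in
  // the preceding test, `latest` is 4, just before the outer while loop begins.
  return proposed_start_time < latest ? proposed_start_time : latest;
}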
5455 
5456 TEST_F(CostAnalysisPrefetchIntervalPickerTest, ConsecutiveConditionals) {
5457   // This is a test for b/170668492, where prefetching for consecutive
5458   // conditionals can cause the prefetch to start inside the previous
5459   // conditional's called computations.
5460   absl::string_view hlo_string = R"(
5461   HloModule bug, is_scheduled=true
5462 
5463   true_computation.0 {
5464     p0 = (f32[3]{0}) parameter(0)                   // 5
5465     gte = f32[3]{0} get-tuple-element(p0), index=0  // 6
5466     ROOT neg1 = f32[3]{0} negate(gte)               // 7
5467   }
5468 
5469   false_computation.0 {
5470     p0 = (f32[3]{0}) parameter(0)                   // 8
5471     gte = f32[3]{0} get-tuple-element(p0), index=0  // 9
5472     ROOT neg2 = f32[3]{0} negate(gte)               // 10
5473   }
5474 
5475   true_computation.1 {
5476     p0 = (f32[3]{0}) parameter(0)                   // 12
5477     gte = f32[3]{0} get-tuple-element(p0), index=0  // 13
5478     ROOT neg1 = f32[3]{0} negate(gte)               // 14
5479   }
5480 
5481   false_computation.1 {
5482     p0 = (f32[3]{0}) parameter(0)                   // 15
5483     gte = f32[3]{0} get-tuple-element(p0), index=0  // 16
5484     ROOT neg2 = f32[3]{0} negate(gte)               // 17
5485   }
5486 
5487   ENTRY entry {
5488     p0 = f32[3]{0} parameter(0)       // 0
5489     p1 = f32[3]{0} parameter(1)       // 1
5490     p2 = pred[] parameter(2)          // 2
5491     tuple0 = (f32[3]{0}) tuple(p0)    // 3
5492     tuple1 = (f32[3]{0}) tuple(p1)    // 4
5493     conditional0 = f32[3]{0} conditional(p2, tuple0, tuple0), true_computation=true_computation.0, false_computation=false_computation.0  // 11
5494     conditional1 = f32[3]{0} conditional(p2, tuple1, tuple1), true_computation=true_computation.1, false_computation=false_computation.1  // 18
5495     ROOT tuple2 = (f32[3]{0}, f32[3]{0}) tuple(conditional0, conditional1)  // 19
5496   }
5497   )";
5498   TF_ASSERT_OK_AND_ASSIGN(auto module,
5499                           ParseAndReturnVerifiedModule(hlo_string));
5500 
5501   HloCostAnalysis hlo_cost_analysis(ShapeSize);
5502   TF_ASSERT_OK_AND_ASSIGN(auto cost_analysis,
5503                           FakeMemorySpaceAssignmentCostAnalysis::Create(
5504                               hlo_cost_analysis, *module));
5505   CostAnalysisPrefetchIntervalPicker interval_picker(
5506       *cost_analysis,
5507       /*min_async_copy_to_overlap_ratio=*/1.0,
5508       /*max_async_copy_to_overlap_ratio=*/12.0,
5509       /*preferred_async_copy_to_overlap_ratio=*/2.0);
5510 
5511   LOG(INFO) << module->ToString();
5512 
5513   HloInstruction* conditional1 =
5514       module->entry_computation()->GetInstructionWithName("conditional1");
5515   const HloUse use{conditional1, /*operand_number=*/1, /*operand_index=*/{0}};
5516   const Shape& shape =
5517       module->entry_computation()->parameter_instruction(0)->shape();
5518 
5519   // Expect the prefetch to start before conditional0's called
5520   // computations.
5521   EXPECT_LT(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0,
5522                                                     /*end_time=*/11, &use),
5523             5);
5524 }
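// The HloUse above names operand 1 of conditional1 and, through operand_index
// {0}, element 0 of that tuple operand. A minimal sketch (MakeTupleElementUse
// is a hypothetical convenience wrapper, not an XLA API) of constructing such
// a use:
HloUse MakeTupleElementUse(HloInstruction* instruction, int64 operand_number,
                           int64 tuple_index) {
  // operand_index is a ShapeIndex; a single entry selects one tuple element.
  return HloUse{instruction, operand_number, {tuple_index}};
}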
5525 
5526 }  // namespace
5527 }  // namespace xla
5528