1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/memory_space_assignment.h"
17 
18 #include "tensorflow/compiler/xla/debug_options_flags.h"
19 #include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
20 #include "tensorflow/core/lib/math/math_util.h"
21 namespace xla {
22 
23 namespace {
24 // Define a dummy chunk for chunks that will be allocated in the default memory
25 // space and for keeping track of number of asynchronous copies.
26 const HeapSimulator::Chunk kDummyChunk{-1, -1};
// Assumed trip count of every while loop, used by the cost analysis. An
// instruction nested inside `nesting_level` while loops is treated as having
// executed pow(kWhileExecutionCount, nesting_level) times.
constexpr int kWhileExecutionCount = 5;
31 
LooksLikeAnActivation(const HloInstruction * inst)32 bool LooksLikeAnActivation(const HloInstruction* inst) {
33   for (HloInstruction* user : inst->users()) {
34     switch (user->opcode()) {
35       case HloOpcode::kConvolution:
36       case HloOpcode::kDot:
37         if (user->operand(0) == inst) {
38           return true;
39         }
40         break;
41       case HloOpcode::kGather:
42         if (user->operand(1) == inst) {
43           return true;
44         }
45         break;
46       case HloOpcode::kFusion:
47         for (int i = 0; i < user->operand_count(); ++i) {
48           if (user->operand(i) == inst &&
49               LooksLikeAnActivation(user->fused_parameter(i))) {
50             return true;
51           }
52         }
53         break;
54       case HloOpcode::kBitcast:
55         return LooksLikeAnActivation(user);
56       default:
57         return true;
58     }
59   }
60   return false;
61 }
62 
IsCrossProgramPrefetchCandidate(const HloValue & value,const MemorySpaceAssignment::Options & options)63 bool IsCrossProgramPrefetchCandidate(
64     const HloValue& value, const MemorySpaceAssignment::Options& options) {
65   return value.instruction()->parent() ==
66              value.instruction()->GetModule()->entry_computation() &&
67          value.instruction()->opcode() == HloOpcode::kParameter &&
68          (!value.shape().has_layout() ||
69           value.shape().layout().memory_space() !=
70               options.alternate_memory_space) &&
71          value.index().size() == 1 && value.shape().IsArray() &&
72          !value.uses().empty() &&
73          options.size_fn(value) <= options.max_size_in_bytes &&
74          absl::c_all_of(value.uses(), [&](const HloUse& use) {
75            const HloInstruction* inst =
76                use.instruction->operand(use.operand_number);
77 
78            // Skip the LooksLikeAnActivation test since we're testing the
79            // parent GTE and its children below.
80            if (inst->opcode() == HloOpcode::kBitcast &&
81                inst->operand(0)->opcode() == HloOpcode::kGetTupleElement &&
82                inst->operand(0)->operand(0)->opcode() ==
83                    HloOpcode::kParameter) {
84              return true;
85            }
86 
87            return inst->opcode() == HloOpcode::kGetTupleElement &&
88                   !LooksLikeAnActivation(inst);
89          });
90 }
91 
92 absl::optional<MemorySpaceAssignment::BufferInterval>
FindCrossProgramPrefetchCandidate(const HloAliasAnalysis & alias_analysis,const HloLiveRange & hlo_live_range,const MemorySpaceAssignment::Options & options)93 FindCrossProgramPrefetchCandidate(
94     const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range,
95     const MemorySpaceAssignment::Options& options) {
96   std::vector<MemorySpaceAssignment::BufferInterval> candidates;
97   for (const HloBuffer& buffer : alias_analysis.buffers()) {
98     CHECK_GE(buffer.values().size(), 1);
99     const HloValue* value = buffer.values().at(0);
100     if (IsCrossProgramPrefetchCandidate(*value, options)) {
101       MemorySpaceAssignment::BufferInterval interval;
102       interval.buffer = value;
103       interval.size = options.size_fn(*value);
104       interval.start = 0;
105       interval.end = hlo_live_range.schedule_end_time();
106       interval.need_allocation = true;
107       interval.colocations = {++buffer.values().begin(), buffer.values().end()};
108       candidates.emplace_back(interval);
109     }
110   }
111 
112   // The buffer_interval_compare ought to do a good job picking the most
113   // appropriate buffer to cross program prefetch, but empirically, it makes
114   // worse choices than just picking the largest buffer.
115   // TODO(b/152421603): Investigate.
116   auto size_compare = [](const auto& x, const auto& y) {
117     return x.size < y.size;
118   };
119   auto& compare = options.default_cross_program_prefetch_heuristic &&
120                           options.buffer_interval_compare
121                       ? *options.buffer_interval_compare
122                       : size_compare;
123 
124   auto best_candidate = absl::c_max_element(candidates, compare);
125   if (best_candidate == candidates.end()) {
126     return absl::nullopt;
127   }
128   return *best_candidate;
129 }
130 
131 }  // namespace
132 
133 /*static*/ StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>>
Create(const HloCostAnalysis & cost_analysis,float async_copy_bandwidth_bytes_per_second,float alternate_mem_bandwidth_bytes_per_second,const HloModule & module)134 MemorySpaceAssignmentCostAnalysis::Create(
135     const HloCostAnalysis& cost_analysis,
136     float async_copy_bandwidth_bytes_per_second,
137     float alternate_mem_bandwidth_bytes_per_second, const HloModule& module) {
138   TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
139   TF_ASSIGN_OR_RETURN(auto hlo_live_range,
140                       HloLiveRange::Run(module.schedule(), *alias_analysis,
141                                         module.entry_computation()));
142   auto call_graph = CallGraph::Build(&module);
143   return absl::WrapUnique(new MemorySpaceAssignmentCostAnalysis(
144       cost_analysis, async_copy_bandwidth_bytes_per_second,
145       alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
146       std::move(hlo_live_range), std::move(call_graph)));
147 }
148 
GetAlternateMemoryBenefit(const HloInstruction & instruction,float elapsed_time_due_to_alternate_mem,MemorySpaceAssignmentCostAnalysis::Cache * cache) const149 float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
150     const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
151     MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
152   float elapsed_time_due_to_compute =
153       GetInstructionElapsedDueToCompute(instruction);
154   float elapsed_time_due_to_memory =
155       GetInstructionElapsedDueToMemory(instruction);
156   if (elapsed_time_due_to_memory > elapsed_time_due_to_compute) {
157     // Memory bound, return how much alternate memory is better.
158     float while_nest_multiplier;
159     if (cache) {
160       // If there is a cache provided, memoize the while nest multiplier.
161       auto it = cache->while_nest_multiplier.find(&instruction);
162       if (it != cache->while_nest_multiplier.end()) {
163         while_nest_multiplier = it->second;
164       } else {
165         while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
166             kWhileExecutionCount,
167             CalculateComputationNestLevel(&instruction,
168                                           /*while_only=*/true));
169         cache->while_nest_multiplier[&instruction] = while_nest_multiplier;
170       }
171     } else {
172       while_nest_multiplier = tensorflow::MathUtil::IPow<float>(
173           kWhileExecutionCount,
174           CalculateComputationNestLevel(&instruction,
175                                         /*while_only=*/true));
176     }
177     return (elapsed_time_due_to_memory - elapsed_time_due_to_alternate_mem) *
178            while_nest_multiplier;
179   } else {
180     // Compute bound, return how far off are we to memory boundedness.
181     return elapsed_time_due_to_memory - elapsed_time_due_to_compute;
182   }
183 }
184 
GetMemoryBoundedness(const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval & interval,MemorySpaceAssignmentCostAnalysis::Cache * cache) const185 float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
186     const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval,
187     MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
188   const HloInstruction& defining_instruction =
189       *interval.buffer->defining_instruction();
190   float alternate_mem_benefit = GetAlternateMemoryBenefit(
191       defining_instruction,
192       GetInstructionElapsedDueToMemory(defining_instruction,
193                                        /*operand_in_alternate_mem=*/{},
194                                        /*output_in_alternate_mem=*/true),
195       cache);
196   for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt(
197            interval.buffer->defining_position().instruction,
198            interval.buffer->defining_position().index)) {
199     for (const HloValue* value : buffer->values()) {
200       for (const HloUse& use : value->uses()) {
201         // We look inside the called computations of while and conditional, so
202         // don't use the benefit of while and conditional directly.
203         if (use.instruction->opcode() == HloOpcode::kWhile ||
204             use.instruction->opcode() == HloOpcode::kConditional) {
205           continue;
206         }
207         float use_alternate_mem_benefit =
208             GetAlternateMemoryBenefit(*use.instruction,
209                                       GetInstructionElapsedDueToMemory(
210                                           *use.instruction, use.operand_number),
211                                       cache);
212         // If the benefit is positive (memory bound), add it to this buffer's
213         // benefit. If the benefit is negative (compute bound), calculate the
214         // maximum.
215         if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
216           alternate_mem_benefit += use_alternate_mem_benefit;
217         } else {
218           alternate_mem_benefit =
219               std::max(alternate_mem_benefit, use_alternate_mem_benefit);
220         }
221       }
222     }
223   }
224 
225   // Penalize larger buffers by dividing the benefit by the square root of the
226   // size. Empirically, we observed this resulted in better performance compared
227   // to dividing by the size.
228   return alternate_mem_benefit / std::sqrt(interval.size);
229 }
230 
CalculateComputationNestLevel(const HloInstruction * instruction,bool while_only) const231 int MemorySpaceAssignmentCostAnalysis::CalculateComputationNestLevel(
232     const HloInstruction* instruction, bool while_only) const {
233   int nest_level = 0;
234   const HloComputation* computation = instruction->parent();
235   while (!computation->IsEntryComputation()) {
236     auto node = call_graph_->GetNode(computation);
237     auto callsites = node.caller_callsites();
238     CHECK_EQ(callsites.size(), 1) << "The module is not flattened!";
239     auto callsite = callsites[0];
240     if (!while_only || callsite.instruction()->opcode() == HloOpcode::kWhile) {
241       ++nest_level;
242     }
243     computation = callsite.instruction()->parent();
244   }
245   return nest_level;
246 }
247 
GetInstructionElapsedDueToCompute(const HloInstruction & instruction) const248 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToCompute(
249     const HloInstruction& instruction) const {
250   return std::max(
251       cost_analysis_.flop_count(instruction) /
252           cost_analysis_.per_second_rate(HloCostAnalysis::kFlopsKey),
253       cost_analysis_.transcendental_count(instruction) /
254           cost_analysis_.per_second_rate(HloCostAnalysis::kTranscendentalsKey));
255 }
256 
GetInstructionElapsedDueToMemory(const HloInstruction & instruction,absl::optional<int64> operand_in_alternate_mem,bool output_in_alternate_mem) const257 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedDueToMemory(
258     const HloInstruction& instruction,
259     absl::optional<int64> operand_in_alternate_mem,
260     bool output_in_alternate_mem) const {
261   float bytes_accessed = cost_analysis_.bytes_accessed(instruction);
262   float elapsed_due_to_bytes =
263       bytes_accessed /
264       cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
265   if (operand_in_alternate_mem) {
266     // Estimate the elapsed time due to the operand being in the alternate
267     // memory space.
268     float operand_bytes_accessed = cost_analysis_.operand_bytes_accessed(
269         instruction, *operand_in_alternate_mem);
270     float elapsed_due_to_operand_bytes =
271         operand_bytes_accessed / alternate_mem_bandwidth_bytes_per_second_;
272     bytes_accessed -= operand_bytes_accessed;
273     elapsed_due_to_bytes =
274         elapsed_due_to_operand_bytes +
275         bytes_accessed /
276             cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
277   }
278   if (output_in_alternate_mem) {
279     // Estimate the elapsed time due to the output being in the alternate memory
280     // space.
281     float output_bytes_accessed =
282         cost_analysis_.output_bytes_accessed(instruction);
283     float elapsed_due_to_output_bytes =
284         output_bytes_accessed / alternate_mem_bandwidth_bytes_per_second_;
285     bytes_accessed -= output_bytes_accessed;
286     elapsed_due_to_bytes =
287         elapsed_due_to_output_bytes +
288         bytes_accessed /
289             cost_analysis_.per_second_rate(HloCostAnalysis::kBytesAccessedKey);
290   }
291   return elapsed_due_to_bytes;
292 }
293 
GetInstructionElapsed(const HloInstruction & instruction) const294 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsed(
295     const HloInstruction& instruction) const {
296   return std::max(GetInstructionElapsedDueToCompute(instruction),
297                   GetInstructionElapsedDueToMemory(instruction));
298 }
299 
GetInstructionElapsedInAlternateMemory(const HloInstruction & instruction,absl::optional<int64> operand_in_alternate_mem,bool output_in_alternate_mem) const300 float MemorySpaceAssignmentCostAnalysis::GetInstructionElapsedInAlternateMemory(
301     const HloInstruction& instruction,
302     absl::optional<int64> operand_in_alternate_mem,
303     bool output_in_alternate_mem) const {
304   return std::max(
305       GetInstructionElapsedDueToCompute(instruction),
306       GetInstructionElapsedDueToMemory(instruction, operand_in_alternate_mem,
307                                        output_in_alternate_mem));
308 }
309 
GetAsyncCopyElapsed(const Shape & shape) const310 float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed(
311     const Shape& shape) const {
312   int64 size_in_bytes = cost_analysis_.GetShapeSize(shape);
313   return static_cast<float>(size_in_bytes) /
314          async_copy_bandwidth_bytes_per_second_;
315 }
316 
GetScheduleEndTime() const317 int64 MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const {
318   return hlo_live_range_->schedule_end_time();
319 }
320 
CanAllocateInAlternateMemoryNoCopy(const Shape & shape,int64 start_time,int64 end_time) const321 bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
322     const Shape& shape, int64 start_time, int64 end_time) const {
323   return end_time - start_time <= max_overlap_count_;
324 }
325 
PreferredEvictionEndTime(const Shape & shape,int64 start_time,int64 latest_end_time) const326 int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime(
327     const Shape& shape, int64 start_time, int64 latest_end_time) const {
328   return std::min(start_time + min_overlap_count_, latest_end_time);
329 }
330 
LatestPrefetchStartTime(const Shape & shape,int64 start_time,int64 end_time,const HloUse * use) const331 int64 InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime(
332     const Shape& shape, int64 start_time, int64 end_time,
333     const HloUse* use) const {
334   return end_time - min_overlap_count_;
335 }
336 
PreferredPrefetchStartTime(const Shape & shape,int64 earliest_prefetch_start_time,int64 latest_prefetch_start_time,int64 prefetch_end_time) const337 int64 InstructionCountPrefetchIntervalPicker::PreferredPrefetchStartTime(
338     const Shape& shape, int64 earliest_prefetch_start_time,
339     int64 latest_prefetch_start_time, int64 prefetch_end_time) const {
340   return std::max(earliest_prefetch_start_time,
341                   prefetch_end_time - max_overlap_count_);
342 }
343 
Begin(const HloUse & use,int64 start_time,int64 end_time)344 void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use,
345                                                    int64 start_time,
346                                                    int64 end_time) {
347   end_time_ = end_time;
348   const Shape& shape = ShapeUtil::GetSubshape(
349       use.instruction->operand(use.operand_number)->shape(), use.operand_index);
350   current_prefetch_time_ =
351       PreferredPrefetchStartTime(shape, start_time, end_time, end_time);
352 }
353 
Next()354 int64 InstructionCountPrefetchIntervalPicker::Next() {
355   CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
356                     "Done() is false";
357   return current_prefetch_time_++;
358 }
359 
Done() const360 bool InstructionCountPrefetchIntervalPicker::Done() const {
361   return end_time_ - current_prefetch_time_ <= min_overlap_count_;
362 }
363 
ToDebugString() const364 std::string InstructionCountPrefetchIntervalPicker::ToDebugString() const {
365   return absl::StrCat("Overlapped HLOs = ", end_time_ - current_prefetch_time_);
366 }
367 
ToNoCopyDebugString(const Shape & shape,int64 start_time,int64 end_time) const368 std::string InstructionCountPrefetchIntervalPicker::ToNoCopyDebugString(
369     const Shape& shape, int64 start_time, int64 end_time) const {
370   return absl::StrCat("Overlapped HLOs = ", end_time - start_time);
371 }
372 
CostAnalysisPrefetchIntervalPicker(const MemorySpaceAssignmentCostAnalysis & cost_analysis,float min_async_copy_to_overlap_ratio,float max_async_copy_to_overlap_ratio,float preferred_async_copy_to_overlap_ratio)373 CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
374     const MemorySpaceAssignmentCostAnalysis& cost_analysis,
375     float min_async_copy_to_overlap_ratio,
376     float max_async_copy_to_overlap_ratio,
377     float preferred_async_copy_to_overlap_ratio)
378     : while_nest_level_(
379           cost_analysis.hlo_live_range().instruction_schedule().size(), 0),
380       computation_nest_level_(
381           cost_analysis.hlo_live_range().instruction_schedule().size(), 0),
382       cost_analysis_(cost_analysis),
383       min_async_copy_to_overlap_ratio_(min_async_copy_to_overlap_ratio),
384       max_async_copy_to_overlap_ratio_(max_async_copy_to_overlap_ratio),
385       preferred_async_copy_to_overlap_ratio_(
386           preferred_async_copy_to_overlap_ratio) {
387   instruction_schedule_ =
388       &cost_analysis_.hlo_live_range().instruction_schedule();
389 
390   // Create a vector of elapsed times and while nesting levels of HLO
391   // instructions. The elapsed times are multiplied by pow(kWhileExecutionCount,
392   // nest_level) to account for executing the HLOs multiple times in while
393   // loops.
394   std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
395                                                0.0);
396   for (const auto& instruction_and_logical_time : *instruction_schedule_) {
397     // To avoid double counting, don't include the elapsed time of while and
398     // conditional HLOs.
399     const HloInstruction* instruction = instruction_and_logical_time.first;
400     int64 logical_time = instruction_and_logical_time.second;
401     if (logical_time >= instructions_elapsed_time.size()) {
402       instructions_elapsed_time.resize(logical_time + 1, 0.0);
403       while_nest_level_.resize(logical_time + 1, 0);
404     }
405     int while_nest_level = cost_analysis_.CalculateComputationNestLevel(
406         instruction_and_logical_time.first, /*while_only=*/true);
407     while_nest_level_[logical_time] = while_nest_level;
408     int computation_nest_level = cost_analysis_.CalculateComputationNestLevel(
409         instruction_and_logical_time.first, /*while_only=*/false);
410     computation_nest_level_[logical_time] = computation_nest_level;
411     if (instruction->opcode() == HloOpcode::kWhile ||
412         instruction->opcode() == HloOpcode::kConditional) {
413       continue;
414     }
415     float elapsed_time = cost_analysis_.GetInstructionElapsed(
416         *instruction_and_logical_time.first);
417     instructions_elapsed_time[logical_time] =
418         elapsed_time * tensorflow::MathUtil::IPow<float>(kWhileExecutionCount,
419                                                          while_nest_level);
420   }
421   // As an optimization, create a cumulative sum vector of elapsed time.
422   float cumsum = 0.0;
423   elapsed_time_cumsum_.reserve(instructions_elapsed_time.size());
424   for (float elapsed_time : instructions_elapsed_time) {
425     cumsum += elapsed_time;
426     elapsed_time_cumsum_.push_back(cumsum);
427   }
428   // To be able to accurately determine the minimum nest level between a start
429   // time and an end time efficiently, populate a data structure that stores the
430   // closest nest level change index.
431   int prev_nest_level = 0;
432   int change_idx = -1;
433   while_nest_level_change_.reserve(instructions_elapsed_time.size());
434   for (int i = 0; i < while_nest_level_.size(); ++i) {
435     int nest_level = while_nest_level_[i];
436     if (nest_level != prev_nest_level) {
437       prev_nest_level = nest_level;
438       change_idx = i - 1;
439     }
440     while_nest_level_change_.push_back(change_idx);
441   }
442 }
443 
CanAllocateInAlternateMemoryNoCopy(const Shape & shape,int64 start_time,int64 end_time) const444 bool CostAnalysisPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
445     const Shape& shape, int64 start_time, int64 end_time) const {
446   // Even though this method returns if we allow the buffer in alternate memory
447   // _without_ asynchronous copies, calculate how long it would have taken to
448   // copy it and compare it to the elapsed time in the logical interval.
449   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
450   float logical_interval_elapsed =
451       GetLogicalIntervalElapsed(start_time, end_time);
452   return max_async_copy_to_overlap_ratio_ * max_overlap_multiplier_ *
453              async_copy_elapsed >
454          logical_interval_elapsed;
455 }
456 
PreferredEvictionEndTime(const Shape & shape,int64 start_time,int64 latest_end_time) const457 int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime(
458     const Shape& shape, int64 start_time, int64 latest_end_time) const {
459   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
460   int64 end_time;
461   for (end_time = start_time + 1; end_time <= latest_end_time; ++end_time) {
462     float logical_interval_elapsed =
463         GetLogicalIntervalElapsed(start_time, end_time);
464     if (logical_interval_elapsed >=
465         min_async_copy_to_overlap_ratio_ * async_copy_elapsed) {
466       break;
467     }
468   }
469   return end_time;
470 }
471 
LatestPrefetchStartTime(const Shape & shape,int64 start_time,int64 end_time,const HloUse * use) const472 int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime(
473     const Shape& shape, int64 start_time, int64 end_time,
474     const HloUse* use) const {
475   // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_.
476   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
477   // If there is a use, estimate the time we would save by having this op in
478   // alternate memory.
479   float inst_elapsed_reduction = 0.0f;
480   if (use) {
481     float elapsed_time =
482         cost_analysis_.GetInstructionElapsed(*use->instruction);
483     float elapsed_time_in_alternate_mem =
484         cost_analysis_.GetInstructionElapsedInAlternateMemory(
485             *use->instruction, use->operand_number,
486             /*output_in_alternate_mem=*/false);
487     inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem;
488   }
489   int end_nest_level = computation_nest_level_[end_time];
490 
491   // Find the latest time we're allowed to start prefetching.
492   float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed;
493   int latest_prefetch_time;
494   for (latest_prefetch_time = end_time - 1;
495        latest_prefetch_time >= start_time &&
496        (computation_nest_level_[latest_prefetch_time] != end_nest_level ||
497         min_interval >
498             GetLogicalIntervalElapsed(latest_prefetch_time, end_time) +
499                 inst_elapsed_reduction);
500        --latest_prefetch_time) {
501   }
502 
503   return latest_prefetch_time;
504 }
505 
PreferredPrefetchStartTime(const Shape & shape,int64 earliest_prefetch_start_time,int64 latest_prefetch_start_time,int64 prefetch_end_time) const506 int64 CostAnalysisPrefetchIntervalPicker::PreferredPrefetchStartTime(
507     const Shape& shape, int64 earliest_prefetch_start_time,
508     int64 latest_prefetch_start_time, int64 prefetch_end_time) const {
509   // Between the earliest and latest prefetch interval, find the interval
510   // closest to the preferred interval and start iterating from there.
511   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
512   int64 preferred_prefetch_start_time = earliest_prefetch_start_time;
513   float preferred_interval =
514       preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed;
515   float best_interval = GetLogicalIntervalElapsed(earliest_prefetch_start_time,
516                                                   prefetch_end_time);
517   int end_nest_level = computation_nest_level_[prefetch_end_time];
518   for (int64 prefetch_start_time = earliest_prefetch_start_time + 1;
519        prefetch_start_time <= latest_prefetch_start_time;
520        ++prefetch_start_time) {
521     float interval =
522         GetLogicalIntervalElapsed(prefetch_start_time, prefetch_end_time);
523     if (computation_nest_level_[prefetch_start_time] == end_nest_level &&
524         std::abs(preferred_interval - interval) <
525             std::abs(preferred_interval - best_interval)) {
526       best_interval = interval;
527       preferred_prefetch_start_time = prefetch_start_time;
528     }
529   }
530   return preferred_prefetch_start_time;
531 }
532 
LatestPrefetchEndTime(int64 original_prefetch_end_time,int64 proposed_prefetch_end_time) const533 int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime(
534     int64 original_prefetch_end_time, int64 proposed_prefetch_end_time) const {
535   // Iterate towards the beginning until we find a suitable end time that is the
536   // same while nest level as the original prefetch end time.
537   int64 original_nest_level =
538       computation_nest_level_[original_prefetch_end_time];
539   int64 new_prefetch_end_time;
540   for (new_prefetch_end_time = proposed_prefetch_end_time;
541        computation_nest_level_[new_prefetch_end_time] != original_nest_level;
542        --new_prefetch_end_time) {
543   }
544   return new_prefetch_end_time;
545 }
546 
Begin(const HloUse & use,int64 start_time,int64 end_time)547 void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use,
548                                                int64 start_time,
549                                                int64 end_time) {
550   const Shape& shape = ShapeUtil::GetSubshape(
551       use.instruction->operand(use.operand_number)->shape(), use.operand_index);
552   // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_.
553   async_copy_elapsed_ = cost_analysis_.GetAsyncCopyElapsed(shape);
554   // Estimate the time we would save by having this op in alternate memory.
555   float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction);
556   float elapsed_time_in_alternate_mem =
557       cost_analysis_.GetInstructionElapsedInAlternateMemory(
558           *use.instruction, use.operand_number,
559           /*output_in_alternate_mem=*/false);
560   inst_elapsed_reduction_ = elapsed_time - elapsed_time_in_alternate_mem;
561   end_logical_time_ = end_time;
562   int end_nest_level = computation_nest_level_[end_logical_time_];
563 
564   // Find the latest time we're allowed to start prefetching.
565   float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
566   latest_prefetch_time_ =
567       LatestPrefetchStartTime(shape, start_time, end_time, &use);
568 
569   // Find the earliest time we're allowed to start prefetching.
570   float max_interval = max_async_copy_to_overlap_ratio_ *
571                        max_overlap_multiplier_ * async_copy_elapsed_;
572   for (earliest_prefetch_time_ = start_time;
573        earliest_prefetch_time_ <= end_logical_time_ &&
574        (computation_nest_level_[earliest_prefetch_time_] != end_nest_level ||
575         max_interval < GetLogicalIntervalElapsed(earliest_prefetch_time_,
576                                                  end_logical_time_));
577        ++earliest_prefetch_time_) {
578   }
579   if (earliest_prefetch_time_ > latest_prefetch_time_) {
580     // There is no available prefetch interval for the given start and end
581     // times. Set the iterators accordingly to ensure Done() returns true.
582     increasing_prefetch_time_iterator_ = earliest_prefetch_time_;
583     decreasing_prefetch_time_iterator_ = latest_prefetch_time_;
584     CHECK(Done());
585     return;
586   }
587 
588   int64 starting_prefetch_time = PreferredPrefetchStartTime(
589       shape, earliest_prefetch_time_, latest_prefetch_time_, end_logical_time_);
590   float preferred_interval =
591       preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed_;
592   VLOG(4) << "Interval min/max/preferred = " << min_interval << " "
593           << max_interval << " " << preferred_interval
594           << " prefetch time earliest/latest/starting = "
595           << earliest_prefetch_time_ << " " << latest_prefetch_time_ << " "
596           << starting_prefetch_time;
597 
598   increasing_prefetch_time_iterator_ = starting_prefetch_time;
599   decreasing_prefetch_time_iterator_ = starting_prefetch_time;
600   using_increasing_prefetch_time_iterator_ = true;
601   // Since both iterators start at the same position, call Next() once to
602   // advance one of the iterators.
603   Next();
604 }
605 
Next()606 int64 CostAnalysisPrefetchIntervalPicker::Next() {
607   CHECK(!Done()) << "Prefetch interval picker's Next() is called even though "
608                     "Done() is false";
609   if (using_increasing_prefetch_time_iterator_) {
610     int64 prefetch_time = increasing_prefetch_time_iterator_++;
611     while (increasing_prefetch_time_iterator_ <= latest_prefetch_time_ &&
612            computation_nest_level_[increasing_prefetch_time_iterator_] !=
613                computation_nest_level_[end_logical_time_]) {
614       ++increasing_prefetch_time_iterator_;
615     }
616     if (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_) {
617       using_increasing_prefetch_time_iterator_ = false;
618     }
619     return prefetch_time;
620   } else {
621     int64 prefetch_time = decreasing_prefetch_time_iterator_--;
622     while (decreasing_prefetch_time_iterator_ >= earliest_prefetch_time_ &&
623            computation_nest_level_[decreasing_prefetch_time_iterator_] !=
624                computation_nest_level_[end_logical_time_]) {
625       --decreasing_prefetch_time_iterator_;
626     }
627     if (increasing_prefetch_time_iterator_ <= latest_prefetch_time_) {
628       using_increasing_prefetch_time_iterator_ = true;
629     }
630     return prefetch_time;
631   }
632 }
633 
Done() const634 bool CostAnalysisPrefetchIntervalPicker::Done() const {
635   return increasing_prefetch_time_iterator_ > latest_prefetch_time_ &&
636          decreasing_prefetch_time_iterator_ < earliest_prefetch_time_;
637 }
638 
SetRetryNumber(int retry_number)639 void CostAnalysisPrefetchIntervalPicker::SetRetryNumber(int retry_number) {
640   // Use twice as large max overlap limit in each retry.
641   max_overlap_multiplier_ = 1 << retry_number;
642 }
643 
GetMinWhileNestLevel(int64 start_time,int64 end_time) const644 int CostAnalysisPrefetchIntervalPicker::GetMinWhileNestLevel(
645     int64 start_time, int64 end_time) const {
646   int min_nest_level =
647       std::min(while_nest_level_[start_time], while_nest_level_[end_time]);
648   int change_idx = while_nest_level_change_[end_time];
649   while (change_idx >= start_time) {
650     min_nest_level = std::min(min_nest_level, while_nest_level_[change_idx]);
651     change_idx = while_nest_level_change_[change_idx];
652   }
653   return min_nest_level;
654 }
655 
GetLogicalIntervalElapsed(int64 start_time,int64 end_time) const656 float CostAnalysisPrefetchIntervalPicker::GetLogicalIntervalElapsed(
657     int64 start_time, int64 end_time) const {
658   CHECK_LE(start_time, end_time);
659   if (start_time == end_time) {
660     return 0.0;
661   }
662   if (start_time < 0) {
663     start_time = 0;
664   }
665   // Since elapsed_time_cumsum_ is already weighed by the while loop nesting
666   // level, normalize the elapsed time by dividing with the nesting factor of
667   // the interval (start and end times).
668   int interval_while_nest_level = GetMinWhileNestLevel(start_time, end_time);
669   return (elapsed_time_cumsum_[end_time - 1] -
670           elapsed_time_cumsum_[start_time]) /
671          tensorflow::MathUtil::IPow<float>(kWhileExecutionCount,
672                                            interval_while_nest_level);
673 }
674 
ToDebugString() const675 std::string CostAnalysisPrefetchIntervalPicker::ToDebugString() const {
676   int current_logical_prefetch_time = using_increasing_prefetch_time_iterator_
677                                           ? increasing_prefetch_time_iterator_
678                                           : decreasing_prefetch_time_iterator_;
679   float logical_interval_elapsed = GetLogicalIntervalElapsed(
680       current_logical_prefetch_time, end_logical_time_);
681   return absl::StrCat(
682       "Async copy elapsed (s) = ", async_copy_elapsed_,
683       ", inst elapsed reduction (s) = ", inst_elapsed_reduction_,
684       ", logical interval elapsed (s) = ", logical_interval_elapsed,
685       ", interval = (", current_logical_prefetch_time, ", ", end_logical_time_,
686       ")");
687 }
688 
ToNoCopyDebugString(const Shape & shape,int64 start_time,int64 end_time) const689 std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString(
690     const Shape& shape, int64 start_time, int64 end_time) const {
691   float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape);
692   float logical_interval_elapsed =
693       GetLogicalIntervalElapsed(start_time, end_time);
694   return absl::StrCat(
695       "Async copy elapsed (s) = ", async_copy_elapsed,
696       ", logical interval elapsed (s) = ", logical_interval_elapsed);
697 }
698 
699 absl::optional<float>
BufferIntervalAlternateMemoryBenefit(const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval & interval) const700 CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit(
701     const GlobalDecreasingSizeBestFitHeap<HloValue>::BufferInterval& interval)
702     const {
703   return cost_analysis_.GetMemoryBoundedness(interval);
704 }
705 
operator ==(const MemorySpaceAssignment::Allocation & other) const706 bool MemorySpaceAssignment::Allocation::operator==(
707     const MemorySpaceAssignment::Allocation& other) const {
708   return defining_position() == other.defining_position() &&
709          uses() == other.uses() && memory_space() == other.memory_space() &&
710          chunk() == other.chunk() && start_time() == other.start_time() &&
711          end_time() == other.end_time() &&
712          is_copy_allocation() == other.is_copy_allocation();
713 }
714 
operator ==(const MemorySpaceAssignment::CopyAllocation & other) const715 bool MemorySpaceAssignment::CopyAllocation::operator==(
716     const MemorySpaceAssignment::CopyAllocation& other) const {
717   return static_cast<const Allocation&>(*this) ==
718              static_cast<const Allocation&>(other) &&
719          copy_done_schedule_before() == other.copy_done_schedule_before() &&
720          copy_start_schedule_after() == other.copy_start_schedule_after() &&
721          copy_start() == other.copy_start() && copy_done() == other.copy_done();
722 }
723 
ToString() const724 std::string MemorySpaceAssignment::AllocationValue::ToString() const {
725   std::string out = absl::StrCat("computation = ", computation()->name());
726   absl::StrAppend(&out, "\n position:\n");
727   absl::StrAppend(&out, "  ", defining_position_.ToString(), "\n");
728   absl::StrAppend(&out, " uses:\n");
729   for (const Use& use : uses_) {
730     absl::StrAppend(&out, "  ", use.hlo_use.ToString(), "\n");
731   }
732   return out;
733 }
734 
ToShortString() const735 std::string MemorySpaceAssignment::AllocationValue::ToShortString() const {
736   return absl::StrCat("computation = ", computation()->name(),
737                       ", position = ", defining_position_.ToString(),
738                       ", value = ", value_->ToShortString());
739 }
740 
CreateAllocationValues(const AlternateMemoryBestFitHeap::BufferInterval & buffer_interval,std::vector<AllocationValue> & allocation_values) const741 void AlternateMemoryBestFitHeap::CreateAllocationValues(
742     const AlternateMemoryBestFitHeap::BufferInterval& buffer_interval,
743     std::vector<AllocationValue>& allocation_values) const {
744   const HloValue* value = buffer_interval.buffer;
745   VLOG(3) << "Creating AllocationValues for: " << value->ToString();
746 
747   // Find and sort all non-trivial (excluding GTE, Tuple, and bitcast)
748   // positions. We create an AllocationValue object for each non-trivial
749   // position. And for each AllocationValue object, we create an
750   // AllocationSequence consisting of one or more Allocation objects.The reason
751   // why we exclude the trivial positions from AllocationValue is because
752   // Allocation objects have special support for tuples and bitcasts.
753   const absl::flat_hash_map<const HloInstruction*, int64>&
754       instruction_schedule = hlo_live_range_.instruction_schedule();
755   std::vector<HloPosition> positions;
756   for (const HloPosition& position : value->positions()) {
757     const HloInstruction* instruction = position.instruction;
758     if (instruction->opcode() != HloOpcode::kGetTupleElement &&
759         instruction->opcode() != HloOpcode::kTuple &&
760         instruction->opcode() != HloOpcode::kBitcast) {
761       positions.push_back(position);
762     }
763   }
764   absl::c_stable_sort(positions,
765                       [&](const HloPosition& pos1, const HloPosition& pos2) {
766                         return instruction_schedule.at(pos1.instruction) <
767                                instruction_schedule.at(pos2.instruction);
768                       });
769 
770   // Create an AllocationValue for each non-trivial position.
771   absl::flat_hash_set<const HloComputation*> computations;
772   int beginning_idx = allocation_values.size();
773   for (int i = 0; i < positions.size(); ++i) {
774     const HloPosition& position = positions.at(i);
775     allocation_values.emplace_back(value, position, buffer_interval.size);
776   }
777 
778   std::vector<HloUse> uses(value->uses());
779   absl::c_stable_sort(uses, [&](const HloUse& use1, const HloUse& use2) {
780     return instruction_schedule.at(use1.instruction) <
781            instruction_schedule.at(use2.instruction);
782   });
783 
784   // Associate each use with an AllocationValue. Each AllocationValue contains a
785   // position and uses in the same computation. Furthermore, if the original
786   // HloValue had multiple non-trivial positions in the same computation, those
787   // will get their own AllocationValue as well. We split these HloValues so
788   // that when we insert CopyStart/CopyDone in CopyAllocation::Process, they
789   // point to the latest position. We then replace the operand of the use with
790   // CopyStart/CopyDone with an operand of the latest position.
791   for (const HloUse& use : uses) {
792     int64 use_time = instruction_schedule.at(use.instruction);
793     HloComputation* use_computation = use.instruction->parent();
794 
795     AllocationValue* last_allocation_value = nullptr;
796     for (int i = beginning_idx; i < allocation_values.size(); ++i) {
797       AllocationValue* allocation_value = &allocation_values.at(i);
798       if (allocation_value->computation() == use_computation &&
799           instruction_schedule.at(
800               allocation_value->defining_position().instruction) < use_time) {
801         last_allocation_value = allocation_value;
802       }
803     }
804     CHECK(last_allocation_value != nullptr);
805     last_allocation_value->AddUse(use, use_time);
806   }
807 
808   for (int i = beginning_idx; i < allocation_values.size(); ++i) {
809     VLOG(3) << "Created allocation value: "
810             << allocation_values.at(i).ToString();
811   }
812 }
813 
FindAliases(std::vector<AllocationValue> * allocation_values,bool skip_values_with_no_uses) const814 void AlternateMemoryBestFitHeap::FindAliases(
815     std::vector<AllocationValue>* allocation_values,
816     bool skip_values_with_no_uses) const {
817   absl::flat_hash_map<const HloInstruction*, const AllocationValue*>
818       values_by_defining_inst;
819   for (AllocationValue& value : *allocation_values) {
820     // Skip the value if it doesn't have any uses.
821     if (value.uses().empty() && skip_values_with_no_uses) {
822       continue;
823     }
824     CHECK_EQ(values_by_defining_inst.count(value.defining_instruction()), 0);
825     values_by_defining_inst[value.defining_instruction()] = &value;
826   }
827   auto maybe_add_alias_with_instruction = [&](const HloInstruction* instruction,
828                                               AllocationValue::Use* use) {
829     auto aliased_value_it = values_by_defining_inst.find(instruction);
830     if (aliased_value_it != values_by_defining_inst.end()) {
831       VLOG(3) << "Adding aliasing for use " << use->hlo_use.ToString() << " to "
832               << aliased_value_it->second->ToShortString();
833       use->aliases.push_back(aliased_value_it->second->defining_position());
834     }
835   };
836 
837   for (AllocationValue& value : *allocation_values) {
838     for (AllocationValue::Use& use : value.uses()) {
839       // Find any aliases with the instruction itself (operand and output must
840       // alias).
841       maybe_add_alias_with_instruction(use.hlo_use.instruction, &use);
842 
843       // Find any aliases with the parameters of called computations.
844       for (const HloComputation* called_computation :
845            use.hlo_use.instruction->called_computations()) {
846         for (const HloInstruction* parameter_instruction :
847              called_computation->parameter_instructions()) {
848           maybe_add_alias_with_instruction(parameter_instruction, &use);
849         }
850       }
851 
852       // Special case for kWhile: the root of the body computation must alias as
853       // well.
854       if (use.hlo_use.instruction->opcode() == HloOpcode::kWhile) {
855         HloPosition root_alias{
856             use.hlo_use.instruction->while_body()->root_instruction(),
857             use.hlo_use.operand_index};
858         VLOG(3) << "Adding while body root aliasing for use "
859                 << use.hlo_use.ToString() << " to " << root_alias;
860         use.aliases.push_back(root_alias);
861       }
862     }
863   }
864 }
865 
866 std::vector<const AlternateMemoryBestFitHeap::BufferInterval*>
GetSortedColocatedIntervals(const AlternateMemoryBestFitHeap::BufferInterval & interval) const867 AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
868     const AlternateMemoryBestFitHeap::BufferInterval& interval) const {
869   std::vector<const BufferInterval*> colocated_intervals;
870   std::vector<const BufferInterval*> worklist = {&interval};
871   while (!worklist.empty()) {
872     const BufferInterval* item = worklist.back();
873     worklist.pop_back();
874     colocated_intervals.push_back(item);
875     for (const HloValue* buffer_colocated : item->colocations) {
876       worklist.push_back(&buffer_intervals_.at(buffer_colocated));
877     }
878   }
879 
880   absl::c_stable_sort(colocated_intervals, [&](const BufferInterval* x,
881                                                const BufferInterval* y) {
882     return std::make_pair(x->start, x->end) < std::make_pair(y->start, y->end);
883   });
884   return colocated_intervals;
885 }
886 
// Decides whether `use` (a use of `value`) may be served from alternate memory.
//
// Returns false when options_.is_use_allowed_in_alternate_mem_fn rejects the
// use. Uses by kWhile and kConditional instructions get the additional checks
// described inline below; every other use is allowed.
bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
    const AllocationValue& value, const HloUse& use) const {
  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
  if (!options_.is_use_allowed_in_alternate_mem_fn(use)) {
    return false;
  }
  if (use.instruction->opcode() == HloOpcode::kWhile) {
    HloComputation* while_body = use.instruction->while_body();

    // We don't want to allocate this buffer in alternate memory if it will be
    // evicted anyway. Find out if it has an early use or a late definition that
    // would make sense to keep it in the alternate memory.
    HloValue* parameter_value =
        &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
            while_body->parameter_instruction(0), use.operand_index);
    int64 parameter_time =
        instruction_schedule.at(while_body->parameter_instruction(0));
    int64 root_time = instruction_schedule.at(while_body->root_instruction());
    // Earliest non-trivial (not GTE/Tuple/Bitcast) use of the body parameter
    // value that is scheduled after the parameter; stays at root_time if there
    // is no such use before the root.
    int64 min_use_time = root_time;
    for (const HloUse& parameter_use : parameter_value->uses()) {
      int64 use_time = instruction_schedule.at(parameter_use.instruction);
      if (parameter_use.instruction->opcode() != HloOpcode::kGetTupleElement &&
          parameter_use.instruction->opcode() != HloOpcode::kTuple &&
          parameter_use.instruction->opcode() != HloOpcode::kBitcast &&
          use_time > parameter_time) {
        min_use_time = std::min(min_use_time, use_time);
      }
    }
    // If there is no use of this buffer inside the while loop, there is no need
    // to allocate it in the loop.
    if (min_use_time == root_time) {
      VLOG(4) << "While allocation not allowed in alternate memory. "
              << "use time = " << min_use_time << ", root time = " << root_time;
      return false;
    }
    const Shape& shape = parameter_value->shape();
    // Allow the buffer in alternate memory if the buffer has a short live range
    // either at the beginning or end of the while loop body.
    if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
            shape, parameter_time, min_use_time)) {
      VLOG(4) << "While allocation not allowed in alternate memory. "
              << "use time = " << min_use_time << ", root time = " << root_time;
      return false;
    }
    // Check if there is a required assignment for the while loop output.
    HloValue* while_value =
        &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
            use.instruction, use.operand_index);
    int64 while_time = instruction_schedule.at(use.instruction);
    auto existing_required_assignment =
        RequiredMemoryAssignmentAt(while_value, while_time);
    if (existing_required_assignment &&
        existing_required_assignment->memory_space == MemorySpace::kDefault) {
      VLOG(4) << "While allocation not allowed in alternate memory because "
                 "there is a required default memory assignment.";
      return false;
    }
    // All while-specific checks passed; falls through to `return true` below.
  } else if (use.instruction->opcode() == HloOpcode::kConditional) {
    // For any use of this conditional (the same value might be passed into
    // multiple called computations), determine if the parameter->first use
    // dependency is short.
    int64 conditional_time = instruction_schedule.at(use.instruction);
    for (const AllocationValue::Use& other_use : value.uses()) {
      // Only look at uses of `value` by this same conditional instruction.
      if (other_use.hlo_use.instruction != use.instruction) {
        continue;
      }
      // The called computation is looked up at operand_number - 1.
      // NOTE(review): this presumes operand 0 is the branch selector and is
      // never one of these uses; operand_number == 0 would index at -1 --
      // confirm against callers.
      HloComputation* called_computation =
          use.instruction->called_computations().at(
              other_use.hlo_use.operand_number - 1);
      const HloInstruction* parameter_instruction =
          called_computation->parameter_instruction(0);
      HloValue* parameter_value =
          &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
              parameter_instruction, other_use.hlo_use.operand_index);
      int64 parameter_time = instruction_schedule.at(parameter_instruction);
      // Earliest non-trivial use of the parameter value inside the called
      // computation; stays at conditional_time if no such use is scheduled
      // earlier than the conditional.
      int64 min_use_time = conditional_time;
      for (const HloUse& parameter_use : parameter_value->uses()) {
        if (parameter_use.instruction->parent() == called_computation &&
            parameter_use.instruction->opcode() !=
                HloOpcode::kGetTupleElement &&
            parameter_use.instruction->opcode() != HloOpcode::kTuple &&
            parameter_use.instruction->opcode() != HloOpcode::kBitcast) {
          min_use_time = std::min(
              min_use_time, instruction_schedule.at(parameter_use.instruction));
        }
      }
      // One called computation that accepts a no-copy allocation is enough to
      // allow the use.
      if (options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
              parameter_value->shape(), parameter_time, min_use_time)) {
        VLOG(4) << "Conditional allocation allowed in alternate memory for "
                   "computation = "
                << called_computation->name()
                << ", parameter time = " << parameter_time
                << ", min use time = " << min_use_time;
        return true;
      } else {
        VLOG(4) << "Conditional allocation not allowed in alternate memory for "
                   "computation = "
                << called_computation->name()
                << ", parameter time = " << parameter_time
                << ", min use time = " << min_use_time;
      }
    }
    // None of the matching uses qualified.
    return false;
  }

  return true;
}
994 
AppendBufferInfoDebugString(const AlternateMemoryBestFitHeap::BufferInterval & interval,std::string * debug_str) const995 void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString(
996     const AlternateMemoryBestFitHeap::BufferInterval& interval,
997     std::string* debug_str) const {
998   // Columns in buffer information:
999   // buffer_id: int. This value can be used to match the allocation in
1000   // allocation information.
1001   // buffer_name: string.
1002   // alt_mem_benefit: float. Roughly corresponds to how much the cost analysis
1003   // thought it would be beneficial to put this in the alternate memory. The
1004   // higher the value, the more it is memory bound.
1005   // size: int. In bytes.
1006   // definition_time: int. Logical time this value was defined in the schedule.
1007   // use_times: string. This is a semicolon-separated list of integers for all
1008   // the use times.
1009   // use_names: string. This is a semicolon-separated list of string
1010   // representation of uses.
1011   if (debug_str->empty()) {
1012     // Append the column names.
1013     absl::StrAppend(debug_str,
1014                     "buffer_id,buffer_name,alt_mem_benefit,size,"
1015                     "definition_time,use_times,use_names\n");
1016   }
1017   const HloBuffer& buffer =
1018       alias_analysis_.GetBufferContainingValue(*interval.buffer);
1019   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1020   int64 definition_time =
1021       instruction_schedule.at(interval.buffer->defining_position().instruction);
1022   std::vector<std::pair<int64, std::string>> uses;
1023   for (const HloValue* value : buffer.values()) {
1024     for (const HloUse& use : value->uses()) {
1025       uses.push_back(
1026           {instruction_schedule.at(use.instruction), use.ToString()});
1027     }
1028   }
1029   absl::c_sort(uses);
1030   std::vector<int64> use_times;
1031   std::vector<std::string> use_names;
1032   use_times.reserve(uses.size());
1033   use_names.reserve(uses.size());
1034   for (const auto& use : uses) {
1035     use_times.push_back(use.first);
1036     use_names.push_back(use.second);
1037   }
1038 
1039   absl::StrAppend(debug_str, buffer.id(), ",");
1040   absl::StrAppend(debug_str, "\"", interval.buffer->ToShortString(), "\",");
1041   auto alternate_memory_benefit =
1042       options_.prefetch_interval_picker->BufferIntervalAlternateMemoryBenefit(
1043           interval);
1044   absl::StrAppend(
1045       debug_str, alternate_memory_benefit ? *alternate_memory_benefit : 0, ",");
1046   absl::StrAppend(debug_str, interval.size, ",");
1047   absl::StrAppend(debug_str, definition_time, ",");
1048   absl::StrAppend(debug_str, "\"", absl::StrJoin(use_times, ";"), "\",");
1049   absl::StrAppend(debug_str, "\"", absl::StrJoin(use_names, ";"), "\"");
1050   absl::StrAppend(debug_str, "\n");
1051 }
1052 
AppendAllocationInfoDebugString(const AllocationValue & value,const MemorySpaceAssignment::Allocation & allocation,std::string & debug_str) const1053 void AlternateMemoryBestFitHeap::AppendAllocationInfoDebugString(
1054     const AllocationValue& value,
1055     const MemorySpaceAssignment::Allocation& allocation,
1056     std::string& debug_str) const {
1057   // Columns in allocation information:
1058   // buffer_id: int. This value can be used the match with buffer info.
1059   // size: int. In bytes.
1060   // offset: int. In bytes.
1061   // start_time: int. Logical start time of the allocation.
1062   // end_time: int. Logical end time of the allocation.
1063   if (debug_str.empty()) {
1064     // Append the column names.
1065     absl::StrAppend(&debug_str, "buffer_id,size,offset,start_time,end_time\n");
1066   }
1067   if (allocation.memory_space() == MemorySpace::kAlternate) {
1068     const HloBuffer& buffer =
1069         alias_analysis_.GetBufferContainingValue(*value.value());
1070     absl::StrAppend(&debug_str, buffer.id(), ",");
1071     absl::StrAppend(&debug_str, value.size(), ",");
1072     absl::StrAppend(&debug_str, allocation.chunk().offset, ",");
1073     absl::StrAppend(&debug_str, allocation.start_time(), ",");
1074     absl::StrAppend(&debug_str, allocation.end_time(), "\n");
1075   }
1076 }
1077 
DumpDebugStringsIfEnabled() const1078 void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const {
1079   if (!options_.dump_fn) {
1080     return;
1081   }
1082   options_.dump_fn("bufferinfo", buffer_info_str_);
1083   options_.dump_fn("allocinfo", allocation_info_str_);
1084 }
1085 
// Runs the alternate memory assignment over all buffer intervals and returns
// the resulting heap.
//
// In order: optionally allocates a cross-program prefetch buffer, adds the
// input/output required assignments, accounts for intervals reserved in
// alternate memory, then walks the sorted buffer intervals and tries to
// allocate each eligible one (with retries and optional repacking). Finally
// the debug strings are logged/dumped and result_ is moved into the returned
// HeapSimulator::Result.
HeapSimulator::Result<HloValue> AlternateMemoryBestFitHeap::Finish() {
  if (options_.enable_cross_program_prefetch) {
    absl::optional<AlternateMemoryBestFitHeap::BufferInterval>
        prefetch_candidate = FindCrossProgramPrefetchCandidate(
            alias_analysis_, hlo_live_range_, options_);
    if (prefetch_candidate) {
      HloModule* module =
          prefetch_candidate->buffer->instruction()->GetModule();
      AllocateCrossProgramPrefetchBuffer(module, prefetch_candidate);
    }
  }

  std::vector<BufferInterval> sorted_buffer_intervals =
      GetSortedBufferIntervals();

  VLOG(1) << "Assigning buffers to alternate memory. Max heap size = "
          << options_.max_size_in_bytes;

  AddInputAndOutputRequiredAssignments();

  // Logging only: print the flattened schedule with its logical indices.
  if (VLOG_IS_ON(3)) {
    VLOG(3) << "Flattened instruction sequence:";
    const auto& instruction_sequence =
        hlo_live_range_.flattened_instruction_sequence().instructions();
    for (int i = 0; i < instruction_sequence.size(); ++i) {
      VLOG(3) << " " << i << ": " << instruction_sequence[i]->parent()->name()
              << " " << instruction_sequence[i]->name();
    }
  }

  // First pass: total up the size of every interval whose colocated set is
  // reserved in alternate memory, before any allocation is attempted.
  for (const auto& interval : sorted_buffer_intervals) {
    auto colocated_intervals = GetSortedColocatedIntervals(interval);
    if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
      // Increment the reserved part of alternate memory so that it is not
      // available for other buffers.
      reserved_in_bytes_ += options_.size_fn(*interval.buffer);
    }
  }
  VLOG(2) << "Total reserved bytes = " << reserved_in_bytes_;

  // Second pass: filter out ineligible intervals and allocate the rest.
  for (auto& interval : sorted_buffer_intervals) {
    if (!interval.need_allocation) {
      continue;
    }

    if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
            interval)) {
      continue;
    }

    HloInstruction* inst = interval.buffer->instruction();
    HloModule* module = inst->GetModule();

    // Don't intra-program prefetch a cross program prefetch
    if (inst->opcode() == HloOpcode::kParameter &&
        absl::c_count(module->CrossProgramPrefetches(),
                      std::make_pair(inst->parameter_number(),
                                     interval.buffer->index())) > 0) {
      VLOG(3) << "Skip " << interval.buffer->ToShortString()
              << " because it is cross-program prefetched.";
      continue;
    }

    if (interval.size > available_heap_size()) {
      VLOG(3) << "Skip " << interval.buffer->ToShortString()
              << " because the buffer is larger than the heap size.";
      continue;
    }

    auto colocated_intervals = GetSortedColocatedIntervals(interval);

    // Reserved intervals are not allocated here; their positions are only
    // colored with the alternate memory space.
    if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
      VLOG(3) << "Interval " << interval.buffer->ToShortString()
              << " is reserved in the alternate memory.";
      for (const BufferInterval* colocated_interval : colocated_intervals) {
        const HloValue* value = colocated_interval->buffer;
        // Color all of the aliased reserved buffers here because reserved
        // alternate memory allocations will not have an entry in preset
        // allocations that is normally used for coloring.
        for (auto& position : value->positions()) {
          VLOG(4) << "Coloring " << position.ToString();
          Shape* shape = ShapeUtil::GetMutableSubshape(
              position.instruction->mutable_shape(), position.index);
          CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
                                  << position.ToString();
          shape->mutable_layout()->set_memory_space(
              options_.alternate_memory_space);
        }
      }
      continue;
    }

    if (colocated_intervals.size() > 1 &&
        !options_.allocate_across_sequential_calls) {
      VLOG(4) << "Not allocating " << interval.buffer->ToShortString()
              << " because it aliases with another interval and "
              << " allocate_across_sequential_calls is false.";
      continue;
    }

    // Skip this buffer if ConsumeFuel reports that no fuel is left.
    if (!ConsumeFuel("memory_space_assignment", [&] {
          return absl::StrCat("Ran out of fuel at buffer: ",
                              colocated_intervals[0]->buffer->ToShortString());
        })) {
      continue;
    }

    AppendBufferInfoDebugString(interval, &buffer_info_str_);

    std::vector<AllocationValue> allocation_values;
    CreateAllocationValuesFromColocatedIntervals(colocated_intervals,
                                                 allocation_values);

    // Retry allocating this value with larger limits if allocation fails.
    // `repacked` limits this buffer to a single repack attempt.
    bool repacked = false;
    for (int retry_number = 0; retry_number < options_.max_retries;
         retry_number++) {
      AddRequiredAssignmentsForColocatedIntervals(colocated_intervals);
      bool final_retry = (retry_number == options_.max_retries - 1);
      options_.prefetch_interval_picker->SetRetryNumber(retry_number);
      Result result =
          AllocateAllocationValues(absl::MakeSpan(allocation_values));
      VLOG(2) << "Allocation result = "
              << absl::StrFormat("%x", static_cast<int>(result));
      if (result_requires_uncommit(result) ||
          (!final_retry && result_failed_because_of_async_copy(result))) {
        // Roll back and let the loop move on to the next retry number. If this
        // was the last iteration, the loop ends without finalizing anything.
        UncommitPendingChunks(absl::MakeSpan(allocation_values));
        VLOG(2) << "Couldn't allocate. Retry number " << retry_number;
      } else if ((result_is(result, Result::kFailOutOfMemory) ||
                  options_.repack_after_every_allocation) &&
                 num_repacks_ < options_.max_repacks && !repacked) {
        UncommitPendingChunks(absl::MakeSpan(allocation_values));
        ++num_repacks_;
        repacked = true;
        CHECK_NE(options_.repacker, nullptr);
        std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>
            repack_allocation_blocks;
        ExportAllocationsForRepacking(repack_allocation_blocks);
        VLOG(2) << "Repacking.";
        auto repack_status =
            options_.repacker->Repack(absl::MakeSpan(repack_allocation_blocks));
        CHECK_EQ(repack_status.status(), Status::OK());
        VLOG(2) << "Repack complete. Modified = " << *repack_status;
        if (*repack_status) {
          ImportRepackedAllocations();
          // Cancel out the loop's increment so the same retry number is
          // attempted again after the repack.
          --retry_number;
        }
      } else {
        FinalizeAllocations(absl::MakeSpan(allocation_values));
        break;
      }
    }
  }

  VLOG(3) << "Debug buffer info: ";
  VLOG(3) << buffer_info_str_;
  VLOG(3) << "Debug allocation info: ";
  VLOG(3) << allocation_info_str_;
  DumpDebugStringsIfEnabled();

  // heap_size is read before result_ is moved from.
  HeapSimulator::Result<HloValue> result;
  result.heap_size = result_.heap_size;
  result.heap_results.emplace_back(std::move(result_));
  return result;
}
1251 
AddRequiredAssignmentsForColocatedIntervals(absl::Span<const AlternateMemoryBestFitHeap::BufferInterval * const> colocated_intervals)1252 void AlternateMemoryBestFitHeap::AddRequiredAssignmentsForColocatedIntervals(
1253     absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
1254         colocated_intervals) {
1255   // TODO(berkin): For now, place the phi values due to conditionals in
1256   // default memory.
1257   for (const BufferInterval* colocated_interval : colocated_intervals) {
1258     const HloValue* value = colocated_interval->buffer;
1259     for (const auto& position : value->positions()) {
1260       if (position.instruction->opcode() == HloOpcode::kConditional) {
1261         VLOG(3) << "Adding required assignment for condition output: "
1262                 << value->ToShortString();
1263         AddRequiredAssignment(position.instruction, position.index,
1264                               MemorySpace::kDefault);
1265         for (const HloComputation* called_computation :
1266              position.instruction->called_computations()) {
1267           AddRequiredAssignment(called_computation->root_instruction(),
1268                                 position.index, MemorySpace::kDefault);
1269         }
1270       }
1271     }
1272   }
1273 }
1274 
CreateAllocationValuesFromColocatedIntervals(absl::Span<const AlternateMemoryBestFitHeap::BufferInterval * const> colocated_intervals,std::vector<MemorySpaceAssignment::AllocationValue> & allocation_values)1275 void AlternateMemoryBestFitHeap::CreateAllocationValuesFromColocatedIntervals(
1276     absl::Span<const AlternateMemoryBestFitHeap::BufferInterval* const>
1277         colocated_intervals,
1278     std::vector<MemorySpaceAssignment::AllocationValue>& allocation_values) {
1279   // Create AllocationValues for all the colocated intervals.
1280   for (const auto& colocated_interval : colocated_intervals) {
1281     CreateAllocationValues(*colocated_interval, allocation_values);
1282   }
1283   FindAliases(&allocation_values, /*skip_values_with_no_uses=*/true);
1284 }
1285 
// Tries to allocate every use of every AllocationValue in `allocation_values`.
//
// For each use, an AllocationRequest spanning [definition time, use time] is
// built and handed to AllocateSegment; the outcome is folded into the returned
// Result via result_mark. Uses by while and conditional instructions have
// their use time moved to a parameter of the called computation (see the
// comments below). Uses by non-root bitcasts do not get a request of their
// own. After each use, the allocation that is live at the use time is
// propagated to the use's aliased positions as a required assignment.
//
// Returns as soon as the accumulated result satisfies
// result_requires_uncommit; otherwise returns the accumulated result after all
// values and uses have been processed.
AlternateMemoryBestFitHeap::Result
AlternateMemoryBestFitHeap::AllocateAllocationValues(
    absl::Span<MemorySpaceAssignment::AllocationValue> allocation_values) {
  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();

  // Data structure to contain the preferred offset for a given computation.
  // We ensure that the same offset will be allocated outside the while loop
  // as well as inside the while loop.
  absl::flat_hash_map<const HloComputation*, AliasedOffset*>
      preferred_offset_for_computation;

  Result result = Result::kSuccess;
  for (AllocationValue& allocation_value : allocation_values) {
    int64 definition_time =
        instruction_schedule.at(allocation_value.defining_instruction());

    // Pick up a preferred offset recorded by an earlier while use whose body
    // is this value's computation, if there is one.
    AliasedOffset* preferred_offset = nullptr;
    auto preferred_offset_it =
        preferred_offset_for_computation.find(allocation_value.computation());
    if (preferred_offset_it != preferred_offset_for_computation.end()) {
      preferred_offset = preferred_offset_it->second;
    }

    // Iterate over the uses.
    for (int use_idx = 0; use_idx < allocation_value.uses().size(); ++use_idx) {
      const AllocationValue::Use& use = allocation_value.uses().at(use_idx);
      const HloUse hlo_use = use.hlo_use;
      int64 use_time = instruction_schedule.at(hlo_use.instruction);
      int64 latest_prefetch_time = use_time;
      bool allow_no_copy_alternate_mem_allocation = true;
      absl::optional<int64> earliest_prefetch_time = absl::nullopt;

      // Sequential calls include kWhile, kCall, and kConditional opcodes.
      bool is_sequential_call =
          (GetInstructionCallContext(hlo_use.instruction->opcode()) ==
           CallContext::kSequential);
      if (is_sequential_call) {
        // A prefetch feeding a sequential call has to finish before the
        // earliest called computation starts.
        for (const HloComputation* called_computation :
             hlo_use.instruction->called_computations()) {
          const HloLiveRange::TimeBound& computation_span =
              hlo_live_range_.computation_span_times().at(called_computation);
          latest_prefetch_time =
              std::min(computation_span.start - 1, latest_prefetch_time);
        }
        if (hlo_use.instruction->opcode() == HloOpcode::kWhile) {
          // Given an example while loop and flattened schedule (logical times
          // shown on the left):
          //
          // 0:  a = ...
          // 1:  ...
          //     cond {
          // 2:   p = param(0)
          // 3:   ...
          //     }
          //     body {
          // 4:   p = param(0)
          // 5:   ...
          // 6:   ROOT ...
          //     }
          // 7:  w = while(a), body=body, cond=cond
          //
          // When processing "a" (time 0) and its while use (time 7), we update
          // the interval to time 0-4. This is so that the remaining interval
          // (5-6) can be allocated separately and this buffer doesn't waste
          // alternate memory space within the while loop body.
          HloComputation* while_body = hlo_use.instruction->while_body();
          // We require while body ROOTs to be the last in the schedule.
          CHECK_EQ(instruction_schedule.at(while_body->root_instruction()) + 1,
                   instruction_schedule.at(hlo_use.instruction))
              << "While body ROOTs need to be the last in the schedule!  "
                 "Please run RootInstructionSinker.";
          // Replace the use time with the parameter time so that we can decide
          // on alternate memory allocations within the while loop body when we
          // look at uses within the while loop body.
          use_time =
              instruction_schedule.at(while_body->parameter_instruction(0));
        } else if (hlo_use.instruction->opcode() == HloOpcode::kConditional) {
          // Replace the use time with the earliest parameter of called
          // computations.
          for (const HloComputation* called_computation :
               hlo_use.instruction->called_computations()) {
            use_time = std::min(
                use_time, instruction_schedule.at(
                              called_computation->parameter_instruction(0)));
          }
        }
      }

      // Add a required assignment in default memory if the use is not allowed
      // in alternate memory.
      if (!IsUseAllowedInAlternateMemory(allocation_value, hlo_use)) {
        AddRequiredAssignment(allocation_value.value(), hlo_use.instruction,
                              MemorySpace::kDefault, use_time);
      } else if (use_idx > 0) {
        // We allow buffers in alternate memory that are passed into
        // conditionals to give up their alternate memory allocation inside the
        // called computation. This means that if a conditional operator has an
        // alternate memory allocation, subsequent uses cannot use the same
        // alternate memory allocation in order not to clobber data. So we force
        // default memory allocation for these subsequent uses.
        const AllocationValue::Use& previous_use =
            allocation_value.uses().at(use_idx - 1);
        if (previous_use.hlo_use.instruction->opcode() ==
                HloOpcode::kConditional &&
            previous_use.hlo_use.instruction != hlo_use.instruction) {
          allow_no_copy_alternate_mem_allocation = false;
          earliest_prefetch_time =
              instruction_schedule.at(previous_use.hlo_use.instruction);
          VLOG(3) << "Previous use (" << previous_use.hlo_use.ToString()
                  << ") of use (" << hlo_use.ToString()
                  << ") is a conditional, so this use will need to evict. "
                  << "Earliest prefetch time = " << *earliest_prefetch_time;
        }
      }

      // Bitcasts don't define buffers and don't directly consume buffers. Skip
      // allocating buffers for bitcast uses (unless they are the root
      // instruction). The uses that feed from bitcasts will be handled
      // specially.
      if (hlo_use.instruction->opcode() != HloOpcode::kBitcast ||
          hlo_use.instruction ==
              hlo_use.instruction->parent()->root_instruction()) {
        AllocationRequest request;
        // Rarely, (e.g., when conditional true and false parameters are the
        // same), definition time can be the time of the conditional and use
        // time is the parameter use, which is less.
        request.start_time = std::min(definition_time, use_time);
        request.end_time = use_time;
        request.latest_prefetch_time = latest_prefetch_time;
        request.size = allocation_value.size();
        request.allow_no_copy_alternate_mem_allocation =
            allow_no_copy_alternate_mem_allocation;
        request.earliest_prefetch_time = earliest_prefetch_time;
        request.preferred_offset = preferred_offset;
        request.use = &use;
        request.allocation_value = &allocation_value;
        result_mark(AllocateSegment(request), result);
        if (result_requires_uncommit(result)) {
          // The accumulated result requires the pending state to be
          // uncommitted, so stop here and hand the result back to the caller
          // instead of allocating the remaining uses.
          return result;
        }

        // If there are multiple uses, they can try using the memory allocation
        // already at the alternate memory.
        definition_time = instruction_schedule.at(hlo_use.instruction);
      }

      // Propagate the allocation to any aliases this use might have had.
      // NOTE(review): GetLiveAllocationAt returns nullptr when no allocation
      // in the sequence covers use_time; the code below dereferences the
      // result without checking -- confirm an allocation always exists here.
      MemorySpaceAssignment::Allocation* aliased_allocation =
          GetLiveAllocationAt(*allocation_value.allocation_sequence(),
                              use_time);
      for (const HloPosition& aliased_position : use.aliases) {
        AddAliasedRequiredAssignment(aliased_position.instruction,
                                     aliased_position.index,
                                     aliased_allocation);
      }

      // Special case for while loops since the root offset must agree with
      // other offsets: remember the preferred offset for the while loop body.
      if (hlo_use.instruction->opcode() == HloOpcode::kWhile &&
          aliased_allocation->memory_space() == MemorySpace::kAlternate) {
        preferred_offset_for_computation[hlo_use.instruction->while_body()] =
            GetAliasedOffset(*aliased_allocation);
      }
    }
  }
  return result;
}
1456 
operator <(const AsynchronousCopy & a,const AsynchronousCopy & b)1457 bool operator<(const AsynchronousCopy& a, const AsynchronousCopy& b) {
1458   return (a.start_time < b.start_time && a.end_time <= b.end_time) ||
1459          (a.start_time <= b.start_time && a.end_time < b.end_time);
1460 }
1461 
// Records `copy` in ranges_. If ranges_ already holds an element that compares
// equivalent to `copy` (so nothing new is inserted), that element must have
// the same start time as `copy`; anything else is a CHECK failure.
void AsynchronousCopyOrdering::AddCopy(const AsynchronousCopy& copy) {
  auto it_and_inserted = ranges_.insert(copy);
  // .second is false when an equivalent element was already present; .first
  // then refers to that existing element.
  CHECK(it_and_inserted.second ||
        it_and_inserted.first->start_time == copy.start_time);
}
1467 
// Removes the element of ranges_ that compares equivalent to `copy`. It is a
// CHECK failure if no such element exists.
void AsynchronousCopyOrdering::RemoveCopy(const AsynchronousCopy& copy) {
  auto copy_it = ranges_.find(copy);
  CHECK(copy_it != ranges_.end());
  ranges_.erase(copy_it);
}
1473 
ViolatesOrdering(int64 start_time,int64 end_time) const1474 absl::optional<AsynchronousCopy> AsynchronousCopyOrdering::ViolatesOrdering(
1475     int64 start_time, int64 end_time) const {
1476   // We allow identical start and end times. It is enough to check for just the
1477   // start time in case we find a match in ranges_ because the found value will
1478   // either be identical to {start_time, end_time} (and this doesn't violate) or
1479   // its start_time will be smaller and end_time will be larger (this violates).
1480   auto copy_it = ranges_.find(
1481       {start_time, end_time, MemorySpaceAssignment::MemorySpace::kAlternate});
1482   if (copy_it != ranges_.end() && copy_it->start_time != start_time) {
1483     VLOG(4) << "Violates ordering: (" << start_time << ", " << end_time
1484             << ") and (" << copy_it->start_time << ", " << copy_it->end_time
1485             << ")";
1486     return *copy_it;
1487   }
1488   return absl::nullopt;
1489 }
1490 
// Returns the AliasedOffset registered for `allocation` in
// aliased_offset_map_ (keyed by the allocation's address). It is a CHECK
// failure if the allocation has no entry.
AlternateMemoryBestFitHeap::AliasedOffset*
AlternateMemoryBestFitHeap::GetAliasedOffset(
    const MemorySpaceAssignment::Allocation& allocation) {
  auto aliased_offset_it = aliased_offset_map_.find(&allocation);
  CHECK(aliased_offset_it != aliased_offset_map_.end());
  return aliased_offset_it->second;
}
1498 
// Associates `allocation` with an AliasedOffset.
//
// If `aliased_offset` is null, a new AliasedOffset is appended to
// aliased_offsets_, initialized with the allocation's chunk offset, and used;
// otherwise the allocation joins the given one. The allocation must be in
// alternate memory, must not already be registered, and its chunk offset must
// equal the AliasedOffset's offset (all enforced with CHECKs).
void AlternateMemoryBestFitHeap::CreateOrAddToAliasedOffset(
    const MemorySpaceAssignment::Allocation& allocation,
    AlternateMemoryBestFitHeap::AliasedOffset* aliased_offset) {
  CHECK(allocation.memory_space() == MemorySpace::kAlternate);
  CHECK(!aliased_offset_map_.contains(&allocation));
  if (!aliased_offset) {
    // No group supplied: start a new one at this allocation's offset.
    aliased_offsets_.push_back({allocation.chunk().offset});
    aliased_offset = &aliased_offsets_.back();
  }
  CHECK_EQ(allocation.chunk().offset, aliased_offset->offset);
  // Register the allocation in both directions: group -> allocation and
  // allocation -> group.
  CHECK(aliased_offset->allocations.insert(&allocation).second);
  aliased_offset_map_[&allocation] = aliased_offset;
}
1512 
1513 /*static*/ MemorySpaceAssignment::Allocation*
GetLiveAllocationAt(const MemorySpaceAssignment::AllocationSequence & allocations,int64 time)1514 AlternateMemoryBestFitHeap::GetLiveAllocationAt(
1515     const MemorySpaceAssignment::AllocationSequence& allocations, int64 time) {
1516   for (auto allocation_it = allocations.rbegin();
1517        allocation_it != allocations.rend(); ++allocation_it) {
1518     if ((*allocation_it)->start_time() <= time &&
1519         (*allocation_it)->end_time() >= time) {
1520       return allocation_it->get();
1521     }
1522   }
1523   return nullptr;
1524 }
1525 
AllocateCrossProgramPrefetchBuffer(HloModule * module,absl::optional<BufferInterval> prefetch_candidate)1526 void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
1527     HloModule* module, absl::optional<BufferInterval> prefetch_candidate) {
1528   if (!prefetch_candidate) {
1529     return;
1530   }
1531 
1532   ChunkCandidate chunk_candidate = FindChunkCandidate(*prefetch_candidate);
1533   if (chunk_candidate.chunk.offset != 0 ||
1534       chunk_candidate.heap_size > available_heap_size()) {
1535     LOG(WARNING)
1536         << "Could not allocate preferred memory for cross program prefetch";
1537     return;
1538   }
1539   AddToPendingChunks(*prefetch_candidate, chunk_candidate);
1540 
1541   const HloValue* buffer = prefetch_candidate->buffer;
1542   int64 parameter = buffer->instruction()->parameter_number();
1543   module->AddCrossProgramPrefetch(parameter, buffer->index());
1544 
1545   MemorySpaceAssignment::AllocationSequence allocations;
1546   allocations.push_back(absl::make_unique<MemorySpaceAssignment::Allocation>(
1547       buffer->defining_position(), MemorySpace::kDefault, kDummyChunk,
1548       prefetch_candidate->start, prefetch_candidate->end));
1549 
1550   // Find the earliest use.
1551   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1552   auto uses = buffer->uses();
1553   auto use_schedule_compare = [&](const HloUse& lhs, const HloUse& rhs) {
1554     return instruction_schedule.at(lhs.instruction) <
1555            instruction_schedule.at(rhs.instruction);
1556   };
1557   auto first_use = absl::c_min_element(uses, use_schedule_compare);
1558   int64 latest_prefetch_time = instruction_schedule.at(first_use->instruction);
1559 
1560   // Find the latest use time.
1561   int64 last_use_time = instruction_schedule.at(
1562       absl::c_max_element(uses, use_schedule_compare)->instruction);
1563   for (const HloValue* colocation : prefetch_candidate->colocations) {
1564     last_use_time = std::max(
1565         last_use_time,
1566         instruction_schedule.at(
1567             absl::c_max_element(colocation->uses(), use_schedule_compare)
1568                 ->instruction));
1569   }
1570 
1571   int64 end_of_program_prefetch_end_time = instruction_schedule.size() - 1;
1572   int64 end_of_program_prefetch_start_time =
1573       options_.prefetch_interval_picker->PreferredPrefetchStartTime(
1574           buffer->defining_position().shape(), last_use_time,
1575           end_of_program_prefetch_end_time, end_of_program_prefetch_end_time);
1576   VLOG(2) << "last use time = " << last_use_time
1577           << ", end-of-program prefetch start time = "
1578           << end_of_program_prefetch_start_time;
1579   bool free_buffer =
1580       (end_of_program_prefetch_start_time > last_use_time &&
1581        end_of_program_prefetch_start_time < end_of_program_prefetch_end_time);
1582   int64 cross_program_prefetch_end_time =
1583       free_buffer ? last_use_time : prefetch_candidate->end;
1584 
1585   AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate,
1586                chunk_candidate.chunk, prefetch_candidate->start,
1587                cross_program_prefetch_end_time, latest_prefetch_time,
1588                &allocations, /*aliased_offset=*/nullptr,
1589                /*is_cross_program_prefetch=*/true);
1590   absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); });
1591   AliasedOffset* cross_program_prefetch_offset =
1592       GetAliasedOffset(*allocations.back());
1593 
1594   if (free_buffer) {
1595     VLOG(2) << "Adding an end-of-program prefetch for freed "
1596                "cross-program-prefetched buffer.";
1597     AddAsyncCopy(*allocations.front(), MemorySpace::kAlternate,
1598                  chunk_candidate.chunk, end_of_program_prefetch_start_time,
1599                  end_of_program_prefetch_end_time,
1600                  end_of_program_prefetch_end_time, &allocations,
1601                  cross_program_prefetch_offset);
1602     CHECK_EQ(cross_program_prefetch_offset->offset,
1603              allocations.back()->chunk().offset);
1604   }
1605 
1606   for (auto& allocation : allocations) {
1607     allocations_->push_back(std::move(allocation));
1608   }
1609 
1610   // Add a repack allocation block for the Allocation objects in alternate
1611   // memory.
1612   CHECK_EQ(repack_allocation_blocks_.size(), 0);
1613   for (const auto& allocation : *allocations_) {
1614     if (allocation->memory_space() == MemorySpace::kAlternate) {
1615       repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
1616           allocation->start_time(), allocation->end_time(),
1617           allocation->chunk().size, allocation->chunk().offset,
1618           static_cast<int64>(repack_allocation_blocks_.size()),
1619           allocation.get()));
1620       RepackAllocationBlock* inserted = &repack_allocation_blocks_.back();
1621       for (RepackAllocationBlock& colocation : repack_allocation_blocks_) {
1622         colocation.colocations.push_back(inserted);
1623         if (&colocation != inserted) {
1624           inserted->colocations.push_back(&colocation);
1625         }
1626       }
1627     }
1628   }
1629 
1630   ClearPendingChunks();
1631 }
1632 
1633 absl::optional<AlternateMemoryBestFitHeap::RequiredMemoryAssignment>
RequiredMemoryAssignmentAt(const HloValue * buffer,int64 time) const1634 AlternateMemoryBestFitHeap::RequiredMemoryAssignmentAt(const HloValue* buffer,
1635                                                        int64 time) const {
1636   auto required_assignment_it = required_assignments_.find(buffer);
1637   absl::optional<RequiredMemoryAssignment> required_assignment_at_time;
1638   if (required_assignment_it != required_assignments_.end()) {
1639     for (const RequiredMemoryAssignment& required_assignment :
1640          required_assignment_it->second) {
1641       if (required_assignment.time == time) {
1642         // Sanity check that there is only one required at time.
1643         CHECK(!required_assignment_at_time);
1644         required_assignment_at_time = required_assignment;
1645       }
1646     }
1647   }
1648   return required_assignment_at_time;
1649 }
1650 
1651 absl::optional<AlternateMemoryBestFitHeap::RequiredMemoryAssignment>
AliasedRequiredAssignmentForUse(const AllocationValue::Use & use) const1652 AlternateMemoryBestFitHeap::AliasedRequiredAssignmentForUse(
1653     const AllocationValue::Use& use) const {
1654   absl::optional<RequiredMemoryAssignment> required_assignment;
1655   for (const HloPosition& position : use.aliases) {
1656     const HloValue* value =
1657         &alias_analysis_.dataflow_analysis().GetUniqueValueAt(
1658             position.instruction, position.index);
1659     int64 time =
1660         hlo_live_range_.instruction_schedule().at(position.instruction);
1661     absl::optional<RequiredMemoryAssignment> required_assignment_for_alias =
1662         RequiredMemoryAssignmentAt(value, time);
1663     if (required_assignment == absl::nullopt) {
1664       required_assignment = required_assignment_for_alias;
1665     } else {
1666       CHECK(required_assignment_for_alias == absl::nullopt ||
1667             required_assignment->equals_ignoring_time(
1668                 *required_assignment_for_alias));
1669     }
1670   }
1671   return required_assignment;
1672 }
1673 
AddAliasedRequiredAssignment(const HloInstruction * instruction,ShapeIndex index,const MemorySpaceAssignment::Allocation * aliased_allocation)1674 void AlternateMemoryBestFitHeap::AddAliasedRequiredAssignment(
1675     const HloInstruction* instruction, ShapeIndex index,
1676     const MemorySpaceAssignment::Allocation* aliased_allocation) {
1677   AliasedOffset* offset = nullptr;
1678   if (aliased_allocation->memory_space() == MemorySpace::kAlternate) {
1679     offset = GetAliasedOffset(*aliased_allocation);
1680   }
1681   AddRequiredAssignment(instruction, index, aliased_allocation->memory_space(),
1682                         offset);
1683 }
1684 
AddRequiredAssignment(const HloValue * value,const HloInstruction * instruction,MemorySpaceAssignment::MemorySpace memory_space,int64 time,AliasedOffset * offset)1685 void AlternateMemoryBestFitHeap::AddRequiredAssignment(
1686     const HloValue* value, const HloInstruction* instruction,
1687     MemorySpaceAssignment::MemorySpace memory_space, int64 time,
1688     AliasedOffset* offset) {
1689   // Check for existing required assignment at this time and make sure it is the
1690   // same as this if there is one.
1691   auto existing_required_assignment = RequiredMemoryAssignmentAt(value, time);
1692   if (existing_required_assignment) {
1693     CHECK(memory_space == existing_required_assignment->memory_space)
1694         << "inst = " << instruction->ToString() << " at " << time;
1695     CHECK((!offset && !existing_required_assignment->offset) ||
1696           offset == existing_required_assignment->offset);
1697     VLOG(3) << "Not adding required assignment because there is one already: "
1698             << value->ToShortString() << " at " << time << " at "
1699             << (memory_space == MemorySpace::kDefault ? "def" : "alt");
1700   } else {
1701     VLOG(3) << "Adding required assignment: " << value->ToShortString()
1702             << " at " << time << " at "
1703             << (memory_space == MemorySpace::kDefault ? "def" : "alt");
1704     RequiredMemoryAssignment required_assignment{memory_space, time, offset};
1705     required_assignments_[value].push_back(required_assignment);
1706     pending_required_assignments_.push_back({value, required_assignment});
1707   }
1708 }
1709 
AddRequiredAssignment(const HloInstruction * instruction,ShapeIndex index,MemorySpace memory_space,AliasedOffset * offset)1710 void AlternateMemoryBestFitHeap::AddRequiredAssignment(
1711     const HloInstruction* instruction, ShapeIndex index,
1712     MemorySpace memory_space, AliasedOffset* offset) {
1713   const HloValue* value =
1714       &alias_analysis_.dataflow_analysis().GetUniqueValueAt(instruction, index);
1715   int64 instruction_time =
1716       hlo_live_range_.instruction_schedule().at(instruction);
1717   AddRequiredAssignment(value, instruction, memory_space, instruction_time,
1718                         offset);
1719 }
1720 
AddInputAndOutputRequiredAssignments()1721 void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
1722   // Go through the parameters and outputs and pin them to the corresponding
1723   // memory by adding a required assignment.
1724   const HloModule& module = alias_analysis_.dataflow_analysis().module();
1725   const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
1726   HloComputation* entry_computation = module.entry_computation();
1727   for (HloInstruction* parameter_instruction :
1728        entry_computation->parameter_instructions()) {
1729     int64 parameter_instruction_time =
1730         instruction_schedule.at(parameter_instruction);
1731     ShapeUtil::ForEachSubshape(
1732         parameter_instruction->shape(),
1733         [&](const Shape& subshape, const ShapeIndex& index) {
1734           MemorySpace memory_space = MemorySpace::kDefault;
1735           if (subshape.has_layout() && subshape.layout().memory_space() ==
1736                                            options_.alternate_memory_space) {
1737             memory_space = MemorySpace::kAlternate;
1738           }
1739           for (const HloBuffer* buffer :
1740                alias_analysis_.ComputeBuffersAt(parameter_instruction, index)) {
1741             for (const HloValue* value : buffer->values()) {
1742               VLOG(3) << "Adding required assignment for parameter value = "
1743                       << value->ToShortString()
1744                       << " time = " << parameter_instruction_time << " space = "
1745                       << (memory_space == MemorySpace::kDefault ? "def"
1746                                                                 : "alt");
1747               required_assignments_[value].push_back(
1748                   {memory_space, /*time=*/parameter_instruction_time});
1749             }
1750           }
1751         });
1752   }
1753   HloInstruction* root_instruction = entry_computation->root_instruction();
1754   int64 root_instruction_time = instruction_schedule.at(root_instruction);
1755   ShapeUtil::ForEachSubshape(
1756       root_instruction->shape(),
1757       [&](const Shape& subshape, const ShapeIndex& index) {
1758         MemorySpace memory_space = MemorySpace::kDefault;
1759         if (subshape.has_layout() && subshape.layout().memory_space() ==
1760                                          options_.alternate_memory_space) {
1761           memory_space = MemorySpace::kAlternate;
1762         }
1763         for (const HloBuffer* buffer :
1764              alias_analysis_.ComputeBuffersAt(root_instruction, index)) {
1765           for (const HloValue* value : buffer->values()) {
1766             VLOG(3) << "Adding required assignment for output value = "
1767                     << value->ToShortString()
1768                     << " time = " << root_instruction_time << " space = "
1769                     << (memory_space == MemorySpace::kDefault ? "def" : "alt");
1770             required_assignments_[value].push_back(
1771                 {memory_space, /*time=*/root_instruction_time});
1772           }
1773         }
1774       });
1775 }
1776 
AreIntervalsReservedInAlternateMemory(absl::Span<const BufferInterval * const> colocated_intervals) const1777 bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory(
1778     absl::Span<const BufferInterval* const> colocated_intervals) const {
1779   auto is_position_in_alternate_memory = [&](const HloPosition& position) {
1780     const Shape& shape = position.shape();
1781     return shape.has_layout() &&
1782            shape.layout().memory_space() == options_.alternate_memory_space;
1783   };
1784 
1785   const HloModule& module = alias_analysis_.dataflow_analysis().module();
1786   const HloComputation* entry_computation = module.entry_computation();
1787   const HloInstruction* root_instruction =
1788       entry_computation->root_instruction();
1789   for (const BufferInterval* colocated_interval : colocated_intervals) {
1790     const HloValue* value = colocated_interval->buffer;
1791     if (value->defining_instruction()->opcode() == HloOpcode::kParameter &&
1792         value->defining_instruction()->parent() == entry_computation &&
1793         is_position_in_alternate_memory(value->defining_position())) {
1794       return true;
1795     }
1796 
1797     for (const HloPosition& position : value->positions()) {
1798       if (position.instruction == root_instruction &&
1799           is_position_in_alternate_memory(position)) {
1800         return true;
1801       }
1802     }
1803   }
1804   return false;
1805 }
1806 
ExportAllocationsForRepacking(std::vector<MemorySpaceAssignmentRepacker::AllocationBlock * > & allocations)1807 void AlternateMemoryBestFitHeap::ExportAllocationsForRepacking(
1808     std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*>& allocations) {
1809   for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) {
1810     allocations.push_back(&allocation_block);
1811   }
1812 }
1813 
ImportRepackedAllocations()1814 void AlternateMemoryBestFitHeap::ImportRepackedAllocations() {
1815   interval_tree_ = {};
1816   for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) {
1817     MemorySpaceAssignment::Allocation* allocation = allocation_block.allocation;
1818     VLOG(3) << "Moved " << allocation->ToString() << ", size "
1819             << allocation->chunk().size << ", (" << allocation_block.start_time
1820             << ", " << allocation_block.end_time << ") from "
1821             << allocation_block.initial_offset << " to "
1822             << allocation_block.offset;
1823     allocation_block.allocation->mutable_chunk()->offset =
1824         allocation_block.offset;
1825     interval_tree_.Add(allocation_block.start_time, allocation_block.end_time,
1826                        {allocation_block.offset, allocation_block.size});
1827     allocation_block.initial_offset = allocation_block.offset;
1828     allocation_block.offset = -1;
1829   }
1830 }
1831 
UncommitPendingChunks(absl::Span<AllocationValue> allocation_values)1832 void AlternateMemoryBestFitHeap::UncommitPendingChunks(
1833     absl::Span<AllocationValue> allocation_values) {
1834   // Clear the allocation sequence of the allocation values so that in case we
1835   // retry allocation after uncommitting.
1836   for (AllocationValue& allocation_value : allocation_values) {
1837     allocation_value.allocation_sequence()->clear();
1838   }
1839   for (const auto& interval_and_chunk : pending_chunks_) {
1840     const BufferInterval& interval = interval_and_chunk.first;
1841     const Chunk& chunk = interval_and_chunk.second.chunk;
1842     VLOG(3) << "Uncommitting: (" << interval.start << ", " << interval.end
1843             << ") off = " << chunk.offset << " size = " << chunk.size;
1844     interval_tree_.Remove(interval.start, interval.end, chunk);
1845   }
1846   for (const auto& interval : pending_async_copies_) {
1847     if (interval.destination == MemorySpace::kAlternate) {
1848       prefetch_interval_tree_.Remove(interval.start_time, interval.end_time,
1849                                      kDummyChunk);
1850       async_copy_ordering_.RemoveCopy(interval);
1851     } else {
1852       eviction_interval_tree_.Remove(interval.start_time, interval.end_time,
1853                                      kDummyChunk);
1854     }
1855   }
1856   for (const auto& value_and_required_assignment :
1857        pending_required_assignments_) {
1858     auto& required_assignment_vector =
1859         required_assignments_[value_and_required_assignment.first];
1860     const RequiredMemoryAssignment& required_assignment =
1861         value_and_required_assignment.second;
1862     VLOG(3) << "Removing required assignment: "
1863             << (required_assignment.memory_space == MemorySpace::kDefault
1864                     ? "def"
1865                     : "alt")
1866             << " time = " << required_assignment.time << " off = "
1867             << (required_assignment.offset ? required_assignment.offset->offset
1868                                            : -1);
1869     for (auto it = required_assignment_vector.begin();
1870          it != required_assignment_vector.end(); ++it) {
1871       if (*it == value_and_required_assignment.second) {
1872         required_assignment_vector.erase(it);
1873         break;
1874       }
1875     }
1876   }
1877   ClearPendingChunks();
1878 }
1879 
FinalizeAllocations(absl::Span<AllocationValue> allocation_values)1880 void AlternateMemoryBestFitHeap::FinalizeAllocations(
1881     absl::Span<AllocationValue> allocation_values) {
1882   absl::flat_hash_map<const AliasedOffset*,
1883                       std::vector<MemorySpaceAssignment::Allocation*>>
1884       colocation_map;
1885   for (AllocationValue& allocation_value : allocation_values) {
1886     for (auto& allocation : *allocation_value.allocation_sequence()) {
1887       AppendAllocationInfoDebugString(allocation_value, *allocation,
1888                                       allocation_info_str_);
1889       allocations_->push_back(std::move(allocation));
1890       MemorySpaceAssignment::Allocation* inserted_allocation =
1891           allocations_->back().get();
1892       if (inserted_allocation->memory_space() == MemorySpace::kAlternate) {
1893         colocation_map[GetAliasedOffset(*inserted_allocation)].push_back(
1894             inserted_allocation);
1895       }
1896     }
1897   }
1898   // The allocations that have the same AliasedOffset need to be colocated.
1899   // Export these to repack_allocation_blocks_ so that we can repack them to
1900   // reduce fragmentation.
1901   for (auto& colocation : colocation_map) {
1902     std::vector<MemorySpaceAssignmentRepacker::AllocationBlock*> colocations;
1903     for (MemorySpaceAssignment::Allocation* colocated_allocation :
1904          colocation.second) {
1905       repack_allocation_blocks_.push_back(MakeRepackAllocationBlock(
1906           colocated_allocation->start_time(), colocated_allocation->end_time(),
1907           colocated_allocation->chunk().size,
1908           colocated_allocation->chunk().offset,
1909           static_cast<int64>(repack_allocation_blocks_.size()),
1910           colocated_allocation));
1911       colocations.push_back(&repack_allocation_blocks_.back());
1912     }
1913     for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block :
1914          colocations) {
1915       repack_block->colocations = colocations;
1916     }
1917   }
1918   ClearPendingChunks();
1919 }
1920 
// Empties all of the "pending" bookkeeping containers (chunks, asynchronous
// copies and required assignments) together with the aliased-offset map and
// list. Called in this file at the end of both UncommitPendingChunks and
// FinalizeAllocations. Only the containers are cleared here; interval trees
// are not touched.
void AlternateMemoryBestFitHeap::ClearPendingChunks() {
  pending_chunks_.clear();
  pending_async_copies_.clear();
  pending_required_assignments_.clear();
  // NOTE(review): the map is cleared before the offsets it presumably points
  // into; keep this order unless the element types are confirmed independent.
  aliased_offset_map_.clear();
  aliased_offsets_.clear();
}
1928 
AddToPendingChunks(const BufferInterval & buffer_interval,const ChunkCandidate & chunk_candidate)1929 void AlternateMemoryBestFitHeap::AddToPendingChunks(
1930     const BufferInterval& buffer_interval,
1931     const ChunkCandidate& chunk_candidate) {
1932   VLOG(3) << "Committing chunk: " << buffer_interval.start << "-"
1933           << buffer_interval.end << " : [" << chunk_candidate.chunk.offset
1934           << ", " << chunk_candidate.chunk.size << "]";
1935   pending_chunks_.emplace_back(buffer_interval, chunk_candidate);
1936   CommitChunk(buffer_interval, chunk_candidate);
1937 }
1938 
// Allocates the segment (request.start_time, request.end_time) of
// request.allocation_value for request.use, appending to / extending the
// allocation value's allocation sequence as needed.
//
// The strategies are tried in this order:
//   1. If start_time == end_time, attach the use to the allocation that is
//      already live at that time.
//   2. Honor a required assignment at start_time (extend the previous default
//      memory allocation, or create a new allocation in the required space).
//   3. Keep the buffer in alternate memory without a copy
//      (AllocateInAlternateMemoryNoCopy), unless default memory is required
//      at either end or the request disallows it.
//   4. Make sure there is an allocation in default memory: evict the latest
//      alternate memory allocation, or create a fresh default allocation.
//   5. Prefetch from that default memory allocation into alternate memory,
//      unless default memory is required at end_time.
//   6. Otherwise serve the use from default memory.
//
// Returns Result::kSuccess when the use was served by steps 1, 3, 5 or by the
// "default memory required at end" path. Failure paths return values built
// with result_mark.
// NOTE(review): result_mark is defined outside this chunk; it presumably
// merges its first argument into the second and returns the merged value --
// confirm against its definition.
AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::AllocateSegment(
    const AllocationRequest& request) {
  auto allocation_sequence = request.allocation_value->allocation_sequence();
  // start_time == end_time is a special case where the value is consumed
  // multiple times by the same instruction. We can just find the previous
  // allocation and use that allocation.
  if (request.start_time == request.end_time) {
    MemorySpaceAssignment::Allocation* allocation =
        GetLiveAllocationAt(*allocation_sequence, request.end_time);
    CHECK_NE(allocation, nullptr);
    allocation->AddUse(request.use->hlo_use);
    return Result::kSuccess;
  }

  const HloPosition& defining_position =
      request.allocation_value->defining_position();
  VLOG(2) << "Finding allocation for "
          << request.allocation_value->ToShortString() << " ("
          << request.start_time << ", " << request.end_time
          << ") latest prefetch = " << request.latest_prefetch_time
          << " last use = " << request.allocation_value->uses().back().time
          << " use = " << request.use->hlo_use.ToString()
          << ". Size = " << request.size
          << ", def pos = " << defining_position.ToString();
  CHECK_LE(request.start_time, request.end_time);

  // There could be a requirement to pin this buffer to default memory either
  // because it is a parameter or an output.  If the buffer is a parameter, then
  // we're allowed to prefetch. If the use expects the output to be in default
  // memory, we cannot prefetch it because if we did, it would be in alternate
  // memory instead.
  auto required_assignment_at_start = RequiredMemoryAssignmentAt(
      request.allocation_value->value(), request.start_time);
  absl::optional<MemorySpace> required_memory_space_at_start;
  if (required_assignment_at_start) {
    required_memory_space_at_start = required_assignment_at_start->memory_space;
  }
  // Find required assignment both for the use and its aliases. If they are both
  // non-nullopt, then make sure they require the same assignment.
  auto required_assignment_at_end = RequiredMemoryAssignmentAt(
      request.allocation_value->value(), request.end_time);
  auto aliased_required_assignment_at_end =
      AliasedRequiredAssignmentForUse(*request.use);
  if (required_assignment_at_end != aliased_required_assignment_at_end) {
    if (required_assignment_at_end == absl::nullopt) {
      // Only the alias imposes a requirement; adopt it for this use.
      required_assignment_at_end = aliased_required_assignment_at_end;
    } else {
      CHECK(aliased_required_assignment_at_end == absl::nullopt ||
            aliased_required_assignment_at_end->equals_ignoring_time(
                *required_assignment_at_end));
    }
  }
  absl::optional<MemorySpace> required_memory_space_at_end;
  if (required_assignment_at_end) {
    required_memory_space_at_end = required_assignment_at_end->memory_space;
  }

  if (required_assignment_at_start) {
    if (!allocation_sequence->empty()) {
      // We shouldn't have a situation where the required assignment at start is
      // at alternate memory space and we have existing allocations in the
      // allocation sequence. The only time we'll have required assignment at
      // start to be in the alternate memory space is in called computations
      // (e.g., while body) and we shouldn't have any allocations in the
      // allocation sequence so far.
      CHECK(required_assignment_at_start->memory_space ==
            MemorySpace::kDefault);
      // Find the previous allocation in default memory (might not be the very
      // last one) and extend its lifetime to include the start time of this
      // segment.
      auto prev_allocation_in_default_mem_it = std::find_if(
          allocation_sequence->rbegin(), allocation_sequence->rend(),
          [&](const auto& allocation) {
            return allocation->memory_space() == MemorySpace::kDefault &&
                   allocation->defining_position() == defining_position;
          });
      CHECK(prev_allocation_in_default_mem_it != allocation_sequence->rend());
      (*prev_allocation_in_default_mem_it)->Extend(request.start_time);
    } else {
      // First allocation for this value: create it in the required memory
      // space, covering only start_time. When alternate memory is required,
      // the chunk is placed at the offset dictated by the required assignment.
      absl::optional<Chunk> aliased_chunk = absl::nullopt;
      if (required_assignment_at_start->memory_space ==
          MemorySpace::kAlternate) {
        aliased_chunk =
            Chunk{required_assignment_at_start->offset->offset, request.size};
      }
      allocation_sequence->push_back(
          absl::make_unique<MemorySpaceAssignment::Allocation>(
              defining_position, required_assignment_at_start->memory_space,
              aliased_chunk, request.start_time, request.start_time));
      if (required_assignment_at_start->memory_space ==
          MemorySpace::kAlternate) {
        CreateOrAddToAliasedOffset(*allocation_sequence->back(),
                                   required_assignment_at_start->offset);
      }
    }
  }

  // Holds the outcome of the no-copy attempt (if one is made); failure
  // information from later steps is merged into it via result_mark.
  Result allocation_result = Result::kSuccess;
  // First try keeping the allocation entirely in the alternate memory.
  if (required_memory_space_at_start != MemorySpace::kDefault &&
      required_memory_space_at_end != MemorySpace::kDefault &&
      request.allow_no_copy_alternate_mem_allocation) {
    allocation_result = AllocateInAlternateMemoryNoCopy(request);
    if (allocation_result == Result::kSuccess) {
      return Result::kSuccess;
    }
  }

  auto prev_allocation_it = allocation_sequence->rbegin();
  // Find a previous allocation that is in the default memory space (not
  // necessarily the very last allocation).
  auto prev_allocation_in_default_mem_it = std::find_if(
      allocation_sequence->rbegin(), allocation_sequence->rend(),
      [&](const auto& allocation) {
        return allocation->memory_space() == MemorySpace::kDefault &&
               allocation->defining_position() == defining_position;
      });

  if (prev_allocation_in_default_mem_it == allocation_sequence->rend() &&
      prev_allocation_it != allocation_sequence->rend() &&
      (*prev_allocation_it)->memory_space() == MemorySpace::kAlternate &&
      (*prev_allocation_it)->defining_position() == defining_position) {
    // If there was an allocation for this HloValue that was in the alternate
    // memory space, we also need to perform an eviction.
    Result eviction_result = Evict(request);
    if (eviction_result != Result::kSuccess) {
      // A non-success eviction requires us to uncommit previous allocations.
      return result_mark(Result::kFailRequiresUncommit, eviction_result);
    }
    // On success Evict appended the copy-to-default allocation at the back of
    // the sequence, so the newest element is the default memory allocation.
    prev_allocation_in_default_mem_it = allocation_sequence->rbegin();
  } else if (prev_allocation_in_default_mem_it == allocation_sequence->rend()) {
    // Nothing to evict and nothing in default memory yet: create a default
    // memory allocation spanning the whole segment.
    allocation_sequence->push_back(
        absl::make_unique<MemorySpaceAssignment::Allocation>(
            defining_position, MemorySpace::kDefault, /*chunk=*/absl::nullopt,
            request.start_time, request.end_time));
    prev_allocation_in_default_mem_it = allocation_sequence->rbegin();
  }

  CHECK(prev_allocation_in_default_mem_it != allocation_sequence->rend());
  CHECK((*prev_allocation_in_default_mem_it)->memory_space() ==
        MemorySpace::kDefault);

  // If the buffer must be in default memory at the end_time, don't prefetch.
  if (required_memory_space_at_end == MemorySpace::kDefault) {
    VLOG(3)
        << "Not trying to prefetch because use requires buffer in default mem.";
    (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
    (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
    return Result::kSuccess;
  }

  // Finally, try to prefetch the buffer into alternate memory.
  Result prefetch_result =
      Prefetch(request, **prev_allocation_in_default_mem_it);
  if (prefetch_result == Result::kSuccess) {
    return Result::kSuccess;
  }
  // NOTE(review): the return value is ignored here, so this call is only
  // useful if result_mark updates allocation_result in place -- confirm.
  result_mark(prefetch_result, allocation_result);

  // If the end assignment was required to be in alternate memory but that
  // wasn't possible, then this allocation is invalid.
  if (required_memory_space_at_end == MemorySpace::kAlternate) {
    return result_mark(Result::kFailRequiresUncommit, allocation_result);
  }

  // If a copy wasn't inserted, then add this use to the latest allocation in
  // default memory.
  (*prev_allocation_in_default_mem_it)->Extend(request.end_time);
  (*prev_allocation_in_default_mem_it)->AddUse(request.use->hlo_use);
  // The use is now served from default memory; the returned value still
  // carries whatever was recorded in allocation_result above.
  return allocation_result;
}
2110 
AddAsyncCopy(const MemorySpaceAssignment::Allocation & prev_allocation,MemorySpace memory_space,absl::optional<Chunk> chunk,int64 start_time,int64 end_time,int64 copy_done_schedule_before_time,MemorySpaceAssignment::AllocationSequence * allocations,AliasedOffset * aliased_offset,bool is_cross_program_prefetch)2111 void AlternateMemoryBestFitHeap::AddAsyncCopy(
2112     const MemorySpaceAssignment::Allocation& prev_allocation,
2113     MemorySpace memory_space, absl::optional<Chunk> chunk, int64 start_time,
2114     int64 end_time, int64 copy_done_schedule_before_time,
2115     MemorySpaceAssignment::AllocationSequence* allocations,
2116     AliasedOffset* aliased_offset, bool is_cross_program_prefetch) {
2117   VLOG(3) << "Copy to "
2118           << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
2119                   ? "default"
2120                   : "alternate")
2121           << " memory between " << start_time << " and "
2122           << copy_done_schedule_before_time << " keeping until " << end_time;
2123   CHECK_LT(start_time, copy_done_schedule_before_time);
2124 
2125   allocations->push_back(
2126       absl::make_unique<MemorySpaceAssignment::CopyAllocation>(
2127           prev_allocation, memory_space, chunk, start_time, end_time,
2128           copy_done_schedule_before_time, is_cross_program_prefetch));
2129 
2130   // Register the additional async copy with the interval tree to keep track of
2131   // the limit at any given time.
2132   pending_async_copies_.push_back(
2133       {start_time, copy_done_schedule_before_time, memory_space});
2134   if (memory_space == MemorySpaceAssignment::MemorySpace::kAlternate) {
2135     prefetch_interval_tree_.Add(start_time, copy_done_schedule_before_time,
2136                                 kDummyChunk);
2137     async_copy_ordering_.AddCopy(pending_async_copies_.back());
2138     CreateOrAddToAliasedOffset(*allocations->back(), aliased_offset);
2139   } else {
2140     eviction_interval_tree_.Add(start_time, copy_done_schedule_before_time,
2141                                 kDummyChunk);
2142   }
2143 }
2144 
ViolatesMaximumOutstandingAsyncCopies(int64 start_time,int64 end_time,bool is_prefetch,int64 extra_async_copy_limit) const2145 bool AlternateMemoryBestFitHeap::ViolatesMaximumOutstandingAsyncCopies(
2146     int64 start_time, int64 end_time, bool is_prefetch,
2147     int64 extra_async_copy_limit) const {
2148   if (options_.max_outstanding_prefetches < 0 && is_prefetch) {
2149     return false;
2150   }
2151   if (options_.max_outstanding_evictions < 0 && !is_prefetch) {
2152     return false;
2153   }
2154 
2155   // Count the prefetches/evictions in the interval tree for the given interval.
2156   if (is_prefetch) {
2157     int64 num_prefetches =
2158         prefetch_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
2159             .size();
2160     return num_prefetches >=
2161            options_.max_outstanding_prefetches + extra_async_copy_limit;
2162   } else {
2163     int64 num_evictions =
2164         eviction_interval_tree_.ChunksOverlappingInTime(start_time, end_time)
2165             .size();
2166     return num_evictions >=
2167            options_.max_outstanding_evictions + extra_async_copy_limit;
2168   }
2169 }
2170 
2171 absl::optional<AsynchronousCopy>
ViolatesAsyncCopyOrdering(int64 start_time,int64 end_time) const2172 AlternateMemoryBestFitHeap::ViolatesAsyncCopyOrdering(int64 start_time,
2173                                                       int64 end_time) const {
2174   return async_copy_ordering_.ViolatesOrdering(start_time, end_time);
2175 }
2176 
2177 AlternateMemoryBestFitHeap::Result
AllocateInAlternateMemoryNoCopy(const AllocationRequest & request)2178 AlternateMemoryBestFitHeap::AllocateInAlternateMemoryNoCopy(
2179     const AllocationRequest& request) {
2180   MemorySpaceAssignment::Allocation* prev_allocation = nullptr;
2181   bool can_eliminate_copy = false;
2182   if (request.allocation_value->allocation_sequence()->empty()) {
2183     // There hasn't been any allocations for this interval so far. We can
2184     // eliminate copy if the value can be placed in the alternate memory.
2185     can_eliminate_copy = options_.is_allowed_in_alternate_mem_fn(
2186         *request.allocation_value->value());
2187   } else {
2188     // If there has been a previous allocation, we can eliminate the copy if the
2189     // previous allocation was also in the alternate memory.
2190     prev_allocation =
2191         request.allocation_value->allocation_sequence()->back().get();
2192     can_eliminate_copy =
2193         (prev_allocation->memory_space() == MemorySpace::kAlternate);
2194   }
2195 
2196   if (!can_eliminate_copy) {
2197     return Result::kFailPrevAllocationNotInAlternateMem;
2198   }
2199 
2200   const HloPosition& defining_position =
2201       request.allocation_value->defining_position();
2202   if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
2203           defining_position.shape(), request.start_time + 1,
2204           request.end_time)) {
2205     return Result::kFailLiveRangeTooLong;
2206   }
2207 
2208   BufferInterval alternate_mem_interval;
2209   alternate_mem_interval.buffer = request.allocation_value->value();
2210   alternate_mem_interval.size = request.size;
2211   alternate_mem_interval.end = request.end_time;
2212   alternate_mem_interval.start = request.start_time;
2213 
2214   // Prefer the offset that was previously used for the previous allocation.
2215   AliasedOffset* preferred_offset = nullptr;
2216   if (prev_allocation != nullptr) {
2217     preferred_offset = GetAliasedOffset(*prev_allocation);
2218     // If there is a previous allocation, set the start time one after the end
2219     // of the previous allocation's end.
2220     alternate_mem_interval.start = prev_allocation->end_time() + 1;
2221   }
2222 
2223   if (request.preferred_offset) {
2224     // Sanity check that if there is a preferred offset provided in the request,
2225     // it matches with the previous allocation.
2226     CHECK(!preferred_offset || request.preferred_offset == preferred_offset)
2227         << "preferred_offset = " << preferred_offset->offset
2228         << ", request.preferred_offset = " << request.preferred_offset->offset;
2229     preferred_offset = request.preferred_offset;
2230   }
2231 
2232   VLOG(3) << "We can eliminate copy to alternate memory. Preferred offset = "
2233           << (preferred_offset ? preferred_offset->offset : -1);
2234   // In case there are additional uses after this use, we rely on the last use
2235   // time to try to reserve a chunk in the heap simulator. This is to prevent
2236   // the following scenario:
2237   //
2238   //                            +-------+
2239   //                           /         \
2240   //                   Producer--->Use1   +-->Use2
2241   //                       +---------+---------+
2242   // New buffer:           |         |         |
2243   //                       +---------+---------+
2244   //
2245   //                                     +-----------+
2246   // Current heap:                       | offset: 0 |
2247   //           --------------------------+-----------+------
2248   //
2249   // Because we allocate buffers greedily, Producer to Use1 segment first, and
2250   // then Use1 to Use2 segment, it is possible to allocate the first segment at
2251   // an offset that is available for the first segment (e.g. offset 0) but not
2252   // for the entire live range. This can result in unnecessary copies. By using
2253   // the last use time, we try to find an allocation that is available for the
2254   // entire Producer to Use2 range.
2255   absl::optional<ChunkCandidate> chunk_candidate = FindBestChunkCandidate(
2256       request, preferred_offset, &alternate_mem_interval);
2257   // Check if the new heap size fits within limits. Also ensure if a
2258   // preferred offset was provided, that offset was used.
2259   if (chunk_candidate) {
2260     VLOG(3) << "Keep the buffer in alternate memory. Offset = "
2261             << chunk_candidate->chunk.offset
2262             << ", size = " << chunk_candidate->chunk.size
2263             << ", heap_size = " << chunk_candidate->heap_size
2264             << ", prefetch picker = "
2265             << options_.prefetch_interval_picker->ToNoCopyDebugString(
2266                    defining_position.shape(), request.start_time,
2267                    request.end_time);
2268     AddToPendingChunks(alternate_mem_interval, *chunk_candidate);
2269 
2270     // If there was a previous allocation, the buffer location is the
2271     // same as the previous. Otherwise, it is the operand.
2272     if (prev_allocation != nullptr &&
2273         (prev_allocation->is_copy_allocation() ||
2274          prev_allocation->defining_position() == defining_position)) {
2275       prev_allocation->Extend(request.end_time);
2276     } else {
2277       request.allocation_value->allocation_sequence()->push_back(
2278           absl::make_unique<MemorySpaceAssignment::Allocation>(
2279               defining_position, MemorySpace::kAlternate,
2280               chunk_candidate->chunk, request.start_time, request.end_time));
2281       CreateOrAddToAliasedOffset(
2282           *request.allocation_value->allocation_sequence()->back(),
2283           preferred_offset);
2284     }
2285     request.allocation_value->allocation_sequence()->back()->AddUse(
2286         request.use->hlo_use);
2287     return Result::kSuccess;
2288   }
2289   return Result::kFailOutOfMemory;
2290 }
2291 
Evict(const AllocationRequest & request)2292 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Evict(
2293     const AllocationRequest& request) {
2294   CHECK_GT(request.allocation_value->allocation_sequence()->size(), 0);
2295   MemorySpaceAssignment::Allocation* prev_allocation =
2296       request.allocation_value->allocation_sequence()->back().get();
2297   int64 eviction_start_time = prev_allocation->start_time();
2298   int64 eviction_end_time = prev_allocation->end_time();
2299   CHECK(eviction_start_time <= eviction_end_time);
2300 
2301   int64 preferred_eviction_end_time =
2302       std::max(options_.prefetch_interval_picker->PreferredEvictionEndTime(
2303                    request.allocation_value->defining_position().shape(),
2304                    eviction_start_time, request.end_time),
2305                eviction_end_time);
2306   // Evictions must complete by the time of this use.
2307   preferred_eviction_end_time =
2308       std::min(preferred_eviction_end_time, request.latest_prefetch_time);
2309 
2310   BufferInterval eviction_mem_interval;
2311   eviction_mem_interval.buffer = request.allocation_value->value();
2312   eviction_mem_interval.size = request.size;
2313   // Try to reserve a buffer from the end of the previous allocation to the
2314   // preferred eviction end time.
2315   eviction_mem_interval.start = eviction_end_time + 1;
2316   eviction_mem_interval.end = preferred_eviction_end_time;
2317   int64 preferred_offset = prev_allocation->chunk().offset;
2318   VLOG(3) << "Eviction (" << eviction_start_time << ", " << eviction_end_time
2319           << ") preferred end time = " << eviction_mem_interval.end;
2320 
2321   for (; eviction_mem_interval.end > eviction_end_time;
2322        --eviction_mem_interval.end) {
2323     ChunkCandidate chunk_candidate =
2324         FindChunkCandidate(eviction_mem_interval, preferred_offset);
2325     if (chunk_candidate.chunk.offset == preferred_offset) {
2326       AddToPendingChunks(eviction_mem_interval, chunk_candidate);
2327       break;
2328     }
2329   }
2330   eviction_end_time = eviction_mem_interval.end;
2331 
2332   VLOG(3) << "Evicting buffer at " << prev_allocation->chunk().offset << " ("
2333           << eviction_start_time << ", " << eviction_end_time << ")";
2334 
2335   bool eviction_interval_too_short = (eviction_start_time == eviction_end_time);
2336   bool eviction_violates_outstanding_copies =
2337       ViolatesMaximumOutstandingAsyncCopies(eviction_start_time,
2338                                             eviction_end_time,
2339                                             /*is_prefetch=*/false);
2340 
2341   // See if this interval would violate the asynchronous copy limit.
2342   if (!eviction_interval_too_short && !eviction_violates_outstanding_copies) {
2343     prev_allocation->Extend(eviction_end_time);
2344     AddAsyncCopy(*prev_allocation, MemorySpace::kDefault,
2345                  /*chunk=*/absl::nullopt, eviction_start_time,
2346                  prev_allocation->end_time(), eviction_end_time,
2347                  request.allocation_value->allocation_sequence(),
2348                  /*aliased_offset=*/nullptr);
2349   } else {
2350     if (eviction_violates_outstanding_copies) {
2351       VLOG(3) << "This violates the maximum async copies.";
2352     } else {
2353       VLOG(3) << "Eviction interval is too short (" << eviction_start_time
2354               << ", " << eviction_end_time << ").";
2355     }
2356     // If the original interval violated the limit, try sub-intervals within
2357     // this interval.
2358     bool eviction_scheduled = false;
2359     for (int64 time = eviction_start_time; time < eviction_end_time; ++time) {
2360       VLOG(4) << "Try evicting (" << time << ", " << time + 1 << ")";
2361       if (!ViolatesMaximumOutstandingAsyncCopies(time, time + 1,
2362                                                  /*is_prefetch=*/false)) {
2363         VLOG(3) << "Eviction successful.";
2364         AddAsyncCopy(*prev_allocation, MemorySpace::kDefault,
2365                      /*chunk=*/absl::nullopt, time, time + 1, time + 1,
2366                      request.allocation_value->allocation_sequence(),
2367                      /*aliased_offset=*/nullptr);
2368         eviction_scheduled = true;
2369         break;
2370       }
2371     }
2372 
2373     if (!eviction_scheduled) {
2374       // If the eviction couldn't be scheduled, then fail. This buffer will be
2375       // kept in the default memory.
2376       VLOG(3) << "Bailing: Could not evict " << request.use->hlo_use.ToString()
2377               << " because we hit the limit of maximum asynchronous copies "
2378               << "between "
2379               << hlo_live_range_.flattened_instruction_sequence()
2380                      .instructions()[eviction_start_time]
2381               << " and "
2382               << hlo_live_range_.flattened_instruction_sequence()
2383                      .instructions()[eviction_end_time];
2384       // return false;
2385       return Result::kFailOutOfAsyncCopies;
2386     }
2387   }
2388   // return true;
2389   return Result::kSuccess;
2390 }
2391 
FindPrefetchEndTime(const AllocationRequest & request,int64 earliest_prefetch_time) const2392 int64 AlternateMemoryBestFitHeap::FindPrefetchEndTime(
2393     const AllocationRequest& request, int64 earliest_prefetch_time) const {
2394   int64 prefetch_end_time = request.latest_prefetch_time;
2395 
2396   const HloUse& use = request.use->hlo_use;
2397   const Shape& shape = ShapeUtil::GetSubshape(
2398       use.instruction->operand(use.operand_number)->shape(), use.operand_index);
2399   for (int retry_number = 0;
2400        retry_number < options_.prefetch_copy_done_reorder_max_retries;
2401        ++retry_number) {
2402     int64 latest_prefetch_time =
2403         options_.prefetch_interval_picker->LatestPrefetchStartTime(
2404             shape, earliest_prefetch_time, prefetch_end_time, &use);
2405     VLOG(4) << "Latest prefetch start time = " << latest_prefetch_time
2406             << ", earliest prefetch start time = " << earliest_prefetch_time
2407             << ", prefetch end time = " << prefetch_end_time;
2408     // Return if we couldn't find a suitable prefetch start time.
2409     if (latest_prefetch_time < earliest_prefetch_time) {
2410       break;
2411     }
2412 
2413     // Return either if there is no other violating asynchronous copy (since we
2414     // don't need to change the prefetch end time) or if the violating
2415     // asynchronous copy ends after the prefetch end time.
2416     auto violating_async_copy =
2417         ViolatesAsyncCopyOrdering(latest_prefetch_time, prefetch_end_time);
2418     if (!violating_async_copy ||
2419         violating_async_copy->end_time >= prefetch_end_time) {
2420       break;
2421     }
2422     VLOG(4) << "Violating async copy: (" << violating_async_copy->start_time
2423             << ", " << violating_async_copy->end_time << ")";
2424 
2425     int64 new_prefetch_end_time =
2426         options_.prefetch_interval_picker->LatestPrefetchEndTime(
2427             prefetch_end_time, violating_async_copy->end_time);
2428     if (new_prefetch_end_time > earliest_prefetch_time) {
2429       VLOG(3) << "Update prefetch end time = " << new_prefetch_end_time;
2430       prefetch_end_time = new_prefetch_end_time;
2431     } else {
2432       VLOG(3) << "Can't update prefetch end time = " << new_prefetch_end_time
2433               << " because earliest prefetch start time = "
2434               << earliest_prefetch_time;
2435       break;
2436     }
2437   }
2438 
2439   return prefetch_end_time;
2440 }
2441 
Prefetch(const AllocationRequest & request,const MemorySpaceAssignment::Allocation & prev_allocation_in_default_mem)2442 AlternateMemoryBestFitHeap::Result AlternateMemoryBestFitHeap::Prefetch(
2443     const AllocationRequest& request,
2444     const MemorySpaceAssignment::Allocation& prev_allocation_in_default_mem) {
2445   // Try partially placing the buffer in the alternate space. The time that is
2446   // overlapped will be used to asynchronously copy the buffer from the
2447   // default memory to the alternate memory.
2448   //
2449   //                      start                 end
2450   //                      time                  time
2451   //                      X---------------------X
2452   // Alternate:                          +------+
2453   // Default:             +---------------------+
2454   //                                     ^      ^
2455   //                                   Copy    Copy
2456   //                                   Start   Done
2457   int64 earliest_prefetch_time =
2458       prev_allocation_in_default_mem.earliest_available_time();
2459   if (request.earliest_prefetch_time) {
2460     earliest_prefetch_time =
2461         std::max(earliest_prefetch_time, *request.earliest_prefetch_time);
2462   }
2463   int64 prefetch_end_time =
2464       FindPrefetchEndTime(request, earliest_prefetch_time);
2465 
2466   options_.prefetch_interval_picker->Begin(
2467       request.use->hlo_use, earliest_prefetch_time, prefetch_end_time);
2468   VLOG(3) << "Trying prefetch picker = "
2469           << options_.prefetch_interval_picker->ToDebugString();
2470 
2471   // Create an alternate memory interval that starts at the earliest
2472   // possible position, given by max_prefetch_interval.
2473   BufferInterval alternate_mem_interval;
2474   alternate_mem_interval.buffer = request.allocation_value->value();
2475   alternate_mem_interval.size = request.size;
2476   // While uses might be allowed to have additional outstanding prefetches.
2477   int64 extra_async_copy_limit =
2478       request.use->hlo_use.instruction->opcode() == HloOpcode::kWhile
2479           ? options_.while_use_extra_outstanding_prefetch_limit
2480           : 0;
2481   Result result = Result::kSuccess;
2482   while (!options_.prefetch_interval_picker->Done()) {
2483     alternate_mem_interval.start = options_.prefetch_interval_picker->Next();
2484     CHECK_LT(alternate_mem_interval.start, prefetch_end_time);
2485     VLOG(4) << "Trying alternate memory allocation ("
2486             << alternate_mem_interval.start << ", " << request.end_time << ")";
2487     // If this additional asynchronous copy would violate the limit, try a
2488     // different interval.
2489     if (ViolatesAsyncCopyOrdering(alternate_mem_interval.start,
2490                                   prefetch_end_time)) {
2491       VLOG(4) << "This would violate asynchronous copy ordering.";
2492       result_mark(Result::kFailViolatesAsyncCopyOrdering, result);
2493       continue;
2494     }
2495     if (ViolatesMaximumOutstandingAsyncCopies(
2496             alternate_mem_interval.start, prefetch_end_time,
2497             /*is_prefetch=*/true, extra_async_copy_limit)) {
2498       VLOG(4) << "This would violate the outstanding async copy limit.";
2499       result_mark(Result::kFailOutOfAsyncCopies, result);
2500       continue;
2501     }
2502 
2503     auto chunk_candidate = FindBestChunkCandidate(
2504         request, request.preferred_offset, &alternate_mem_interval);
2505     // Check if we could find a suitable chunk.
2506     if (chunk_candidate) {
2507       VLOG(3) << "Move the buffer to alternate memory at "
2508               << alternate_mem_interval.start
2509               << ". Offset = " << chunk_candidate->chunk.offset
2510               << ", size = " << chunk_candidate->chunk.size
2511               << ", heap_size = " << chunk_candidate->heap_size
2512               << ", prefetch picker = "
2513               << options_.prefetch_interval_picker->ToDebugString();
2514       AddToPendingChunks(alternate_mem_interval, *chunk_candidate);
2515 
2516       AddAsyncCopy(prev_allocation_in_default_mem, MemorySpace::kAlternate,
2517                    chunk_candidate->chunk, alternate_mem_interval.start,
2518                    request.end_time, prefetch_end_time,
2519                    request.allocation_value->allocation_sequence(),
2520                    request.preferred_offset);
2521 
2522       request.allocation_value->allocation_sequence()->back()->AddUse(
2523           request.use->hlo_use);
2524       return Result::kSuccess;
2525     }
2526     result_mark(Result::kFailOutOfMemory, result);
2527   }
2528   // If we didn't consider any prefetch intervals, then the live range was too
2529   // short.
2530   if (result == Result::kSuccess) {
2531     return Result::kFailLiveRangeTooShort;
2532   } else {
2533     return result;
2534   }
2535 }
2536 
2537 absl::optional<AlternateMemoryBestFitHeap::ChunkCandidate>
FindBestChunkCandidate(const AllocationRequest & request,const AliasedOffset * preferred_offset,BufferInterval * alternate_mem_interval) const2538 AlternateMemoryBestFitHeap::FindBestChunkCandidate(
2539     const AllocationRequest& request, const AliasedOffset* preferred_offset,
2540     BufferInterval* alternate_mem_interval) const {
2541   int64 end_time = request.end_time;
2542   if (!preferred_offset) {
2543     // First find the earliest use that is the same or later than the end time.
2544     const auto& uses = request.allocation_value->uses();
2545     auto use_it = uses.begin();
2546     for (; use_it->time < end_time; ++use_it) {
2547     }
2548     CHECK(use_it != uses.end());
2549     int64 earliest_use = use_it->time;
2550 
2551     // Then find the latest use that can be allocated contiguously without
2552     // copies.
2553     const Shape& shape = request.allocation_value->defining_position().shape();
2554     for (;
2555          (use_it + 1) != uses.end() &&
2556          options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
2557              shape, use_it->time, (use_it + 1)->time);
2558          ++use_it) {
2559     }
2560     CHECK(use_it != uses.end());
2561     int64 latest_contiguous_use = use_it->time;
2562 
2563     // Find a chunk that's as long living as possible iterating in reverse over
2564     // the use times.
2565     for (; use_it >= uses.begin() && use_it->time >= end_time; --use_it) {
2566       alternate_mem_interval->end = use_it->time;
2567       ChunkCandidate chunk_candidate =
2568           FindChunkCandidate(*alternate_mem_interval);
2569       if (chunk_candidate.heap_size <= available_heap_size()) {
2570         alternate_mem_interval->end = end_time;
2571         VLOG(3) << "FindBestChunkCandidate earliest use = " << earliest_use
2572                 << ", latest contiguous use = " << latest_contiguous_use
2573                 << ", use with available mem = " << use_it->time
2574                 << ", offset = " << chunk_candidate.chunk.offset;
2575         return chunk_candidate;
2576       }
2577     }
2578     alternate_mem_interval->end = end_time;
2579     return absl::nullopt;
2580   }
2581   // If a preferred offset is given, try to find an allocation at that offset
2582   // only.
2583   alternate_mem_interval->end = end_time;
2584   ChunkCandidate chunk_candidate =
2585       FindChunkCandidate(*alternate_mem_interval, preferred_offset->offset);
2586   if (chunk_candidate.chunk.offset == preferred_offset->offset) {
2587     return chunk_candidate;
2588   }
2589   return absl::nullopt;
2590 }
2591 
2592 StatusOr<MemorySpaceAssignment::AsyncCopyStats>
CalculateAsyncCopyStats() const2593 MemorySpaceAssignment::CalculateAsyncCopyStats() const {
2594   AsyncCopyStats stats;
2595   stats.max_outstanding_async_copies = 0;
2596   stats.num_prefetches = 0;
2597   stats.prefetch_bytes = 0;
2598   stats.num_evictions = 0;
2599   stats.eviction_bytes = 0;
2600   int64 current_copies = 0;
2601   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> dataflow_analysis,
2602                       HloDataflowAnalysis::Run(*module_));
2603   for (const HloComputation* computation :
2604        module_->MakeNonfusionComputations()) {
2605     for (HloInstruction* instruction : computation->instructions()) {
2606       if (instruction->opcode() == HloOpcode::kCopyStart) {
2607         current_copies++;
2608       } else if (instruction->opcode() == HloOpcode::kCopyDone) {
2609         current_copies--;
2610         int64 size =
2611             options_.size_fn(dataflow_analysis->GetUniqueValueAt(instruction));
2612         if (instruction->shape().layout().memory_space() ==
2613             options_.alternate_memory_space) {
2614           ++stats.num_prefetches;
2615           stats.prefetch_bytes += size;
2616         } else {
2617           ++stats.num_evictions;
2618           stats.eviction_bytes += size;
2619         }
2620       }
2621       stats.max_outstanding_async_copies =
2622           std::max(stats.max_outstanding_async_copies, current_copies);
2623     }
2624   }
2625   return stats;
2626 }
2627 
2628 /*static*/ MemorySpaceAssignment::BufferIntervalCompare
GetMemoryBoundednessBufferIntervalCompare(const MemorySpaceAssignmentCostAnalysis & cost_analysis,MemorySpaceAssignmentCostAnalysis::Cache * cache)2629 MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
2630     const MemorySpaceAssignmentCostAnalysis& cost_analysis,
2631     MemorySpaceAssignmentCostAnalysis::Cache* cache) {
2632   return [&cost_analysis, cache](const BufferInterval& x,
2633                                  const BufferInterval& y) {
2634     float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
2635     float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
2636     if (x_memory_boundedness != y_memory_boundedness) {
2637       return x_memory_boundedness > y_memory_boundedness;
2638     }
2639     // Tie-break if the memory boundedness is the same.
2640     return GlobalDecreasingSizeBestFitHeap<
2641         HloValue>::GetSpatialBufferIntervalCompare()(x, y);
2642   };
2643 }
2644 
2645 
2646 /*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
Run(HloModule * module,const HloLiveRange & hlo_live_range,const HloAliasAnalysis & alias_analysis,const Options & options)2647 MemorySpaceAssignment::Run(HloModule* module,
2648                            const HloLiveRange& hlo_live_range,
2649                            const HloAliasAnalysis& alias_analysis,
2650                            const Options& options) {
2651   CHECK(module->has_schedule());
2652   VLOG(3) << "Module before memory space assignment: ";
2653   XLA_VLOG_LINES(3, module->ToString());
2654   VLOG(3) << "Schedule: " << module->schedule().ToString();
2655   MemorySpaceAssignment memory_space_assignment(module, options,
2656                                                 hlo_live_range);
2657 
2658   return memory_space_assignment.RunMemorySpaceAssignment(hlo_live_range,
2659                                                           alias_analysis);
2660 }
2661 
// Runs the memory space assignment steps in order: finds the allocation
// sequence, processes the allocations, schedules the asynchronous copies,
// simplifies the graph, fixes the schedule, and exports/colors the buffers.
// Afterwards it logs the resulting module and asynchronous copy statistics and
// verifies/exports the heap simulator trace.
//
// Any step that returns a non-OK status aborts the run with that status. On
// success, ownership of preset_assignments_ is moved out to the caller.
StatusOr<std::unique_ptr<PresetAssignments>>
MemorySpaceAssignment::RunMemorySpaceAssignment(
    const HloLiveRange& hlo_live_range,
    const HloAliasAnalysis& alias_analysis) {
  TF_RETURN_IF_ERROR(FindAllocationSequence(hlo_live_range, alias_analysis));
  TF_RETURN_IF_ERROR(Process());
  ScheduleAsynchronousCopies();
  TF_RETURN_IF_ERROR(SimplifyGraph());
  TF_RETURN_IF_ERROR(FixSchedule());
  TF_RETURN_IF_ERROR(ExportAndColorBuffers());

  VLOG(3) << "Module after memory space assignment: ";
  XLA_VLOG_LINES(3, module_->ToString());
  // A schedule that fails verification here is treated as fatal (CHECK), not
  // as a returned error.
  TF_CHECK_OK(module_->schedule().Verify());
  TF_ASSIGN_OR_RETURN(AsyncCopyStats stats, CalculateAsyncCopyStats());
  VLOG(1) << "Maximum number of outstanding async copies: "
          << stats.max_outstanding_async_copies;
  VLOG(1) << "Number of prefetches: " << stats.num_prefetches
          << ", in bytes: " << stats.prefetch_bytes;
  VLOG(1) << "Number of evictions: " << stats.num_evictions
          << ", in bytes: " << stats.eviction_bytes;

  TF_RETURN_IF_ERROR(VerifyAndExportHeapSimulatorTrace());

  // preset_assignments_ is a member, so it has to be moved explicitly.
  return std::move(preset_assignments_);
}
2688 
FindAllocationSequence(const HloLiveRange & hlo_live_range,const HloAliasAnalysis & alias_analysis)2689 Status MemorySpaceAssignment::FindAllocationSequence(
2690     const HloLiveRange& hlo_live_range,
2691     const HloAliasAnalysis& alias_analysis) {
2692   auto algorithm = absl::make_unique<AlternateMemoryBestFitHeap>(
2693       &allocations_, options_, alias_analysis, hlo_live_range);
2694 
2695   HeapSimulator::Options heap_simulator_options;
2696   heap_simulator_options.may_reuse_operand_buffers = false;
2697   TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module_,
2698                                         module_->schedule(), alias_analysis,
2699                                         options_.size_fn,
2700                                         heap_simulator_options)
2701                          .status());
2702   return Status::OK();
2703 }
2704 
AddUse(HloUse use)2705 void MemorySpaceAssignment::Allocation::AddUse(HloUse use) {
2706   HloInstruction* operand =
2707       use.instruction->mutable_operand(use.operand_number);
2708   // If the use is a tuple, look inside the tuple to find the actual use.
2709   for (int64 index : use.operand_index) {
2710     if (operand->opcode() != HloOpcode::kTuple) {
2711       break;
2712     }
2713     operand = operand->mutable_operand(index);
2714   }
2715 
2716   // Look beyond GetTupleElement(Tuple()) pattern for any bitcasts.
2717   std::function<HloInstruction*(HloInstruction*)> get_simplified_operand;
2718   get_simplified_operand = [&](HloInstruction* instruction) {
2719     while (instruction->opcode() == HloOpcode::kGetTupleElement) {
2720       HloInstruction* operand =
2721           get_simplified_operand(instruction->mutable_operand(0));
2722       if (operand->opcode() == HloOpcode::kTuple) {
2723         instruction = operand->mutable_operand(instruction->tuple_index());
2724       } else {
2725         return instruction;
2726       }
2727     }
2728     return instruction;
2729   };
2730   operand = get_simplified_operand(operand);
2731 
2732   uses_.push_back(use);
2733 }
2734 
Process(MemorySpaceAssignment * memory_space_assignment)2735 Status MemorySpaceAssignment::Allocation::Process(
2736     MemorySpaceAssignment* memory_space_assignment) {
2737   HloInstruction* producing_instruction = AddGetTupleElements();
2738   HloComputation* computation = producing_instruction->parent();
2739   for (const HloUse& use : uses_) {
2740     Shape operand_shape = use.instruction->operand(use.operand_number)->shape();
2741     HloInstruction* replacement_instruction = producing_instruction;
2742     if (operand_shape.IsTuple()) {
2743       TF_ASSIGN_OR_RETURN(
2744           replacement_instruction,
2745           ReplaceTupleWith(producing_instruction,
2746                            use.instruction->mutable_operand(use.operand_number),
2747                            use.operand_index));
2748     } else if (operand_shape != producing_instruction->shape()) {
2749       VLOG(4) << "Old shape = " << operand_shape.ToString()
2750               << ", new shape = " << producing_instruction->shape().ToString()
2751               << "; inserting a bitcast.";
2752       replacement_instruction = computation->AddInstruction(
2753           HloInstruction::CreateBitcast(operand_shape, producing_instruction));
2754     }
2755     TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
2756         use.operand_number, replacement_instruction));
2757   }
2758   return Status::OK();
2759 }
2760 
ReplaceTupleWith(HloInstruction * new_instruction,HloInstruction * tuple,ShapeIndex shape_index)2761 StatusOr<HloInstruction*> MemorySpaceAssignment::Allocation::ReplaceTupleWith(
2762     HloInstruction* new_instruction, HloInstruction* tuple,
2763     ShapeIndex shape_index) {
2764   const Shape& tuple_shape = tuple->shape();
2765   CHECK(tuple->shape().IsTuple())
2766       << "ReplaceTupleWith was called for a non-tuple. Tuple = "
2767       << tuple->ToString()
2768       << ", new_instruction = " << new_instruction->ToString()
2769       << ", shape_index = " << shape_index.ToString();
2770 
2771   HloComputation* computation = new_instruction->parent();
2772   std::vector<HloInstruction*> tuple_args(tuple_shape.tuple_shapes_size());
2773   for (int64 i = 0; i < tuple_shape.tuple_shapes_size(); ++i) {
2774     const Shape& subshape = tuple_shape.tuple_shapes(i);
2775     // If tuple is a tuple instruction, we can get the tuple instruction's
2776     // operand to construct the new tuple to improve compilation time
2777     // performance.
2778     auto get_operand = [&]() {
2779       if (tuple->opcode() == HloOpcode::kTuple) {
2780         return tuple->mutable_operand(i);
2781       } else {
2782         return computation->AddInstruction(
2783             HloInstruction::CreateGetTupleElement(subshape, tuple, i));
2784       }
2785     };
2786     if (i == shape_index[0]) {
2787       // If the subshape is still a tuple, recurse and pass a new shape index
2788       // for the one level deeper.
2789       if (subshape.IsTuple()) {
2790         TF_ASSIGN_OR_RETURN(tuple_args[i],
2791                             ReplaceTupleWith(new_instruction, get_operand(),
2792                                              ShapeIndex(shape_index.begin() + 1,
2793                                                         shape_index.end())));
2794       } else {
2795         if (subshape != new_instruction->shape()) {
2796           VLOG(4) << "Old shape = " << subshape.ToString()
2797                   << ", new shape = " << new_instruction->shape().ToString()
2798                   << "; inserting a bitcast.";
2799           new_instruction = computation->AddInstruction(
2800               HloInstruction::CreateBitcast(subshape, new_instruction));
2801         } else if (tuple->opcode() == HloOpcode::kTuple &&
2802                    tuple->operand(i) == new_instruction) {
2803           // If the tuple element is the same as the new instruction, we
2804           // actually don't have to create a new tuple, just return the original
2805           // tuple.
2806           VLOG(4) << "Tuple already contains the new instruction = "
2807                   << new_instruction->ToShortString()
2808                   << " tuple = " << tuple->ToShortString();
2809           return tuple;
2810         }
2811         tuple_args[i] = new_instruction;
2812       }
2813     } else {
2814       tuple_args[i] = get_operand();
2815     }
2816   }
2817   return computation->AddInstruction(HloInstruction::CreateTuple(tuple_args));
2818 }
2819 
AddGetTupleElements()2820 HloInstruction* MemorySpaceAssignment::Allocation::AddGetTupleElements() {
2821   HloInstruction* producing_instruction = defining_position().instruction;
2822   CHECK_NE(producing_instruction, nullptr);
2823 
2824   Shape shape = defining_position().shape();
2825   CHECK(shape.IsArray()) << "Allocation shape is not an array. Shape = "
2826                          << shape.ToString()
2827                          << " position = " << defining_position().shape();
2828   HloComputation* computation = producing_instruction->parent();
2829 
2830   // If the instruction we're processing is a tuple, we (recursively) search or
2831   // create kGetTupleElement instructions and copy that value. Asynchronous
2832   // copies only support array types.
2833   for (int64 index : defining_position().index) {
2834     // We first search if there already is a get-tuple-element with the correct
2835     // index. If there is no such get-tuple-element, we create one.
2836     auto gte_it = absl::c_find_if(
2837         producing_instruction->users(), [index](const HloInstruction* use) {
2838           return use != use->parent()->root_instruction() &&
2839                  use->opcode() == HloOpcode::kGetTupleElement &&
2840                  use->tuple_index() == index;
2841         });
2842     if (gte_it != producing_instruction->users().end()) {
2843       producing_instruction = *gte_it;
2844     } else {
2845       producing_instruction =
2846           computation->AddInstruction(HloInstruction::CreateGetTupleElement(
2847               producing_instruction->shape().tuple_shapes(index),
2848               producing_instruction, index));
2849     }
2850   }
2851   return producing_instruction;
2852 }
2853 
ToString() const2854 std::string MemorySpaceAssignment::Allocation::ToString() const {
2855   std::string memory_space_str = "def";
2856   if (memory_space_ == MemorySpace::kAlternate) {
2857     memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")");
2858   }
2859   return absl::StrCat("Allocation in ", memory_space_str, " defined at ",
2860                       defining_position_.ToString());
2861 }
2862 
ToString() const2863 std::string MemorySpaceAssignment::CopyAllocation::ToString() const {
2864   std::string memory_space_str = "def";
2865   if (memory_space_ == MemorySpace::kAlternate) {
2866     memory_space_str = absl::StrCat("alt (off: ", chunk_->offset, ")");
2867   }
2868   return absl::StrCat("Copy Allocation in ", memory_space_str, " from ",
2869                       prev_allocation_.ToString());
2870 }
2871 
Process(MemorySpaceAssignment * memory_space_assignment)2872 Status MemorySpaceAssignment::CopyAllocation::Process(
2873     MemorySpaceAssignment* memory_space_assignment) {
2874   // Copy allocations need to insert asynchronous copy nodes.
2875   Shape shape = defining_position().shape();
2876   HloInstruction* producing_instruction = AddGetTupleElements();
2877   HloComputation* computation = producing_instruction->parent();
2878   copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart(
2879       ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}),
2880       producing_instruction, is_cross_program_prefetch_));
2881   copy_done_ = computation->AddInstruction(
2882       HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));
2883   VLOG(4) << "Created " << copy_start_->name()
2884           << " for position: " << defining_position().ToString();
2885   // Update the allocation position with the copy done instruction so that if
2886   // there are further copies from it, it can find the correct position.
2887   defining_position_ = HloPosition{copy_done_, {}};
2888 
2889   // Replace all the uses with the new copy instruction.
2890   for (HloUse use : uses_) {
2891     // If the operand is a tuple, we need to descend to the actual instruction
2892     // we want to replace.
2893     HloInstruction* replacement_instruction;
2894     Shape operand_shape = use.instruction->operand(use.operand_number)->shape();
2895     if (operand_shape.IsTuple()) {
2896       TF_ASSIGN_OR_RETURN(
2897           replacement_instruction,
2898           ReplaceTupleWith(copy_done_,
2899                            use.instruction->mutable_operand(use.operand_number),
2900                            use.operand_index));
2901     } else if (operand_shape != copy_done_->shape()) {
2902       VLOG(4) << "Old shape = " << operand_shape.ToString()
2903               << ", new shape = " << copy_done_->shape().ToString()
2904               << "; inserting a bitcast.";
2905       replacement_instruction = computation->AddInstruction(
2906           HloInstruction::CreateBitcast(operand_shape, copy_done_));
2907     } else {
2908       replacement_instruction = copy_done_;
2909     }
2910     TF_RETURN_IF_ERROR(use.instruction->ReplaceOperandWith(
2911         use.operand_number, replacement_instruction));
2912   }
2913 
2914   return Status::OK();
2915 }
2916 
Process()2917 Status MemorySpaceAssignment::Process() {
2918   VLOG(1) << "Processing assigned buffers...";
2919   // Insert CopyStart/CopyDone pairs.
2920   for (auto& allocation : allocations_) {
2921     VLOG(3) << "Processing: " << allocation->ToString();
2922     TF_RETURN_IF_ERROR(allocation->Process(this));
2923     // Add the offset and size of the allocation in the alternate memory to
2924     // the output map.
2925     if (allocation->memory_space() == MemorySpace::kAlternate) {
2926       alternate_memory_assignments_.emplace_back(
2927           allocation->defining_position(), allocation->chunk());
2928       alternate_memory_size_ =
2929           std::max(alternate_memory_size_, allocation->chunk().chunk_end());
2930     }
2931   }
2932   return Status::OK();
2933 }
2934 
ExportAndColorBuffers()2935 Status MemorySpaceAssignment::ExportAndColorBuffers() {
2936   VLOG(1) << "Exporting buffers...";
2937   TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(module_));
2938   absl::flat_hash_map<int64, int64> seen_buffer_offsets;
2939   VLOG(3) << "Exported alternate memory allocations:";
2940   for (const auto& position_and_chunk : alternate_memory_assignments_) {
2941     const HloPosition& defining_position = position_and_chunk.first;
2942     const Chunk& chunk = position_and_chunk.second;
2943     const HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(
2944         defining_position.instruction, defining_position.index);
2945     auto seen_buffer_offset_it = seen_buffer_offsets.find(buffer.id());
2946     if (seen_buffer_offset_it != seen_buffer_offsets.end()) {
2947       CHECK_EQ(chunk.offset, seen_buffer_offset_it->second)
2948           << "Mismatch in offset for positions that map to the same value: "
2949           << buffer.ToString() << ", pos: " << defining_position.ToString();
2950     } else {
2951       VLOG(3) << " [" << chunk.offset << ", " << chunk.size
2952               << "] : " << defining_position.ToString() << " ("
2953               << buffer.ToString() << ")";
2954       preset_assignments_->add_chunk(defining_position, chunk);
2955       seen_buffer_offsets[buffer.id()] = chunk.offset;
2956     }
2957   }
2958 
2959   if (!preset_assignments_->chunks().empty()) {
2960     preset_assignments_
2961         ->assignment_information_for_space(options_.alternate_memory_space)
2962         ->size = alternate_memory_size_;
2963   }
2964 
2965   VLOG(3) << "Exported alternate memory sizes:";
2966   for (auto& pair : preset_assignments_->assignment_informations()) {
2967     VLOG(3) << "  space: " << pair.first << ", size: " << pair.second.size;
2968   }
2969 
2970   VLOG(1) << "Coloring buffers...";
2971   // Color the pending positions and all of their aliased buffers.
2972   for (const auto& defining_position_and_chunk :
2973        preset_assignments_->chunks()) {
2974     const HloPosition& defining_position = defining_position_and_chunk.first;
2975     for (auto& buffer : alias_analysis->ComputeBuffersAt(
2976              defining_position.instruction, defining_position.index)) {
2977       for (auto& value : buffer->values()) {
2978         for (auto& position : value->positions()) {
2979           VLOG(4) << "Coloring " << position.ToString();
2980           Shape* shape = ShapeUtil::GetMutableSubshape(
2981               position.instruction->mutable_shape(), position.index);
2982           CHECK(shape->IsArray()) << "Coloring a shape that is not an array: "
2983                                   << position.ToString();
2984           shape->mutable_layout()->set_memory_space(
2985               options_.alternate_memory_space);
2986         }
2987       }
2988     }
2989   }
2990   return Status::OK();
2991 }
2992 
RemoveAssignmentForInstruction(const HloInstruction * instruction)2993 void MemorySpaceAssignment::RemoveAssignmentForInstruction(
2994     const HloInstruction* instruction) {
2995   for (auto& position_and_chunk : alternate_memory_assignments_) {
2996     const HloPosition& position = position_and_chunk.first;
2997     if (position.instruction == instruction) {
2998       VLOG(3) << "Removing instruction from alternate memory assignments.";
2999       // Swap the removed position and chunk with the back and pop back.
3000       position_and_chunk = alternate_memory_assignments_.back();
3001       alternate_memory_assignments_.pop_back();
3002       break;
3003     }
3004   }
3005 }
3006 
// Cleans up every scheduled non-fusion computation of the module:
//   1. Drops all control dependencies.
//   2. Repeats, until a full pass makes no change:
//      - removes instructions that are safely removable, have no users, have
//        no side effects, are not the root and are not CopyStart/CopyDone;
//      - forwards GetTupleElement(Tuple(a, b, ...), i) to the i-th tuple
//        operand;
//      - forwards Tuple(GetTupleElement(x, 0), ..., GetTupleElement(x, n-1))
//        to x when the tuple covers all elements of x in order.
// Removed instructions are also dropped from alternate_memory_assignments_ and
// replaced by nullptr in flattened_instructions_. Returns the first non-OK
// status produced by a graph mutation.
Status MemorySpaceAssignment::SimplifyGraph() {
  VLOG(1) << "Simplifying graph...";
  for (HloComputation* computation : module_->MakeNonfusionComputations()) {
    // Parallel computations aren't in the schedule and don't need to be
    // modified.
    if (!computations_in_schedule_.contains(computation)) {
      VLOG(4) << "Not simplifying " << computation->name()
              << " because it's not in the schedule.";
      continue;
    }
    // Drop control dependencies. Since the computation is already scheduled, we
    // don't need control dependencies anymore, and having control
    // predecessors/successors prevents us from removing instructions without
    // users (HloComputation::IsSafelyRemovable returns false if there are
    // control dependencies).
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      TF_RETURN_IF_ERROR(instruction->DropAllControlDeps());
    }
    // We perform limited DCE and forward the tuple operand in patterns like
    // GetTupleElement(Tuple(a, b), 0). This is mostly because memory space
    // assignment is ran late in compilation (after DCE and arithmetic
    // simplification passes) and we don't want to generate redundant code.  Run
    // to fixed point.
    bool computation_modified = true;
    while (computation_modified) {
      computation_modified = false;
      VLOG(4) << "Running simplify graph loop over " << computation->name();
      // The post order is recomputed on every pass, so each pass iterates over
      // a snapshot taken before that pass's modifications.
      for (HloInstruction* instruction :
           computation->MakeInstructionPostOrder()) {
        if (computation->IsSafelyRemovable(instruction) &&
            instruction->user_count() == 0 && !instruction->HasSideEffect() &&
            instruction != computation->root_instruction() &&
            instruction->opcode() != HloOpcode::kCopyStart &&
            instruction->opcode() != HloOpcode::kCopyDone) {
          VLOG(4) << "Instruction removed: " << instruction->ToString();
          // Ensure the alternate memory assignments don't contain a reference
          // to the removed instruction.
          RemoveAssignmentForInstruction(instruction);
          // Instead of deleting the instruction from the schedule, replace it
          // with a nullptr. This is needed because FixSchedule relies on the
          // logical time that is the index into flattened_instructions_ for
          // scheduling asynchronous copies.
          auto instruction_it =
              absl::c_find(flattened_instructions_, instruction);
          if (instruction_it != flattened_instructions_.end()) {
            *instruction_it = nullptr;
          }
          TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
          computation_modified = true;
        } else if (instruction->opcode() == HloOpcode::kGetTupleElement) {
          // GetTupleElement(Tuple(...), i): users can read the i-th tuple
          // operand directly. The GTE itself is left for the removal branch
          // of a later pass.
          HloInstruction* operand = instruction->mutable_operand(0);
          if (operand->opcode() == HloOpcode::kTuple) {
            HloInstruction* forwarded_instruction =
                operand->mutable_operand(instruction->tuple_index());
            VLOG(4) << "Replacing uses of " << instruction->ToString()
                    << " with " << forwarded_instruction->ToString();
            TF_RETURN_IF_ERROR(
                instruction->ReplaceAllUsesWith(forwarded_instruction));
            computation_modified = true;
          }
        } else if (instruction->opcode() == HloOpcode::kTuple) {
          // Replace Tuple(GetTupleElement(x), ..., GetTupleElement(x)) pattern
          // with x.
          // Precondition: the first operand is a GTE whose source has exactly
          // as many tuple elements as this tuple has operands.
          bool can_replace =
              instruction->operand_count() > 0 &&
              instruction->operand(0)->opcode() ==
                  HloOpcode::kGetTupleElement &&
              instruction->operand(0)
                      ->operand(0)
                      ->shape()
                      .tuple_shapes_size() == instruction->operand_count();
          // Every operand must be a GTE of that same source, with a tuple
          // index equal to its operand position.
          for (int operand_number = 0;
               operand_number < instruction->operand_count();
               ++operand_number) {
            const HloInstruction* operand =
                instruction->operand(operand_number);
            if (operand->opcode() != HloOpcode::kGetTupleElement ||
                operand->tuple_index() != operand_number ||
                operand->operand(0) != instruction->operand(0)->operand(0)) {
              can_replace = false;
              break;
            }
          }
          if (can_replace) {
            HloInstruction* forwarded_instruction =
                instruction->mutable_operand(0)->mutable_operand(0);
            VLOG(4) << "Replacing uses of " << instruction->ToString()
                    << " with " << forwarded_instruction->ToString();
            TF_RETURN_IF_ERROR(
                instruction->ReplaceAllUsesWith(forwarded_instruction));
            computation_modified = true;
          }
        }
      }
    }
  }

  return Status::OK();
}
3107 
EnsureInstructionAndOperandsInserted(HloInstruction * new_instruction,HloInstructionSequence * new_sequence,absl::flat_hash_set<HloInstruction * > * inserted_instructions) const3108 void MemorySpaceAssignment::EnsureInstructionAndOperandsInserted(
3109     HloInstruction* new_instruction, HloInstructionSequence* new_sequence,
3110     absl::flat_hash_set<HloInstruction*>* inserted_instructions) const {
3111   if (inserted_instructions->contains(new_instruction)) {
3112     return;
3113   }
3114   for (HloInstruction* operand : new_instruction->operands()) {
3115     // CopyStart/CopyDone dependencies should always be already inserted; it is
3116     // a red flag when they haven't already been inserted.
3117     CHECK((operand->opcode() != HloOpcode::kCopyStart &&
3118            operand->opcode() != HloOpcode::kCopyDone) ||
3119           inserted_instructions->contains(operand))
3120         << "Inserted instruction " << new_instruction->ToString()
3121         << " has un-inserted dependency: " << operand->ToString();
3122     EnsureInstructionAndOperandsInserted(operand, new_sequence,
3123                                          inserted_instructions);
3124   }
3125   VLOG(4) << "inserting: " << new_instruction->ToShortString();
3126   new_sequence->push_back(new_instruction);
3127   inserted_instructions->insert(new_instruction);
3128 }
3129 
ScheduleAsynchronousCopies()3130 void MemorySpaceAssignment::ScheduleAsynchronousCopies() {
3131   VLOG(1) << "Scheduling asynchronous copies...";
3132   for (MemorySpace memory_space :
3133        {MemorySpace::kDefault, MemorySpace::kAlternate}) {
3134     std::vector<CopyAllocation*> copy_allocations;
3135     for (auto& allocation : allocations_) {
3136       if (allocation->is_copy_allocation()) {
3137         auto copy_allocation = static_cast<CopyAllocation*>(allocation.get());
3138         if (copy_allocation->memory_space() == memory_space) {
3139           copy_allocations.push_back(copy_allocation);
3140         }
3141       }
3142     }
3143 
3144     absl::c_stable_sort(
3145         copy_allocations, [](CopyAllocation* first, CopyAllocation* second) {
3146           return std::forward_as_tuple(first->copy_done_schedule_before(),
3147                                        first->copy_start_schedule_after()) <
3148                  std::forward_as_tuple(second->copy_done_schedule_before(),
3149                                        second->copy_start_schedule_after());
3150         });
3151 
3152     CopyAllocation* prev_copy_allocation = nullptr;
3153     for (CopyAllocation* copy_allocation : copy_allocations) {
3154       // If the copy start doesn't happen to be scheduled at the correct
3155       // computation, delay it until the correct computation starts.
3156       int64 copy_start_schedule_after =
3157           copy_allocation->copy_start_schedule_after();
3158       // Accessing flattened_instructions_ here without checking if it is
3159       // nullptr is safe because this method is called before SimplifyGraph.
3160       while (copy_allocation->defining_position().instruction->parent() !=
3161              flattened_instructions_[copy_start_schedule_after]->parent()) {
3162         VLOG(4) << "Delaying CopyStart (" << copy_start_schedule_after << " to "
3163                 << (copy_start_schedule_after + 1) << ") for "
3164                 << copy_allocation->copy_start()->ToString()
3165                 << " because it is not in the correct computation.";
3166         copy_allocation->set_copy_start_schedule_after(
3167             ++copy_start_schedule_after);
3168       }
3169 
3170       schedule_after_[copy_allocation->copy_start_schedule_after()].push_back(
3171           copy_allocation->copy_start());
3172       schedule_before_[copy_allocation->copy_done_schedule_before()].push_back(
3173           copy_allocation->copy_done());
3174       prev_copy_allocation = copy_allocation;
3175     }
3176   }
3177 }
3178 
FixSchedule()3179 Status MemorySpaceAssignment::FixSchedule() {
3180   VLOG(1) << "Fixing schedule...";
3181   CHECK(module_->has_schedule());
3182   HloSchedule& schedule = module_->schedule();
3183   for (const HloComputation* computation :
3184        module_->MakeNonfusionComputations()) {
3185     // Parallel computations aren't in the schedule and don't need to be
3186     // modified.
3187     if (!computations_in_schedule_.contains(computation)) {
3188       VLOG(4) << "Not scheduling " << computation->name()
3189               << " because it's not in the schedule.";
3190       continue;
3191     }
3192     CHECK(schedule.is_computation_scheduled(computation));
3193     HloInstructionSequence new_sequence;
3194 
3195     absl::flat_hash_set<HloInstruction*> inserted_instructions;
3196 
3197     VLOG(4) << "Scheduling: " << computation->ToString();
3198 
3199     for (int64 instruction_index = 0;
3200          instruction_index < flattened_instructions_.size();
3201          ++instruction_index) {
3202       auto insts_before_iter = schedule_before_.find(instruction_index);
3203       if (insts_before_iter != schedule_before_.end()) {
3204         for (HloInstruction* new_instruction : insts_before_iter->second) {
3205           if (new_instruction->parent() == computation) {
3206             VLOG(4) << "before " << instruction_index << ": "
3207                     << new_instruction->name();
3208             EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
3209                                                  &inserted_instructions);
3210           }
3211         }
3212       }
3213       HloInstruction* instruction = flattened_instructions_[instruction_index];
3214       // Insert only if it is not deleted (SimplifyGraph sets it to nullptr if
3215       // it was deleted) and not previously inserted. Also bitcasts and tuples
3216       // are treated specially and only inserted as a result of operand
3217       // dependencies.
3218       if (instruction != nullptr &&
3219           !inserted_instructions.contains(instruction) &&
3220           instruction->parent() == computation &&
3221           instruction->opcode() != HloOpcode::kBitcast &&
3222           instruction->opcode() != HloOpcode::kTuple) {
3223         VLOG(4) << "inst " << instruction_index << ": " << instruction->name();
3224         EnsureInstructionAndOperandsInserted(instruction, &new_sequence,
3225                                              &inserted_instructions);
3226       }
3227       auto insts_after_iter = schedule_after_.find(instruction_index);
3228       if (insts_after_iter != schedule_after_.end()) {
3229         for (HloInstruction* new_instruction : insts_after_iter->second) {
3230           if (new_instruction->parent() == computation) {
3231             VLOG(4) << "after " << instruction_index << ": "
3232                     << new_instruction->name();
3233             EnsureInstructionAndOperandsInserted(new_instruction, &new_sequence,
3234                                                  &inserted_instructions);
3235           }
3236         }
3237       }
3238     }
3239     // For rare cases where the original sequence is empty, ensure the root
3240     // instruction and its dependencies are scheduled.
3241     EnsureInstructionAndOperandsInserted(computation->root_instruction(),
3242                                          &new_sequence, &inserted_instructions);
3243     CHECK_EQ(new_sequence.size(), computation->instruction_count())
3244         << "New sequence for computation " << computation->name() << " has "
3245         << new_sequence.size() << " instructions, expects "
3246         << computation->instruction_count() << ".";
3247     schedule.set_sequence(computation, new_sequence);
3248   }
3249 
3250   return Status::OK();
3251 }
3252 
// Verifies the chunks recorded in preset_assignments_ and exports them as a
// heap simulator trace.
//
// The method:
//   1. Runs alias analysis and live range analysis on the module.
//   2. CHECKs that every CopyStart instruction copies between two different
//      memory spaces.
//   3. For every (position, chunk) in preset_assignments_->chunks(), derives
//      the time interval(s) during which each value of the buffer occupies
//      the chunk and verifies that no two chunks overlap both in time and in
//      space.
//   4. Appends the resulting ALLOC/FREE events, in sorted order, to the heap
//      simulator trace of options_.alternate_memory_space and logs the
//      maximum memory usage.
//
// Returns an InternalError if two chunks overlap in time and space, or any
// error propagated from the analyses; otherwise Status::OK().
Status MemorySpaceAssignment::VerifyAndExportHeapSimulatorTrace() {
  VLOG(1) << "Verifying...";
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
                      HloAliasAnalysis::Run(module_));
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloLiveRange> hlo_live_range,
                      HloLiveRange::Run(module_->schedule(), *alias_analysis,
                                        module_->entry_computation()));

  // Holds every (time interval, chunk) verified so far; queried for conflicts
  // before each new interval is added.
  BufferIntervalTree interval_tree;
  // Ids of buffers that already have a preset assignment; used to detect
  // duplicates.
  absl::flat_hash_set<int64> seen_buffers;
  // The key for events is: time, is_free, value_id. This is so that the events
  // are sorted first by time, then within the same time, allocations are sorted
  // earlier than frees, and finally the value id as a tie breaker.
  // Note: because this is a map, writing an event with a key that already
  // exists replaces the earlier entry.
  std::map<std::tuple<int64, bool, int64>,
           std::tuple<const HloValue*, Chunk, HeapSimulatorTrace::Event::Kind>>
      events;

  // Records an ALLOC event at start_time and a FREE event at end_time for
  // `value`, then checks `chunk` against all previously added chunks that are
  // live during [start_time, end_time - 1]. Returns an InternalError if any of
  // them overlaps `chunk` in space; otherwise adds the interval to
  // interval_tree and returns OK.
  auto add_allocation_and_verify = [&](int64 start_time, int64 end_time,
                                       const Chunk& chunk,
                                       const HloValue* value) {
    events[std::make_tuple(start_time, /*is_free=*/false, value->id())] =
        std::make_tuple(value, chunk, HeapSimulatorTrace::Event::ALLOC);
    events[std::make_tuple(end_time, /*is_free=*/true, value->id())] =
        std::make_tuple(value, chunk, HeapSimulatorTrace::Event::FREE);

    // Get the chunks overlapping in time and search if they overlap in space
    // as well.
    // TODO(berkin): For now checking against end_time - 1 (exclusive), but we
    // really should check against end_time (inclusive) for cases where the
    // operand can't share buffer with user (see
    // HloDataflowAnalysis::CanShareOperandBufferWithUser).
    for (const Chunk& overlapping_chunk :
         interval_tree.ChunksOverlappingInTime(start_time, end_time - 1)) {
      if (chunk.OverlapsWith(overlapping_chunk)) {
        return InternalError(
            ("Value %s (%d, %d) off: %d size: %d overlaps with another chunk"
             " off: %d size: %d"),
            value->ToShortString(), start_time, end_time, chunk.offset,
            chunk.size, overlapping_chunk.offset, overlapping_chunk.size);
      }
    }
    interval_tree.Add(start_time, end_time - 1, chunk);
    return Status::OK();
  };

  // Go through all instructions in the module to ensure CopyStart/CopyDone
  // instructions copy between alternate memory and default memory.
  // Only CopyStart is inspected here: the memory space at shape index {1} is
  // treated as the source and the one at index {0} as the destination, and
  // the two are required to differ.
  for (const HloComputation* computation :
       module_->MakeNonfusionComputations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      if (instruction->opcode() == HloOpcode::kCopyStart) {
        int64 from_memory_space =
            ShapeUtil::GetSubshape(instruction->shape(), {1})
                .layout()
                .memory_space();
        int64 to_memory_space =
            ShapeUtil::GetSubshape(instruction->shape(), {0})
                .layout()
                .memory_space();
        CHECK_NE(from_memory_space, to_memory_space)
            << "Asynchronous copy to the same memory space: "
            << instruction->ToString();
      }
    }
  }

  // Verify each preset chunk assignment.
  for (const auto& position_and_chunk : preset_assignments_->chunks()) {
    const HloPosition& position = position_and_chunk.first;
    const Chunk& chunk = position_and_chunk.second;
    const HloBuffer& buffer =
        alias_analysis->GetUniqueBufferAt(position.instruction, position.index);
    // A buffer must not receive more than one preset assignment.
    CHECK(!seen_buffers.contains(buffer.id()))
        << "Multiple preset assignments for the same buffer: "
        << buffer.ToString() << ", pos: " << position.ToString()
        << ", off: " << chunk.offset << ", size: " << chunk.size;
    seen_buffers.insert(buffer.id());

    for (const HloValue* value : buffer.values()) {
      const HloLiveRange::TimeBound& time_bound =
          hlo_live_range->buffer_live_ranges().at(value);
      // Find the use scheduled latest. last_use_instruction stays nullptr when
      // no use is scheduled strictly after time_bound.start.
      const HloInstruction* last_use_instruction = nullptr;
      int64 last_use_time = time_bound.start;
      for (const HloUse& use : value->uses()) {
        int64 use_time =
            hlo_live_range->instruction_schedule().at(use.instruction);
        if (use_time > last_use_time) {
          last_use_time = use_time;
          last_use_instruction = use.instruction;
        }
      }

      // Declared as a std::function (rather than auto) so that the lambda can
      // refer to itself for nested conditionals.
      std::function<Status(const HloInstruction*, int64, int64,
                           absl::string_view)>
          split_conditional_buffer;
      // Handles a value whose last use is `use_instruction` (a conditional)
      // and whose interval is (start_time, end_time). One allocation is
      // recorded per called computation that uses the value, plus one for the
      // span from start_time up to just before the earliest called
      // computation starts. `indent_string` only affects log formatting.
      split_conditional_buffer = [&](const HloInstruction* use_instruction,
                                     int64 start_time, int64 end_time,
                                     absl::string_view indent_string) {
        // Special case when verifying conditional: we internally split the use
        // of alternate memory in conditionals, so fish them out from the
        // conditionals.
        VLOG(3) << indent_string
                << "Splitting conditional buffer: " << buffer.ToString()
                << " value: " << value->ToShortString() << ": (" << start_time
                << ", " << end_time << ") off: " << chunk.offset
                << ", size: " << chunk.size;
        int64 earliest_computation_start_time = end_time;
        for (const HloComputation* called_computation :
             use_instruction->called_computations()) {
          earliest_computation_start_time =
              std::min(earliest_computation_start_time,
                       hlo_live_range->computation_span_times()
                           .at(called_computation)
                           .start);
          // These locals intentionally shadow the outer last_use_time /
          // last_use_instruction; they are scoped to this called computation.
          int64 parameter_time = -1;
          int64 last_use_time = -1;
          const HloInstruction* last_use_instruction = nullptr;
          // Time of the parameter through which the value enters this
          // computation (first matching position only).
          for (const HloPosition& position : value->positions()) {
            if (position.instruction->opcode() == HloOpcode::kParameter &&
                position.instruction->parent() == called_computation) {
              parameter_time = hlo_live_range->instruction_schedule().at(
                  position.instruction);
              break;
            }
          }
          // Latest use of the value inside this computation.
          for (const HloUse& use : value->uses()) {
            int64 use_time =
                hlo_live_range->instruction_schedule().at(use.instruction);
            if (use.instruction->parent() == called_computation &&
                use_time > last_use_time) {
              last_use_time = use_time;
              last_use_instruction = use.instruction;
            }
          }
          // Computations that never use the value contribute no allocation.
          if (last_use_time != -1) {
            CHECK_NE(parameter_time, -1);
            VLOG(3) << indent_string
                    << " computation: " << called_computation->name() << ": ("
                    << parameter_time << ", " << last_use_time << ")";
            CHECK(last_use_instruction);
            if (last_use_instruction->opcode() == HloOpcode::kConditional) {
              // The last use is another (nested) conditional. Call this
              // function recursively.
              TF_RETURN_IF_ERROR(split_conditional_buffer(
                  last_use_instruction, parameter_time, last_use_time,
                  absl::StrCat(indent_string, "  ")));
            } else {
              // Never extend the allocation past the enclosing interval.
              last_use_time = std::min(last_use_time, end_time);
              TF_RETURN_IF_ERROR(add_allocation_and_verify(
                  parameter_time, last_use_time, chunk, value));
            }
          }
        }
        VLOG(3) << indent_string << " from beginning until first computation: ("
                << start_time << ", " << (earliest_computation_start_time - 1)
                << ")";
        TF_RETURN_IF_ERROR(add_allocation_and_verify(
            start_time, earliest_computation_start_time - 1, chunk, value));
        return Status::OK();
      };

      if (last_use_instruction &&
          last_use_instruction->opcode() == HloOpcode::kConditional) {
        TF_RETURN_IF_ERROR(split_conditional_buffer(
            last_use_instruction, time_bound.start, time_bound.end, " "));
      } else if (!value->uses().empty()) {
        // Regular case: a single allocation from the start of the live range
        // to the last use, clamped to the end of the live range. Values
        // without any uses are not recorded.
        last_use_time = std::min(last_use_time, time_bound.end);
        VLOG(3) << " buffer: " << buffer.ToString()
                << " value: " << value->ToShortString() << ": ("
                << time_bound.start << ", " << last_use_time
                << ") off: " << chunk.offset << ", size: " << chunk.size;
        TF_RETURN_IF_ERROR(add_allocation_and_verify(
            time_bound.start, last_use_time, chunk, value));
      }
    }
  }

  // Export the collected events, in key order, into the heap simulator trace
  // of the alternate memory space while tracking the running memory usage.
  HeapSimulatorTrace* heap_trace =
      &preset_assignments_
           ->assignment_information_for_space(options_.alternate_memory_space)
           ->heap_simulator_trace;
  int64 memory_usage = 0;
  int64 max_memory_usage = 0;
  for (const auto& event : events) {
    int64 time;
    bool is_free;
    int64 buffer_id;
    std::tie(time, is_free, buffer_id) = event.first;
    const HloValue* value;
    Chunk chunk;
    HeapSimulatorTrace::Event::Kind kind;
    std::tie(value, chunk, kind) = event.second;
    HeapSimulatorTrace::Event* heap_trace_event = heap_trace->add_events();
    heap_trace_event->set_kind(kind);
    // The third key component is the HloValue id (see add_allocation_and_verify)
    // and is what gets exported as the trace's buffer id.
    heap_trace_event->set_buffer_id(buffer_id);
    heap_trace_event->set_instruction_name(value->instruction()->name());
    heap_trace_event->set_computation_name(
        value->instruction()->parent()->name());

    if (kind == HeapSimulatorTrace::Event::ALLOC) {
      memory_usage += chunk.size;
    } else {
      CHECK_EQ(kind, HeapSimulatorTrace::Event::FREE);
      memory_usage -= chunk.size;
    }
    max_memory_usage = std::max(max_memory_usage, memory_usage);
    VLOG(4) << "Memory usage: " << memory_usage << " at time: " << time;
  }
  VLOG(1) << "Max memory usage ignoring fragmentation: " << max_memory_usage;

  return Status::OK();
}
3464 
3465 }  // namespace xla
3466