1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <set>
21 #include <string>
22 
23 #include "absl/container/flat_hash_map.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "absl/container/inlined_vector.h"
26 #include "absl/strings/str_cat.h"
27 #include "absl/strings/str_format.h"
28 #include "absl/strings/str_join.h"
29 #include "tensorflow/compiler/xla/map_util.h"
30 #include "tensorflow/compiler/xla/primitive_util.h"
31 #include "tensorflow/compiler/xla/service/buffer_value.h"
32 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
33 #include "tensorflow/compiler/xla/service/hlo_computation.h"
34 #include "tensorflow/compiler/xla/service/hlo_dce.h"
35 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
36 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
37 #include "tensorflow/compiler/xla/service/hlo_module.h"
38 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
39 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
40 #include "tensorflow/compiler/xla/service/logical_buffer.h"
41 #include "tensorflow/compiler/xla/status_macros.h"
42 #include "tensorflow/compiler/xla/statusor.h"
43 #include "tensorflow/compiler/xla/types.h"
44 #include "tensorflow/compiler/xla/util.h"
45 #include "tensorflow/core/platform/logging.h"
46 
47 namespace xla {
48 namespace {
49 
50 using ::tensorflow::strings::HumanReadableNumBytes;
51 
52 // Potential optimizations:
53 // . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
54 //   of candidates.
55 // . Cache IsRematerializable in Item?  Only correct if control
56 //   predecessors and successors don't change.
57 
58 // Returns true if the given instruction is rematerializable.
IsRematerializable(const HloInstruction * instruction)59 bool IsRematerializable(const HloInstruction* instruction) {
60   if (instruction->opcode() == HloOpcode::kCopy) {
61     if (LayoutUtil::Equal(instruction->shape().layout(),
62                           instruction->operand(0)->shape().layout())) {
63       // Don't rematerialize copies added by copy insertion (layout doesn't
64       // change).
65       return false;
66     }
67   }
68 
69   // Don't rematerialize instructions with side effects or instructions which
70   // cannot be cloned safely.
71   switch (instruction->opcode()) {
72     case HloOpcode::kCall:
73     case HloOpcode::kConstant:
74     case HloOpcode::kConditional:
75     case HloOpcode::kAllReduce:
76     case HloOpcode::kCustomCall:
77     case HloOpcode::kParameter:
78     case HloOpcode::kWhile:
79       return false;
80     default:
81       return !instruction->HasSideEffect();
82   }
83 }
84 
85 // Checks whether an instruction can be rematerialized, by looking up the
86 // cache before, and eventually calling the IsRematerializable() API.
CanBeRematerialized(const HloInstruction * instruction,absl::flat_hash_map<const HloInstruction *,bool> * remat_able)87 bool CanBeRematerialized(
88     const HloInstruction* instruction,
89     absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
90   auto it = remat_able->find(instruction);
91   if (it != remat_able->end()) {
92     return it->second;
93   }
94   bool rematerializable = IsRematerializable(instruction);
95   (*remat_able)[instruction] = rematerializable;
96   return rematerializable;
97 }
98 
99 // Type holding a unique identifier for each Buffer object.
100 using BufferId = int64;
101 using BufferIdList = absl::InlinedVector<BufferId, 3>;
102 
103 // We wrap HloInstruction* with an Item that holds auxiliary
104 // per-instruction state.
105 struct Item {
106   HloInstruction* instruction;
107 
108   // True once the instruction is marked as placed (when BeginInstruction
109   // has been called for this instruction).
110   bool placed = false;
111 
112   // To avoid an infinite loop rematerializing the same set of
113   // instructions ad infinitum, keep a blacklist of instructions
114   // which should not be rematerialized.
115   bool blacklisted = false;
116 
117   // The buffers defined by this instruction.
118   BufferIdList buffers_defined;
119 
120   // The buffers used by this instruction.
121   BufferIdList buffers_used;
122 
123  private:
124   friend class InstructionList;
125 
126   // Items are arranged in a doubly linked list.
127   Item* next;
128   Item* prev;
129 
130   // List is ordered by position, which can however be duplicated as
131   // new instructions are inserted.  See InsertBeforeInstructions
132   // comment for details.
133   int64 position;
134 };
135 
136 using ItemList = absl::InlinedVector<Item*, 3>;
137 
138 // Class which maintains an ordered list of instructions with fast insertion
139 // before arbitrary elements.
140 class InstructionList {
141  public:
InstructionList(const HloInstructionSequence & order)142   explicit InstructionList(const HloInstructionSequence& order) {
143     int64 position = 0;
144     Item* last = nullptr;
145     for (HloInstruction* inst : order.instructions()) {
146       // Add a new item to the linked list.
147       Item* item = new Item;
148       item->next = nullptr;
149       item->prev = last;
150       if (last == nullptr) {
151         first_ = item;
152       } else {
153         last->next = item;
154       }
155       last = item;
156 
157       // Initially position numbers are uniquely assigned in order. Later as
158       // instructions are added with InsertBefore* methods, some instructions
159       // may have duplicate position numbers, but the values will be guaranteed
160       // to be monotonically increasing through the list, and so is still useful
161       // for quickly(-ish) determining the order of arbitrary instructions in
162       // the list.
163       item->instruction = inst;
164       item->position = position;
165       position++;
166 
167       item_map_[inst] = item;
168     }
169   }
170 
~InstructionList()171   ~InstructionList() {
172     for (Item* item = first_; item != nullptr;) {
173       Item* next = item->next;
174       delete item;
175       item = next;
176     }
177   }
178 
size() const179   size_t size() const { return item_map_.size(); }
180 
181   // For ordered iteration over items.
182   //    for (auto item = q.first(); item != nullptr; item = q.next(item)) {...}
first() const183   Item* first() const { return first_; }
next(Item * item) const184   Item* next(Item* item) const { return item->next; }
185 
186   // Creates an Item for the given instruction, but doesn't add it to the list.
187   // (Use InsertBeforeInstructions to add the Item to the list.)
CreateItem(HloInstruction * inst)188   Item* CreateItem(HloInstruction* inst) {
189     Item* item = new Item;
190     item->instruction = inst;
191     CHECK(item_map_.insert({inst, item}).second)
192         << "inserting inst twice " << inst->name();
193     return item;
194   }
195 
196   // Return the Item corresponding to inst.
GetItem(const HloInstruction * inst) const197   Item* GetItem(const HloInstruction* inst) const {
198     auto iter = item_map_.find(inst);
199     CHECK(iter != item_map_.end()) << "Did not find " << inst->name();
200     return iter->second;
201   }
202 
203   // Insert instruction 'to_insert' immediately before the earliest instruction
204   // in 'before_instructions'.
205   //
206   // Each instruction gets a non-decreasing ordinal number. We use this to let
207   // InsertBeforeInstructions quickly insert an instruction before the earliest
208   // instruction in a set of instructions.  If position_number_[a] <
209   // position_number_[b] then 'a' comes before 'b' in the list. If the position
210   // numbers are the same then nothing can be said about their order without
211   // examining the list.
212   //
213   // On object construction this ordinal is precisely the instruction's index
214   // in the list. Later, instructions inserted via InsertBefore receive
215   // duplicate values. However, monotonicity is preserved.
InsertBeforeInstructions(Item * to_insert,absl::Span<Item * const> before_instructions)216   void InsertBeforeInstructions(Item* to_insert,
217                                 absl::Span<Item* const> before_instructions) {
218     VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
219             << " before {"
220             << absl::StrJoin(before_instructions, ", ",
221                              [](string* out, Item* item) {
222                                absl::StrAppend(out, item->instruction->name());
223                              })
224             << "}";
225 
226     // Find the minimal position number of any instruction in
227     // 'before_instructions'.
228     CHECK(!before_instructions.empty());
229     Item* min_position_item = nullptr;
230     for (Item* item : before_instructions) {
231       if (min_position_item == nullptr ||
232           item->position < min_position_item->position) {
233         min_position_item = item;
234       }
235     }
236 
237     // Because more than one instruction in 'before_instructions' may have a
238     // position number of 'min_position_number', find the first such instruction
239     // with position number 'min_position_number'.
240 
241     // First find first instruction with the min position.
242     while (min_position_item->prev != nullptr &&
243            min_position_item->position == min_position_item->prev->position) {
244       min_position_item = min_position_item->prev;
245     }
246 
247     // Now scan forwards until we find one of the before_instructions.
248     while (!absl::c_linear_search(before_instructions, min_position_item)) {
249       min_position_item = min_position_item->next;
250     }
251     return InsertBefore(to_insert, min_position_item);
252   }
253 
Blacklist(const HloInstruction * inst)254   void Blacklist(const HloInstruction* inst) {
255     GetItem(inst)->blacklisted = true;
256   }
257 
258  private:
259   // Insert instruction 'item' immediately before 'before' in the list.
InsertBefore(Item * item,Item * before)260   void InsertBefore(Item* item, Item* before) {
261     VLOG(3) << "InsertBefore: " << item->instruction->name() << " before "
262             << before->instruction->name();
263     // Insert new item into linked list.
264     item->prev = before->prev;
265     item->next = before;
266     before->prev = item;
267     if (item->prev != nullptr) {
268       item->prev->next = item;
269     } else {
270       first_ = item;
271     }
272 
273     // Assign the same position number to the newly added instruction as
274     // 'before'. This guarantees monotonicity of the position numbers, but not
275     // uniqueness.
276     item->position = before->position;
277   }
278 
279   Item* first_;
280 
281   // Item for each instruction.
282   absl::flat_hash_map<const HloInstruction*, Item*> item_map_;
283 };
284 
285 // Return the items which use the given LogicalBuffer. Sets
286 // has_indirect_users to whether any of the uses is indirect. A use is indirect
287 // if the instruction defining logical_buffer is not an operand of the use. This
288 // can happen via buffer aliasing (eg, tuples).
GetUsers(const InstructionList & instruction_list,const LogicalBuffer * logical_buffer,const TuplePointsToAnalysis & points_to_analysis,bool * has_indirect_users)289 ItemList GetUsers(const InstructionList& instruction_list,
290                   const LogicalBuffer* logical_buffer,
291                   const TuplePointsToAnalysis& points_to_analysis,
292                   bool* has_indirect_users) {
293   ItemList users;
294   // To identify uses iterate through all HloInstruction users of the
295   // BufferAliases of the logical buffer.
296   *has_indirect_users = false;
297   for (const BufferAlias& buffer_alias :
298        points_to_analysis.GetBufferAliases(*logical_buffer)) {
299     for (const HloInstruction* user : buffer_alias.instruction()->users()) {
300       if (points_to_analysis.DoesNotUseOperandBuffer(
301               buffer_alias.instruction(), buffer_alias.index(), user)) {
302         // The alias may be an operand of 'user', but the LogicalBuffer cannot
303         // possibly be used by the instruction so ignore 'user'. This is the
304         // case, for example, for the tuple element buffers in a GetTupleElement
305         // instruction (the GTE instruction only uses the pointer vector).
306         continue;
307       }
308       if (buffer_alias.instruction() != logical_buffer->instruction()) {
309         *has_indirect_users = true;
310       }
311       // A buffer may be used by the instruction via more than one alias. For
312       // example, a buffer which appears in more than one element of a tuple.
313       Item* user_item = instruction_list.GetItem(user);
314       if (!absl::c_linear_search(users, user_item)) {
315         users.push_back(user_item);
316       }
317     }
318   }
319   return users;
320 }
321 
322 // Class for tracking memory usage of a computation as the instructions are
323 // placed sequentially. Memory usage is the sum of the sizes of live values
324 // (LogicalBuffers) at the current point in the instruction sequence.
325 class MemoryUsageTracker {
326  public:
327   MemoryUsageTracker(
328       const HloComputation* computation,
329       const HloRematerialization::ShapeSizeFunction& size_function,
330       const TuplePointsToAnalysis& points_to_analysis,
331       const InstructionList& instruction_list);
332 
333   // Starts the placement of the given instruction. This adds the sizes of the
334   // LogicalBuffers defined by the instruction to the current memory
335   // usage. Placement is broken into two steps (BeginInstruction and
336   // EndInstruction) to accurately model memory usage. At BeginInstruction the
337   // memory for the output value(s) of the current instruction is allocated. At
338   // EndInstruction memory for dead operand(s) is freed.
339   Status BeginInstruction(Item* item);
340 
341   // Finishes the placement of the current instruction. This frees any dead
342   // operands or dead result of the instruction. This must be called after
343   // each call to BeginInstruction.
344   Status EndInstruction();
345 
346   // Returns the number of bytes that the current memory usage will be reduced
347   // if the given instruction is rematerialized.
348   int64 MemoryReducedIfRematerialized(Item* item) const;
349 
350   // Adjusts memory usage to account for the rematerialization of
351   // original_item for all remaining unplaced uses. The rematerialization
352   // is remat_item. This method should be called after the HLO graph has
353   // been transformed (rematerialization instruction created and connected to
354   // uses).
355   Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
356 
357   // Returns whether the given instruction has been placed (BeginInstruction
358   // has been called with 'instruction' as the argument).
IsPlaced(const HloInstruction * instruction) const359   bool IsPlaced(const HloInstruction* instruction) const {
360     return instruction_list_.GetItem(instruction)->placed;
361   }
362 
363   // Returns the current memory usage. This is the sum of sizes of all live
364   // values.
memory_usage() const365   int64 memory_usage() const { return memory_usage_; }
366 
367   // Check invariants of the data structure. This is expensive to call.
368   bool Check() const;
369 
370   string ToString() const;
371 
372  private:
373   // A Buffer represents a single LogicalBuffer in the computation including
374   // various metadata useful for tracking liveness of the value. A LogicalBuffer
375   // is not used directly because the HLO graph is transformed and
376   // TuplePointsToAnalysis which owns all LogicalBuffers cannot be updated after
377   // HLO graph transformations.
378   struct Buffer {
379     // The unique id of this Buffer. This value is equal to the buffer's index
380     // in the vector buffers_.
381     const BufferId id;
382 
383     // The instruction which defines this buffer.
384     Item* defining_instruction;
385 
386     // The materialized size of the buffer in bytes.
387     const int64 size;
388 
389     // Whether this buffer is live-out of the computation.
390     bool live_out;
391 
392     // Whether this buffer has indirect uses. Ie, an instruction which is not a
393     // user of defining_instruction uses this buffer. This can occur due to
394     // buffer aliasing (eg, tuples).
395     bool has_indirect_uses;
396 
397     // The instructions which use this buffer.
398     ItemList users;
399 
400     // The number of users (HloInstructions) of this buffer which have not yet
401     // been placed in the sequence.
402     int64 unfinished_user_count;
403 
ToStringxla::__anon6c3ee80d0111::MemoryUsageTracker::Buffer404     string ToString() const {
405       return absl::StrCat("Buffer ", id, " (defined by ",
406                           defining_instruction->instruction->name(), ", size ",
407                           size, " bytes)");
408     }
409   };
410 
411   // Creates a Buffer representing the given logical buffer. The buffer is added
412   // to buffers_ and a reference is returned.
CreateBufferFromLogicalBuffer(const LogicalBuffer * logical_buffer,const TuplePointsToAnalysis & points_to_analysis,const HloRematerialization::ShapeSizeFunction & size_function,bool live_out)413   Buffer& CreateBufferFromLogicalBuffer(
414       const LogicalBuffer* logical_buffer,
415       const TuplePointsToAnalysis& points_to_analysis,
416       const HloRematerialization::ShapeSizeFunction& size_function,
417       bool live_out) {
418     bool has_indirect_uses = false;
419     ItemList users = GetUsers(instruction_list_, logical_buffer,
420                               points_to_analysis, &has_indirect_uses);
421     return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
422                      size_function(logical_buffer->shape()), std::move(users),
423                      live_out, has_indirect_uses);
424   }
425 
426   // Create a new buffer representing a rematerialization of given buffer for
427   // the given uses.
RematerializeBuffer(const Buffer & original_buffer,Item * remat_item,ItemList && rematerialized_uses)428   Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
429                               ItemList&& rematerialized_uses) {
430     CHECK(original_buffer.defining_instruction->placed)
431         << original_buffer.defining_instruction->instruction->name();
432     CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString();
433     CHECK(!original_buffer.live_out) << original_buffer.ToString();
434     for (Item* use : rematerialized_uses) {
435       CHECK(!use->placed) << use->instruction->name();
436     }
437     return NewBuffer(remat_item, original_buffer.size,
438                      std::move(rematerialized_uses), /*live_out=*/false,
439                      /*has_indirect_uses=*/false);
440   }
441 
442   // Return number of bytes allocated for the buffer with the given id. Buffers
443   // allocated by the calling computation (eg, parameter and output buffers) are
444   // considered to have zero bytes because the memory is accounted for in a
445   // different computation.
AllocatedSize(BufferId buffer_id) const446   int64 AllocatedSize(BufferId buffer_id) const {
447     const Buffer& buffer = buffers_.at(buffer_id);
448     HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode();
449     if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
450       return 0;
451     } else {
452       return buffer.size;
453     }
454   }
455 
456   // Returns true if BeginInstruction and EndInstruction has been called for the
457   // given instruction.
IsFinished(Item * item) const458   bool IsFinished(Item* item) const {
459     return item->placed && item != in_progress_item_;
460   }
461 
462   // Returns whether the given buffer is being used by the in-progress
463   // instruction.
IsInUse(BufferId buffer_id) const464   bool IsInUse(BufferId buffer_id) const {
465     if (in_progress_item_ == nullptr) {
466       return false;
467     }
468     const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
469     return absl::c_linear_search(in_progress_uses, buffer_id);
470   }
471 
472   // Returns whether the given instruction is live at the current program
473   // point.
IsCurrentlyLive(BufferId buffer_id) const474   bool IsCurrentlyLive(BufferId buffer_id) const {
475     const Buffer& buffer = buffers_[buffer_id];
476     return (buffer.defining_instruction->placed &&
477             buffer.unfinished_user_count > 0);
478   }
479 
480   // Create a new buffer, add it to buffers_, and return a reference.
NewBuffer(Item * defining_instruction,int64 size,ItemList && users,bool live_out,bool has_indirect_uses)481   Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users,
482                     bool live_out, bool has_indirect_uses) {
483     int buffer_id = buffers_.size();
484     buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out,
485                               has_indirect_uses, users,
486                               static_cast<int64>(users.size())});
487     return buffers_.back();
488   }
489 
490   const HloComputation* computation_;
491 
492   // Instruction list containing the ordering of instructions in
493   // computation_. This is the order in which instructions are placed
494   // (BeginInstruction/EndInstruction calls).
495   const InstructionList& instruction_list_;
496 
497   // Memory usage at the currently placed instruction.
498   int64 memory_usage_ = 0;
499 
500   // The instruction currently being placed. This value is non-null only
501   // between the calling of BeginInstruction and EndInstruction.
502   Item* in_progress_item_ = nullptr;
503 
504   // All buffers in the computation.
505   std::vector<Buffer> buffers_;
506 };
507 
// Builds the tracker's Buffer table from the points-to analysis: one Buffer
// per LogicalBuffer defined in the computation, except that kWhile reuses its
// operand's buffers (see below). Also wires up each Item's buffers_defined /
// buffers_used lists.
MemoryUsageTracker::MemoryUsageTracker(
    const HloComputation* computation,
    const HloRematerialization::ShapeSizeFunction& size_function,
    const TuplePointsToAnalysis& points_to_analysis,
    const InstructionList& instruction_list)
    : computation_(computation), instruction_list_(instruction_list) {
  // Buffers aliased by the computation's root are live-out; they are sized as
  // zero by AllocatedSize since the caller accounts for them.
  PointsToSet::BufferSet live_out_set =
      points_to_analysis.GetPointsToSet(computation_->root_instruction())
          .CreateFlattenedSet();
  // Maps each LogicalBuffer seen so far to the id of its tracker Buffer. Used
  // to find the operand's Buffer when a kWhile aliases it.
  absl::flat_hash_map<const LogicalBuffer*, BufferId>
      logical_buffer_to_buffer_id;

  // Walk instructions in schedule order so that, when we reach a kWhile, the
  // buffers of its operand have already been assigned ids.
  for (auto* item = instruction_list_.first(); item != nullptr;
       item = instruction_list_.next(item)) {
    const HloInstruction* const instruction = item->instruction;
    for (const LogicalBuffer* logical_buffer :
         points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
      Buffer* buffer;
      if (instruction->opcode() == HloOpcode::kWhile) {
        // The while instruction defines no new buffers. Instead it reuses the
        // buffers of its operand. Find the Buffer of its operand at the
        // proper ShapeIndex.
        const PointsToSet& operand_points_to =
            points_to_analysis.GetPointsToSet(instruction->operand(0));
        CHECK_EQ(operand_points_to.element(logical_buffer->index()).size(), 1);
        const LogicalBuffer* source_logical_buffer =
            operand_points_to.element(logical_buffer->index())[0];
        buffer =
            &buffers_.at(logical_buffer_to_buffer_id.at(source_logical_buffer));

        // Mark buffer as has indirect use and live out.
        buffer->has_indirect_uses = true;
        buffer->live_out =
            buffer->live_out || ContainsKey(live_out_set, logical_buffer);

        // Add users of while to Buffer users. 'unused' need not be
        // initialized: GetUsers unconditionally sets *has_indirect_users.
        bool unused;
        for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
                                        points_to_analysis, &unused)) {
          // Deduplicate: the user may already be present via another alias.
          if (!absl::c_linear_search(buffer->users, user_item)) {
            buffer->users.push_back(user_item);
            buffer->unfinished_user_count++;
            user_item->buffers_used.push_back(buffer->id);
          }
        }
      } else {
        // Ordinary instruction: create a fresh Buffer and record it as defined
        // by this item and used by each of its users.
        buffer = &CreateBufferFromLogicalBuffer(
            logical_buffer, points_to_analysis, size_function,
            ContainsKey(live_out_set, logical_buffer));
        item->buffers_defined.push_back(buffer->id);
        for (Item* user : buffer->users) {
          user->buffers_used.push_back(buffer->id);
        }
      }

      logical_buffer_to_buffer_id[logical_buffer] = buffer->id;
    }
  }
  XLA_VLOG_LINES(10, ToString());
  DCHECK(Check());
}
569 
BeginInstruction(Item * item)570 Status MemoryUsageTracker::BeginInstruction(Item* item) {
571   const HloInstruction* instruction = item->instruction;
572   VLOG(3) << "BeginInstruction " << instruction->name();
573   TF_RET_CHECK(in_progress_item_ == nullptr);
574   in_progress_item_ = item;
575 
576   item->placed = true;
577 
578   // All buffers defined by this instruction need memory.
579   for (BufferId buffer_id : item->buffers_defined) {
580     VLOG(3) << "  Buffer " << buffers_.at(buffer_id).ToString()
581             << " is now live.";
582     memory_usage_ += AllocatedSize(buffer_id);
583   }
584 
585   // TODO(b/37686934): Elementwise instructions can share the buffer of a (dead)
586   // operand. Account for this potential reuse here.
587 
588   VLOG(3) << "  memory usage = " << memory_usage_;
589   VLOG(10) << ToString();
590 
591   if (VLOG_IS_ON(1)) {
592     DCHECK(Check());
593   }
594   return Status::OK();
595 }
596 
EndInstruction()597 Status MemoryUsageTracker::EndInstruction() {
598   TF_RET_CHECK(in_progress_item_ != nullptr);
599   VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name();
600 
601   for (BufferId buffer_id : in_progress_item_->buffers_used) {
602     Buffer& buffer = buffers_.at(buffer_id);
603     buffer.unfinished_user_count--;
604     CHECK_GE(buffer.unfinished_user_count, 0)
605         << buffer.ToString() << " has negative unfinished use count.";
606     if (buffer.unfinished_user_count == 0) {
607       // Buffer is now dead.
608       VLOG(3) << "  " << buffer.ToString() << " is now dead.";
609       memory_usage_ -= AllocatedSize(buffer_id);
610       CHECK_GE(memory_usage_, 0);
611     }
612   }
613 
614   // If any buffer defined by this instruction has no uses, then memory can be
615   // reclaimed immediately.
616   for (BufferId buffer_id : in_progress_item_->buffers_defined) {
617     const Buffer& buffer = buffers_.at(buffer_id);
618     if (buffer.unfinished_user_count == 0) {
619       VLOG(3) << "  " << buffer.ToString() << " is immediately dead.";
620       memory_usage_ -= AllocatedSize(buffer_id);
621       CHECK_GE(memory_usage_, 0);
622     }
623   }
624 
625   in_progress_item_ = nullptr;
626 
627   VLOG(3) << "  memory usage = " << memory_usage_;
628   VLOG(10) << ToString();
629 
630   if (VLOG_IS_ON(1)) {
631     DCHECK(Check());
632   }
633   return Status::OK();
634 }
635 
MemoryReducedIfRematerialized(Item * item) const636 int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
637   CHECK_NE(in_progress_item_, nullptr);
638   if (!item->placed || item == in_progress_item_) {
639     return 0;
640   }
641 
642   // TODO(b/37687140): Rematerialization can increase peak memory consumption at
643   // an earlier point in the program if rematerialization extends the live range
644   // of the operand of the instruction being rematerialized across the live
645   // range of the value of instruction being rematerialized. Don't rematerialize
646   // in this case (ie, return 0 here).
647 
648   // Compute the amount of memory reduced (if any) by rematerializing
649   // 'instruction'. The LogicalBuffers defined by 'instruction' will no longer
650   // be live at this program point, so initially set memory_reduced to the
651   // size of its defined values.
652   int64 memory_reduced = 0;
653   for (BufferId buffer_id : item->buffers_defined) {
654     // Avoid rematerializing instructions with indirect uses as it is difficult
655     // to reason about liveness after rematerializing the instruction.
656     // TODO(b/37714814): Consider rematerialzing instructions with indirect
657     // uses.
658     if (buffers_.at(buffer_id).has_indirect_uses) {
659       return 0;
660     }
661 
662     if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) {
663       memory_reduced += AllocatedSize(buffer_id);
664     }
665   }
666 
667   // Account for any logical buffers whose live range must be extended across
668   // this program point.
669   for (BufferId buffer_id : item->buffers_used) {
670     if (!IsCurrentlyLive(buffer_id)) {
671       // This logical buffer is used by 'instruction' but is not live at this
672       // program point. Rematerializing 'instruction' will extend the buffer's
673       // live range across this program point.
674       memory_reduced -= AllocatedSize(buffer_id);
675     }
676   }
677 
678   return memory_reduced;
679 }
680 
// Updates the tracker's bookkeeping after 'original_item' has been
// rematerialized as 'remat_item' in the HLO graph. The original instruction
// must already be placed and the rematerialization must not be. The steps
// below are order-sensitive: used buffers are revived before their user
// counts are bumped, and each old defined buffer is killed before its
// replacement is created.
Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
                                                        Item* remat_item) {
  VLOG(3) << "AddRematerializedInstruction: original_instruction = "
          << original_item->instruction->name()
          << ", remat_instruction = " << remat_item->instruction->name();

  TF_RET_CHECK(in_progress_item_ != nullptr);
  TF_RET_CHECK(original_item->placed) << original_item->instruction->name();
  TF_RET_CHECK(!remat_item->placed) << remat_item->instruction->name();

  // Construct the list of buffers used and defined by the rematerialization.
  // The rematerialization reads exactly the same operand buffers.
  remat_item->buffers_used = original_item->buffers_used;

  // Account for the additional buffer uses created by the new rematerialization
  // instruction. Update memory usage if the rematerialization makes a dead
  // buffer live again.
  for (BufferId buffer_id : original_item->buffers_used) {
    Buffer& buffer = buffers_.at(buffer_id);
    if (buffer.unfinished_user_count == 0) {
      // Buffer used by this instruction was dead, now is alive.
      memory_usage_ += AllocatedSize(buffer.id);
    }

    buffer.unfinished_user_count++;
    buffer.users.push_back(remat_item);
  }

  // Create a new set of Buffers defined by the new rematerialization
  // instruction. Update the internal data structures and memory use to account
  // for them.
  for (BufferId old_buffer_id : original_item->buffers_defined) {
    Buffer& old_buffer = buffers_.at(old_buffer_id);

    // Partition the old buffer's users: placed users stay with the original
    // buffer (they have already consumed it); unplaced users are transferred
    // to the rematerialized buffer.
    ItemList placed_users;
    ItemList unplaced_users;
    for (Item* user : old_buffer.users) {
      if (user->placed) {
        CHECK(IsFinished(user)) << user->instruction->name();
        placed_users.push_back(user);
      } else {
        unplaced_users.push_back(user);
      }
    }
    old_buffer.users = std::move(placed_users);
    old_buffer.unfinished_user_count = 0;

    // Buffer is now dead.
    memory_usage_ -= AllocatedSize(old_buffer.id);

    Buffer& new_buffer =
        RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));

    remat_item->buffers_defined.push_back(new_buffer.id);
    // Repoint every transferred user's use-list entry from the old buffer id
    // to the new one.
    for (Item* user : new_buffer.users) {
      BufferIdList& buffers_used = user->buffers_used;
      std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
                   new_buffer.id);
    }
  }

  VLOG(3) << "  memory usage = " << memory_usage_;
  XLA_VLOG_LINES(10, ToString());

  DCHECK(Check());

  return Status::OK();
}
748 
ToString() const749 string MemoryUsageTracker::ToString() const {
750   string output =
751       absl::StrCat("MemoryUsageTracker for ", computation_->name(), "\n");
752   absl::StrAppend(&output,
753                   "Memory usage: ", HumanReadableNumBytes(memory_usage()), " (",
754                   memory_usage(), " bytes)");
755   for (auto* item = instruction_list_.first(); item != nullptr;
756        item = instruction_list_.next(item)) {
757     const HloInstruction* instruction = item->instruction;
758     string inprogress = item == in_progress_item_ ? " in-progress" : "";
759     string placed = item->placed ? " placed" : "";
760     absl::StrAppend(&output, "  ", instruction->name(), inprogress, placed,
761                     "\n    Defines:\n");
762     for (BufferId buffer_id : item->buffers_defined) {
763       const Buffer& buffer = buffers_[buffer_id];
764       string live = IsCurrentlyLive(buffer_id) ? " live" : "";
765       absl::StrAppend(&output, "      ", buffer.ToString(), live, ", ",
766                       buffer.unfinished_user_count, " unfinished uses\n");
767     }
768     absl::StrAppend(&output, "    Uses:\n");
769     for (BufferId buffer_id : item->buffers_used) {
770       absl::StrAppend(&output, "      ", buffers_[buffer_id].ToString(), "\n");
771     }
772   }
773   return output;
774 }
775 
Check() const776 bool MemoryUsageTracker::Check() const {
777   auto elements_are_unique = [](const BufferIdList& vec) {
778     return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
779   };
780 
781   // Verify buffers_defined per instruction.
782   for (auto* instruction : computation_->instructions()) {
783     const BufferIdList& defined_buffers =
784         instruction_list_.GetItem(instruction)->buffers_defined;
785     CHECK(elements_are_unique(defined_buffers))
786         << "Instruction " << instruction->name()
787         << " does not have unique defined buffers: "
788         << absl::StrJoin(
789                defined_buffers, ", ", [this](string* out, BufferId buffer_id) {
790                  absl::StrAppend(out, buffers_.at(buffer_id).ToString());
791                });
792 
793     for (const Buffer& buffer : buffers_) {
794       if (buffer.defining_instruction->instruction == instruction) {
795         CHECK(absl::c_linear_search(defined_buffers, buffer.id))
796             << "Instruction " << instruction->name()
797             << " defined buffers is missing: " << buffer.ToString();
798       }
799     }
800   }
801 
802   // Verify buffers_used per instruction.
803   for (auto* instruction : computation_->instructions()) {
804     const BufferIdList& used_buffers =
805         instruction_list_.GetItem(instruction)->buffers_used;
806     CHECK(elements_are_unique(used_buffers))
807         << "Instruction " << instruction->name()
808         << " does not have unique used buffers: "
809         << absl::StrJoin(
810                used_buffers, ", ", [this](string* out, BufferId buffer_id) {
811                  absl::StrAppend(out, buffers_.at(buffer_id).ToString());
812                });
813   }
814   for (const Buffer& buffer : buffers_) {
815     int64 unfinished_uses = 0;
816     for (Item* user : buffer.users) {
817       const BufferIdList& used_buffers = user->buffers_used;
818       CHECK(absl::c_linear_search(used_buffers, buffer.id))
819           << "Instruction " << user->instruction->name()
820           << " used buffers is missing " << buffer.ToString();
821       if (!IsFinished(user)) {
822         unfinished_uses++;
823       }
824     }
825     CHECK_EQ(buffer.unfinished_user_count, unfinished_uses)
826         << "Incorrect unplaced use count for " << buffer.ToString();
827   }
828   return true;
829 }
830 
831 // Computes and returns the cost of rematerializing the given instruction.
832 // Cost per rematerialized instruction is defined as:
833 //
834 // memory_limit_bytes / memory_reduced
835 //
836 // The idea is to choose the operation that will save the most memory for
837 // rematerialization and do not worry about how much the compute costs since
838 // running out of memory is more harmful than taking longer to get the answer.
RematerializationCost(const HloInstruction * instruction,const MemoryUsageTracker & memory_tracker,int64 memory_reduced,int64 memory_limit_bytes)839 int64 RematerializationCost(const HloInstruction* instruction,
840                             const MemoryUsageTracker& memory_tracker,
841                             int64 memory_reduced, int64 memory_limit_bytes) {
842   // If none of the users of 'instruction' have been placed in the sequence (as
843   // tracked by memory_tracker), then rematerialization of 'instruction' is a
844   // zero-cost move of 'instruction' in the sequence.
845   if (!absl::c_any_of(instruction->users(),
846                       [&memory_tracker](const HloInstruction* inst) {
847                         return memory_tracker.IsPlaced(inst);
848                       })) {
849     return 0;
850   }
851 
852   CHECK_GT(memory_reduced, 0);
853   // Return the inverse of the benefit of rematerialization.
854   return memory_limit_bytes / memory_reduced;
855 }
856 
857 // Selects and returns the best candidate instruction for rematerialization.
858 // The instruction with lowest rematerialization cost is selected among those
859 // candidate which reduce memory use at the program point of the current
860 // instruction as indicated by memory_tracker. nullptr is returned if no
861 // candidate can be found.
PickRematerializationCandidate(const MemoryUsageTracker & memory_tracker,const InstructionList & instruction_list,int64 memory_limit_bytes,absl::flat_hash_map<const HloInstruction *,bool> * remat_able)862 Item* PickRematerializationCandidate(
863     const MemoryUsageTracker& memory_tracker,
864     const InstructionList& instruction_list, int64 memory_limit_bytes,
865     absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
866   Item* best_item = nullptr;
867   int64 best_cost = 0;
868 
869   // TODO(b/35244891): This is currently quadratic in the number of HLO
870   // instructions.
871   for (auto* item = instruction_list.first(); item != nullptr;
872        item = instruction_list.next(item)) {
873     if (!item->placed) {
874       // Only iterate up to the currently placed instruction.
875       // We are trying to reduce memory usage at the placed
876       // instruction so rematerializing later values is of no benefit.
877       break;
878     }
879     HloInstruction* candidate = item->instruction;
880     VLOG(5) << "considering rematerialization candidate " << candidate->name();
881 
882     if (item->blacklisted) {
883       // Skip instructions on the blacklist to avoid infinite loops of
884       // rematerializing the same instruction(s) repeatedly.
885       VLOG(5) << "candidate " << candidate->name()
886               << " is excluded from rematerialization";
887       continue;
888     }
889     if (!CanBeRematerialized(candidate, remat_able)) {
890       VLOG(5) << "candidate " << candidate->name()
891               << " not viable: is not rematerializable";
892       continue;
893     }
894 
895     // If any of the candidate's control successor has been placed, we need to
896     // skip this candidate. Otherwise we will violate control dependency.
897     bool control_successor_placed =
898         std::any_of(candidate->control_successors().begin(),
899                     candidate->control_successors().end(),
900                     [&memory_tracker](const HloInstruction* inst) {
901                       return memory_tracker.IsPlaced(inst);
902                     });
903 
904     if (control_successor_placed) {
905       continue;
906     }
907 
908     const int64 memory_reduced =
909         memory_tracker.MemoryReducedIfRematerialized(item);
910 
911     if (memory_reduced <= 0) {
912       VLOG(5) << "candidate " << candidate->name()
913               << " memory reduced = " << memory_reduced << " <=  0";
914       continue;
915     }
916 
917     const int cost = RematerializationCost(candidate, memory_tracker,
918                                            memory_reduced, memory_limit_bytes);
919 
920     VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
921             << memory_reduced << ", cost per byte " << cost;
922 
923     if (best_item == nullptr || cost < best_cost) {
924       VLOG(5) << "candidate " << candidate->name() << " now best";
925       best_item = item;
926       best_cost = cost;
927     }
928   }
929   return best_item;
930 }
931 
932 }  // namespace
933 
ComputePeakMemory(const HloComputation * computation,const HloInstructionSequence & order) const934 StatusOr<int64> HloRematerialization::ComputePeakMemory(
935     const HloComputation* computation,
936     const HloInstructionSequence& order) const {
937   InstructionList instruction_list(order);
938   MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
939                              instruction_list);
940   int64 peak_memory = tracker.memory_usage();
941   for (auto* item = instruction_list.first(); item != nullptr;
942        item = instruction_list.next(item)) {
943     const HloInstruction* instruction = item->instruction;
944     TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
945     TF_ASSIGN_OR_RETURN(int64 callee_usage,
946                         CalledComputationsMemoryUsage(instruction));
947     peak_memory =
948         std::max<int64>(peak_memory, tracker.memory_usage() + callee_usage);
949     TF_RETURN_IF_ERROR(tracker.EndInstruction());
950   }
951   VLOG(1) << "Peak memory for " << computation->name() << ": "
952           << HumanReadableNumBytes(peak_memory);
953   return peak_memory;
954 }
955 
CalledComputationsMemoryUsage(const HloInstruction * instruction) const956 StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
957     const HloInstruction* instruction) const {
958   const CallSite* callsite =
959       call_graph_->GetNode(instruction->parent()).GetCallSite(instruction);
960   if (callsite == nullptr || callsite->context() == CallContext::kParallel) {
961     return 0;
962   }
963   int64 callee_usage = 0;
964   for (const HloComputation* computation : callsite->called_computations()) {
965     TF_RET_CHECK(ContainsKey(computation_peak_memory_, computation));
966     callee_usage += computation_peak_memory_.at(computation);
967   }
968   return callee_usage;
969 }
970 
// Rematerializes instructions within 'computation' (and, recursively, within
// computations called in a sequential context) so that memory usage at every
// program point stays at or below 'memory_limit_bytes' where possible.
// Rewrites the schedule for 'computation' to include the rematerialized
// instructions and updates computation_peak_memory_. Returns true if the
// computation or any callee was changed. Best-effort: the limit may still be
// exceeded if no profitable candidate is found at some program point.
StatusOr<bool> HloRematerialization::RematerializeComputation(
    HloComputation* computation, HloSchedule* schedule,
    int64 memory_limit_bytes) {
  VLOG(1) << "Rematerializing computation " << computation->name()
          << " with limit " << HumanReadableNumBytes(memory_limit_bytes);
  VLOG(1) << "peak memory usage is "
          << HumanReadableNumBytes(computation_peak_memory_.at(computation));
  // Each computation is rematerialized at most once per pass.
  CHECK(!ContainsKey(rematerialized_computations_, computation));

  InstructionList instruction_list(schedule->sequence(computation));
  MemoryUsageTracker memory_tracker(computation, size_function_,
                                    *points_to_analysis_, instruction_list);
  bool changed = false;

  // If the rematerialization makes the source instruction dead, then the
  // rematerialization is added to 'remat_move_instructions' (the
  // rematerialization is essentially a move). If the next rematerialization of
  // the instruction is also a move then the rematerialization is added to the
  // blacklist.
  absl::flat_hash_set<const HloInstruction*> remat_move_instructions;

  // The map from instructions to their rematerializable status.
  absl::flat_hash_map<const HloInstruction*, bool> remat_able;

  // The peak memory of the computation at any point in the instruction
  // sequence.
  int64 peak_memory = memory_tracker.memory_usage();

  // Total count of instructions rematerialized.
  int64 remat_count = 0;
  // Total count of clones created minus number of original rematerialized
  // instructions which are dead.
  int64 net_instructions_added = 0;

  const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);

  // Iterate through all instructions in the sequence. At each instruction
  // (program point) if memory_usage exceeds the specified limit then
  // rematerialize HLO instructions until memory_usage is reduced.
  int64 instruction_index = 0;
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    const HloInstruction* instruction = item->instruction;
    // Memory used by computations this instruction calls sequentially.
    TF_ASSIGN_OR_RETURN(int64 callee_usage,
                        CalledComputationsMemoryUsage(instruction));
    TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item));

    VLOG(2) << "Program point at " << instruction->name()
            << ", memory usage = " << memory_tracker.memory_usage()
            << ", callee usage = " << callee_usage << ", [" << instruction_index
            << "/" << instruction_list.size() << "]";
    instruction_index++;

    // Rematerialize one candidate at a time until under the limit or out of
    // candidates.
    while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
      VLOG(2) << "Over memory limit at instruction " << instruction->name()
              << ", using "
              << HumanReadableNumBytes(memory_tracker.memory_usage() +
                                       callee_usage)
              << ", limit is " << HumanReadableNumBytes(memory_limit_bytes);

      Item* best_item = PickRematerializationCandidate(
          memory_tracker, instruction_list, memory_limit_bytes, &remat_able);

      if (best_item == nullptr) {
        VLOG(3) << "Unable to find rematerialization candidate at program "
                   "point "
                << instruction->name() << ". Memory usage = "
                << HumanReadableNumBytes(memory_tracker.memory_usage() +
                                         callee_usage);
        break;
      }

      HloInstruction* best = best_item->instruction;
      VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
              << HumanReadableNumBytes(
                     memory_tracker.MemoryReducedIfRematerialized(best_item))
              << ")";
      changed = true;
      remat_count++;

      HloInstruction* remat =
          computation->AddInstruction(best->Clone(/*suffix=*/"remat"));

      // Add control dependencies to the new operation.
      for (auto successor : best->control_successors()) {
        TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
      }
      for (auto predecessor : best->control_predecessors()) {
        TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
      }

      Item* remat_item = instruction_list.CreateItem(remat);

      // Replace each remaining use of 'best' with the rematerialization.
      // (Iterate over a copy since ReplaceUseWith mutates the user list.)
      std::vector<HloInstruction*> best_users_copy = best->users();
      for (HloInstruction* user : best_users_copy) {
        if (!memory_tracker.IsPlaced(user)) {
          VLOG(2) << "  Replacing use of " << best->name() << " in "
                  << user->name() << " with " << remat->name();
          TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
        }
      }

      // Account for the rematerialization in the memory tracker.
      TF_RETURN_IF_ERROR(
          memory_tracker.AddRematerializedInstruction(best_item, remat_item));

      // Insert rematerialized instruction right before the earliest unplaced
      // use of the instruction *and* the earliest unplaced last use of any
      // operands of remat. Unplaced uses of the remat's operands are included
      // because we don't want to extend the live range of remat's operands as
      // this could increase memory usage.
      ItemList place_before;
      for (auto user : remat->users()) {
        place_before.push_back(instruction_list.GetItem(user));
      }
      for (auto* operand : remat->operands()) {
        for (auto* operand_user : operand->users()) {
          if (operand_user != remat) {
            Item* operand_user_item = instruction_list.GetItem(operand_user);
            if (!operand_user_item->placed) {
              place_before.push_back(operand_user_item);
            }
          }
        }
      }
      // Insert rematerialized instruction before any of its successors to
      // preserve ordering regarding control dependency.
      for (auto successor : remat->control_successors()) {
        Item* successor_item = instruction_list.GetItem(successor);
        // Assert to make sure we never remat an operation with control
        // successor already placed.
        CHECK(!successor_item->placed) << successor_item->instruction->name();
        place_before.push_back(successor_item);
      }
      instruction_list.InsertBeforeInstructions(remat_item, place_before);

      // If the rematerialized instruction is dead then rematerialization is
      // essentially a move. Don't delete the instruction now because we don't
      // want duplicate HloInstruction* values during the course of the
      // transformation because we keep maps with HloInstruction* values as
      // keys.
      if (best->users().empty()) {
        VLOG(2) << best->name() << " is now dead";
        if (ContainsKey(remat_move_instructions, best)) {
          // Previously, 'best' was a rematerialization which killed the
          // instruction it was a copying of. Now 'remat' is a rematerialization
          // of 'best' and kills 'best'. Stop rematerializing this instruction
          // to avoid an infinite loop.
          instruction_list.Blacklist(remat);
        }
        remat_move_instructions.insert(remat);
      } else {
        net_instructions_added++;
      }

      VLOG(1) << "memory_usage after rematerialization = "
              << HumanReadableNumBytes(memory_tracker.memory_usage());
    }

    const CallSite* callsite = call_graph_node.GetCallSite(instruction);
    if (callsite != nullptr &&
        callsite->context() == CallContext::kSequential &&
        memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
      // Memory usage exceeds the limit. Try to rematerialize any
      // subcomputation(s) that this instruction calls.
      VLOG(1) << "Memory usage still over the limit ("
              << (memory_tracker.memory_usage() + callee_usage) << " > "
              << memory_limit_bytes
              << "). Rematerializing computations called by "
              << instruction->name();

      // Recompute callee usage to account for any rematerialization performed
      // in the callee computations.
      for (HloComputation* called_computation :
           callsite->called_computations()) {
        if (!ContainsKey(rematerialized_computations_, called_computation)) {
          // Memory limit for the subcomputation is the memory limit less the
          // amount of memory used at this point in the computation.
          int64 subcomputation_memory_limit_bytes = std::max<int64>(
              0, memory_limit_bytes - memory_tracker.memory_usage());
          TF_ASSIGN_OR_RETURN(
              bool subcomputation_changed,
              RematerializeComputation(called_computation, schedule,
                                       subcomputation_memory_limit_bytes));
          changed |= subcomputation_changed;
        }
      }
      TF_ASSIGN_OR_RETURN(callee_usage,
                          CalledComputationsMemoryUsage(instruction));
    }

    peak_memory = std::max<int64>(peak_memory,
                                  memory_tracker.memory_usage() + callee_usage);
    VLOG(3) << "peak memory usage = " << HumanReadableNumBytes(peak_memory);

    TF_RETURN_IF_ERROR(memory_tracker.EndInstruction());
  }

  // Verify some invariants on the memory tracker.
  CHECK_EQ(memory_tracker.memory_usage(), 0);
  for (auto* instruction : computation->instructions()) {
    CHECK(memory_tracker.IsPlaced(instruction)) << instruction->name();
  }

  VLOG(1) << "In computation " << computation->name() << " rematerialized "
          << remat_count << " instructions; " << net_instructions_added
          << " net instructions added";
  VLOG(1) << "  peak memory usage now " << HumanReadableNumBytes(peak_memory)
          << " (was "
          << HumanReadableNumBytes(computation_peak_memory_.at(computation))
          << ")";

  // Update peak memory used by computation.
  computation_peak_memory_.at(computation) = peak_memory;

  // Update order to include rematerialized instructions.
  HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation);
  sequence.clear();
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    HloInstruction* instruction = item->instruction;
    sequence.push_back(instruction);
  }
  rematerialized_computations_.insert(computation);

  instructions_rematerialized_ += remat_count;
  net_instructions_added_ += net_instructions_added;

  return changed;
}
1202 
// Entry point of the pass. Rematerializes instructions in 'module' (which
// must have a schedule) to try to bring peak memory usage under the
// configured limit, then runs HloDCE to remove originals made dead by the
// rewrite. Returns true if the module was changed; logs a warning if the
// memory limit could not be reached.
StatusOr<bool> HloRematerialization::Run(HloModule* module) {
  VLOG(1) << "HloRematerialization() with memory limit of "
          << HumanReadableNumBytes(memory_limit_bytes_);
  XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());

  // Initialize pass object state.
  computation_peak_memory_.clear();
  rematerialized_computations_.clear();
  instructions_rematerialized_ = 0;
  net_instructions_added_ = 0;

  TF_RET_CHECK(module->has_schedule());
  // Points-to analysis is consumed by the MemoryUsageTrackers constructed
  // during rematerialization.
  TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));

  // Adjust memory limit to account for the output of the entry
  // computation. This is necessary because the per-computation accounting in
  // MemoryUsageTracker do not include output as these are typically allocated
  // by the caller.
  int64 module_output_size = 0;
  ShapeUtil::ForEachSubshape(
      module->result_shape(),
      [&module_output_size, this](const Shape& subshape,
                                  const ShapeIndex& /*index*/) {
        module_output_size += size_function_(subshape);
      });

  const int64 adjusted_memory_limit_bytes =
      memory_limit_bytes_ - module_output_size;
  VLOG(1) << "Adjusted memory limit accounting for output ("
          << HumanReadableNumBytes(module_output_size)
          << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);

  // Compute peak memory usage of all computations in the module called in a
  // sequential context.
  call_graph_ = CallGraph::Build(module);
  TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
      [this, module](const CallGraphNode& node) -> Status {
        if (node.context() == CallContext::kSequential) {
          TF_ASSIGN_OR_RETURN(
              computation_peak_memory_[node.computation()],
              ComputePeakMemory(node.computation(), module->schedule().sequence(
                                                        node.computation())));
        }
        return Status::OK();
      },
      /*visit_unreachable_nodes=*/false));

  // The peak memory usage of the module equals the peak memory use of the entry
  // computation plus the output size of the computation. This is because the
  // peak memory for a computation does not include the output as this is
  // typically accounted for in the caller.
  const int64 before_peak_memory =
      computation_peak_memory_.at(module->entry_computation()) +
      module_output_size;
  VLOG(1) << "Peak memory usage of module (before): "
          << HumanReadableNumBytes(before_peak_memory);

  // Subcomputations called by the entry computation will also be
  // rematerialized.
  TF_ASSIGN_OR_RETURN(
      bool changed,
      RematerializeComputation(module->entry_computation(), &module->schedule(),
                               adjusted_memory_limit_bytes));

  // Rematerialization can introduce dead code. This occurs if all uses of an
  // instruction are replaced with rematerializations of the instruction.
  TF_ASSIGN_OR_RETURN(bool dead_code_removed, HloDCE().Run(module));
  changed |= dead_code_removed;

  // After DCE, the module sequence may include instructions which no longer
  // exist.
  TF_RETURN_IF_ERROR(module->schedule().Update());
  VLOG(1) << "Rematerialized " << instructions_rematerialized_
          << " instructions in module " << module->name() << "; "
          << net_instructions_added_ << " net instructions added";
  const int64 current_peak_memory =
      computation_peak_memory_.at(module->entry_computation()) +
      module_output_size;
  VLOG(1) << "Peak memory usage of module now "
          << HumanReadableNumBytes(current_peak_memory) << " ("
          << current_peak_memory << " bytes), was "
          << HumanReadableNumBytes(before_peak_memory) << " ("
          << before_peak_memory << " bytes)";
  const int64 reduced_peak_memory = before_peak_memory - current_peak_memory;
  VLOG(1) << "Reduced peak memory by "
          << HumanReadableNumBytes(reduced_peak_memory) << " ("
          << reduced_peak_memory << " bytes)";

  // Report before/after peak sizes to the caller if requested.
  if (sizes_ != nullptr) {
    sizes_->before_bytes = before_peak_memory;
    sizes_->after_bytes = current_peak_memory;
  }

  XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());

  if (current_peak_memory > memory_limit_bytes_) {
    LOG(WARNING) << absl::StrFormat(
        "Can't reduce memory use below %s (%d bytes) by rematerialization; "
        "only reduced to %s (%d bytes)",
        HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
        HumanReadableNumBytes(current_peak_memory), current_peak_memory);
  }

  return changed;
}
1308 
1309 }  // namespace xla
1310