/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"

#include <algorithm>
#include <memory>
#include <set>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {

using ::tensorflow::strings::HumanReadableNumBytes;

// Potential optimizations:
// . TODO(b/35244891): Avoid N^2 behavior by keeping a priority queue
//   of candidates.
// . Cache IsRematerializable in Item? Only correct if control
//   predecessors and successors don't change.

// Returns true if the given instruction is rematerializable.
bool IsRematerializable(const HloInstruction* instruction) {
  if (instruction->opcode() == HloOpcode::kCopy) {
    if (LayoutUtil::Equal(instruction->shape().layout(),
                          instruction->operand(0)->shape().layout())) {
      // Don't rematerialize copies added by copy insertion (layout doesn't
      // change).
      return false;
    }
  }

  // Don't rematerialize instructions with side effects or instructions which
  // cannot be cloned safely.
  switch (instruction->opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConstant:
    case HloOpcode::kConditional:
    case HloOpcode::kAllReduce:
    case HloOpcode::kCustomCall:
    case HloOpcode::kParameter:
    case HloOpcode::kWhile:
      return false;
    default:
      return !instruction->HasSideEffect();
  }
}
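
// Illustrative example (schematic HLO, not taken from this file): under the
// rules above, an elementwise add with no side effects is rematerializable,
// while a parameter or a layout-preserving copy is not:
//
//   p0  = f32[16] parameter(0)   // not rematerializable (kParameter)
//   add = f32[16] add(p0, p0)    // rematerializable
//   cp  = f32[16] copy(add)      // not rematerializable if layouts match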

// Checks whether an instruction can be rematerialized by first consulting the
// cache and, on a cache miss, calling IsRematerializable().
bool CanBeRematerialized(
    const HloInstruction* instruction,
    absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
  auto it = remat_able->find(instruction);
  if (it != remat_able->end()) {
    return it->second;
  }
  bool rematerializable = IsRematerializable(instruction);
  (*remat_able)[instruction] = rematerializable;
  return rematerializable;
}

// Type holding a unique identifier for each Buffer object.
using BufferId = int64;
using BufferIdList = absl::InlinedVector<BufferId, 3>;

// We wrap HloInstruction* with an Item that holds auxiliary
// per-instruction state.
struct Item {
  HloInstruction* instruction;

  // True once the instruction is marked as placed (when BeginInstruction
  // has been called for this instruction).
  bool placed = false;

  // To avoid an infinite loop rematerializing the same set of
  // instructions ad infinitum, keep a blacklist of instructions
  // which should not be rematerialized.
  bool blacklisted = false;

  // The buffers defined by this instruction.
  BufferIdList buffers_defined;

  // The buffers used by this instruction.
  BufferIdList buffers_used;

 private:
  friend class InstructionList;

  // Items are arranged in a doubly linked list.
  Item* next;
  Item* prev;

  // The list is ordered by position, though positions may be duplicated as
  // new instructions are inserted. See the InsertBeforeInstructions
  // comment for details.
  int64 position;
};

using ItemList = absl::InlinedVector<Item*, 3>;

// Class which maintains an ordered list of instructions with fast insertion
// before arbitrary elements.
class InstructionList {
 public:
  explicit InstructionList(const HloInstructionSequence& order) {
    int64 position = 0;
    Item* last = nullptr;
    for (HloInstruction* inst : order.instructions()) {
      // Add a new item to the linked list.
      Item* item = new Item;
      item->next = nullptr;
      item->prev = last;
      if (last == nullptr) {
        first_ = item;
      } else {
        last->next = item;
      }
      last = item;

      // Initially position numbers are uniquely assigned in order. Later, as
      // instructions are added with InsertBefore* methods, some instructions
      // may have duplicate position numbers, but the values are guaranteed to
      // be monotonically increasing through the list, and so are still useful
      // for quickly(-ish) determining the order of arbitrary instructions in
      // the list.
      item->instruction = inst;
      item->position = position;
      position++;

      item_map_[inst] = item;
    }
  }

  ~InstructionList() {
    for (Item* item = first_; item != nullptr;) {
      Item* next = item->next;
      delete item;
      item = next;
    }
  }

  size_t size() const { return item_map_.size(); }

  // For ordered iteration over items.
  // for (auto item = q.first(); item != nullptr; item = q.next(item)) {...}
  Item* first() const { return first_; }
  Item* next(Item* item) const { return item->next; }

  // Creates an Item for the given instruction, but doesn't add it to the list.
  // (Use InsertBeforeInstructions to add the Item to the list.)
  Item* CreateItem(HloInstruction* inst) {
    Item* item = new Item;
    item->instruction = inst;
    CHECK(item_map_.insert({inst, item}).second)
        << "inserting inst twice " << inst->name();
    return item;
  }

  // Returns the Item corresponding to inst.
  Item* GetItem(const HloInstruction* inst) const {
    auto iter = item_map_.find(inst);
    CHECK(iter != item_map_.end()) << "Did not find " << inst->name();
    return iter->second;
  }

  // Inserts instruction 'to_insert' immediately before the earliest
  // instruction in 'before_instructions'.
  //
  // Each instruction gets a non-decreasing ordinal position number. We use
  // this to let InsertBeforeInstructions quickly insert an instruction before
  // the earliest instruction in a set of instructions. If a's position number
  // is less than b's, then 'a' comes before 'b' in the list. If the position
  // numbers are the same then nothing can be said about their order without
  // examining the list.
  //
  // On object construction this ordinal is precisely the instruction's index
  // in the list. Later, instructions inserted via InsertBefore receive
  // duplicate values. However, monotonicity is preserved (see the worked
  // example following this method).
  void InsertBeforeInstructions(Item* to_insert,
                                absl::Span<Item* const> before_instructions) {
    VLOG(3) << "InsertBeforeInstructions: " << to_insert->instruction->name()
            << " before {"
            << absl::StrJoin(before_instructions, ", ",
                             [](string* out, Item* item) {
                               absl::StrAppend(out, item->instruction->name());
                             })
            << "}";

    // Find the minimal position number of any instruction in
    // 'before_instructions'.
    CHECK(!before_instructions.empty());
    Item* min_position_item = nullptr;
    for (Item* item : before_instructions) {
      if (min_position_item == nullptr ||
          item->position < min_position_item->position) {
        min_position_item = item;
      }
    }

    // Because more than one instruction in 'before_instructions' may have the
    // minimal position number, find the first instruction in the list with
    // that position number.

    // First find the first instruction with the min position.
    while (min_position_item->prev != nullptr &&
           min_position_item->position == min_position_item->prev->position) {
      min_position_item = min_position_item->prev;
    }

    // Now scan forwards until we find one of the before_instructions.
    while (!absl::c_linear_search(before_instructions, min_position_item)) {
      min_position_item = min_position_item->next;
    }
    return InsertBefore(to_insert, min_position_item);
  }
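
  // Worked example of the position-number scheme (illustrative only). Given
  // an initial list with unique positions
  //
  //   a(0) -> b(1) -> c(2) -> d(3)
  //
  // inserting x before {c, d} places x immediately before c, and x copies c's
  // position number, so positions stay monotonic but are no longer unique:
  //
  //   a(0) -> b(1) -> x(2) -> c(2) -> d(3)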

  void Blacklist(const HloInstruction* inst) {
    GetItem(inst)->blacklisted = true;
  }

 private:
  // Inserts instruction 'item' immediately before 'before' in the list.
  void InsertBefore(Item* item, Item* before) {
    VLOG(3) << "InsertBefore: " << item->instruction->name() << " before "
            << before->instruction->name();
    // Insert new item into linked list.
    item->prev = before->prev;
    item->next = before;
    before->prev = item;
    if (item->prev != nullptr) {
      item->prev->next = item;
    } else {
      first_ = item;
    }

    // Assign the same position number to the newly added instruction as
    // 'before'. This guarantees monotonicity of the position numbers, but not
    // uniqueness.
    item->position = before->position;
  }

  Item* first_;

  // Item for each instruction.
  absl::flat_hash_map<const HloInstruction*, Item*> item_map_;
};

// Returns the items which use the given LogicalBuffer. Sets
// has_indirect_users to whether any of the uses is indirect. A use is
// indirect if the instruction defining logical_buffer is not an operand of
// the use. This can happen via buffer aliasing (eg, tuples).
ItemList GetUsers(const InstructionList& instruction_list,
                  const LogicalBuffer* logical_buffer,
                  const TuplePointsToAnalysis& points_to_analysis,
                  bool* has_indirect_users) {
  ItemList users;
  // To identify uses, iterate through all HloInstruction users of the
  // BufferAliases of the logical buffer.
  *has_indirect_users = false;
  for (const BufferAlias& buffer_alias :
       points_to_analysis.GetBufferAliases(*logical_buffer)) {
    for (const HloInstruction* user : buffer_alias.instruction()->users()) {
      if (points_to_analysis.DoesNotUseOperandBuffer(
              buffer_alias.instruction(), buffer_alias.index(), user)) {
        // The alias may be an operand of 'user', but the LogicalBuffer cannot
        // possibly be used by the instruction so ignore 'user'. This is the
        // case, for example, for the tuple element buffers in a
        // GetTupleElement instruction (the GTE instruction only uses the
        // pointer vector).
        continue;
      }
      if (buffer_alias.instruction() != logical_buffer->instruction()) {
        *has_indirect_users = true;
      }
      // A buffer may be used by the instruction via more than one alias. For
      // example, a buffer which appears in more than one element of a tuple.
      Item* user_item = instruction_list.GetItem(user);
      if (!absl::c_linear_search(users, user_item)) {
        users.push_back(user_item);
      }
    }
  }
  return users;
}
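
// Illustrative example of an indirect use (schematic HLO, not taken from this
// file): given
//
//   a = ...
//   t = tuple(a, b)
//   w = while(t), ...
//
// 'w' uses the buffer defined by 'a' via the alias (t, {0}) even though 'a'
// is not an operand of 'w', so GetUsers would report it with
// *has_indirect_users = true.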

// Class for tracking memory usage of a computation as the instructions are
// placed sequentially. Memory usage is the sum of the sizes of live values
// (LogicalBuffers) at the current point in the instruction sequence.
class MemoryUsageTracker {
 public:
  MemoryUsageTracker(
      const HloComputation* computation,
      const HloRematerialization::ShapeSizeFunction& size_function,
      const TuplePointsToAnalysis& points_to_analysis,
      const InstructionList& instruction_list);

  // Starts the placement of the given instruction. This adds the sizes of the
  // LogicalBuffers defined by the instruction to the current memory
  // usage. Placement is broken into two steps (BeginInstruction and
  // EndInstruction) to accurately model memory usage. At BeginInstruction the
  // memory for the output value(s) of the current instruction is allocated. At
  // EndInstruction memory for dead operand(s) is freed.
  Status BeginInstruction(Item* item);

  // Finishes the placement of the current instruction. This frees any dead
  // operands or dead result of the instruction. This must be called after
  // each call to BeginInstruction.
  Status EndInstruction();

  // Returns the number of bytes by which the current memory usage would be
  // reduced if the given instruction were rematerialized.
  int64 MemoryReducedIfRematerialized(Item* item) const;

  // Adjusts memory usage to account for the rematerialization of
  // original_item for all remaining unplaced uses. The rematerialized
  // instruction is remat_item. This method should be called after the HLO
  // graph has been transformed (rematerialization instruction created and
  // connected to uses).
  Status AddRematerializedInstruction(Item* original_item, Item* remat_item);

  // Returns whether the given instruction has been placed (BeginInstruction
  // has been called with 'instruction' as the argument).
  bool IsPlaced(const HloInstruction* instruction) const {
    return instruction_list_.GetItem(instruction)->placed;
  }

  // Returns the current memory usage. This is the sum of sizes of all live
  // values.
  int64 memory_usage() const { return memory_usage_; }

  // Checks invariants of the data structure. This is expensive to call.
  bool Check() const;

  string ToString() const;

 private:
  // A Buffer represents a single LogicalBuffer in the computation including
  // various metadata useful for tracking liveness of the value. A
  // LogicalBuffer is not used directly because the HLO graph is transformed,
  // and TuplePointsToAnalysis, which owns all LogicalBuffers, cannot be
  // updated after HLO graph transformations.
  struct Buffer {
    // The unique id of this Buffer. This value is equal to the buffer's index
    // in the vector buffers_.
    const BufferId id;

    // The instruction which defines this buffer.
    Item* defining_instruction;

    // The materialized size of the buffer in bytes.
    const int64 size;

    // Whether this buffer is live-out of the computation.
    bool live_out;

    // Whether this buffer has indirect uses. Ie, an instruction which is not
    // a user of defining_instruction uses this buffer. This can occur due to
    // buffer aliasing (eg, tuples).
    bool has_indirect_uses;

    // The instructions which use this buffer.
    ItemList users;

    // The number of users (HloInstructions) of this buffer which have not yet
    // been placed in the sequence.
    int64 unfinished_user_count;

    string ToString() const {
      return absl::StrCat("Buffer ", id, " (defined by ",
                          defining_instruction->instruction->name(), ", size ",
                          size, " bytes)");
    }
  };

  // Creates a Buffer representing the given logical buffer. The buffer is
  // added to buffers_ and a reference is returned.
  Buffer& CreateBufferFromLogicalBuffer(
      const LogicalBuffer* logical_buffer,
      const TuplePointsToAnalysis& points_to_analysis,
      const HloRematerialization::ShapeSizeFunction& size_function,
      bool live_out) {
    bool has_indirect_uses = false;
    ItemList users = GetUsers(instruction_list_, logical_buffer,
                              points_to_analysis, &has_indirect_uses);
    return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
                     size_function(logical_buffer->shape()), std::move(users),
                     live_out, has_indirect_uses);
  }

  // Creates a new buffer representing a rematerialization of the given buffer
  // for the given uses.
  Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
                              ItemList&& rematerialized_uses) {
    CHECK(original_buffer.defining_instruction->placed)
        << original_buffer.defining_instruction->instruction->name();
    CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString();
    CHECK(!original_buffer.live_out) << original_buffer.ToString();
    for (Item* use : rematerialized_uses) {
      CHECK(!use->placed) << use->instruction->name();
    }
    return NewBuffer(remat_item, original_buffer.size,
                     std::move(rematerialized_uses), /*live_out=*/false,
                     /*has_indirect_uses=*/false);
  }

  // Returns the number of bytes allocated for the buffer with the given id.
  // Buffers allocated by the calling computation (eg, parameter and output
  // buffers) are considered to have zero bytes because the memory is
  // accounted for in a different computation.
  int64 AllocatedSize(BufferId buffer_id) const {
    const Buffer& buffer = buffers_.at(buffer_id);
    HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode();
    if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
      return 0;
    } else {
      return buffer.size;
    }
  }
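
  // For example (illustrative): under this rule a 1024-byte buffer defined by
  // a kParameter, or one that is live-out of the computation, contributes 0
  // bytes here; any other 1024-byte buffer contributes 1024 bytes.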

  // Returns true if BeginInstruction and EndInstruction have been called for
  // the given instruction.
  bool IsFinished(Item* item) const {
    return item->placed && item != in_progress_item_;
  }

  // Returns whether the given buffer is being used by the in-progress
  // instruction.
  bool IsInUse(BufferId buffer_id) const {
    if (in_progress_item_ == nullptr) {
      return false;
    }
    const BufferIdList& in_progress_uses = in_progress_item_->buffers_used;
    return absl::c_linear_search(in_progress_uses, buffer_id);
  }

  // Returns whether the given buffer is live at the current program
  // point.
  bool IsCurrentlyLive(BufferId buffer_id) const {
    const Buffer& buffer = buffers_[buffer_id];
    return (buffer.defining_instruction->placed &&
            buffer.unfinished_user_count > 0);
  }

  // Creates a new buffer, adds it to buffers_, and returns a reference.
  Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users,
                    bool live_out, bool has_indirect_uses) {
    int buffer_id = buffers_.size();
    buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out,
                              has_indirect_uses, users,
                              static_cast<int64>(users.size())});
    return buffers_.back();
  }

  const HloComputation* computation_;

  // Instruction list containing the ordering of instructions in
  // computation_. This is the order in which instructions are placed
  // (BeginInstruction/EndInstruction calls).
  const InstructionList& instruction_list_;

  // Memory usage at the currently placed instruction.
  int64 memory_usage_ = 0;

  // The instruction currently being placed. This value is non-null only
  // between the calling of BeginInstruction and EndInstruction.
  Item* in_progress_item_ = nullptr;

  // All buffers in the computation.
  std::vector<Buffer> buffers_;
};
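
// Typical driver loop for the tracker (a minimal sketch; 'size_fn',
// 'points_to', and 'list' are placeholder names). ComputePeakMemory below is
// the real in-tree use of this pattern:
//
//   MemoryUsageTracker tracker(computation, size_fn, points_to, list);
//   for (auto* item = list.first(); item != nullptr; item = list.next(item)) {
//     TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
//     // tracker.memory_usage() now includes item's outputs and live operands.
//     TF_RETURN_IF_ERROR(tracker.EndInstruction());  // frees dead operands
//   }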

MemoryUsageTracker::MemoryUsageTracker(
    const HloComputation* computation,
    const HloRematerialization::ShapeSizeFunction& size_function,
    const TuplePointsToAnalysis& points_to_analysis,
    const InstructionList& instruction_list)
    : computation_(computation), instruction_list_(instruction_list) {
  PointsToSet::BufferSet live_out_set =
      points_to_analysis.GetPointsToSet(computation_->root_instruction())
          .CreateFlattenedSet();
  absl::flat_hash_map<const LogicalBuffer*, BufferId>
      logical_buffer_to_buffer_id;

  for (auto* item = instruction_list_.first(); item != nullptr;
       item = instruction_list_.next(item)) {
    const HloInstruction* const instruction = item->instruction;
    for (const LogicalBuffer* logical_buffer :
         points_to_analysis.GetBuffersDefinedByInstruction(instruction)) {
      Buffer* buffer;
      if (instruction->opcode() == HloOpcode::kWhile) {
        // The while instruction defines no new buffers. Instead it reuses the
        // buffers of its operand. Find the Buffer of its operand at the
        // proper ShapeIndex.
        const PointsToSet& operand_points_to =
            points_to_analysis.GetPointsToSet(instruction->operand(0));
        CHECK_EQ(operand_points_to.element(logical_buffer->index()).size(), 1);
        const LogicalBuffer* source_logical_buffer =
            operand_points_to.element(logical_buffer->index())[0];
        buffer = &buffers_.at(
            logical_buffer_to_buffer_id.at(source_logical_buffer));

        // Mark the buffer as having indirect uses and update its live-out
        // status.
        buffer->has_indirect_uses = true;
        buffer->live_out =
            buffer->live_out || ContainsKey(live_out_set, logical_buffer);

        // Add users of the while to the Buffer's users.
        bool unused;
        for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
                                        points_to_analysis, &unused)) {
          if (!absl::c_linear_search(buffer->users, user_item)) {
            buffer->users.push_back(user_item);
            buffer->unfinished_user_count++;
            user_item->buffers_used.push_back(buffer->id);
          }
        }
      } else {
        buffer = &CreateBufferFromLogicalBuffer(
            logical_buffer, points_to_analysis, size_function,
            ContainsKey(live_out_set, logical_buffer));
        item->buffers_defined.push_back(buffer->id);
        for (Item* user : buffer->users) {
          user->buffers_used.push_back(buffer->id);
        }
      }

      logical_buffer_to_buffer_id[logical_buffer] = buffer->id;
    }
  }
  XLA_VLOG_LINES(10, ToString());
  DCHECK(Check());
}

Status MemoryUsageTracker::BeginInstruction(Item* item) {
  const HloInstruction* instruction = item->instruction;
  VLOG(3) << "BeginInstruction " << instruction->name();
  TF_RET_CHECK(in_progress_item_ == nullptr);
  in_progress_item_ = item;

  item->placed = true;

  // All buffers defined by this instruction need memory.
  for (BufferId buffer_id : item->buffers_defined) {
    VLOG(3) << "  Buffer " << buffers_.at(buffer_id).ToString()
            << " is now live.";
    memory_usage_ += AllocatedSize(buffer_id);
  }

  // TODO(b/37686934): Elementwise instructions can share the buffer of a
  // (dead) operand. Account for this potential reuse here.

  VLOG(3) << "  memory usage = " << memory_usage_;
  VLOG(10) << ToString();

  if (VLOG_IS_ON(1)) {
    DCHECK(Check());
  }
  return Status::OK();
}

Status MemoryUsageTracker::EndInstruction() {
  TF_RET_CHECK(in_progress_item_ != nullptr);
  VLOG(3) << "EndInstruction " << in_progress_item_->instruction->name();

  for (BufferId buffer_id : in_progress_item_->buffers_used) {
    Buffer& buffer = buffers_.at(buffer_id);
    buffer.unfinished_user_count--;
    CHECK_GE(buffer.unfinished_user_count, 0)
        << buffer.ToString() << " has negative unfinished use count.";
    if (buffer.unfinished_user_count == 0) {
      // Buffer is now dead.
      VLOG(3) << "  " << buffer.ToString() << " is now dead.";
      memory_usage_ -= AllocatedSize(buffer_id);
      CHECK_GE(memory_usage_, 0);
    }
  }

  // If any buffer defined by this instruction has no uses, then memory can be
  // reclaimed immediately.
  for (BufferId buffer_id : in_progress_item_->buffers_defined) {
    const Buffer& buffer = buffers_.at(buffer_id);
    if (buffer.unfinished_user_count == 0) {
      VLOG(3) << "  " << buffer.ToString() << " is immediately dead.";
      memory_usage_ -= AllocatedSize(buffer_id);
      CHECK_GE(memory_usage_, 0);
    }
  }

  in_progress_item_ = nullptr;

  VLOG(3) << "  memory usage = " << memory_usage_;
  VLOG(10) << ToString();

  if (VLOG_IS_ON(1)) {
    DCHECK(Check());
  }
  return Status::OK();
}

int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
  CHECK_NE(in_progress_item_, nullptr);
  if (!item->placed || item == in_progress_item_) {
    return 0;
  }

  // TODO(b/37687140): Rematerialization can increase peak memory consumption
  // at an earlier point in the program if rematerialization extends the live
  // range of the operand of the instruction being rematerialized across the
  // live range of the value of the instruction being rematerialized. Don't
  // rematerialize in this case (ie, return 0 here).

  // Compute the amount of memory reduced (if any) by rematerializing
  // 'instruction'. The LogicalBuffers defined by 'instruction' will no longer
  // be live at this program point, so initially set memory_reduced to the
  // size of its defined values.
  int64 memory_reduced = 0;
  for (BufferId buffer_id : item->buffers_defined) {
    // Avoid rematerializing instructions with indirect uses as it is
    // difficult to reason about liveness after rematerializing the
    // instruction.
    // TODO(b/37714814): Consider rematerializing instructions with indirect
    // uses.
    if (buffers_.at(buffer_id).has_indirect_uses) {
      return 0;
    }

    if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) {
      memory_reduced += AllocatedSize(buffer_id);
    }
  }

  // Account for any logical buffers whose live range must be extended across
  // this program point.
  for (BufferId buffer_id : item->buffers_used) {
    if (!IsCurrentlyLive(buffer_id)) {
      // This logical buffer is used by 'instruction' but is not live at this
      // program point. Rematerializing 'instruction' will extend the buffer's
      // live range across this program point.
      memory_reduced -= AllocatedSize(buffer_id);
    }
  }

  return memory_reduced;
}
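
// Worked example (illustrative): suppose 'item' defines one 4096-byte buffer
// that is live at this program point but not used by the in-progress
// instruction, and uses one 1024-byte operand buffer that is currently dead.
// Rematerializing 'item' frees the 4096 bytes but extends the operand's live
// range across this point, so the reported reduction is 4096 - 1024 = 3072
// bytes.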

Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
                                                        Item* remat_item) {
  VLOG(3) << "AddRematerializedInstruction: original_instruction = "
          << original_item->instruction->name()
          << ", remat_instruction = " << remat_item->instruction->name();

  TF_RET_CHECK(in_progress_item_ != nullptr);
  TF_RET_CHECK(original_item->placed) << original_item->instruction->name();
  TF_RET_CHECK(!remat_item->placed) << remat_item->instruction->name();

  // Construct the list of buffers used and defined by the rematerialization.
  remat_item->buffers_used = original_item->buffers_used;

  // Account for the additional buffer uses created by the new
  // rematerialization instruction. Update memory usage if the
  // rematerialization makes a dead buffer live again.
  for (BufferId buffer_id : original_item->buffers_used) {
    Buffer& buffer = buffers_.at(buffer_id);
    if (buffer.unfinished_user_count == 0) {
      // Buffer used by this instruction was dead, now is alive.
      memory_usage_ += AllocatedSize(buffer.id);
    }

    buffer.unfinished_user_count++;
    buffer.users.push_back(remat_item);
  }

  // Create a new set of Buffers defined by the new rematerialization
  // instruction. Update the internal data structures and memory use to account
  // for them.
  for (BufferId old_buffer_id : original_item->buffers_defined) {
    Buffer& old_buffer = buffers_.at(old_buffer_id);

    ItemList placed_users;
    ItemList unplaced_users;
    for (Item* user : old_buffer.users) {
      if (user->placed) {
        CHECK(IsFinished(user)) << user->instruction->name();
        placed_users.push_back(user);
      } else {
        unplaced_users.push_back(user);
      }
    }
    old_buffer.users = std::move(placed_users);
    old_buffer.unfinished_user_count = 0;

    // Buffer is now dead.
    memory_usage_ -= AllocatedSize(old_buffer.id);

    Buffer& new_buffer =
        RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));

    remat_item->buffers_defined.push_back(new_buffer.id);
    for (Item* user : new_buffer.users) {
      BufferIdList& buffers_used = user->buffers_used;
      std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
                   new_buffer.id);
    }
  }

  VLOG(3) << "  memory usage = " << memory_usage_;
  XLA_VLOG_LINES(10, ToString());

  DCHECK(Check());

  return Status::OK();
}

string MemoryUsageTracker::ToString() const {
  string output =
      absl::StrCat("MemoryUsageTracker for ", computation_->name(), "\n");
  absl::StrAppend(&output,
                  "Memory usage: ", HumanReadableNumBytes(memory_usage()),
                  " (", memory_usage(), " bytes)");
  for (auto* item = instruction_list_.first(); item != nullptr;
       item = instruction_list_.next(item)) {
    const HloInstruction* instruction = item->instruction;
    string inprogress = item == in_progress_item_ ? " in-progress" : "";
    string placed = item->placed ? " placed" : "";
    absl::StrAppend(&output, "  ", instruction->name(), inprogress, placed,
                    "\n    Defines:\n");
    for (BufferId buffer_id : item->buffers_defined) {
      const Buffer& buffer = buffers_[buffer_id];
      string live = IsCurrentlyLive(buffer_id) ? " live" : "";
      absl::StrAppend(&output, "      ", buffer.ToString(), live, ", ",
                      buffer.unfinished_user_count, " unfinished uses\n");
    }
    absl::StrAppend(&output, "    Uses:\n");
    for (BufferId buffer_id : item->buffers_used) {
      absl::StrAppend(&output, "      ", buffers_[buffer_id].ToString(), "\n");
    }
  }
  return output;
}

bool MemoryUsageTracker::Check() const {
  auto elements_are_unique = [](const BufferIdList& vec) {
    return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
  };

  // Verify buffers_defined per instruction.
  for (auto* instruction : computation_->instructions()) {
    const BufferIdList& defined_buffers =
        instruction_list_.GetItem(instruction)->buffers_defined;
    CHECK(elements_are_unique(defined_buffers))
        << "Instruction " << instruction->name()
        << " does not have unique defined buffers: "
        << absl::StrJoin(
               defined_buffers, ", ", [this](string* out, BufferId buffer_id) {
                 absl::StrAppend(out, buffers_.at(buffer_id).ToString());
               });

    for (const Buffer& buffer : buffers_) {
      if (buffer.defining_instruction->instruction == instruction) {
        CHECK(absl::c_linear_search(defined_buffers, buffer.id))
            << "Instruction " << instruction->name()
            << " defined buffers is missing: " << buffer.ToString();
      }
    }
  }

  // Verify buffers_used per instruction.
  for (auto* instruction : computation_->instructions()) {
    const BufferIdList& used_buffers =
        instruction_list_.GetItem(instruction)->buffers_used;
    CHECK(elements_are_unique(used_buffers))
        << "Instruction " << instruction->name()
        << " does not have unique used buffers: "
        << absl::StrJoin(
               used_buffers, ", ", [this](string* out, BufferId buffer_id) {
                 absl::StrAppend(out, buffers_.at(buffer_id).ToString());
               });
  }
  for (const Buffer& buffer : buffers_) {
    int64 unfinished_uses = 0;
    for (Item* user : buffer.users) {
      const BufferIdList& used_buffers = user->buffers_used;
      CHECK(absl::c_linear_search(used_buffers, buffer.id))
          << "Instruction " << user->instruction->name()
          << " used buffers is missing " << buffer.ToString();
      if (!IsFinished(user)) {
        unfinished_uses++;
      }
    }
    CHECK_EQ(buffer.unfinished_user_count, unfinished_uses)
        << "Incorrect unplaced use count for " << buffer.ToString();
  }
  return true;
}

// Computes and returns the cost of rematerializing the given instruction.
// Cost per rematerialized instruction is defined as:
//
// memory_limit_bytes / memory_reduced
//
// The idea is to choose the operation that will save the most memory and not
// to worry about how much the recomputation costs, since running out of
// memory is more harmful than taking longer to get the answer.
int64 RematerializationCost(const HloInstruction* instruction,
                            const MemoryUsageTracker& memory_tracker,
                            int64 memory_reduced, int64 memory_limit_bytes) {
  // If none of the users of 'instruction' have been placed in the sequence
  // (as tracked by memory_tracker), then rematerialization of 'instruction'
  // is a zero-cost move of 'instruction' in the sequence.
  if (!absl::c_any_of(instruction->users(),
                      [&memory_tracker](const HloInstruction* inst) {
                        return memory_tracker.IsPlaced(inst);
                      })) {
    return 0;
  }

  CHECK_GT(memory_reduced, 0);
  // Return the inverse of the benefit of rematerialization.
  return memory_limit_bytes / memory_reduced;
}
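
// Worked example (illustrative): with memory_limit_bytes = 1 GiB, a candidate
// that reduces memory by 64 MiB costs 1073741824 / 67108864 = 16, while one
// that reduces memory by 128 MiB costs 8; the larger reduction wins. A
// candidate none of whose users have been placed costs 0 and is preferred
// over both.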

// Selects and returns the best candidate instruction for rematerialization.
// The instruction with the lowest rematerialization cost is selected among
// those candidates which reduce memory use at the program point of the
// current instruction, as indicated by memory_tracker. Returns nullptr if no
// candidate can be found.
Item* PickRematerializationCandidate(
    const MemoryUsageTracker& memory_tracker,
    const InstructionList& instruction_list, int64 memory_limit_bytes,
    absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
  Item* best_item = nullptr;
  int64 best_cost = 0;

  // TODO(b/35244891): This is currently quadratic in the number of HLO
  // instructions.
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    if (!item->placed) {
      // Only iterate up to the currently placed instruction.
      // We are trying to reduce memory usage at the placed
      // instruction so rematerializing later values is of no benefit.
      break;
    }
    HloInstruction* candidate = item->instruction;
    VLOG(5) << "considering rematerialization candidate " << candidate->name();

    if (item->blacklisted) {
      // Skip instructions on the blacklist to avoid infinite loops of
      // rematerializing the same instruction(s) repeatedly.
      VLOG(5) << "candidate " << candidate->name()
              << " is excluded from rematerialization";
      continue;
    }
    if (!CanBeRematerialized(candidate, remat_able)) {
      VLOG(5) << "candidate " << candidate->name()
              << " not viable: is not rematerializable";
      continue;
    }

    // If any of the candidate's control successors has been placed, we need
    // to skip this candidate. Otherwise we will violate a control dependency.
    bool control_successor_placed =
        std::any_of(candidate->control_successors().begin(),
                    candidate->control_successors().end(),
                    [&memory_tracker](const HloInstruction* inst) {
                      return memory_tracker.IsPlaced(inst);
                    });

    if (control_successor_placed) {
      continue;
    }

    const int64 memory_reduced =
        memory_tracker.MemoryReducedIfRematerialized(item);

    if (memory_reduced <= 0) {
      VLOG(5) << "candidate " << candidate->name()
              << " memory reduced = " << memory_reduced << " <= 0";
      continue;
    }

    const int64 cost = RematerializationCost(
        candidate, memory_tracker, memory_reduced, memory_limit_bytes);

    VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
            << memory_reduced << ", cost per byte " << cost;

    if (best_item == nullptr || cost < best_cost) {
      VLOG(5) << "candidate " << candidate->name() << " now best";
      best_item = item;
      best_cost = cost;
    }
  }
  return best_item;
}

}  // namespace

StatusOr<int64> HloRematerialization::ComputePeakMemory(
    const HloComputation* computation,
    const HloInstructionSequence& order) const {
  InstructionList instruction_list(order);
  MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
                             instruction_list);
  int64 peak_memory = tracker.memory_usage();
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    const HloInstruction* instruction = item->instruction;
    TF_RETURN_IF_ERROR(tracker.BeginInstruction(item));
    TF_ASSIGN_OR_RETURN(int64 callee_usage,
                        CalledComputationsMemoryUsage(instruction));
    peak_memory =
        std::max<int64>(peak_memory, tracker.memory_usage() + callee_usage);
    TF_RETURN_IF_ERROR(tracker.EndInstruction());
  }
  VLOG(1) << "Peak memory for " << computation->name() << ": "
          << HumanReadableNumBytes(peak_memory);
  return peak_memory;
}

StatusOr<int64> HloRematerialization::CalledComputationsMemoryUsage(
    const HloInstruction* instruction) const {
  const CallSite* callsite =
      call_graph_->GetNode(instruction->parent()).GetCallSite(instruction);
  if (callsite == nullptr || callsite->context() == CallContext::kParallel) {
    return 0;
  }
  int64 callee_usage = 0;
  for (const HloComputation* computation : callsite->called_computations()) {
    TF_RET_CHECK(ContainsKey(computation_peak_memory_, computation));
    callee_usage += computation_peak_memory_.at(computation);
  }
  return callee_usage;
}

StatusOr<bool> HloRematerialization::RematerializeComputation(
    HloComputation* computation, HloSchedule* schedule,
    int64 memory_limit_bytes) {
  VLOG(1) << "Rematerializing computation " << computation->name()
          << " with limit " << HumanReadableNumBytes(memory_limit_bytes);
  VLOG(1) << "peak memory usage is "
          << HumanReadableNumBytes(computation_peak_memory_.at(computation));
  CHECK(!ContainsKey(rematerialized_computations_, computation));

  InstructionList instruction_list(schedule->sequence(computation));
  MemoryUsageTracker memory_tracker(computation, size_function_,
                                    *points_to_analysis_, instruction_list);
  bool changed = false;

  // If the rematerialization makes the source instruction dead, then the
  // rematerialization is added to 'remat_move_instructions' (the
  // rematerialization is essentially a move). If the next rematerialization
  // of the instruction is also a move then the rematerialization is added to
  // the blacklist.
  absl::flat_hash_set<const HloInstruction*> remat_move_instructions;

  // The map from instructions to their rematerializable status.
  absl::flat_hash_map<const HloInstruction*, bool> remat_able;

  // The peak memory of the computation at any point in the instruction
  // sequence.
  int64 peak_memory = memory_tracker.memory_usage();

  // Total count of instructions rematerialized.
  int64 remat_count = 0;
  // Total count of clones created minus the number of original rematerialized
  // instructions which are dead.
  int64 net_instructions_added = 0;

  const CallGraphNode& call_graph_node = call_graph_->GetNode(computation);

  // Iterate through all instructions in the sequence. At each instruction
  // (program point) if memory_usage exceeds the specified limit then
  // rematerialize HLO instructions until memory_usage is reduced.
  int64 instruction_index = 0;
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    const HloInstruction* instruction = item->instruction;
    TF_ASSIGN_OR_RETURN(int64 callee_usage,
                        CalledComputationsMemoryUsage(instruction));
    TF_RETURN_IF_ERROR(memory_tracker.BeginInstruction(item));

    VLOG(2) << "Program point at " << instruction->name()
            << ", memory usage = " << memory_tracker.memory_usage()
            << ", callee usage = " << callee_usage << ", [" << instruction_index
            << "/" << instruction_list.size() << "]";
    instruction_index++;

    while (memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
      VLOG(2) << "Over memory limit at instruction " << instruction->name()
              << ", using "
              << HumanReadableNumBytes(memory_tracker.memory_usage() +
                                       callee_usage)
              << ", limit is " << HumanReadableNumBytes(memory_limit_bytes);

      Item* best_item = PickRematerializationCandidate(
          memory_tracker, instruction_list, memory_limit_bytes, &remat_able);

      if (best_item == nullptr) {
        VLOG(3) << "Unable to find rematerialization candidate at program "
                   "point "
                << instruction->name() << ". Memory usage = "
                << HumanReadableNumBytes(memory_tracker.memory_usage() +
                                         callee_usage);
        break;
      }

      HloInstruction* best = best_item->instruction;
      VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
              << HumanReadableNumBytes(
                     memory_tracker.MemoryReducedIfRematerialized(best_item))
              << ")";
      changed = true;
      remat_count++;

      HloInstruction* remat =
          computation->AddInstruction(best->Clone(/*suffix=*/"remat"));

      // Add control dependencies to the new operation.
      for (auto successor : best->control_successors()) {
        TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
      }
      for (auto predecessor : best->control_predecessors()) {
        TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
      }

      Item* remat_item = instruction_list.CreateItem(remat);

      // Replace each remaining use of 'best' with the rematerialization.
      std::vector<HloInstruction*> best_users_copy = best->users();
      for (HloInstruction* user : best_users_copy) {
        if (!memory_tracker.IsPlaced(user)) {
          VLOG(2) << "  Replacing use of " << best->name() << " in "
                  << user->name() << " with " << remat->name();
          TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
        }
      }

      // Account for the rematerialization in the memory tracker.
      TF_RETURN_IF_ERROR(
          memory_tracker.AddRematerializedInstruction(best_item, remat_item));

      // Insert the rematerialized instruction right before the earliest
      // unplaced use of the instruction *and* the earliest unplaced last use
      // of any operands of remat. Unplaced uses of the remat's operands are
      // included because we don't want to extend the live range of remat's
      // operands as this could increase memory usage.
      ItemList place_before;
      for (auto user : remat->users()) {
        place_before.push_back(instruction_list.GetItem(user));
      }
      for (auto* operand : remat->operands()) {
        for (auto* operand_user : operand->users()) {
          if (operand_user != remat) {
            Item* operand_user_item = instruction_list.GetItem(operand_user);
            if (!operand_user_item->placed) {
              place_before.push_back(operand_user_item);
            }
          }
        }
      }
      // Insert the rematerialized instruction before any of its control
      // successors to preserve ordering regarding control dependencies.
      for (auto successor : remat->control_successors()) {
        Item* successor_item = instruction_list.GetItem(successor);
        // Assert to make sure we never remat an operation with a control
        // successor already placed.
        CHECK(!successor_item->placed) << successor_item->instruction->name();
        place_before.push_back(successor_item);
      }
      instruction_list.InsertBeforeInstructions(remat_item, place_before);

      // If the rematerialized instruction is dead then rematerialization is
      // essentially a move. Don't delete the instruction now because we don't
      // want duplicate HloInstruction* values during the course of the
      // transformation because we keep maps with HloInstruction* values as
      // keys.
      if (best->users().empty()) {
        VLOG(2) << best->name() << " is now dead";
        if (ContainsKey(remat_move_instructions, best)) {
          // Previously, 'best' was a rematerialization which killed the
          // instruction it was a copy of. Now 'remat' is a rematerialization
          // of 'best' and kills 'best'. Stop rematerializing this instruction
          // to avoid an infinite loop.
          instruction_list.Blacklist(remat);
        }
        remat_move_instructions.insert(remat);
      } else {
        net_instructions_added++;
      }

      VLOG(1) << "memory_usage after rematerialization = "
              << HumanReadableNumBytes(memory_tracker.memory_usage());
    }

    const CallSite* callsite = call_graph_node.GetCallSite(instruction);
    if (callsite != nullptr &&
        callsite->context() == CallContext::kSequential &&
        memory_tracker.memory_usage() + callee_usage > memory_limit_bytes) {
      // Memory usage exceeds the limit. Try to rematerialize any
      // subcomputation(s) that this instruction calls.
      VLOG(1) << "Memory usage still over the limit ("
              << (memory_tracker.memory_usage() + callee_usage) << " > "
              << memory_limit_bytes
              << "). Rematerializing computations called by "
              << instruction->name();

      // Recompute callee usage to account for any rematerialization performed
      // in the callee computations.
      for (HloComputation* called_computation :
           callsite->called_computations()) {
        if (!ContainsKey(rematerialized_computations_, called_computation)) {
          // The memory limit for the subcomputation is the memory limit less
          // the amount of memory used at this point in the computation.
          int64 subcomputation_memory_limit_bytes = std::max<int64>(
              0, memory_limit_bytes - memory_tracker.memory_usage());
          TF_ASSIGN_OR_RETURN(
              bool subcomputation_changed,
              RematerializeComputation(called_computation, schedule,
                                       subcomputation_memory_limit_bytes));
          changed |= subcomputation_changed;
        }
      }
      TF_ASSIGN_OR_RETURN(callee_usage,
                          CalledComputationsMemoryUsage(instruction));
    }

    peak_memory = std::max<int64>(
        peak_memory, memory_tracker.memory_usage() + callee_usage);
    VLOG(3) << "peak memory usage = " << HumanReadableNumBytes(peak_memory);

    TF_RETURN_IF_ERROR(memory_tracker.EndInstruction());
  }

  // Verify some invariants on the memory tracker.
  CHECK_EQ(memory_tracker.memory_usage(), 0);
  for (auto* instruction : computation->instructions()) {
    CHECK(memory_tracker.IsPlaced(instruction)) << instruction->name();
  }

  VLOG(1) << "In computation " << computation->name() << " rematerialized "
          << remat_count << " instructions; " << net_instructions_added
          << " net instructions added";
  VLOG(1) << "  peak memory usage now " << HumanReadableNumBytes(peak_memory)
          << " (was "
          << HumanReadableNumBytes(computation_peak_memory_.at(computation))
          << ")";

  // Update the peak memory used by the computation.
  computation_peak_memory_.at(computation) = peak_memory;

  // Update the order to include rematerialized instructions.
  HloInstructionSequence& sequence = schedule->GetOrCreateSequence(computation);
  sequence.clear();
  for (auto* item = instruction_list.first(); item != nullptr;
       item = instruction_list.next(item)) {
    HloInstruction* instruction = item->instruction;
    sequence.push_back(instruction);
  }
  rematerialized_computations_.insert(computation);

  instructions_rematerialized_ += remat_count;
  net_instructions_added_ += net_instructions_added;

  return changed;
}

StatusOr<bool> HloRematerialization::Run(HloModule* module) {
  VLOG(1) << "HloRematerialization() with memory limit of "
          << HumanReadableNumBytes(memory_limit_bytes_);
  XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString());

  // Initialize pass object state.
  computation_peak_memory_.clear();
  rematerialized_computations_.clear();
  instructions_rematerialized_ = 0;
  net_instructions_added_ = 0;

  TF_RET_CHECK(module->has_schedule());
  TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module));

  // Adjust the memory limit to account for the output of the entry
  // computation. This is necessary because the per-computation accounting in
  // MemoryUsageTracker does not include the output, as it is typically
  // allocated by the caller.
  int64 module_output_size = 0;
  ShapeUtil::ForEachSubshape(
      module->result_shape(),
      [&module_output_size, this](const Shape& subshape,
                                  const ShapeIndex& /*index*/) {
        module_output_size += size_function_(subshape);
      });

  const int64 adjusted_memory_limit_bytes =
      memory_limit_bytes_ - module_output_size;
  VLOG(1) << "Adjusted memory limit accounting for output ("
          << HumanReadableNumBytes(module_output_size)
          << "): " << HumanReadableNumBytes(adjusted_memory_limit_bytes);

  // Compute the peak memory usage of all computations in the module called in
  // a sequential context.
  call_graph_ = CallGraph::Build(module);
  TF_RETURN_IF_ERROR(call_graph_->VisitNodes(
      [this, module](const CallGraphNode& node) -> Status {
        if (node.context() == CallContext::kSequential) {
          TF_ASSIGN_OR_RETURN(
              computation_peak_memory_[node.computation()],
              ComputePeakMemory(node.computation(), module->schedule().sequence(
                                                        node.computation())));
        }
        return Status::OK();
      },
      /*visit_unreachable_nodes=*/false));

  // The peak memory usage of the module equals the peak memory use of the
  // entry computation plus the output size of the computation. This is
  // because the peak memory for a computation does not include the output as
  // this is typically accounted for in the caller.
  const int64 before_peak_memory =
      computation_peak_memory_.at(module->entry_computation()) +
      module_output_size;
  VLOG(1) << "Peak memory usage of module (before): "
          << HumanReadableNumBytes(before_peak_memory);

  // Subcomputations called by the entry computation will also be
  // rematerialized.
  TF_ASSIGN_OR_RETURN(
      bool changed,
      RematerializeComputation(module->entry_computation(),
                               &module->schedule(),
                               adjusted_memory_limit_bytes));

  // Rematerialization can introduce dead code. This occurs if all uses of an
  // instruction are replaced with rematerializations of the instruction.
  TF_ASSIGN_OR_RETURN(bool dead_code_removed, HloDCE().Run(module));
  changed |= dead_code_removed;

  // After DCE, the module sequence may include instructions which no longer
  // exist.
  TF_RETURN_IF_ERROR(module->schedule().Update());
  VLOG(1) << "Rematerialized " << instructions_rematerialized_
          << " instructions in module " << module->name() << "; "
          << net_instructions_added_ << " net instructions added";
  const int64 current_peak_memory =
      computation_peak_memory_.at(module->entry_computation()) +
      module_output_size;
  VLOG(1) << "Peak memory usage of module now "
          << HumanReadableNumBytes(current_peak_memory) << " ("
          << current_peak_memory << " bytes), was "
          << HumanReadableNumBytes(before_peak_memory) << " ("
          << before_peak_memory << " bytes)";
  const int64 reduced_peak_memory = before_peak_memory - current_peak_memory;
  VLOG(1) << "Reduced peak memory by "
          << HumanReadableNumBytes(reduced_peak_memory) << " ("
          << reduced_peak_memory << " bytes)";

  if (sizes_ != nullptr) {
    sizes_->before_bytes = before_peak_memory;
    sizes_->after_bytes = current_peak_memory;
  }

  XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());

  if (current_peak_memory > memory_limit_bytes_) {
    LOG(WARNING) << absl::StrFormat(
        "Can't reduce memory use below %s (%d bytes) by rematerialization; "
        "only reduced to %s (%d bytes)",
        HumanReadableNumBytes(memory_limit_bytes_), memory_limit_bytes_,
        HumanReadableNumBytes(current_peak_memory), current_peak_memory);
  }

  return changed;
}

}  // namespace xla