1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/logical_buffer_analysis.h"
17 
18 #include <utility>
19 
20 #include "absl/memory/memory.h"
21 #include "tensorflow/compiler/xla/shape_util.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/platform/logging.h"
24 
25 namespace xla {
26 
27 namespace {
28 
29 // Gather fusion instructions from 'instruction' into 'fusion_instructions'.
GatherFusionInstructions(HloInstruction * instruction,std::vector<HloInstruction * > * fusion_instructions)30 void GatherFusionInstructions(
31     HloInstruction* instruction,
32     std::vector<HloInstruction*>* fusion_instructions) {
33   CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
34   for (auto* fused : instruction->fused_instructions()) {
35     if (fused->opcode() == HloOpcode::kFusion) {
36       GatherFusionInstructions(fused, fusion_instructions);
37     }
38   }
39   fusion_instructions->push_back(instruction);
40 }
41 
42 }  // namespace
43 
44 /* static */ StatusOr<std::unique_ptr<LogicalBufferAnalysis>>
Run(const HloModule * module)45 LogicalBufferAnalysis::Run(const HloModule* module) {
46   std::unique_ptr<LogicalBufferAnalysis> analysis(
47       new LogicalBufferAnalysis(module));
48   TF_RETURN_IF_ERROR(analysis->Analyze());
49   return std::move(analysis);
50 }
51 
Analyze()52 Status LogicalBufferAnalysis::Analyze() {
53   // Empirically we usually have a few more logical buffers than instructions,
54   // so reserve 10% more than the number of instructions to avoid frequent
55   // resizes.
56   logical_buffers_.clear();
57   logical_buffers_.reserve((module_->instruction_count() * 11) / 10);
58 
59   // We filter out fusion computations, and get to them through fusion
60   // instructions. This is because it's possible to have orphaned (unreachable)
61   // fusion computations, and we don't want to try to assign buffers to those.
62   std::vector<HloInstruction*> fusion_instructions;
63   for (auto* computation : module_->MakeNonfusionComputations()) {
64     TF_RETURN_IF_ERROR(computation->Accept(this));
65     for (auto* instruction : computation->instructions()) {
66       if (instruction->opcode() != HloOpcode::kFusion) {
67         continue;
68       }
69       GatherFusionInstructions(instruction, &fusion_instructions);
70     }
71   }
72   for (auto* instruction : fusion_instructions) {
73     TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
74   }
75   return Status::OK();
76 }
77 
GetBuffer(LogicalBuffer::Id id) const78 LogicalBuffer& LogicalBufferAnalysis::GetBuffer(LogicalBuffer::Id id) const {
79   CHECK_GE(id, 0);
80   CHECK_LT(id, logical_buffers_.size());
81   return *logical_buffers_[id];
82 }
83 
GetBuffer(HloInstruction * instruction,const ShapeIndex & index) const84 LogicalBuffer& LogicalBufferAnalysis::GetBuffer(HloInstruction* instruction,
85                                                 const ShapeIndex& index) const {
86   return *output_buffers_.at(std::make_pair(instruction, index));
87 }
88 
NewLogicalBuffer(HloInstruction * instruction,const ShapeIndex & index)89 void LogicalBufferAnalysis::NewLogicalBuffer(HloInstruction* instruction,
90                                              const ShapeIndex& index) {
91   CHECK_EQ(logical_buffers_.size(), next_buffer_id_);
92   logical_buffers_.emplace_back(
93       absl::make_unique<LogicalBuffer>(instruction, index, next_buffer_id_));
94   output_buffers_[std::make_pair(instruction, index)] =
95       logical_buffers_.back().get();
96 
97   ++next_buffer_id_;
98 }
99 
DefaultAction(HloInstruction * hlo_instruction)100 Status LogicalBufferAnalysis::DefaultAction(HloInstruction* hlo_instruction) {
101   // Create a logical buffer for each output of the instruction.
102   ShapeUtil::ForEachSubshape(
103       hlo_instruction->shape(),
104       [this, hlo_instruction](const Shape& shape, const ShapeIndex& index) {
105         NewLogicalBuffer(hlo_instruction, index);
106       });
107 
108   return Status::OK();
109 }
110 
HandleGetTupleElement(HloInstruction *)111 Status LogicalBufferAnalysis::HandleGetTupleElement(HloInstruction*) {
112   // GetTupleElement does not create buffers.
113   return Status::OK();
114 }
115 
HandleAddDependency(HloInstruction * add_dependency)116 Status LogicalBufferAnalysis::HandleAddDependency(
117     HloInstruction* add_dependency) {
118   // AddDependency just forwards the value of its zero-th operand and does not
119   // create buffers.
120   return Status::OK();
121 }
122 
HandleCopy(HloInstruction * copy)123 Status LogicalBufferAnalysis::HandleCopy(HloInstruction* copy) {
124   // The top-level buffer (index={}) for kCopy is newly created, but all other
125   // buffers (in the case of a tuple shape) come from the operand
126   NewLogicalBuffer(copy, /*index=*/{});
127   return Status::OK();
128 }
129 
HandleBitcast(HloInstruction *)130 Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) {
131   // A kBitcast instruction aliases its operand. That is, the buffer of its
132   // result *is* the buffer of its operand.
133   return Status::OK();
134 }
135 
HandleDomain(HloInstruction *)136 Status LogicalBufferAnalysis::HandleDomain(HloInstruction*) {
137   // A kDomain instruction aliases its operand. That is, the buffer of its
138   // result *is* the buffer of its operand.
139   return Status::OK();
140 }
141 
HandleRecvDone(HloInstruction * recv_done)142 Status LogicalBufferAnalysis::HandleRecvDone(HloInstruction* recv_done) {
143   // RecvDone produces a two-element tuple containing the data value (which
144   // aliases part of its operand) and a token. Only the tuple index table and
145   // the token are defined by the RecvDone.
146   NewLogicalBuffer(recv_done, /*index=*/{});
147   NewLogicalBuffer(recv_done, /*index=*/{1});
148   return Status::OK();
149 }
150 
HandleSend(HloInstruction * send)151 Status LogicalBufferAnalysis::HandleSend(HloInstruction* send) {
152   // Send creates new buffers for the top-level tuple, the context (tuple
153   // element at {1}), and the token (tuple element at {2}). Tuple element at {0}
154   // is an alias of the Send operand, so we don't need to create a new Logical
155   // Buffer for that.
156   NewLogicalBuffer(send, /*index=*/{});
157   NewLogicalBuffer(send, /*index=*/{1});
158   NewLogicalBuffer(send, /*index=*/{2});
159   return Status::OK();
160 }
161 
HandleTuple(HloInstruction * tuple)162 Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) {
163   // A Tuple instruction only creates the top-level buffer.
164   NewLogicalBuffer(tuple, /*index=*/{});
165   return Status::OK();
166 }
167 
HandleTupleSelect(HloInstruction * tuple_select)168 Status LogicalBufferAnalysis::HandleTupleSelect(HloInstruction* tuple_select) {
169   // Select allocates a new buffer and then shallow copies the on_true or
170   // on_false buffer into this new buffer.
171   NewLogicalBuffer(tuple_select, /*index=*/{});
172   return Status::OK();
173 }
174 
175 }  // namespace xla
176