1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
17
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/str_cat.h"
20 #include "llvm/IR/BasicBlock.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/Instructions.h"
23 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
24 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
27 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
28 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/types.h"
31
32 namespace xla {
33 namespace gpu {
34
35 using absl::StrAppend;
36 using absl::StrCat;
37
EmitBasePointersForHlos(absl::Span<const HloInstruction * const> io_hlos,absl::Span<const HloInstruction * const> non_io_hlos)38 void HloToIrBindings::EmitBasePointersForHlos(
39 absl::Span<const HloInstruction* const> io_hlos,
40 absl::Span<const HloInstruction* const> non_io_hlos) {
41 // I/O HLOs are bound to the arguments of the current IR function. I.e.,
42 //
43 // void IrFunction(io_0, io_1, ..., io_{m-1}, temp_buffer_base) {
44 llvm::Function* function = b_->GetInsertBlock()->getParent();
45 CHECK_EQ(io_hlos.size() + 1, function->arg_size());
46
47 // An HLO can have duplicated operands. This data structure remembers which
48 // operand HLOs are already bound to avoid rebinding the same HLO.
49 absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
50 auto arg_iter = function->arg_begin();
51 for (const HloInstruction* io_hlo : io_hlos) {
52 if (!already_bound_for_this_function.contains(io_hlo)) {
53 if (!is_nested_ && io_hlo->opcode() == HloOpcode::kGetTupleElement) {
54 BindHloToIrValue(*io_hlo, EmitGetTupleElement(io_hlo, &*arg_iter));
55 } else {
56 BindHloToIrValue(*io_hlo, &*arg_iter);
57 }
58 already_bound_for_this_function.insert(io_hlo);
59 }
60 ++arg_iter;
61 }
62
63 temp_buffer_base_ = &*arg_iter;
64 temp_buffer_base_->setName("temp_buffer");
65
66 for (const HloInstruction* non_io_hlo : non_io_hlos) {
67 if (already_bound_for_this_function.contains(non_io_hlo)) {
68 continue;
69 }
70 already_bound_for_this_function.insert(non_io_hlo);
71
72 if (non_io_hlo->opcode() == HloOpcode::kGetTupleElement) {
73 if (!is_nested_) {
74 // Lookup allocation GetTupleElement operand.
75 const BufferAllocation::Slice slice =
76 buffer_assignment_
77 ->GetUniqueTopLevelSlice(non_io_hlo->LatestNonGteAncestor())
78 .ConsumeValueOrDie();
79 // We are not in a nested context, so check non-thread-local allocation.
80 CHECK(!slice.allocation()->is_thread_local());
81 const int64 offset = slice.offset();
82 CHECK_NE(nullptr, temp_buffer_base_);
83 // Emit IR for GetTupleElement instruction and bind to emitted value.
84 llvm::Value* base_ptr =
85 b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset));
86 BindHloToIrValue(*non_io_hlo,
87 EmitGetTupleElement(non_io_hlo, base_ptr));
88 }
89 continue;
90 }
91
92 if (!buffer_assignment_->HasTopLevelAllocation(non_io_hlo)) {
93 continue;
94 }
95
96 ShapeUtil::ForEachSubshape(
97 non_io_hlo->shape(),
98 [&](const Shape& /*subshape*/, const ShapeIndex& index) {
99 // A non-IO HLO with a buffer is bound to
100 // (1) an alloca if it is thread-local, or
101 // (2) an internal pointer in temp_buffer_base according to its
102 // offset.
103 auto slice_result =
104 buffer_assignment_->GetUniqueSlice(non_io_hlo, index);
105 if (!slice_result.ok()) {
106 return;
107 }
108 const BufferAllocation::Slice slice =
109 slice_result.ConsumeValueOrDie();
110 if (slice.allocation()->is_thread_local()) {
111 llvm::Type* pointee_type =
112 llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
113 BindHloToIrValue(*non_io_hlo, b_->CreateAlloca(pointee_type),
114 index);
115 } else if (slice.allocation()->is_constant()) {
116 llvm::Value* global_for_constant = module_->getGlobalVariable(
117 llvm_ir::ConstantBufferAllocationToGlobalName(
118 *slice.allocation()));
119 BindHloToIrValue(*non_io_hlo, global_for_constant);
120 } else {
121 const int64 offset = slice.offset();
122 CHECK_NE(nullptr, temp_buffer_base_);
123 BindHloToIrValue(
124 *non_io_hlo,
125 b_->CreateInBoundsGEP(temp_buffer_base_, b_->getInt64(offset)),
126 index);
127 }
128 });
129 }
130 }
131
EmitGetTupleElement(const HloInstruction * gte,llvm::Value * base_ptr)132 llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
133 llvm::Value* base_ptr) {
134 // TODO(b/26344050): tighten the alignment based on the real element type.
135 if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
136 return llvm_ir::EmitGetTupleElement(
137 gte->shape(), gte->tuple_index(), /*alignment=*/1,
138 GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_);
139 }
140 return llvm_ir::EmitGetTupleElement(
141 gte->shape(), gte->tuple_index(), /*alignment=*/1,
142 EmitGetTupleElement(gte->operand(0), base_ptr), b_);
143 }
144
145 // Returns true if `value` has a name that should not be changed.
HasMeaningfulName(llvm::Value * value)146 static bool HasMeaningfulName(llvm::Value* value) {
147 if (auto* global = llvm::dyn_cast<llvm::GlobalValue>(value)) {
148 return global->getLinkage() != llvm::GlobalValue::PrivateLinkage;
149 }
150 return false;
151 }
152
GetTypedIrValue(const HloInstruction & hlo,ShapeIndexView shape_index,llvm::Value * ir_value)153 llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
154 ShapeIndexView shape_index,
155 llvm::Value* ir_value) {
156 llvm::Type* pointee_type = llvm_ir::ShapeToIrType(
157 ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_);
158 llvm::Type* dest_type = pointee_type->getPointerTo();
159
160 llvm::Value* typed_ir_value;
161 if (llvm::isa<llvm::GlobalVariable>(ir_value)) {
162 typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
163 llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
164 } else {
165 typed_ir_value = b_->CreateBitCast(ir_value, pointee_type->getPointerTo());
166 }
167 if (!HasMeaningfulName(ir_value)) {
168 ir_value->setName(llvm_ir::IrName(&hlo, "raw"));
169 }
170 if (!HasMeaningfulName(typed_ir_value)) {
171 typed_ir_value->setName(llvm_ir::IrName(&hlo, "typed"));
172 }
173 return typed_ir_value;
174 }
175
BindHloToIrValue(const HloInstruction & hlo,llvm::Value * ir_value,ShapeIndexView shape_index)176 void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
177 llvm::Value* ir_value,
178 ShapeIndexView shape_index) {
179 VLOG(2) << "Binding " << hlo.ToString();
180
181 const Shape& hlo_shape = hlo.shape();
182 llvm::Value* typed_ir_value = GetTypedIrValue(hlo, shape_index, ir_value);
183
184 if (!BoundToIrValue(hlo)) {
185 // Set the root of ShapeTree first before assigning the element ir value.
186 InsertOrDie(&base_ptrs_, &hlo, ShapeTree<llvm::Value*>(hlo_shape, nullptr));
187 }
188 *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value;
189 }
190
191 // Determines whether hlo's buffers are never modified within the execution of
192 // consumer.
BuffersInvariantWithinConsumer(const HloInstruction & hlo,const HloInstruction & consumer,const BufferAssignment * buffer_assignment)193 static bool BuffersInvariantWithinConsumer(
194 const HloInstruction& hlo, const HloInstruction& consumer,
195 const BufferAssignment* buffer_assignment) {
196 // Check if consumer is inside a fusion node -- if so, "dereference" it until
197 // we get to a non-fusion node.
198 const HloInstruction* c = &consumer;
199 while (c->IsFused()) {
200 c = c->parent()->FusionInstruction();
201 }
202
203 // If, after dereferencing c, we end up with a node that's not inside our
204 // module's top-level computation (say our node is inside a while loop), we
205 // give up on marking array as invariant, because this HLO may be run multiple
206 // times (e.g. multiple while loop iterations, or multiple invocations of a
207 // reducer's computation). TODO(jlebar): We could relax this constraint if we
208 // emitted an llvm.invariant.group.barrier at the end of the computation.
209 return c->parent() == c->GetModule()->entry_computation() &&
210 buffer_assignment->HaveDisjointSlices(&hlo, &consumer);
211 }
212
GetIrArray(const HloInstruction & hlo,const HloInstruction & consumer,const ShapeIndex & shape_index)213 llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
214 const HloInstruction& consumer,
215 const ShapeIndex& shape_index) {
216 llvm::Value* base_ptr = GetBasePointer(hlo, shape_index);
217 CHECK_NE(base_ptr, nullptr)
218 << "Buffer not assigned for shape_index " << shape_index.ToString()
219 << " of " << hlo.ToString();
220 llvm_ir::IrArray ir_array(base_ptr,
221 ShapeUtil::GetSubshape(hlo.shape(), shape_index));
222 alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array, shape_index);
223
224 // The GPU backend emits one kernel per top-level HLO, and LLVM views
225 // execution of one kernel as the "whole program" executed on the GPU.
226 // Therefore if hlo's output buffer is not modified within consumer, and if
227 // consumer runs hlo only once (so that it doesn't create two different
228 // outputs), then we can mark ir_array as invariant over the whole program.
229 if (BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) {
230 VLOG(2) << "Marking " << hlo.name() << " as invariant within "
231 << consumer.name();
232 ir_array.MarkInvariantOverWholeProgram(&module_->getContext());
233 }
234
235 return ir_array;
236 }
237
UnbindAllLocalIrValues()238 void HloToIrBindings::UnbindAllLocalIrValues() {
239 std::vector<const HloInstruction*> hlos_to_unbind;
240 for (auto& key_value : base_ptrs_) {
241 if (!llvm::isa<llvm::GlobalVariable>(
242 (key_value.second.element({}))->stripPointerCasts())) {
243 hlos_to_unbind.push_back(key_value.first);
244 }
245 }
246 for (const HloInstruction* hlo_to_unbind : hlos_to_unbind) {
247 VLOG(2) << "Unbinding " << hlo_to_unbind->ToString();
248 base_ptrs_.erase(hlo_to_unbind);
249 }
250 }
251
ToString() const252 string HloToIrBindings::ToString() const {
253 string s = StrCat("** HloToIrBindings **\n");
254 StrAppend(&s, " is_nested_=", is_nested_, "\n");
255 StrAppend(&s,
256 " temp_buffer_base_=", llvm_ir::DumpToString(*temp_buffer_base_),
257 "\n");
258
259 if (base_ptrs_.empty()) {
260 return s;
261 }
262
263 // Iterate over all computations in the module in topological order, and print
264 // out the base pointers we have in each computation in topological order.
265 for (const HloComputation* computation :
266 base_ptrs_.begin()->first->GetModule()->MakeComputationPostOrder()) {
267 bool is_first = true;
268 for (const HloInstruction* instr :
269 computation->MakeInstructionPostOrder()) {
270 auto it = base_ptrs_.find(instr);
271 if (it == base_ptrs_.end()) {
272 continue;
273 }
274 if (is_first) {
275 StrAppend(&s, " Base pointers for computation ", computation->name(),
276 ":\n");
277 is_first = false;
278 }
279 StrAppend(&s, " ", instr->ToString());
280
281 const ShapeTree<llvm::Value*>& shape_tree = it->second;
282 if (!instr->shape().IsTuple()) {
283 const llvm::Value* val = shape_tree.begin()->second;
284 StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n");
285 continue;
286 }
287
288 StrAppend(&s, "\n");
289 for (auto shape_it = shape_tree.begin(); shape_it != shape_tree.end();
290 ++shape_it) {
291 llvm::Value* val = shape_it->second;
292 StrAppend(&s, " ", shape_it->first.ToString(), " -> ",
293 (val != nullptr ? llvm_ir::DumpToString(*val) : "null"),
294 "\n");
295 }
296 }
297 }
298 return s;
299 }
300
301 } // namespace gpu
302 } // namespace xla
303