1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
17 
18 #include <algorithm>
19 #include <functional>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "llvm/IR/BasicBlock.h"
24 #include "llvm/IR/Value.h"
25 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
30 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
31 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
32 #include "tensorflow/compiler/xla/shape.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/statusor.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/core/platform/logging.h"
38 
39 namespace xla {
40 
41 using llvm_ir::IrArray;
42 
DefaultAction(HloInstruction * hlo)43 Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) {
44   indexed_generators_[hlo] =
45       [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
46     if (generated_value_cache_[hlo].contains(index.multidim())) {
47       llvm::Value* generated_value =
48           generated_value_cache_[hlo][index.multidim()];
49       llvm::BasicBlock* generated_value_bb = nullptr;
50       if (auto* generated_instruction =
51               llvm::dyn_cast<llvm::Instruction>(generated_value)) {
52         generated_value_bb = generated_instruction->getParent();
53       }
54       // Ideally, we should be able to reuse the cached generated value if it
55       // dominates the current insertion block. However, the check for dominance
56       // can be expensive and unreliable when the function is being constructed.
57       //
58       // It's also worth experimenting what if we don't do caching at all.
59       // LLVM's CSE or GVN should be able to easily merge common subexpressions
60       // that would be regenerated without caching. But this might increase the
61       // JIT compilation time.
62       if (generated_value_bb == nullptr ||
63           generated_value_bb == b_->GetInsertBlock()) {
64         VLOG(3) << "The cached generated value is reused.";
65         return generated_value;
66       }
67       VLOG(3) << "The cached generated value can't be reused, because it is in "
68                  "a different BB ("
69               << generated_value_bb->getName().str()
70               << ") from the current insertion block ("
71               << b_->GetInsertBlock()->getName().str() << ").";
72     }
73 
74     TF_ASSIGN_OR_RETURN(generated_value_cache_[hlo][index.multidim()],
75                         elemental_emitter_->MakeElementGenerator(
76                             hlo, indexed_generators_)(index));
77     return generated_value_cache_[hlo][index.multidim()];
78   };
79   return Status::OK();
80 }
81 
HandleConstant(HloInstruction * constant)82 Status FusedIrEmitter::HandleConstant(HloInstruction* constant) {
83   indexed_generators_[constant] = [=](const IrArray::Index& index) {
84     const Literal& literal = constant->literal();
85     llvm::Constant* initializer =
86         llvm_ir::ConvertLiteralToIrConstant(literal, module_);
87     llvm::GlobalVariable* global = new llvm::GlobalVariable(
88         *b_->GetInsertBlock()->getModule(), initializer->getType(),
89         /*isConstant=*/true,
90         /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
91         /*Initializer=*/initializer,
92         /*Name=*/"");
93     global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
94     llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast(
95         global,
96         llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo());
97     return IrArray(shape_constant, constant->shape())
98         .EmitReadArrayElement(index, b_);
99   };
100 
101   return Status::OK();
102 }
103 
HandleGetTupleElement(HloInstruction * get_tuple_element)104 Status FusedIrEmitter::HandleGetTupleElement(
105     HloInstruction* get_tuple_element) {
106   auto emit_tuple_element_ptr = [=]() -> StatusOr<llvm::Value*> {
107     const HloInstruction* tuple_operand = get_tuple_element->operand(0);
108     llvm::Value* tuple_ptr;
109     if (tuple_operand->opcode() == HloOpcode::kGetTupleElement) {
110       TF_ASSIGN_OR_RETURN(tuple_ptr, non_indexed_generators_[tuple_operand]());
111     } else {
112       if (tuple_operand->opcode() != HloOpcode::kParameter) {
113         return Unimplemented(
114             "GetTupleElement fusion currently only supports parameter or "
115             "nested"
116             "GetTupleElement as tuple operand, found an exception: %s",
117             tuple_operand->name());
118       }
119       tuple_ptr =
120           GetBasePointerForFusedParameter(tuple_operand->parameter_number());
121     }
122 
123     // Lookup tuple element pointer.
124     return llvm_ir::EmitGetTupleElement(get_tuple_element->shape(),
125                                         get_tuple_element->tuple_index(),
126                                         /*alignment=*/1, tuple_ptr, b_);
127   };
128 
129   if (!get_tuple_element->shape().IsTuple()) {
130     indexed_generators_[get_tuple_element] =
131         [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
132       // TODO(b/34080002) Add aliasing information to tuple element IrArray.
133       TF_ASSIGN_OR_RETURN(llvm::Value * tuple_element_ptr,
134                           emit_tuple_element_ptr());
135       return IrArray(tuple_element_ptr, get_tuple_element->shape())
136           .EmitReadArrayElement(index, b_);
137     };
138   } else {
139     non_indexed_generators_[get_tuple_element] = emit_tuple_element_ptr;
140   }
141   return Status::OK();
142 }
143 
HandleParameter(HloInstruction * parameter)144 Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) {
145   indexed_generators_[parameter] =
146       [=](const IrArray::Index& index) -> llvm::Value* {
147     if (tiled_parameter_info_) {
148       if (llvm::Value* param_tile_buffer =
149               tiled_parameter_info_->GetBufferForParameter(
150                   parameter->parameter_number())) {
151         // TODO(jlebar): Add AA metadata to this load.  Tile buffers are global
152         // variables, so LLVM's points-to analysis doesn't help us much.  And we
153         // want the AA info to be present before address spaces are inferred
154         // (which is pretty late in the pipeline), so even if we had
155         // address-space-based AA in LLVM, it wouldn't help us much here.
156         return b_->CreateLoad(
157             b_->CreateGEP(param_tile_buffer, {index.GetConstantWithIndexType(0),
158                                               tiled_parameter_info_->x(),
159                                               tiled_parameter_info_->y()}),
160             "tiled_buffer");
161       }
162     }
163     return GetIrArrayForFusedParameter(parameter->parameter_number())
164         .EmitReadArrayElement(index, b_);
165   };
166   return Status::OK();
167 }
168 
HandleTuple(HloInstruction * tuple)169 Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) {
170   absl::Span<HloInstruction* const> operands(tuple->operands());
171   std::vector<llvm::Type*> operand_elemental_ir_types;
172   for (HloInstruction* operand : operands) {
173     operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
174         operand->shape().element_type(), module_));
175   }
176   indexed_generators_[tuple] =
177       [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
178     llvm::Value* ret = llvm::UndefValue::get(
179         llvm::StructType::get(b_->getContext(), operand_elemental_ir_types));
180     for (size_t i = 0; i < ShapeUtil::TupleElementCount(tuple->shape()); ++i) {
181       TF_ASSIGN_OR_RETURN(llvm::Value * val_i,
182                           indexed_generators_[operands[i]](index));
183       ret = b_->CreateInsertValue(ret, val_i, i);
184     }
185     return ret;
186   };
187   return Status::OK();
188 }
189 
FinishVisit(HloInstruction * root)190 Status FusedIrEmitter::FinishVisit(HloInstruction* root) {
191   fused_root_ = root;
192   return Status::OK();
193 }
194 
GetRootGenerator() const195 FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetRootGenerator() const {
196   CHECK_NE(nullptr, fused_root_)
197       << "GetRootGenerator should be called after Accept.";
198   return indexed_generators_.at(fused_root_);
199 }
200 
GetGenerator(const HloInstruction * instruction) const201 FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetGenerator(
202     const HloInstruction* instruction) const {
203   return indexed_generators_.at(instruction);
204 }
205 
IsFusedIrEmitterInefficient(const HloInstruction * consumer,const HloInstruction * producer)206 bool FusedIrEmitter::IsFusedIrEmitterInefficient(
207     const HloInstruction* consumer, const HloInstruction* producer) {
208   if (consumer->opcode() != HloOpcode::kFusion) {
209     return false;
210   }
211   // Collects for each instruction in the fusion node from which (indirect)
212   // users newly created index values are passed. Roughly speaking, we reuse
213   // index values if the shapes are equal when ignoring the element type (we may
214   // reuse also if the shape change is a bitcast, but we don't consider that
215   // here). By ignoring potential reuses our estimate whether the fusion emitter
216   // is inefficient is a bit more conservative than necessary.
217   absl::flat_hash_map<const HloInstruction*,
218                       absl::flat_hash_set<const HloInstruction*>>
219       indexing_users;
220   // Stores the number of different index accesses for each instruction in the
221   // fusion node. The fusion emitter caches access with the same index, so this
222   // value indicates how many times a specific instruction will be emitted.
223   absl::flat_hash_map<const HloInstruction*, int64> index_usage_count;
224   index_usage_count[consumer] = 1;
225 
226   auto evaluate_fusion_computation = [&indexing_users, &index_usage_count](
227                                          const HloInstruction* fusion) {
228     auto postorder =
229         fusion->fused_instructions_computation()->MakeInstructionPostOrder();
230     std::reverse(postorder.begin(), postorder.end());
231     for (const auto* instruction : postorder) {
232       if (instruction->opcode() == HloOpcode::kParameter) {
233         continue;
234       }
235       int64& total = index_usage_count[instruction];
236       if (indexing_users[instruction].empty()) {
237         total = index_usage_count[fusion];
238       } else {
239         total = 0;
240         for (const auto* user : indexing_users[instruction]) {
241           int64 weight = 1;
242           // Concatenate is special: the index differs for each operand, so
243           // in the worst case we have to deal with as many index values as
244           // the number of operands of Concatenate. By considering the worst
245           // case, we are more conservative than necessary regarding
246           // refusing to fuse.
247           if (user->opcode() == HloOpcode::kConcatenate) {
248             weight = user->operand_count();
249           }
250           total += index_usage_count[user] * weight;
251         }
252       }
253       for (const auto* operand : instruction->operands()) {
254         // For simplicity we assume that all shape and layout changing
255         // operations invalidate index reuse.
256         if (Shape::Equal().IgnoreElementType()(operand->shape(),
257                                                instruction->shape())) {
258           // If the index is reused, it means the operand gets index values
259           // from the same set of (indirect) users as 'instruction' itself.
260           indexing_users[operand].insert(indexing_users[instruction].begin(),
261                                          indexing_users[instruction].end());
262         } else {
263           // If the index is not reused, it means 'instruction' computes a
264           // new index derived from the index it gets.
265           indexing_users[operand].insert(instruction);
266         }
267       }
268     }
269   };
270   evaluate_fusion_computation(consumer);
271 
272   // Also account for the 'producer' if it would be fused. Find the operand it
273   // corresponds to.
274   for (int64 operand_num = 0; operand_num < consumer->operand_count();
275        ++operand_num) {
276     if (consumer->operand(operand_num) == producer) {
277       auto instruction = consumer->fused_parameter(operand_num);
278       int64& total = index_usage_count[producer];
279       total = 0;
280       for (const auto* user : indexing_users[instruction]) {
281         total += index_usage_count[user];
282       }
283       break;
284     }
285   }
286 
287   // If 'producer' is a fusion node as well, also evaluate it.
288   if (producer->opcode() == HloOpcode::kFusion) {
289     evaluate_fusion_computation(producer);
290   }
291 
292   // Sum up the total number of emitted ops.
293   int64 total = 0;
294   for (const auto& entry : index_usage_count) {
295     total += entry.second;
296   }
297 
298   // Check that the code duplication has at most a factor of 8 (where 8 is an
299   // arbitrary constant that seems to work).
300   return total > 8 * index_usage_count.size();
301 }
302 
303 }  // namespace xla
304