1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
17 
18 #include <algorithm>
19 #include <functional>
20 
21 #include "llvm/IR/BasicBlock.h"
22 #include "llvm/IR/Value.h"
23 #include "tensorflow/compiler/xla/map_util.h"
24 #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
25 #include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
30 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
31 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
32 #include "tensorflow/compiler/xla/shape.h"
33 #include "tensorflow/compiler/xla/shape_util.h"
34 #include "tensorflow/compiler/xla/status_macros.h"
35 #include "tensorflow/compiler/xla/statusor.h"
36 #include "tensorflow/compiler/xla/util.h"
37 #include "tensorflow/core/platform/logging.h"
38 
39 namespace xla {
40 
41 using llvm_ir::IrArray;
42 
DefaultAction(const HloInstruction * hlo)43 Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) {
44   indexed_generators_[hlo] =
45       [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
46     if (llvm::Value* generated_value = FindOrDefault(
47             generated_value_cache_[hlo], index.multidim(), nullptr)) {
48       llvm::BasicBlock* generated_value_bb = nullptr;
49       if (auto* generated_instruction =
50               llvm::dyn_cast<llvm::Instruction>(generated_value)) {
51         generated_value_bb = generated_instruction->getParent();
52       }
53       // Ideally, we should be able to reuse the cached generated value if it
54       // dominates the current insertion block. However, the check for dominance
55       // can be expensive and unreliable when the function is being constructed.
56       //
57       // It's also worth experimenting what if we don't do caching at all.
58       // LLVM's CSE or GVN should be able to easily merge common subexpressions
59       // that would be regenerated without caching. But this might increase the
60       // JIT compilation time.
61       if (generated_value_bb == nullptr ||
62           generated_value_bb == b_->GetInsertBlock()) {
63         VLOG(3) << "The cached generated value is reused.";
64         return generated_value;
65       }
66       VLOG(3) << "The cached generated value can't be reused, because it is in "
67                  "a different BB ("
68               << generated_value_bb->getName().str()
69               << ") from the current insertion block ("
70               << b_->GetInsertBlock()->getName().str() << ").";
71     }
72 
73     TF_ASSIGN_OR_RETURN(llvm::Value* const generated_value,
74                         elemental_emitter_->MakeElementGenerator(
75                             hlo, indexed_generators_)(index));
76     generated_value_cache_[hlo][index.multidim()] = generated_value;
77     return generated_value;
78   };
79   return Status::OK();
80 }
81 
HandleConstant(const HloInstruction * constant)82 Status FusedIrEmitter::HandleConstant(const HloInstruction* constant) {
83   unsigned global_address_space =
84       llvm_ir::GetGlobalMemoryAddressSpace(*module_);
85   indexed_generators_[constant] = [=](const IrArray::Index& index) {
86     const Literal& literal = constant->literal();
87     llvm::Constant* initializer =
88         llvm_ir::ConvertLiteralToIrConstant(literal, module_);
89     llvm::GlobalVariable* global = new llvm::GlobalVariable(
90         *b_->GetInsertBlock()->getModule(), initializer->getType(),
91         /*isConstant=*/true,
92         /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
93         /*Initializer=*/initializer,
94         /*Name=*/"", /*InsertBefore=*/nullptr,
95         /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
96         /*AddressSpace=*/global_address_space,
97         /*isExternallyInitialized=*/false);
98 
99     global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
100     llvm::Constant* shape_constant =
101         llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
102             global,
103             llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo());
104     return IrArray(shape_constant, constant->shape())
105         .EmitReadArrayElement(index, b_, constant->name());
106   };
107 
108   return Status::OK();
109 }
110 
HandleGetTupleElement(const HloInstruction * get_tuple_element)111 Status FusedIrEmitter::HandleGetTupleElement(
112     const HloInstruction* get_tuple_element) {
113   return InternalError("Tuple parameters are not supported for fusion");
114 }
115 
HandleParameter(const HloInstruction * parameter)116 Status FusedIrEmitter::HandleParameter(const HloInstruction* parameter) {
117   if (indexed_generators_.find(parameter) == indexed_generators_.end()) {
118     return InvalidArgument("Unbound parameter: %s", parameter->ToString());
119   }
120   return Status::OK();
121 }
122 
HandleTuple(const HloInstruction * tuple)123 Status FusedIrEmitter::HandleTuple(const HloInstruction* tuple) {
124   absl::Span<HloInstruction* const> operands(tuple->operands());
125   std::vector<llvm::Type*> operand_elemental_ir_types;
126   for (HloInstruction* operand : operands) {
127     operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
128         operand->shape().element_type(), module_));
129   }
130   indexed_generators_[tuple] =
131       [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
132     llvm::Value* ret = llvm::UndefValue::get(
133         llvm::StructType::get(b_->getContext(), operand_elemental_ir_types));
134     for (size_t i = 0; i < ShapeUtil::TupleElementCount(tuple->shape()); ++i) {
135       TF_ASSIGN_OR_RETURN(llvm::Value * val_i,
136                           indexed_generators_[operands[i]](index));
137       ret = b_->CreateInsertValue(ret, val_i, i);
138     }
139     return ret;
140   };
141   return Status::OK();
142 }
143 
IsFusedIrEmitterInefficient(const HloInstruction * consumer,const HloInstruction * producer)144 bool FusedIrEmitter::IsFusedIrEmitterInefficient(
145     const HloInstruction* consumer, const HloInstruction* producer) {
146   if (consumer->opcode() != HloOpcode::kFusion) {
147     return false;
148   }
149   FusionNodeIndexingEvaluation eval_consumer(consumer);
150   if (producer->opcode() != HloOpcode::kFusion) {
151     return eval_consumer.CodeDuplicationTooHigh(producer);
152   }
153   // If 'producer' is a fusion node as well, also evaluate it. Pass the
154   // evaluated duplication of the fusion node if it is merged into consumer.
155   FusionNodeIndexingEvaluation eval_producer(
156       producer, eval_consumer.EvaluateEmittedInstructions(producer));
157   return eval_producer.MaxCodeDuplicationTooHigh();
158 }
159 
GetGenerator(const HloInstruction * instruction)160 StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::GetGenerator(
161     const HloInstruction* instruction) {
162   std::vector<const HloInstruction*> stack;
163   stack.push_back(instruction);
164   while (!stack.empty()) {
165     const HloInstruction* instr = stack.back();
166     stack.pop_back();
167     if (indexed_generators_.count(instr)) {
168       continue;
169     }
170     for (const HloInstruction* operand : instr->operands()) {
171       stack.push_back(operand);
172     }
173     switch (instr->opcode()) {
174       case HloOpcode::kConstant:
175         TF_RETURN_IF_ERROR(HandleConstant(instr));
176         break;
177       case HloOpcode::kGetTupleElement:
178         TF_RETURN_IF_ERROR(HandleGetTupleElement(instr));
179         break;
180       case HloOpcode::kParameter:
181         TF_RETURN_IF_ERROR(HandleParameter(instr));
182         break;
183       case HloOpcode::kTuple:
184         TF_RETURN_IF_ERROR(HandleTuple(instr));
185         break;
186       default:
187         TF_RETURN_IF_ERROR(DefaultAction(instr));
188         break;
189     }
190     CHECK(indexed_generators_.count(instr));
191   }
192   return indexed_generators_.at(instruction);
193 }
194 
195 }  // namespace xla
196