1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <vector>
18
19 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
20
21 #include "llvm/IR/BasicBlock.h"
22 #include "llvm/IR/Function.h"
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/IR/Instructions.h"
25 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
26 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
31 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
32 #include "tensorflow/compiler/xla/service/name_uniquer.h"
33 #include "tensorflow/core/lib/core/status.h"
34
35 namespace xla {
36 namespace gpu {
37
IrEmitterNested(const HloModuleConfig & hlo_module_config,const HloComputation & nested_computation,IrEmitterContext * ir_emitter_context)38 IrEmitterNested::IrEmitterNested(const HloModuleConfig& hlo_module_config,
39 const HloComputation& nested_computation,
40 IrEmitterContext* ir_emitter_context)
41 : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true) {
42 std::vector<const HloInstruction*> io_hlos;
43 emitted_function_ =
44 EmitBasePointersForNestedComputation(nested_computation, &io_hlos);
45 }
46
EmitBasePointersForNestedComputation(const HloComputation & nested_computation,std::vector<const HloInstruction * > * io_hlos)47 llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
48 const HloComputation& nested_computation,
49 std::vector<const HloInstruction*>* io_hlos) {
50 std::vector<llvm::Type*> argument_types;
51 std::vector<int64> argument_dereferenceable_bytes;
52 for (const HloInstruction* param :
53 nested_computation.parameter_instructions()) {
54 io_hlos->push_back(param);
55 const Shape& param_shape = param->shape();
56 argument_types.push_back(
57 llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo());
58 int64 param_size =
59 llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout());
60 argument_dereferenceable_bytes.push_back(param_size);
61 }
62 {
63 const HloInstruction* root = nested_computation.root_instruction();
64 io_hlos->push_back(root);
65 const Shape& root_shape = root->shape();
66 argument_types.push_back(
67 llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo());
68 int64 root_size = llvm_ir::ByteSizeOf(
69 root_shape, ir_emitter_context_->llvm_module()->getDataLayout());
70 argument_dereferenceable_bytes.push_back(root_size);
71 }
72 // The base pointer of the memory block for all pre-allocated temp buffers.
73 argument_types.push_back(b_.getInt8PtrTy());
74
75 llvm::FunctionType* function_type =
76 llvm::FunctionType::get(b_.getVoidTy(), argument_types, false);
77 llvm::Function* function = llvm::Function::Create(
78 function_type, // The function type.
79 llvm::GlobalValue::InternalLinkage, // The linkage type.
80 ir_emitter_context_->name_uniquer()->GetUniqueName(
81 llvm_ir::SanitizeFunctionName(
82 nested_computation.name())), // The name of the function.
83 ir_emitter_context_->llvm_module()); // The parent LLVM module.
84 for (size_t arg_no = 0; arg_no < argument_dereferenceable_bytes.size();
85 ++arg_no) {
86 int64 arg_size = argument_dereferenceable_bytes[arg_no];
87 if (arg_size > 0) {
88 function->addDereferenceableAttr(arg_no + 1, arg_size);
89 }
90 }
91
92 // TODO(b/65380986): Investigate if adding fast math flags for generated
93 // kernels makes sense.
94
95 llvm::BasicBlock* entry_bb =
96 llvm::BasicBlock::Create(function->getContext(), "entry", function);
97 // Emit a "return void" at entry_bb's end, and sets the insert point before
98 // that return instruction.
99 b_.SetInsertPoint(llvm::ReturnInst::Create(function->getContext(), entry_bb));
100
101 std::vector<const HloInstruction*> non_io_hlos;
102 for (const auto* hlo : nested_computation.instructions()) {
103 if (hlo->opcode() != HloOpcode::kParameter &&
104 hlo != nested_computation.root_instruction()) {
105 non_io_hlos.push_back(hlo);
106 }
107 }
108 bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos);
109 return function;
110 }
111
HandleParameter(HloInstruction * parameter)112 Status IrEmitterNested::HandleParameter(HloInstruction* parameter) {
113 return Status::OK();
114 }
115
EmitTargetElementLoop(const HloInstruction & hlo,const llvm_ir::ElementGenerator & element_generator)116 Status IrEmitterNested::EmitTargetElementLoop(
117 const HloInstruction& hlo,
118 const llvm_ir::ElementGenerator& element_generator) {
119 // For MOF we give the loop emitter an array for every output it should
120 // generate.
121 if (hlo.IsMultiOutputFusion()) {
122 std::vector<llvm_ir::IrArray> target_arrays =
123 ConstructIrArrayForOutputs(hlo);
124 TF_RETURN_IF_ERROR(
125 llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop());
126 llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_);
127 return Status::OK();
128 }
129 return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_)
130 .EmitLoop();
131 }
132
133 } // namespace gpu
134 } // namespace xla
135