1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <vector>
18 
19 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h"
20 
21 #include "llvm/IR/BasicBlock.h"
22 #include "llvm/IR/Function.h"
23 #include "llvm/IR/IRBuilder.h"
24 #include "llvm/IR/Instructions.h"
25 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
26 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
31 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
32 #include "tensorflow/compiler/xla/service/name_uniquer.h"
33 #include "tensorflow/core/lib/core/status.h"
34 
35 namespace xla {
36 namespace gpu {
37 
IrEmitterNested(const HloModuleConfig & hlo_module_config,const HloComputation & nested_computation,IrEmitterContext * ir_emitter_context)38 IrEmitterNested::IrEmitterNested(const HloModuleConfig& hlo_module_config,
39                                  const HloComputation& nested_computation,
40                                  IrEmitterContext* ir_emitter_context)
41     : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/true) {
42   std::vector<const HloInstruction*> io_hlos;
43   emitted_function_ =
44       EmitBasePointersForNestedComputation(nested_computation, &io_hlos);
45 }
46 
EmitBasePointersForNestedComputation(const HloComputation & nested_computation,std::vector<const HloInstruction * > * io_hlos)47 llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation(
48     const HloComputation& nested_computation,
49     std::vector<const HloInstruction*>* io_hlos) {
50   std::vector<llvm::Type*> argument_types;
51   std::vector<int64> argument_dereferenceable_bytes;
52   for (const HloInstruction* param :
53        nested_computation.parameter_instructions()) {
54     io_hlos->push_back(param);
55     const Shape& param_shape = param->shape();
56     argument_types.push_back(
57         llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo());
58     int64 param_size =
59         llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout());
60     argument_dereferenceable_bytes.push_back(param_size);
61   }
62   {
63     const HloInstruction* root = nested_computation.root_instruction();
64     io_hlos->push_back(root);
65     const Shape& root_shape = root->shape();
66     argument_types.push_back(
67         llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo());
68     int64 root_size = llvm_ir::ByteSizeOf(
69         root_shape, ir_emitter_context_->llvm_module()->getDataLayout());
70     argument_dereferenceable_bytes.push_back(root_size);
71   }
72   // The base pointer of the memory block for all pre-allocated temp buffers.
73   argument_types.push_back(b_.getInt8PtrTy());
74 
75   llvm::FunctionType* function_type =
76       llvm::FunctionType::get(b_.getVoidTy(), argument_types, false);
77   llvm::Function* function = llvm::Function::Create(
78       function_type,                       // The function type.
79       llvm::GlobalValue::InternalLinkage,  // The linkage type.
80       ir_emitter_context_->name_uniquer()->GetUniqueName(
81           llvm_ir::SanitizeFunctionName(
82               nested_computation.name())),  // The name of the function.
83       ir_emitter_context_->llvm_module());  // The parent LLVM module.
84   for (size_t arg_no = 0; arg_no < argument_dereferenceable_bytes.size();
85        ++arg_no) {
86     int64 arg_size = argument_dereferenceable_bytes[arg_no];
87     if (arg_size > 0) {
88       function->addDereferenceableAttr(arg_no + 1, arg_size);
89     }
90   }
91 
92   // TODO(b/65380986): Investigate if adding fast math flags for generated
93   // kernels makes sense.
94 
95   llvm::BasicBlock* entry_bb =
96       llvm::BasicBlock::Create(function->getContext(), "entry", function);
97   // Emit a "return void" at entry_bb's end, and sets the insert point before
98   // that return instruction.
99   b_.SetInsertPoint(llvm::ReturnInst::Create(function->getContext(), entry_bb));
100 
101   std::vector<const HloInstruction*> non_io_hlos;
102   for (const auto* hlo : nested_computation.instructions()) {
103     if (hlo->opcode() != HloOpcode::kParameter &&
104         hlo != nested_computation.root_instruction()) {
105       non_io_hlos.push_back(hlo);
106     }
107   }
108   bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos);
109   return function;
110 }
111 
HandleParameter(HloInstruction * parameter)112 Status IrEmitterNested::HandleParameter(HloInstruction* parameter) {
113   return Status::OK();
114 }
115 
EmitTargetElementLoop(const HloInstruction & hlo,const llvm_ir::ElementGenerator & element_generator)116 Status IrEmitterNested::EmitTargetElementLoop(
117     const HloInstruction& hlo,
118     const llvm_ir::ElementGenerator& element_generator) {
119   // For MOF we give the loop emitter an array for every output it should
120   // generate.
121   if (hlo.IsMultiOutputFusion()) {
122     std::vector<llvm_ir::IrArray> target_arrays =
123         ConstructIrArrayForOutputs(hlo);
124     TF_RETURN_IF_ERROR(
125         llvm_ir::LoopEmitter(element_generator, target_arrays, &b_).EmitLoop());
126     llvm_ir::EmitTuple(GetIrArray(hlo, hlo), target_arrays, &b_);
127     return Status::OK();
128   }
129   return llvm_ir::LoopEmitter(element_generator, GetIrArray(hlo, hlo), &b_)
130       .EmitLoop();
131 }
132 
133 }  // namespace gpu
134 }  // namespace xla
135