1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <iterator>
#include <limits>

#include "tensorflow/compiler/xla/service/cpu/ir_function.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
25 
26 namespace xla {
27 namespace cpu {
28 
GetComputeFunctionParams(llvm::Module * llvm_module,const int64 num_dynamic_loop_bounds)29 static std::vector<llvm::Type*> GetComputeFunctionParams(
30     llvm::Module* llvm_module, const int64 num_dynamic_loop_bounds) {
31   llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(llvm_module->getContext());
32   llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo();
33   llvm::Type* i64_ptr_type =
34       llvm::Type::getInt64PtrTy(llvm_module->getContext());
35   std::vector<llvm::Type*> compute_function_params(
36       {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type});
37   if (num_dynamic_loop_bounds > 0) {
38     compute_function_params.push_back(i64_ptr_type);
39   }
40   compute_function_params.push_back(i64_ptr_type);
41   return compute_function_params;
42 }
43 
IrFunction(const string & function_name,llvm::Function::LinkageTypes linkage,const HloModuleConfig & module_config,llvm::Module * llvm_module,llvm::IRBuilder<> * b,int64 num_dynamic_loop_bounds)44 IrFunction::IrFunction(const string& function_name,
45                        llvm::Function::LinkageTypes linkage,
46                        const HloModuleConfig& module_config,
47                        llvm::Module* llvm_module, llvm::IRBuilder<>* b,
48                        int64 num_dynamic_loop_bounds)
49     : b_(b),
50       llvm_module_(llvm_module),
51       caller_insert_point_guard_(*b),
52       num_dynamic_loop_bounds_(num_dynamic_loop_bounds) {
53   Initialize(function_name, linkage, module_config);
54 }
55 
~IrFunction()56 IrFunction::~IrFunction() {
57   // Emit function return value.
58   b_->CreateRetVoid();
59 }
60 
GetDynamicLoopBounds()61 DynamicLoopBounds IrFunction::GetDynamicLoopBounds() {
62   DynamicLoopBounds dynamic_loop_bounds(num_dynamic_loop_bounds_);
63   for (int i = 0; i < num_dynamic_loop_bounds_; ++i) {
64     dynamic_loop_bounds[i].first = GetDynamicLoopBound(i * 2 + 0);
65     dynamic_loop_bounds[i].second = GetDynamicLoopBound(i * 2 + 1);
66   }
67   return dynamic_loop_bounds;
68 }
69 
Initialize(const string & function_name,llvm::Function::LinkageTypes linkage,const HloModuleConfig & module_config)70 void IrFunction::Initialize(const string& function_name,
71                             llvm::Function::LinkageTypes linkage,
72                             const HloModuleConfig& module_config) {
73   // The function signature is:
74   //   void function(i8* retval, i8* run_options, i8** params, i8**
75   //   buffer_table,
76   //                 i64* dynamic_loop_bounds, i64* prof_counters)
77   //
78   // For thread local functions:
79   //   retval: points to the returned value.
80   //   params: address of an array with pointers to parameters.
81   //   buffer_table: is null
82   //
83   // For global functions:
84   //   retval: is null
85   //   params: is null
86   //   buffer_table: address of an array with pointers to temporary buffers and
87   //     entry computation parameters (but not to constant buffers).
88   //
89   // Therefore, the generated function's signature (FunctionType) is statically
90   // determined - parameter unpacking is done in code generated into the
91   // function, rather than by a prologue dictated by the platform ABI.
92   //
93   //                      /--------------\
94   //   retval ----------> | return value |
95   //                      \--------------/
96   //
97   //                      /-------------------------------\
98   //   run_options -----> | xla::ExecutableRunOptions |
99   //                      \-------------------------------/
100   //
101   //                     /---------------------------------------------\
102   //   params -------->  |  param 0  |  param 1  | ..... |  param N-1  |
103   //                     |   addr    |   addr    |       |   addr      |
104   //                     \---------------------------------------------/
105   //                          |           |                   |
106   //                          |           |                   |
107   //                          V           V                   V
108   //                     /---------\  /---------\         /-----------\
109   //                     | param 0 |  | param 1 |         | param N-1 |
110   //                     \---------/  \---------/         \-----------/
111   //
112   //                     /---------------------------------------------\
113   //   buffer_table--->  |  buff  0  |  guff  1  | ..... |  buff  N-1  |
114   //                     |   addr    |   addr    |       |   addr      |
115   //                     \---------------------------------------------/
116   //                          |           |                   |
117   //                          |           |                   |
118   //                          V           V                   V
119   //                     /---------\  /---------\         /-----------\
120   //                     | temp  0 |  | temp  1 |         | temp  N-1 |
121   //                     \---------/  \---------/         \-----------/
122   //
123   //                        /--------------------------------------------\
124   // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....|
125   //  (elided for aot)      \--------------------------------------------/
126   //
127   //                     /---------------------------------------------\
128   //   prof counters ->  | counter 0 | counter 1 | ..... | counter N-1 |
129   //                     \---------------------------------------------/
130 
131   // Even though the type of params and buffer_table is void** in the host's
132   // view, in LLVM IR this is represented by i8*, similarly to void*. It's up to
133   // the code to use GEPs to unravel the indirection layers.
134   llvm::FunctionType* function_type = llvm::FunctionType::get(
135       /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()),
136       /*Params=*/
137       GetComputeFunctionParams(llvm_module_, num_dynamic_loop_bounds_),
138       /*isVarArg=*/false);
139 
140   // Functions with local linkage get an inlining bonus.  Because we know
141   // a-priori that embedded functions (non-entry functions) will not have its
142   // name resolved, give it local linkage.
143   function_ = llvm_ir::CreateCpuFunction(function_type, linkage, module_config,
144                                          function_name, llvm_module_);
145 
146   // Set meaningful names for the function's arguments: useful for debugging.
147   llvm::Function::arg_iterator arg_iter = function_->arg_begin();
148   arg_iter->setName("retval");
149   result_arg_ = &*arg_iter;
150   (++arg_iter)->setName("run_options");
151   exec_run_options_arg_ = &*arg_iter;
152   (++arg_iter)->setName("params");
153   parameters_arg_ = &*arg_iter;
154   (++arg_iter)->setName("buffer_table");
155   buffer_table_arg_ = &*arg_iter;
156   if (num_dynamic_loop_bounds_ > 0) {
157     (++arg_iter)->setName("dynamic_loop_bounds");
158     dynamic_loop_bounds_arg_ = &*arg_iter;
159   }
160   (++arg_iter)->setName("prof_counters");
161   profile_counters_arg_ = &*arg_iter;
162 
163   // We know a-priori that the function arguments are guaranteed to point to
164   // disjoint objects.
165   llvm::Argument* retval = result_arg();
166   for (llvm::Argument& argument : function_->args()) {
167     // However, the return buffer aliases the temporaries and thus cannot be
168     // marked noalias.
169     if (&argument == retval) {
170       continue;
171     }
172     function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias);
173   }
174 
175   b_->SetInsertPoint(llvm::BasicBlock::Create(
176       /*Context=*/llvm_module_->getContext(),
177       /*Name=*/"entry",
178       /*Parent=*/function_));
179 }
180 
GetDynamicLoopBound(const int64 offset)181 llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
182   CHECK_GT(num_dynamic_loop_bounds_, 0);
183   CHECK_LT(offset, num_dynamic_loop_bounds_ * 2);
184   string name = absl::StrCat("dynamic_loop_bound_", offset);
185   return b_->CreateLoad(b_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_),
186                                       b_->getInt64(offset), name));
187 }
188 
EncodeArrayFunctionArguments(absl::Span<llvm::Value * const> arguments,absl::string_view name,llvm::IRBuilder<> * b)189 llvm::Value* EncodeArrayFunctionArguments(
190     absl::Span<llvm::Value* const> arguments, absl::string_view name,
191     llvm::IRBuilder<>* b) {
192   llvm::Value* arguments_buffer;
193   llvm::Type* int8ptr_ty = b->getInt8PtrTy();
194   if (arguments.empty()) {
195     arguments_buffer = llvm::Constant::getNullValue(int8ptr_ty->getPointerTo());
196   } else {
197     arguments_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
198         int8ptr_ty, b->getInt32(arguments.size()),
199         absl::StrCat(name, "_parameter_addresses"), b);
200 
201     for (size_t i = 0; i < arguments.size(); i++) {
202       llvm::Value* parameter_as_i8ptr = b->CreateBitCast(
203           arguments[i], b->getInt8PtrTy(),
204           absl::StrCat(name, "_parameter_", i, "_address_as_i8ptr"));
205       llvm::Value* slot_in_param_addresses =
206           b->CreateInBoundsGEP(arguments_buffer, {b->getInt64(i)});
207       b->CreateStore(parameter_as_i8ptr, slot_in_param_addresses);
208     }
209   }
210   return arguments_buffer;
211 }
212 
213 // Emits code to allocate an array of parameter address pointers, and store
214 // each address from 'parameter_addresses'.
215 // Returns an array of compute function call arguments (including parameter
216 // address buffer).
GetArrayFunctionCallArguments(absl::Span<llvm::Value * const> parameter_addresses,llvm::IRBuilder<> * b,absl::string_view name,llvm::Value * return_value_buffer,llvm::Value * exec_run_options_arg,llvm::Value * buffer_table_arg,llvm::Value * profile_counters_arg)217 std::vector<llvm::Value*> GetArrayFunctionCallArguments(
218     absl::Span<llvm::Value* const> parameter_addresses, llvm::IRBuilder<>* b,
219     absl::string_view name, llvm::Value* return_value_buffer,
220     llvm::Value* exec_run_options_arg, llvm::Value* buffer_table_arg,
221     llvm::Value* profile_counters_arg) {
222   llvm::Value* parameter_addresses_buffer =
223       EncodeArrayFunctionArguments(parameter_addresses, name, b);
224 
225   const auto to_int8_ptr = [=](llvm::Value* ptr) {
226     return b->CreatePointerCast(ptr, b->getInt8PtrTy());
227   };
228   std::vector<llvm::Value*> arguments{
229       to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg),
230       parameter_addresses_buffer, buffer_table_arg};
231   if (profile_counters_arg != nullptr) {
232     arguments.push_back(profile_counters_arg);
233   }
234   return arguments;
235 }
236 
237 // Emits a call to a runtime fork/join function which dispatches parallel
238 // calls to 'parallel_function' (and joins threads before returning).
EmitCallToParallelForkJoin(const std::vector<llvm::Value * > & arguments,const Shape & shape,const std::vector<int64> & dimension_partition_counts,llvm::IRBuilder<> * b,llvm::Function * parallel_function,const string & name)239 Status EmitCallToParallelForkJoin(
240     const std::vector<llvm::Value*>& arguments, const Shape& shape,
241     const std::vector<int64>& dimension_partition_counts, llvm::IRBuilder<>* b,
242     llvm::Function* parallel_function, const string& name) {
243   llvm::Module* module = b->GetInsertBlock()->getModule();
244 
245   // Build ParallelForkJoin function type.
246   std::vector<llvm::Type*> compute_function_params =
247       GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0);
248   // Number of parallel compute functions.
249   compute_function_params.push_back(b->getInt32Ty());
250   // Array of partitions. There is an array element for each
251   // partition x partition_dim x 2 (for dimension start and limit).
252   compute_function_params.push_back(
253       llvm::Type::getInt64PtrTy(module->getContext()));
254   // Number of partitioned most-major dimensions in 'shape'.
255   compute_function_params.push_back(b->getInt32Ty());
256   // Function pointer for compute function to be dispatched in parallel.
257   compute_function_params.push_back(
258       llvm::Type::getInt8PtrTy(module->getContext()));
259 
260   llvm::FunctionType* fork_join_type = llvm::FunctionType::get(
261       /*Result=*/llvm::Type::getVoidTy(module->getContext()),
262       /*Params=*/compute_function_params,
263       /*isVarArg=*/false);
264 
265   llvm::Function* fork_join_func = llvm::dyn_cast<llvm::Function>(
266       module
267           ->getOrInsertFunction(runtime::kParallelForkJoinSymbolName,
268                                 fork_join_type)
269           .getCallee());
270   fork_join_func->setCallingConv(llvm::CallingConv::C);
271   fork_join_func->setDoesNotThrow();
272 
273   // Add common compute function arguments.
274   std::vector<llvm::Value*> fork_join_arguments(arguments);
275 
276   // Create ShapePartitionIterator to generate all partitions of 'shape'.
277   ShapePartitionIterator partition_iterator(shape, dimension_partition_counts);
278   const int64 num_partitions = partition_iterator.GetTotalPartitionCount();
279   // Add argument specifying the number of parallel partitions.
280   fork_join_arguments.push_back(b->getInt32(num_partitions));
281 
282   // The number of partitioned most-major dimensions in 'shape'.
283   const int32 num_partitioned_dims = dimension_partition_counts.size();
284   // A dimension partition consists of two elements: [start_index, limit_index).
285   const int32 dim_partition_size = 2;
286   // Calculate array partition stride.
287   const int32 array_partition_stride =
288       num_partitioned_dims * dim_partition_size;
289   // Calculate the total number of elements in the partition array.
290   const int32 partition_array_size =
291       dim_partition_size * num_partitioned_dims * num_partitions;
292 
293   // Store dimension partition values as llvm constants in 'partitions'.
294   // See comments in runtime_fork_join.cc for array layout description.
295   std::vector<llvm::Constant*> partitions(partition_array_size);
296   for (int32 i = 0; i < num_partitions; ++i) {
297     std::vector<std::pair<int64, int64>> dim_partitions =
298         partition_iterator.GetPartition(i);
299     CHECK_EQ(num_partitioned_dims, dim_partitions.size());
300     const int32 partition_index = i * array_partition_stride;
301     for (int32 j = 0; j < num_partitioned_dims; ++j) {
302       const std::pair<int64, int64>& dim_partition = dim_partitions[j];
303       const int32 index = partition_index + j * dim_partition_size;
304       // Store partition [dim_start, dim_limit) intervals for each dimension.
305       partitions[index] = b->getInt64(dim_partition.first);
306       partitions[index + 1] =
307           b->getInt64(dim_partition.first + dim_partition.second);
308     }
309   }
310 
311   // Create global variable out of dimension partitions in 'partitions'.
312   llvm::ArrayType* partitions_array_type =
313       llvm::ArrayType::get(b->getInt64Ty(), partition_array_size);
314   llvm::Constant* partitions_array =
315       llvm::ConstantArray::get(partitions_array_type, partitions);
316   llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable(
317       /*M=*/*module,
318       /*Ty=*/partitions_array_type,
319       /*isConstant=*/true,
320       /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
321       /*Initializer=*/partitions_array,
322       /*Name=*/
323       absl::StrCat(name, "_parallel_dimension_partitions"));
324 
325   // Add argument specifying parallel dimension partitions.
326   fork_join_arguments.push_back(
327       b->CreateBitCast(global_partitions_array,
328                        llvm::Type::getInt64PtrTy(module->getContext())));
329   // Add argument specifying the number of partitioned most-major dimensions.
330   fork_join_arguments.push_back(b->getInt32(num_partitioned_dims));
331   // Add argument for parallel compute function pointer.
332   fork_join_arguments.push_back(
333       b->CreateBitCast(parallel_function, b->getInt8PtrTy()));
334   // Emit call to parallel fork/join.
335   b->CreateCall(fork_join_func, fork_join_arguments);
336 
337   return Status::OK();
338 }
339 
340 }  // namespace cpu
341 }  // namespace xla
342