1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
17 
18 #include <vector>
19 
20 // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
21 #include "absl/strings/str_cat.h"
22 #include "absl/strings/string_view.h"
23 #include "absl/types/span.h"
24 #include "llvm/ADT/APInt.h"
25 #include "llvm/IR/BasicBlock.h"
26 #include "llvm/IR/Constants.h"
27 #include "llvm/IR/DerivedTypes.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/Value.h"
30 #include "tensorflow/compiler/xla/primitive_util.h"
31 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
32 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
33 #include "tensorflow/compiler/xla/service/gpu/target_util.h"
34 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
35 #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
36 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
37 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
38 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
39 #include "tensorflow/compiler/xla/shape_util.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/core/lib/core/status.h"
42 #include "tensorflow/core/platform/types.h"
43 
44 namespace xla {
45 namespace llvm_ir {
46 
47 namespace {
48 
49 // Adds the inner comparison loop body where we compare elements.
EmitCompareLoopBody(int64 iteration_bound,int64 num_values,llvm::Value * element_pair_index,int64 xor_mask,llvm::Type * index_type,std::function<llvm::Value * (int64 operand,llvm::Value * index)> element_address,std::function<void (int64 operand,llvm::Value * index,llvm::Value * value)> write_element,const EmitCallToNestedComputationCallback & emit_compare_callback,llvm::IRBuilder<> * b,bool needs_bounds_checks=true)50 Status EmitCompareLoopBody(
51     int64 iteration_bound, int64 num_values, llvm::Value* element_pair_index,
52     int64 xor_mask, llvm::Type* index_type,
53     std::function<llvm::Value*(int64 operand, llvm::Value* index)>
54         element_address,
55     std::function<void(int64 operand, llvm::Value* index, llvm::Value* value)>
56         write_element,
57     const EmitCallToNestedComputationCallback& emit_compare_callback,
58     llvm::IRBuilder<>* b, bool needs_bounds_checks = true) {
59   auto index_typed_constant = [&](int64 value) {
60     return llvm::ConstantInt::get(index_type, value);
61   };
62   // The 'xor_mask' determines which elements are compared against each other.
63   // Index 'current_keys_index' will be compared with 'current_keys_index' xor
64   // 'xor_mask'. This means that we will always compare a block of consecutive
65   // elements against elements from the adjacent block of the same size. When
66   // 'xor_mask' is a power of 2, it immediately identifies the size of such a
67   // block. We can also have 'xor_mask' being 2^k - 1 (for some value of k). In
68   // that case, we essentially flip the last 'k' - 1 bits when computing the
69   // position of the element to compare to, so the block size is 2^(k - 1).
70   int64 block_size = xor_mask;
71   // Check if it is a value 2^k - 1.
72   if (xor_mask > 1 && (xor_mask & (xor_mask + 1)) == 0) {
73     block_size = (xor_mask + 1) / 2;
74   }
75   auto current_keys_index = element_pair_index;
76   if (block_size == 1) {
77     // If the block size is 1, we take every second element and compare it to
78     // the next one.
79     current_keys_index =
80         b->CreateMul(current_keys_index, index_typed_constant(2));
81   } else if (block_size * 2 < iteration_bound) {
82     // current_keys_index iterates through the 'left' elements of the element
83     // pairs to be compared. We first need to compute the comparison block to
84     // which the element belongs. The block id of that block is index /
85     // block_size.
86     auto block_id =
87         b->CreateUDiv(current_keys_index, index_typed_constant(block_size));
88     // The index of the 'left' element within its block is simply the remainder
89     // when dividing by 'block_size'.
90     auto index_within_block =
91         b->CreateURem(current_keys_index, index_typed_constant(block_size));
92     // The first element of the 'left' block of elements that is compared
93     // against elements from the adjacent 'right' block of elements is
94     // 'block_id' * (2 * 'block_size').
95     auto first_element_in_block =
96         b->CreateMul(block_id, index_typed_constant(2 * block_size));
97     current_keys_index =
98         b->CreateAdd(first_element_in_block, index_within_block);
99   }
100   auto compare_keys_index =
101       b->CreateXor(current_keys_index, index_typed_constant(xor_mask));
102   // current_keys_index < compare_keys_index
103   llvm::Value* is_smaller_index =
104       b->CreateICmpSLT(current_keys_index, compare_keys_index);
105   // compare_keys_index < iteration_bound
106   llvm::Value* index_is_inbounds = b->CreateICmpSLT(
107       compare_keys_index, index_typed_constant(iteration_bound));
108   llvm::Value* do_comparison =
109       needs_bounds_checks ? b->CreateAnd(is_smaller_index, index_is_inbounds)
110                           : b->getInt1(true);
111 
112   // if (is_smaller_index && index_is_inbounds)
113   KernelSupportLibrary ksl(b);
114   return ksl.IfWithStatus("smaller_comparison_index", do_comparison, [&]() {
115     std::vector<llvm::Value*> values_to_compare;
116     for (int i = 0; i < num_values; ++i) {
117       values_to_compare.push_back(element_address(i, compare_keys_index));
118       values_to_compare.push_back(element_address(i, current_keys_index));
119     }
120     llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
121     llvm::Value* compare_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
122         llvm_ir::PrimitiveTypeToIrType(PRED, module), "compare_return_buffer",
123         b);
124     TF_RETURN_IF_ERROR(
125         emit_compare_callback(values_to_compare, compare_return_buffer));
126     llvm::Value* result = b->CreateLoad(compare_return_buffer);
127 
128     // Check if the 'compare' function returns true.
129     llvm::Value* is_smaller_than =
130         b->CreateICmpNE(result, llvm::ConstantInt::get(result->getType(), 0),
131                         "boolean_predicate");
132     ksl.If("is_smaller_than", is_smaller_than, [&]() {
133       for (int64 i = 0; i < num_values; ++i) {
134         // Swap the values.
135         auto value1 = b->CreateLoad(values_to_compare[i * 2]);
136         auto value2 = b->CreateLoad(values_to_compare[i * 2 + 1]);
137         write_element(i, current_keys_index, value1);
138         write_element(i, compare_keys_index, value2);
139       }
140     });
141     return Status::OK();
142   });
143 }
144 
EmitTiledCompareLoop(const IrArray::Index & tiled_keys_index,int64 dimension_to_sort,int64 dimension_to_sort_bound,absl::Span<const int64> xor_masks,const std::vector<IrArray> & params,const std::vector<llvm::Value * > & param_shmem_buffers,int64 tile_size,const EmitCallToNestedComputationCallback & emit_compare_callback,llvm::IRBuilder<> * b)145 Status EmitTiledCompareLoop(
146     const IrArray::Index& tiled_keys_index, int64 dimension_to_sort,
147     int64 dimension_to_sort_bound, absl::Span<const int64> xor_masks,
148     const std::vector<IrArray>& params,
149     const std::vector<llvm::Value*>& param_shmem_buffers, int64 tile_size,
150     const EmitCallToNestedComputationCallback& emit_compare_callback,
151     llvm::IRBuilder<>* b) {
152   KernelSupportLibrary ksl(b);
153   llvm::Value* thread_id = gpu::EmitCallToTargetIntrinsic(
154       gpu::TargetIntrinsicID::kThreadIdx, {}, {}, b);
155   llvm_ir::AddRangeMetadata(0, tile_size / 2,
156                             llvm::cast<llvm::Instruction>(thread_id));
157   thread_id = b->CreateIntCast(thread_id, tiled_keys_index.GetType(),
158                                /*isSigned=*/true, "thread.id.x");
159 
160   auto copy_loop_body =
161       [&](std::function<void(llvm::Value * cache_index, llvm::Value * index)>
162               read_or_write) {
163         auto value_one = tiled_keys_index.GetConstantWithIndexType(1);
164         auto current_keys_index =
165             b->CreateShl(tiled_keys_index[dimension_to_sort], value_one);
166         // We want to copy two adjacent elements. We first check whether the
167         // first index position is within bounds.
168         ksl.If(
169             "smaller_keys_index",
170             b->CreateICmpSLT(current_keys_index,
171                              tiled_keys_index.GetConstantWithIndexType(
172                                  dimension_to_sort_bound)),
173             [&]() {
174               auto cache_index = b->CreateShl(thread_id, value_one);
175               read_or_write(cache_index, current_keys_index);
176               // Increment to go to the next index position.
177               current_keys_index = b->CreateAdd(current_keys_index, value_one);
178               // Here we check whether the next index position is within bounds.
179               ksl.If("inner_smaller_keys_index",
180                      b->CreateICmpSLT(current_keys_index,
181                                       tiled_keys_index.GetConstantWithIndexType(
182                                           dimension_to_sort_bound)),
183                      [&]() {
184                        cache_index = b->CreateAdd(cache_index, value_one);
185                        read_or_write(cache_index, current_keys_index);
186                      });
187             });
188       };
189 
190   // Copy operand tiles from the operand buffers to shared memory.
191   std::vector<llvm::Value*> keys_multi_index = tiled_keys_index.multidim();
192   for (int64 i = 0; i < params.size(); ++i) {
193     copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
194       keys_multi_index[dimension_to_sort] = index;
195       IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
196                                 tiled_keys_index.GetType());
197       auto value = params[i].EmitReadArrayElement(keys_index, b);
198       b->CreateStore(value,
199                      b->CreateGEP(param_shmem_buffers[i],
200                                   {tiled_keys_index.GetConstantWithIndexType(0),
201                                    cache_index}));
202     });
203   }
204   // Wait until all reads have happened.
205   gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBarrierId, {}, {}, b);
206 
207   // Now emit the bodies of the comparison loops.
208   auto element_address = [&](int64 operand, llvm::Value* index) {
209     auto shared_memory_address =
210         b->CreateGEP(param_shmem_buffers[operand],
211                      {tiled_keys_index.GetConstantWithIndexType(0), index});
212     auto ptr_type = shared_memory_address->getType();
213     // We need a generic pointer with address space 0 instead of a pointer to
214     // shared memory (address space 3) so that we can pass it to the comparison
215     // computation.
216     return b->CreateAddrSpaceCast(
217         shared_memory_address,
218         llvm::PointerType::get(ptr_type->getPointerElementType(),
219                                /*AddressSpace=*/0));
220   };
221   auto write_element = [&](int64 operand, llvm::Value* index,
222                            llvm::Value* value) {
223     b->CreateStore(
224         value,
225         b->CreateGEP(param_shmem_buffers[operand],
226                      {tiled_keys_index.GetConstantWithIndexType(0), index}));
227   };
228   for (int64 xor_mask : xor_masks) {
229     // The index of the element pair to be compared within the tile stored in
230     // shared memory. We order the element pairs by the element with the smaller
231     // index.
232     auto element_pair_index = thread_id;
233     // If 'dimension_to_sort_bound' is evenly divisible by 'tile_size', we don't
234     // need any bounds checks.
235     if (dimension_to_sort_bound % tile_size) {
236       // Otherwise we need a bounds check for the last tile. The last tile has
237       // size 'dimension_to_sort_bound' % 'tile_size'.
238       TF_RETURN_IF_ERROR(ksl.IfWithStatus(
239           "is_last_tile",
240           b->CreateICmpUGE(
241               b->CreateMul(tiled_keys_index[dimension_to_sort],
242                            tiled_keys_index.GetConstantWithIndexType(2)),
243               tiled_keys_index.GetConstantWithIndexType(
244                   RoundDownToNearest(dimension_to_sort_bound, tile_size))),
245           [&]() {
246             return EmitCompareLoopBody(
247                 dimension_to_sort_bound % tile_size, params.size(),
248                 element_pair_index, xor_mask, tiled_keys_index.GetType(),
249                 element_address, write_element, emit_compare_callback, b);
250           },
251           [&]() {
252             return EmitCompareLoopBody(
253                 tile_size, params.size(), element_pair_index, xor_mask,
254                 tiled_keys_index.GetType(), element_address, write_element,
255                 emit_compare_callback, b,
256                 /*needs_bounds_checks=*/false);
257           }));
258     } else {
259       TF_RETURN_IF_ERROR(EmitCompareLoopBody(
260           tile_size, params.size(), element_pair_index, xor_mask,
261           tiled_keys_index.GetType(), element_address, write_element,
262           emit_compare_callback, b,
263           /*needs_bounds_checks=*/false));
264     }
265     // Wait until all comparisons have happened.
266     gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBarrierId, {}, {},
267                                    b);
268   }
269 
270   // Copy the operand tiles back from shared memory to the operand buffers.
271   for (int64 i = 0; i < params.size(); ++i) {
272     copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
273       keys_multi_index[dimension_to_sort] = index;
274       IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
275                                 tiled_keys_index.GetType());
276       auto value = b->CreateLoad(b->CreateGEP(
277           param_shmem_buffers[i],
278           {tiled_keys_index.GetConstantWithIndexType(0), cache_index}));
279       params[i].EmitWriteArrayElement(keys_index, value, b);
280     });
281   }
282   // We should normally synchronize here to make sure all writes have happened.
283   // However the very next thing each thread does is reading 2 elements from the
284   // operand buffer and writing it into the same location in shared memory from
285   // which it previously copied it to the operand buffer, and we synchronize
286   // after this has happened. We can be sure that a thread always writes to the
287   // same location in shared memory because we have exactly tile_size / 2 many
288   // threads, and the linear index calculated by ParallelLoopEmitter uses
289   // linear_index = blockIdx.x * blockDim.x + threadIdx.x;
290   return Status::OK();
291 }
292 }  // namespace
293 
EmitSortInPlace(int64 dimension_to_sort,const std::vector<IrArray> & values_arrays,absl::string_view name,absl::Span<const int64> xor_masks,llvm::IRBuilder<> * b,const gpu::LaunchDimensions & launch_dimensions,int64 num_iterations_in_sort_dim,const int64 tile_size,const EmitCallToNestedComputationCallback & emit_compare_callback)294 Status EmitSortInPlace(
295     int64 dimension_to_sort, const std::vector<IrArray>& values_arrays,
296     absl::string_view name, absl::Span<const int64> xor_masks,
297     llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions,
298     int64 num_iterations_in_sort_dim, const int64 tile_size,
299     const EmitCallToNestedComputationCallback& emit_compare_callback) {
300   // Iterate through the keys shape in physical order, but skip the dimension to
301   // sort and make it the innermost loop which is the loop where the comparisons
302   // happen. In the dimension to sort, if we use tiling, we iterate through it
303   // in tiles of 64 elements each, so we use another loop that happens within
304   // one thread to process this tile worth of data (thereby combining several
305   // comparison stages of the bitonic sort algorithm because they all happen
306   // within those 64 elements and are therefore independent of the other
307   // comparisons).
308 
309   const Shape& keys_shape = values_arrays[0].GetShape();
310   int64 rank = keys_shape.rank();
311   int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
312   std::vector<int64> dimensions_in_iteration_order(rank);
313   std::vector<int64> iteration_order_to_logical_order(rank);
314   int64 dim = 0;
315   for (int64 dimension : LayoutUtil::MinorToMajor(keys_shape)) {
316     if (dimension != dimension_to_sort) {
317       dimensions_in_iteration_order[dim] = keys_shape.dimensions(dimension);
318       iteration_order_to_logical_order[dim++] = dimension;
319     }
320   }
321   dimensions_in_iteration_order[dim] = num_iterations_in_sort_dim;
322   iteration_order_to_logical_order[dim] = dimension_to_sort;
323 
324   Shape iteration_shape = ShapeUtil::MakeShape(keys_shape.element_type(),
325                                                dimensions_in_iteration_order);
326 
327   // Allocate shared memory for the tiled compare loop.
328   std::vector<llvm::Value*> param_shmem_buffers(values_arrays.size(), nullptr);
329   if (xor_masks.size() > 1) {
330     llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
331     for (int64 i = 0; i < values_arrays.size(); ++i) {
332       llvm::Type* tile_type = llvm::ArrayType::get(
333           llvm_ir::PrimitiveTypeToIrType(
334               values_arrays[i].GetShape().element_type(), module),
335           tile_size);
336       param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile(
337           module, tile_type, absl::StrCat(name, "_tile_param_", i));
338     }
339   }
340 
341   auto compare_loop_body_emitter =
342       [&](const IrArray::Index& tiles_index) -> Status {
343     // Naive C++ code for the inner compare loop:
344     //
345     // for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
346     //   int64 j = i ^ xor_mask;
347     //   /* emitted in EmitCompareLoopBody() */
348     //   if (i < j && j < dimension_to_sort_bound) {
349     //     int64 min_key = std::min(keys[i], keys[j]);
350     //     keys[j] = std::max(keys[i], keys[j]);
351     //     keys[i] = min_key;
352     //   }
353     // }
354     //
355     // This follows the algorithm described on Wikipedia:
356     // https://en.wikipedia.org/wiki/Bitonic_sorter
357     std::vector<llvm::Value*> keys_multi_index(rank);
358     for (int64 i = 0; i < rank; ++i) {
359       keys_multi_index[iteration_order_to_logical_order[i]] = tiles_index[i];
360     }
361     if (xor_masks.size() > 1) {
362       IrArray::Index keys_index(keys_multi_index, values_arrays[0].GetShape(),
363                                 tiles_index.GetType());
364       TF_RETURN_IF_ERROR(EmitTiledCompareLoop(
365           keys_index, dimension_to_sort, dimension_to_sort_bound, xor_masks,
366           values_arrays, param_shmem_buffers, tile_size, emit_compare_callback,
367           b));
368     } else {
369       auto element_address = [&](int64 operand, llvm::Value* index) {
370         keys_multi_index[dimension_to_sort] = index;
371         IrArray::Index keys_index(keys_multi_index,
372                                   values_arrays[operand].GetShape(),
373                                   tiles_index.GetType());
374         return values_arrays[operand].EmitArrayElementAddress(keys_index, b);
375       };
376       auto write_element = [&](int64 operand, llvm::Value* index,
377                                llvm::Value* value) {
378         keys_multi_index[dimension_to_sort] = index;
379         IrArray::Index keys_index(keys_multi_index,
380                                   values_arrays[operand].GetShape(),
381                                   tiles_index.GetType());
382         values_arrays[operand].EmitWriteArrayElement(keys_index, value, b);
383       };
384       TF_RETURN_IF_ERROR(EmitCompareLoopBody(
385           dimension_to_sort_bound, values_arrays.size(), tiles_index[rank - 1],
386           xor_masks[0], tiles_index.GetType(), element_address, write_element,
387           emit_compare_callback, b));
388     }
389     return Status::OK();
390   };
391   return gpu::ParallelLoopEmitter(compare_loop_body_emitter, iteration_shape,
392                                   launch_dimensions, b)
393       .EmitLoop(name);
394 }
395 
396 }  // namespace llvm_ir
397 }  // namespace xla
398