/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"

#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {

namespace {

// Adds the inner comparison loop body where we compare elements.
Status EmitCompareLoopBody(
    int64 iteration_bound, int64 num_values, llvm::Value* element_pair_index,
    int64 xor_mask, llvm::Type* index_type,
    std::function<llvm::Value*(int64 operand, llvm::Value* index)>
        element_address,
    std::function<void(int64 operand, llvm::Value* index, llvm::Value* value)>
        write_element,
    const EmitCallToNestedComputationCallback& emit_compare_callback,
    llvm::IRBuilder<>* b, bool needs_bounds_checks = true) {
  auto index_typed_constant = [&](int64 value) {
    return llvm::ConstantInt::get(index_type, value);
  };
  // The 'xor_mask' determines which elements are compared against each other.
  // Index 'current_keys_index' will be compared with 'current_keys_index' xor
  // 'xor_mask'. This means that we will always compare a block of consecutive
  // elements against elements from the adjacent block of the same size. When
  // 'xor_mask' is a power of 2, it immediately identifies the size of such a
  // block. We can also have 'xor_mask' being 2^k - 1 (for some value of k). In
  // that case, we essentially flip the last 'k' bits when computing the
  // position of the element to compare to, so the block size is 2^(k - 1).
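  // For example, with 'xor_mask' == 4 (a power of 2), indices 0..3 are
  // compared against indices 4..7 (0<->4, 1<->5, 2<->6, 3<->7), so the block
  // size is 4. With 'xor_mask' == 3 == 2^2 - 1, index 0 is compared with 3
  // and index 1 with 2, i.e. the block {0, 1} against the adjacent block
  // {2, 3}, so the block size is 2.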
  int64 block_size = xor_mask;
  // Check if it is a value 2^k - 1.
  if (xor_mask > 1 && (xor_mask & (xor_mask + 1)) == 0) {
    block_size = (xor_mask + 1) / 2;
  }
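  // E.g. 'xor_mask' == 7 == 2^3 - 1 passes this check (7 & 8 == 0) and
  // yields block_size == (7 + 1) / 2 == 4.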
  auto current_keys_index = element_pair_index;
  if (block_size == 1) {
    // If the block size is 1, we take every second element and compare it to
    // the next one.
    current_keys_index =
        b->CreateMul(current_keys_index, index_typed_constant(2));
  } else if (block_size * 2 < iteration_bound) {
    // current_keys_index iterates through the 'left' elements of the element
    // pairs to be compared. We first need to compute the comparison block to
    // which the element belongs. The block id of that block is index /
    // block_size.
    auto block_id =
        b->CreateUDiv(current_keys_index, index_typed_constant(block_size));
    // The index of the 'left' element within its block is simply the remainder
    // when dividing by 'block_size'.
    auto index_within_block =
        b->CreateURem(current_keys_index, index_typed_constant(block_size));
    // The first element of the 'left' block of elements that is compared
    // against elements from the adjacent 'right' block of elements is
    // 'block_id' * (2 * 'block_size').
    auto first_element_in_block =
        b->CreateMul(block_id, index_typed_constant(2 * block_size));
    current_keys_index =
        b->CreateAdd(first_element_in_block, index_within_block);
  }
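  // A worked example of the mapping above: with block_size == 2 (i.e.
  // xor_mask == 3) and element_pair_index == 3, we get block_id == 1,
  // index_within_block == 1 and first_element_in_block == 4, so
  // current_keys_index == 5, which below gets paired with 5 ^ 3 == 6.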
  auto compare_keys_index =
      b->CreateXor(current_keys_index, index_typed_constant(xor_mask));
  // current_keys_index < compare_keys_index
  llvm::Value* is_smaller_index =
      b->CreateICmpSLT(current_keys_index, compare_keys_index);
  // compare_keys_index < iteration_bound
  llvm::Value* index_is_inbounds = b->CreateICmpSLT(
      compare_keys_index, index_typed_constant(iteration_bound));
  llvm::Value* do_comparison =
      needs_bounds_checks ? b->CreateAnd(is_smaller_index, index_is_inbounds)
                          : b->getInt1(true);

  // if (is_smaller_index && index_is_inbounds)
  KernelSupportLibrary ksl(b);
  return ksl.IfWithStatus("smaller_comparison_index", do_comparison, [&]() {
    std::vector<llvm::Value*> values_to_compare;
    for (int i = 0; i < num_values; ++i) {
      values_to_compare.push_back(element_address(i, compare_keys_index));
      values_to_compare.push_back(element_address(i, current_keys_index));
    }
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    llvm::Value* compare_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
        llvm_ir::PrimitiveTypeToIrType(PRED, module), "compare_return_buffer",
        b);
    TF_RETURN_IF_ERROR(
        emit_compare_callback(values_to_compare, compare_return_buffer));
    llvm::Value* result = b->CreateLoad(compare_return_buffer);

    // Check if the 'compare' function returns true.
    llvm::Value* is_smaller_than =
        b->CreateICmpNE(result, llvm::ConstantInt::get(result->getType(), 0),
                        "boolean_predicate");
    ksl.If("is_smaller_than", is_smaller_than, [&]() {
      for (int64 i = 0; i < num_values; ++i) {
        // Swap the values.
        auto value1 = b->CreateLoad(values_to_compare[i * 2]);
        auto value2 = b->CreateLoad(values_to_compare[i * 2 + 1]);
        write_element(i, current_keys_index, value1);
        write_element(i, compare_keys_index, value2);
      }
    });
    return Status::OK();
  });
}

Status EmitTiledCompareLoop(
    const IrArray::Index& tiled_keys_index, int64 dimension_to_sort,
    int64 dimension_to_sort_bound, absl::Span<const int64> xor_masks,
    const std::vector<IrArray>& params,
    const std::vector<llvm::Value*>& param_shmem_buffers, int64 tile_size,
    const EmitCallToNestedComputationCallback& emit_compare_callback,
    llvm::IRBuilder<>* b) {
  KernelSupportLibrary ksl(b);
  llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b);
  llvm_ir::AddRangeMetadata(0, tile_size / 2,
                            llvm::cast<llvm::Instruction>(thread_id));
  thread_id = b->CreateIntCast(thread_id, tiled_keys_index.GetType(),
                               /*isSigned=*/true, "thread.id.x");

  auto copy_loop_body =
      [&](std::function<void(llvm::Value * cache_index, llvm::Value * index)>
              read_or_write) {
        auto value_one = tiled_keys_index.GetConstantWithIndexType(1);
        auto current_keys_index =
            b->CreateShl(tiled_keys_index[dimension_to_sort], value_one);
        // We want to copy two adjacent elements. We first check whether the
        // first index position is within bounds.
        ksl.If(
            "smaller_keys_index",
            b->CreateICmpSLT(current_keys_index,
                             tiled_keys_index.GetConstantWithIndexType(
                                 dimension_to_sort_bound)),
            [&]() {
              auto cache_index = b->CreateShl(thread_id, value_one);
              read_or_write(cache_index, current_keys_index);
              // Increment to go to the next index position.
              current_keys_index = b->CreateAdd(current_keys_index, value_one);
              // Here we check whether the next index position is within bounds.
              ksl.If("inner_smaller_keys_index",
                     b->CreateICmpSLT(current_keys_index,
                                      tiled_keys_index.GetConstantWithIndexType(
                                          dimension_to_sort_bound)),
                     [&]() {
                       cache_index = b->CreateAdd(cache_index, value_one);
                       read_or_write(cache_index, current_keys_index);
                     });
            });
      };

  // Copy operand tiles from the operand buffers to shared memory.
  std::vector<llvm::Value*> keys_multi_index = tiled_keys_index.multidim();
  for (int64 i = 0; i < params.size(); ++i) {
    copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
      keys_multi_index[dimension_to_sort] = index;
      IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
                                tiled_keys_index.GetType());
      auto value = params[i].EmitReadArrayElement(keys_index, b);
      b->CreateStore(value,
                     b->CreateGEP(param_shmem_buffers[i],
                                  {tiled_keys_index.GetConstantWithIndexType(0),
                                   cache_index}));
    });
  }
  // Wait until all reads have happened.
  llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b);

  // Now emit the bodies of the comparison loops.
  auto element_address = [&](int64 operand, llvm::Value* index) {
    auto shared_memory_address =
        b->CreateGEP(param_shmem_buffers[operand],
                     {tiled_keys_index.GetConstantWithIndexType(0), index});
    auto ptr_type = shared_memory_address->getType();
    // We need a generic pointer with address space 0 instead of a pointer to
    // shared memory (address space 3) so that we can pass it to the comparison
    // computation.
    return b->CreateAddrSpaceCast(
        shared_memory_address,
        llvm::PointerType::get(ptr_type->getPointerElementType(),
                               /*AddressSpace=*/0));
  };
  auto write_element = [&](int64 operand, llvm::Value* index,
                           llvm::Value* value) {
    b->CreateStore(
        value,
        b->CreateGEP(param_shmem_buffers[operand],
                     {tiled_keys_index.GetConstantWithIndexType(0), index}));
  };
  for (int64 xor_mask : xor_masks) {
    // The index of the element pair to be compared within the tile stored in
    // shared memory. We order the element pairs by the element with the smaller
    // index.
    auto element_pair_index = thread_id;
    // If 'dimension_to_sort_bound' is evenly divisible by 'tile_size', we don't
    // need any bounds checks.
    if (dimension_to_sort_bound % tile_size) {
      // Otherwise we need a bounds check for the last tile. The last tile has
      // size 'dimension_to_sort_bound' % 'tile_size'.
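      // For example (purely illustrative numbers): with
      // dimension_to_sort_bound == 150 and tile_size == 64, the last tile
      // holds 150 % 64 == 22 elements, and a thread works on that last tile
      // iff its first key index, 2 * tiled_keys_index[dimension_to_sort], is
      // >= RoundDownToNearest(150, 64) == 128.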
      TF_RETURN_IF_ERROR(ksl.IfWithStatus(
          "is_last_tile",
          b->CreateICmpUGE(
              b->CreateMul(tiled_keys_index[dimension_to_sort],
                           tiled_keys_index.GetConstantWithIndexType(2)),
              tiled_keys_index.GetConstantWithIndexType(
                  RoundDownToNearest(dimension_to_sort_bound, tile_size))),
          [&]() {
            return EmitCompareLoopBody(
                dimension_to_sort_bound % tile_size, params.size(),
                element_pair_index, xor_mask, tiled_keys_index.GetType(),
                element_address, write_element, emit_compare_callback, b);
          },
          [&]() {
            return EmitCompareLoopBody(
                tile_size, params.size(), element_pair_index, xor_mask,
                tiled_keys_index.GetType(), element_address, write_element,
                emit_compare_callback, b,
                /*needs_bounds_checks=*/false);
          }));
    } else {
      TF_RETURN_IF_ERROR(EmitCompareLoopBody(
          tile_size, params.size(), element_pair_index, xor_mask,
          tiled_keys_index.GetType(), element_address, write_element,
          emit_compare_callback, b,
          /*needs_bounds_checks=*/false));
    }
    // Wait until all comparisons have happened.
    llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b);
  }

  // Copy the operand tiles back from shared memory to the operand buffers.
  for (int64 i = 0; i < params.size(); ++i) {
    copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
      keys_multi_index[dimension_to_sort] = index;
      IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
                                tiled_keys_index.GetType());
      auto value = b->CreateLoad(b->CreateGEP(
          param_shmem_buffers[i],
          {tiled_keys_index.GetConstantWithIndexType(0), cache_index}));
      params[i].EmitWriteArrayElement(keys_index, value, b);
    });
  }
  // We should normally synchronize here to make sure all writes have happened.
  // However, the very next thing each thread does is to read 2 elements from
  // the operand buffer and write them into the same locations in shared memory
  // from which it previously copied them to the operand buffer, and we
  // synchronize after this has happened. We can be sure that a thread always
  // writes to the same locations in shared memory because we have exactly
  // tile_size / 2 many threads, and the linear index calculated by
  // ParallelLoopEmitter uses
  // linear_index = blockIdx.x * blockDim.x + threadIdx.x;
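  // (Each thread's cache slots, 2 * threadIdx.x and 2 * threadIdx.x + 1 as
  // computed in copy_loop_body, depend only on its thread id, so the
  // subsequent copy writes to the very slots the same thread wrote here.)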
  return Status::OK();
}
}  // namespace

Status EmitSortInPlace(
    int64 dimension_to_sort, const std::vector<IrArray>& values_arrays,
    absl::string_view name, absl::Span<const int64> xor_masks,
    llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions,
    int64 num_iterations_in_sort_dim, const int64 tile_size,
    const EmitCallToNestedComputationCallback& emit_compare_callback) {
  // Iterate through the keys shape in physical order, but skip the dimension
  // to sort and make it the innermost loop, which is the loop where the
  // comparisons happen. In the dimension to sort, if we use tiling, we iterate
  // through it in tiles of 64 elements each, so we use another loop that
  // happens within one thread to process this tile worth of data (thereby
  // combining several comparison stages of the bitonic sort algorithm, because
  // they all happen within those 64 elements and are therefore independent of
  // the other comparisons).
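  // As an illustration of the xor mask schedule (which is chosen by the
  // caller, not here): a bitonic sorting network over 8 elements consists of
  // the comparison stages with xor masks {1}, {3, 1}, {7, 2, 1} (see the
  // Wikipedia link further below). With tiling, the caller can combine the
  // masks of several such stages into one 'xor_masks' span, as long as every
  // comparison stays within a single tile.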

  const Shape& keys_shape = values_arrays[0].GetShape();
  int64 rank = keys_shape.rank();
  int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
  std::vector<int64> dimensions_in_iteration_order(rank);
  std::vector<int64> iteration_order_to_logical_order(rank);
  int64 dim = 0;
  for (int64 dimension : LayoutUtil::MinorToMajor(keys_shape)) {
    if (dimension != dimension_to_sort) {
      dimensions_in_iteration_order[dim] = keys_shape.dimensions(dimension);
      iteration_order_to_logical_order[dim++] = dimension;
    }
  }
  dimensions_in_iteration_order[dim] = num_iterations_in_sort_dim;
  iteration_order_to_logical_order[dim] = dimension_to_sort;

  Shape iteration_shape = ShapeUtil::MakeShape(keys_shape.element_type(),
                                               dimensions_in_iteration_order);
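  // Illustrative example: for a keys shape f32[64, 1000] with the default
  // {1, 0} layout and dimension_to_sort == 1, MinorToMajor yields [1, 0];
  // skipping dimension 1 gives dimensions_in_iteration_order ==
  // [64, num_iterations_in_sort_dim] and iteration_order_to_logical_order ==
  // [0, 1], so iteration_shape is f32[64, num_iterations_in_sort_dim].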

  // Allocate shared memory for the tiled compare loop.
  std::vector<llvm::Value*> param_shmem_buffers(values_arrays.size(), nullptr);
  if (xor_masks.size() > 1) {
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    for (int64 i = 0; i < values_arrays.size(); ++i) {
      llvm::Type* tile_type = llvm::ArrayType::get(
          llvm_ir::PrimitiveTypeToIrType(
              values_arrays[i].GetShape().element_type(), module),
          tile_size);
      param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile(
          module, tile_type, absl::StrCat(name, "_tile_param_", i));
    }
  }

  auto compare_loop_body_emitter =
      [&](const IrArray::Index& tiles_index) -> Status {
    // Naive C++ code for the inner compare loop:
    //
    // for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
    //   int64 j = i ^ xor_mask;
    //   /* emitted in EmitCompareLoopBody() */
    //   if (i < j && j < dimension_to_sort_bound) {
    //     int64 min_key = std::min(keys[i], keys[j]);
    //     keys[j] = std::max(keys[i], keys[j]);
    //     keys[i] = min_key;
    //   }
    // }
    //
    // This follows the algorithm described on Wikipedia:
    // https://en.wikipedia.org/wiki/Bitonic_sorter
    std::vector<llvm::Value*> keys_multi_index(rank);
    for (int64 i = 0; i < rank; ++i) {
      keys_multi_index[iteration_order_to_logical_order[i]] = tiles_index[i];
    }
    if (xor_masks.size() > 1) {
      IrArray::Index keys_index(keys_multi_index, values_arrays[0].GetShape(),
                                tiles_index.GetType());
      TF_RETURN_IF_ERROR(EmitTiledCompareLoop(
          keys_index, dimension_to_sort, dimension_to_sort_bound, xor_masks,
          values_arrays, param_shmem_buffers, tile_size, emit_compare_callback,
          b));
    } else {
      auto element_address = [&](int64 operand, llvm::Value* index) {
        keys_multi_index[dimension_to_sort] = index;
        IrArray::Index keys_index(keys_multi_index,
                                  values_arrays[operand].GetShape(),
                                  tiles_index.GetType());
        return values_arrays[operand].EmitArrayElementAddress(keys_index, b);
      };
      auto write_element = [&](int64 operand, llvm::Value* index,
                               llvm::Value* value) {
        keys_multi_index[dimension_to_sort] = index;
        IrArray::Index keys_index(keys_multi_index,
                                  values_arrays[operand].GetShape(),
                                  tiles_index.GetType());
        values_arrays[operand].EmitWriteArrayElement(keys_index, value, b);
      };
      TF_RETURN_IF_ERROR(EmitCompareLoopBody(
          dimension_to_sort_bound, values_arrays.size(), tiles_index[rank - 1],
          xor_masks[0], tiles_index.GetType(), element_address, write_element,
          emit_compare_callback, b));
    }
    return Status::OK();
  };
  return gpu::ParallelLoopEmitter(compare_loop_body_emitter, iteration_shape,
                                  launch_dimensions, b)
      .EmitLoop(name);
}

}  // namespace llvm_ir
}  // namespace xla