/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"

#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {

namespace {

// Adds the inner comparison loop body where we compare elements.
Status EmitCompareLoopBody(
    int64 iteration_bound, int64 num_values, llvm::Value* element_pair_index,
    int64 xor_mask, llvm::Type* index_type,
    std::function<llvm::Value*(int64 operand, llvm::Value* index)>
        element_address,
    std::function<void(int64 operand, llvm::Value* index, llvm::Value* value)>
        write_element,
    const EmitCallToNestedComputationCallback& emit_compare_callback,
    llvm::IRBuilder<>* b, bool needs_bounds_checks = true) {
  auto index_typed_constant = [&](int64 value) {
    return llvm::ConstantInt::get(index_type, value);
  };
  // The 'xor_mask' determines which elements are compared against each other.
  // Index 'current_keys_index' will be compared with 'current_keys_index' xor
  // 'xor_mask'. This means that we will always compare a block of consecutive
  // elements against elements from the adjacent block of the same size. When
  // 'xor_mask' is a power of 2, it immediately identifies the size of such a
  // block. We can also have 'xor_mask' being 2^k - 1 (for some value of k). In
  // that case, we essentially flip the last 'k' bits when computing the
  // position of the element to compare to, so the block size is 2^(k - 1).
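  // For example, 'xor_mask' = 7 (i.e. 2^3 - 1) flips the last three bits, so
  // indices 0, 1, 2, 3 are paired with 7, 6, 5, 4 respectively: a block of
  // size 4 is compared against the adjacent block of size 4 in reverse order.
  // With 'xor_mask' = 4 (a power of 2), indices 0..3 are paired with 4..7 in
  // order.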
  int64 block_size = xor_mask;
  // Check if it is a value 2^k - 1.
  if (xor_mask > 1 && (xor_mask & (xor_mask + 1)) == 0) {
    block_size = (xor_mask + 1) / 2;
  }
  auto current_keys_index = element_pair_index;
  if (block_size == 1) {
    // If the block size is 1, we take every second element and compare it to
    // the next one.
    current_keys_index =
        b->CreateMul(current_keys_index, index_typed_constant(2));
  } else if (block_size * 2 < iteration_bound) {
    // current_keys_index iterates through the 'left' elements of the element
    // pairs to be compared. We first need to compute the comparison block to
    // which the element belongs. The block id of that block is index /
    // block_size.
    auto block_id =
        b->CreateUDiv(current_keys_index, index_typed_constant(block_size));
    // The index of the 'left' element within its block is simply the remainder
    // when dividing by 'block_size'.
    auto index_within_block =
        b->CreateURem(current_keys_index, index_typed_constant(block_size));
    // The first element of the 'left' block of elements that is compared
    // against elements from the adjacent 'right' block of elements is
    // 'block_id' * (2 * 'block_size').
    auto first_element_in_block =
        b->CreateMul(block_id, index_typed_constant(2 * block_size));
    current_keys_index =
        b->CreateAdd(first_element_in_block, index_within_block);
  }
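  // For example, with 'block_size' = 2 (and 'iteration_bound' > 4, so the
  // branch above is taken), 'element_pair_index' values 0, 1, 2, 3 map to
  // 'current_keys_index' 0, 1, 4, 5, which are then xor'ed with 'xor_mask' = 3
  // to give 'compare_keys_index' 3, 2, 7, 6.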
  auto compare_keys_index =
      b->CreateXor(current_keys_index, index_typed_constant(xor_mask));
  // current_keys_index < compare_keys_index
  llvm::Value* is_smaller_index =
      b->CreateICmpSLT(current_keys_index, compare_keys_index);
  // compare_keys_index < iteration_bound
  llvm::Value* index_is_inbounds = b->CreateICmpSLT(
      compare_keys_index, index_typed_constant(iteration_bound));
  llvm::Value* do_comparison =
      needs_bounds_checks ? b->CreateAnd(is_smaller_index, index_is_inbounds)
                          : b->getInt1(true);

  // if (is_smaller_index && index_is_inbounds)
  KernelSupportLibrary ksl(b);
  return ksl.IfWithStatus("smaller_comparison_index", do_comparison, [&]() {
    std::vector<llvm::Value*> values_to_compare;
    for (int i = 0; i < num_values; ++i) {
      values_to_compare.push_back(element_address(i, compare_keys_index));
      values_to_compare.push_back(element_address(i, current_keys_index));
    }
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    llvm::Value* compare_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
        llvm_ir::PrimitiveTypeToIrType(PRED, module), "compare_return_buffer",
        b);
    TF_RETURN_IF_ERROR(
        emit_compare_callback(values_to_compare, compare_return_buffer));
    llvm::Value* result = b->CreateLoad(compare_return_buffer);

    // Check if the 'compare' function returns true.
    llvm::Value* is_smaller_than =
        b->CreateICmpNE(result, llvm::ConstantInt::get(result->getType(), 0),
                        "boolean_predicate");
    ksl.If("is_smaller_than", is_smaller_than, [&]() {
      for (int64 i = 0; i < num_values; ++i) {
        // Swap the values.
        auto value1 = b->CreateLoad(values_to_compare[i * 2]);
        auto value2 = b->CreateLoad(values_to_compare[i * 2 + 1]);
        write_element(i, current_keys_index, value1);
        write_element(i, compare_keys_index, value2);
      }
    });
    return Status::OK();
  });
}

Status EmitTiledCompareLoop(
    const IrArray::Index& tiled_keys_index, int64 dimension_to_sort,
    int64 dimension_to_sort_bound, absl::Span<const int64> xor_masks,
    const std::vector<IrArray>& params,
    const std::vector<llvm::Value*>& param_shmem_buffers, int64 tile_size,
    const EmitCallToNestedComputationCallback& emit_compare_callback,
    llvm::IRBuilder<>* b) {
  KernelSupportLibrary ksl(b);
  llvm::Value* thread_id = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b);
  llvm_ir::AddRangeMetadata(0, tile_size / 2,
                            llvm::cast<llvm::Instruction>(thread_id));
  thread_id = b->CreateIntCast(thread_id, tiled_keys_index.GetType(),
                               /*isSigned=*/true, "thread.id.x");

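  // Each thread handles one pair of adjacent elements of the tile, so a tile
  // of 'tile_size' elements is processed by 'tile_size' / 2 threads.
  // 'copy_loop_body' invokes 'read_or_write' for the (up to two) elements
  // handled by the current thread, passing the position of the element within
  // the shared memory tile ('cache_index') and its position within the
  // operand's sort dimension ('index').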
  auto copy_loop_body =
      [&](std::function<void(llvm::Value * cache_index, llvm::Value * index)>
              read_or_write) {
        auto value_one = tiled_keys_index.GetConstantWithIndexType(1);
        auto current_keys_index =
            b->CreateShl(tiled_keys_index[dimension_to_sort], value_one);
        // We want to copy two adjacent elements. We first check whether the
        // first index position is within bounds.
        ksl.If(
            "smaller_keys_index",
            b->CreateICmpSLT(current_keys_index,
                             tiled_keys_index.GetConstantWithIndexType(
                                 dimension_to_sort_bound)),
            [&]() {
              auto cache_index = b->CreateShl(thread_id, value_one);
              read_or_write(cache_index, current_keys_index);
              // Increment to go to the next index position.
              current_keys_index = b->CreateAdd(current_keys_index, value_one);
              // Here we check whether the next index position is within
              // bounds.
              ksl.If("inner_smaller_keys_index",
                     b->CreateICmpSLT(current_keys_index,
                                      tiled_keys_index.GetConstantWithIndexType(
                                          dimension_to_sort_bound)),
                     [&]() {
                       cache_index = b->CreateAdd(cache_index, value_one);
                       read_or_write(cache_index, current_keys_index);
                     });
            });
      };

  // Copy operand tiles from the operand buffers to shared memory.
  std::vector<llvm::Value*> keys_multi_index = tiled_keys_index.multidim();
  for (int64 i = 0; i < params.size(); ++i) {
    copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
      keys_multi_index[dimension_to_sort] = index;
      IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
                                tiled_keys_index.GetType());
      auto value = params[i].EmitReadArrayElement(keys_index, b);
      b->CreateStore(value,
                     b->CreateGEP(param_shmem_buffers[i],
                                  {tiled_keys_index.GetConstantWithIndexType(0),
                                   cache_index}));
    });
  }
  // Wait until all reads have happened.
  llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b);

  // Now emit the bodies of the comparison loops.
  auto element_address = [&](int64 operand, llvm::Value* index) {
    auto shared_memory_address =
        b->CreateGEP(param_shmem_buffers[operand],
                     {tiled_keys_index.GetConstantWithIndexType(0), index});
    auto ptr_type = shared_memory_address->getType();
    // We need a generic pointer with address space 0 instead of a pointer to
    // shared memory (address space 3) so that we can pass it to the comparison
    // computation.
    return b->CreateAddrSpaceCast(
        shared_memory_address,
        llvm::PointerType::get(ptr_type->getPointerElementType(),
                               /*AddressSpace=*/0));
  };
  auto write_element = [&](int64 operand, llvm::Value* index,
                           llvm::Value* value) {
    b->CreateStore(
        value,
        b->CreateGEP(param_shmem_buffers[operand],
                     {tiled_keys_index.GetConstantWithIndexType(0), index}));
  };
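  // Each entry in 'xor_masks' corresponds to one comparison stage of the
  // bitonic sorting network. All stages handled here operate entirely within
  // the tile that is cached in shared memory, with a barrier between
  // consecutive stages.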
  for (int64 xor_mask : xor_masks) {
    // The index of the element pair to be compared within the tile stored in
    // shared memory. We order the element pairs by the element with the
    // smaller index.
    auto element_pair_index = thread_id;
    // If 'dimension_to_sort_bound' is evenly divisible by 'tile_size', we
    // don't need any bounds checks.
    if (dimension_to_sort_bound % tile_size) {
      // Otherwise we need a bounds check for the last tile. The last tile has
      // size 'dimension_to_sort_bound' % 'tile_size'.
      TF_RETURN_IF_ERROR(ksl.IfWithStatus(
          "is_last_tile",
          b->CreateICmpUGE(
              b->CreateMul(tiled_keys_index[dimension_to_sort],
                           tiled_keys_index.GetConstantWithIndexType(2)),
              tiled_keys_index.GetConstantWithIndexType(
                  RoundDownToNearest(dimension_to_sort_bound, tile_size))),
          [&]() {
            return EmitCompareLoopBody(
                dimension_to_sort_bound % tile_size, params.size(),
                element_pair_index, xor_mask, tiled_keys_index.GetType(),
                element_address, write_element, emit_compare_callback, b);
          },
          [&]() {
            return EmitCompareLoopBody(
                tile_size, params.size(), element_pair_index, xor_mask,
                tiled_keys_index.GetType(), element_address, write_element,
                emit_compare_callback, b,
                /*needs_bounds_checks=*/false);
          }));
    } else {
      TF_RETURN_IF_ERROR(EmitCompareLoopBody(
          tile_size, params.size(), element_pair_index, xor_mask,
          tiled_keys_index.GetType(), element_address, write_element,
          emit_compare_callback, b,
          /*needs_bounds_checks=*/false));
    }
    // Wait until all comparisons have happened.
    llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_barrier0, {}, {}, b);
  }

  // Copy the operand tiles back from shared memory to the operand buffers.
  for (int64 i = 0; i < params.size(); ++i) {
    copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
      keys_multi_index[dimension_to_sort] = index;
      IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
                                tiled_keys_index.GetType());
      auto value = b->CreateLoad(b->CreateGEP(
          param_shmem_buffers[i],
          {tiled_keys_index.GetConstantWithIndexType(0), cache_index}));
      params[i].EmitWriteArrayElement(keys_index, value, b);
    });
  }
  // We should normally synchronize here to make sure all writes have happened.
  // However, the very next thing each thread does is read 2 elements from the
  // operand buffer and write them into the same locations in shared memory
  // from which it previously copied them to the operand buffer, and we
  // synchronize after this has happened. We can be sure that a thread always
  // writes to the same locations in shared memory because we have exactly
  // tile_size / 2 threads, and the linear index calculated by
  // ParallelLoopEmitter uses
  // linear_index = blockIdx.x * blockDim.x + threadIdx.x;
  return Status::OK();
}
}  // namespace

Status EmitSortInPlace(
    int64 dimension_to_sort, const std::vector<IrArray>& values_arrays,
    absl::string_view name, absl::Span<const int64> xor_masks,
    llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions,
    int64 num_iterations_in_sort_dim, const int64 tile_size,
    const EmitCallToNestedComputationCallback& emit_compare_callback) {
  // Iterate through the keys shape in physical order, but skip the dimension
  // to sort and make it the innermost loop, which is the loop where the
  // comparisons happen. In the dimension to sort, if we use tiling, we iterate
  // through it in tiles of 64 elements each, so we use another loop that
  // happens within one thread to process this tile worth of data (thereby
  // combining several comparison stages of the bitonic sort algorithm because
  // they all happen within those 64 elements and are therefore independent of
  // the other comparisons).

  const Shape& keys_shape = values_arrays[0].GetShape();
  int64 rank = keys_shape.rank();
  int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
  std::vector<int64> dimensions_in_iteration_order(rank);
  std::vector<int64> iteration_order_to_logical_order(rank);
  int64 dim = 0;
  for (int64 dimension : LayoutUtil::MinorToMajor(keys_shape)) {
    if (dimension != dimension_to_sort) {
      dimensions_in_iteration_order[dim] = keys_shape.dimensions(dimension);
      iteration_order_to_logical_order[dim++] = dimension;
    }
  }
  dimensions_in_iteration_order[dim] = num_iterations_in_sort_dim;
  iteration_order_to_logical_order[dim] = dimension_to_sort;

  Shape iteration_shape = ShapeUtil::MakeShape(keys_shape.element_type(),
                                               dimensions_in_iteration_order);

  // Allocate shared memory for the tiled compare loop.
  std::vector<llvm::Value*> param_shmem_buffers(values_arrays.size(), nullptr);
  if (xor_masks.size() > 1) {
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    for (int64 i = 0; i < values_arrays.size(); ++i) {
      llvm::Type* tile_type = llvm::ArrayType::get(
          llvm_ir::PrimitiveTypeToIrType(
              values_arrays[i].GetShape().element_type(), module),
          tile_size);
      param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile(
          module, tile_type, absl::StrCat(name, "_tile_param_", i));
    }
  }

  auto compare_loop_body_emitter =
      [&](const IrArray::Index& tiles_index) -> Status {
    // Naive C++ code for the inner compare loop:
    //
    // for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
    //   int64 j = i ^ xor_mask;
    //   /* emitted in EmitCompareLoopBody() */
    //   if (i < j && j < dimension_to_sort_bound) {
    //     int64 min_key = std::min(keys[i], keys[j]);
    //     keys[j] = std::max(keys[i], keys[j]);
    //     keys[i] = min_key;
    //   }
    // }
    //
    // This follows the algorithm described on Wikipedia:
    // https://en.wikipedia.org/wiki/Bitonic_sorter
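    //
    // When there is more than one operand ('num_values' > 1), the order is
    // determined by the compare computation, and the elements of all operands
    // are swapped together (see the swap loop in EmitCompareLoopBody).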
    std::vector<llvm::Value*> keys_multi_index(rank);
    for (int64 i = 0; i < rank; ++i) {
      keys_multi_index[iteration_order_to_logical_order[i]] = tiles_index[i];
    }
    if (xor_masks.size() > 1) {
      IrArray::Index keys_index(keys_multi_index, values_arrays[0].GetShape(),
                                tiles_index.GetType());
      TF_RETURN_IF_ERROR(EmitTiledCompareLoop(
          keys_index, dimension_to_sort, dimension_to_sort_bound, xor_masks,
          values_arrays, param_shmem_buffers, tile_size, emit_compare_callback,
          b));
    } else {
      auto element_address = [&](int64 operand, llvm::Value* index) {
        keys_multi_index[dimension_to_sort] = index;
        IrArray::Index keys_index(keys_multi_index,
                                  values_arrays[operand].GetShape(),
                                  tiles_index.GetType());
        return values_arrays[operand].EmitArrayElementAddress(keys_index, b);
      };
      auto write_element = [&](int64 operand, llvm::Value* index,
                               llvm::Value* value) {
        keys_multi_index[dimension_to_sort] = index;
        IrArray::Index keys_index(keys_multi_index,
                                  values_arrays[operand].GetShape(),
                                  tiles_index.GetType());
        values_arrays[operand].EmitWriteArrayElement(keys_index, value, b);
      };
      TF_RETURN_IF_ERROR(EmitCompareLoopBody(
          dimension_to_sort_bound, values_arrays.size(), tiles_index[rank - 1],
          xor_masks[0], tiles_index.GetType(), element_address, write_element,
          emit_compare_callback, b));
    }
    return Status::OK();
  };
  return gpu::ParallelLoopEmitter(compare_loop_body_emitter, iteration_shape,
                                  launch_dimensions, b)
      .EmitLoop(name);
}

}  // namespace llvm_ir
}  // namespace xla