1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
17 #define TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
18 
19 #if GOOGLE_CUDA
20 
21 #define EIGEN_USE_GPU
22 
23 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
24 #include "third_party/cub/device/device_reduce.cuh"
25 #include "third_party/cub/device/device_segmented_reduce.cuh"
26 #include "third_party/cub/iterator/counting_input_iterator.cuh"
27 #include "third_party/cub/iterator/transform_input_iterator.cuh"
28 #include "third_party/cub/warp/warp_reduce.cuh"
29 #include "cuda/include/cuComplex.h"
30 #include "tensorflow/core/kernels/reduction_ops.h"
31 #include "tensorflow/core/lib/core/bits.h"
32 #include "tensorflow/core/util/cuda_kernel_helper.h"
33 #include "tensorflow/core/util/permutation_input_iterator.h"
34 #include "tensorflow/core/util/transform_output_iterator.h"
35 
36 #include <sstream>
37 
38 namespace tensorflow {
39 namespace functor {
40 
41 typedef Eigen::GpuDevice GPUDevice;
42 
// Unary functor returning the square root of its argument, computed with
// Eigen::numext::sqrt. Callable from host and device code.
template <typename T>
struct Sqrt {
  __host__ __device__ T operator()(const T& value) const {
    const T root = Eigen::numext::sqrt(value);
    return root;
  }
};
49 
// Binary functor returning `lhs + rhs`; used as the combining operator for
// sum reductions. Callable from host and device code.
template <typename T>
struct Sum {
  __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
    return lhs + rhs;
  }
};
56 
// Workaround for an nvcc compiler bug: nvcc does not seem to cope with the
// overloaded addition operator of std::complex, so the single-precision
// complex sum goes through cuCaddf instead.
template <>
struct Sum<std::complex<float>> {
  __host__ __device__ std::complex<float> operator()(
      const std::complex<float>& a, const std::complex<float>& b) const {
    const cuFloatComplex lhs = make_cuComplex(a.real(), a.imag());
    const cuFloatComplex rhs = make_cuComplex(b.real(), b.imag());
    const cuFloatComplex total = cuCaddf(lhs, rhs);
    return std::complex<float>(total.x, total.y);
  }
};
68 
// Double-precision complex sum, routed through cuCadd rather than the
// std::complex addition operator.
template <>
struct Sum<std::complex<double>> {
  __host__ __device__ std::complex<double> operator()(
      const std::complex<double>& a, const std::complex<double>& b) const {
    const cuDoubleComplex lhs = make_cuDoubleComplex(a.real(), a.imag());
    const cuDoubleComplex rhs = make_cuDoubleComplex(b.real(), b.imag());
    const cuDoubleComplex total = cuCadd(lhs, rhs);
    return std::complex<double>(total.x, total.y);
  }
};
78 
// Binary functor returning `lhs * rhs`; used as the combining operator for
// product reductions. Callable from host and device code.
template <typename T>
struct Prod {
  __host__ __device__ T operator()(const T& lhs, const T& rhs) const {
    return lhs * rhs;
  }
};
85 
// Workaround for an nvcc compiler bug: nvcc does not seem to cope with the
// overloaded multiplication operator of std::complex, so the single-precision
// complex product goes through cuCmulf instead.
template <>
struct Prod<std::complex<float>> {
  __host__ __device__ std::complex<float> operator()(
      const std::complex<float>& a, const std::complex<float>& b) const {
    const cuFloatComplex lhs = make_cuComplex(a.real(), a.imag());
    const cuFloatComplex rhs = make_cuComplex(b.real(), b.imag());
    const cuFloatComplex product = cuCmulf(lhs, rhs);
    return std::complex<float>(product.x, product.y);
  }
};
97 
// Double-precision complex product, routed through cuCmul rather than the
// std::complex multiplication operator.
template <>
struct Prod<std::complex<double>> {
  __host__ __device__ std::complex<double> operator()(
      const std::complex<double>& a, const std::complex<double>& b) const {
    const cuDoubleComplex lhs = make_cuDoubleComplex(a.real(), a.imag());
    const cuDoubleComplex rhs = make_cuDoubleComplex(b.real(), b.imag());
    const cuDoubleComplex product = cuCmul(lhs, rhs);
    return std::complex<double>(product.x, product.y);
  }
};
107 
// Unary functor returning `a * conj(a)`. The multiplication is delegated to
// Prod<T>, so the complex specializations of Prod are picked up as well.
template <typename T>
struct Square {
  __host__ __device__ T operator()(const T& a) const {
    const T conjugate = Eigen::numext::conj(a);
    return Prod<T>()(a, conjugate);
  }
};
114 
// Unary functor that divides its argument by a divisor fixed at construction.
// The quotient `x / divisor` is returned converted to `outT` (defaults to T).
template <typename T, typename outT = T>
struct DividesBy {
  T divisor;

  __host__ __device__ explicit DividesBy(T divisor_in) : divisor(divisor_in) {}

  __host__ __device__ outT operator()(const T& numerator) const {
    return numerator / divisor;
  }
};
123 
// Workaround for an nvcc compiler bug: nvcc does not seem to cope with the
// overloaded operators of std::complex, so the single-precision complex
// division goes through cuCdivf. The divisor is stored as a cuFloatComplex.
template <>
struct DividesBy<std::complex<float>> {
  cuFloatComplex divisor;

  __host__ __device__ explicit DividesBy(std::complex<float> divisor_in)
      : divisor(make_cuComplex(divisor_in.real(), divisor_in.imag())) {}

  // Returns x / divisor.
  __host__ __device__ std::complex<float> operator()(
      const std::complex<float>& x) const {
    const cuFloatComplex numerator = make_cuComplex(x.real(), x.imag());
    const cuFloatComplex quotient = cuCdivf(numerator, divisor);
    return std::complex<float>(quotient.x, quotient.y);
  }
};
140 
// Double-precision complex division by a fixed divisor, routed through cuCdiv
// rather than the std::complex division operator. The divisor is stored as a
// cuDoubleComplex.
template <>
struct DividesBy<std::complex<double>> {
  cuDoubleComplex divisor;

  __host__ __device__ explicit DividesBy(std::complex<double> divisor_in)
      : divisor(make_cuDoubleComplex(divisor_in.real(), divisor_in.imag())) {}

  // Returns x / divisor.
  __host__ __device__ std::complex<double> operator()(
      const std::complex<double>& x) const {
    const cuDoubleComplex numerator = make_cuDoubleComplex(x.real(), x.imag());
    const cuDoubleComplex quotient = cuCdiv(numerator, divisor);
    return std::complex<double>(quotient.x, quotient.y);
  }
};
155 
// Divides a float by a fixed float divisor and returns the quotient as an
// Eigen::half: the division itself is carried out in float.
template <>
struct DividesBy<float, Eigen::half> {
  float divisor;

  __host__ __device__ explicit DividesBy(float divisor_in)
      : divisor(divisor_in) {}

  __host__ __device__ Eigen::half operator()(const float& numerator) const {
    const float quotient = numerator / divisor;
    return Eigen::half(quotient);
  }
};
166 
// Unary functor converting an Eigen::half to float via
// Eigen::half_impl::half_to_float.
struct HalfToFloat {
  __host__ __device__ float operator()(const Eigen::half& value) const {
    const float widened = Eigen::half_impl::half_to_float(value);
    return widened;
  }
};
172 
// Unary functor converting a float to Eigen::half via
// Eigen::half_impl::float_to_half_rtne.
struct FloatToHalf {
  __host__ __device__ Eigen::half operator()(const float& value) const {
    const Eigen::half narrowed = Eigen::half_impl::float_to_half_rtne(value);
    return narrowed;
  }
};
178 
// Binary functor returning the logical conjunction of two bools.
struct And {
  __host__ __device__ bool operator()(const bool& lhs, const bool& rhs) const {
    return lhs ? rhs : false;
  }
};
184 
// Binary functor returning the logical disjunction of two bools.
struct Or {
  __host__ __device__ bool operator()(const bool& lhs, const bool& rhs) const {
    return lhs ? true : rhs;
  }
};
190 
// Each block walks the input with a stride of (blockDim.x * gridDim.x),
// folding the elements it visits into a per-thread value, and then combines
// the per-thread values with cub::BlockReduce. Thread 0 of block `b` writes
// that block's result to out[b].
// The single-block configuration is used for low-latency reductions of small
// inputs to a scalar.
//
// The block is expected to be launched with `num_threads` threads (the
// BlockReduce type is instantiated with that size).
template <typename T, typename outT, int num_threads, typename Op>
__global__ void BlockReduceKernel(
    T in, outT out, int num_elems, Op op,
    typename std::iterator_traits<T>::value_type initVal) {
  typedef typename std::iterator_traits<T>::value_type value_type;
  typedef cub::BlockReduce<value_type, num_threads> BlockReduce;

  __shared__ typename BlockReduce::TempStorage temp_storage;

  const int block = blockIdx.x;
  const int thread = threadIdx.x;
  const int first = block * blockDim.x + thread;
  const int step = blockDim.x * gridDim.x;

  // Threads that start past the end of the input keep initVal.
  value_type acc = initVal;
  if (first < num_elems) {
    acc = in[first];
    int pos = first + step;
    while (pos < num_elems) {
      acc = op(acc, in[pos]);
      pos += step;
    }
  }

  // Only include threads that actually hold input values in the reduction:
  //
  // elements: -----------------
  // grid:     |====|====|====|====|====|
  const int valid_threads =
      max(min(num_elems - block * blockDim.x, num_threads), 0);

  acc = BlockReduce(temp_storage).Reduce(acc, op, valid_threads);

  if (thread == 0) out[block] = acc;
}
228 
// Maps one warp (32 threads) to each row: the lanes of a warp stride across
// the row's columns 32 at a time, then the per-lane values are combined with
// cub::WarpReduce and lane 0 writes the row result to out[row].
// blockDim.x must be a multiple of 32 (asserted below).
template <typename T, typename outT, typename Op>
__global__ void RowReduceKernel(
    T in, outT out, int num_rows, int num_cols, Op op,
    typename std::iterator_traits<T>::value_type initVal) {
  typedef typename std::iterator_traits<T>::value_type value_type;
  typedef cub::WarpReduce<value_type> WarpReduce;

  __shared__ typename WarpReduce::TempStorage temp_storage;

  // The row index is built from the warp's position rather than from a flat
  // thread id (the original author describes this as defensive against
  // integer overflow).
  assert(blockDim.x % 32 == 0);
  const int warps_in_block = blockDim.x / 32;
  const int warp_in_block = threadIdx.x / 32;
  const int row = blockIdx.x * warps_in_block + warp_in_block;
  const int lane = threadIdx.x % 32;

  // With a single column there is nothing to reduce: copy one element per
  // thread instead of one row per warp.
  if (num_cols == 1) {
    const int gid = threadIdx.x + blockIdx.x * blockDim.x;
    if (gid < num_rows) out[gid] = in[gid];
    return;
  }

  value_type acc = initVal;
  if (row < num_rows && lane < num_cols) {
    const int row_start = row * num_cols;
    acc = in[row_start + lane];
    for (int col = lane + 32; col < num_cols; col += 32) {
      acc = op(acc, in[row_start + col]);
    }
  }

  // Only the first min(num_cols, 32) lanes hold input values.
  acc = WarpReduce(temp_storage).Reduce(acc, op, min(num_cols, 32));

  if (row < num_rows && lane == 0) out[row] = acc;
}
267 
// Thin wrapper around a single T1 value with an empty default constructor
// (the value is left uninitialized), a conversion back to T1 and assignment
// from T1.
template <typename T1>
struct storage_type {
  T1 val;

  __host__ __device__ storage_type() {}

  __host__ __device__ operator T1() { return val; }

  __host__ __device__ storage_type& operator=(const T1& value) {
    val = value;
    return *this;
  }
};
278 
// Complex variant of storage_type: keeps the real and imaginary parts as two
// plain T2 members (left uninitialized by the empty default constructor) and
// converts to / assigns from std::complex<T2>.
template <typename T2>
struct storage_type<std::complex<T2>> {
  T2 real;
  T2 imag;

  __host__ __device__ storage_type() {}

  __host__ __device__ operator std::complex<T2>() {
    return std::complex<T2>(real, imag);
  }

  __host__ __device__ storage_type& operator=(const std::complex<T2>& value) {
    real = value.real();
    imag = value.imag();
    return *this;
  }
};
294 
// Column reduction for narrow inputs; works only if there are <= 16 columns.
// Each warp covers several rows at once (32 / num_cols of them), so a lane
// handles one (row, column) pair. Per-lane values are first folded over rows
// a whole grid apart, then across the rows inside the warp with shuffles, then
// across the warps of the block through shared memory. The result for column
// `col` of block `blockIdx.y` is written to out[col * gridDim.y + blockIdx.y].
//
// The indexing assumes blockDim.x == 32 (threadIdx.x is used as the lane id
// for the shuffle) and at most 32 warps per block (33-element row stride in
// the 32 * 33 shared buffer).
template <typename T, typename outT, typename Op>
__global__ void ColumnReduceMax16ColumnsKernel(
    T in, outT out, int num_rows, int num_cols, Op op,
    typename std::iterator_traits<T>::value_type initVal) {
  typedef typename std::iterator_traits<T>::value_type value_type;

  // 1D array necessary due to bug in CUDA 9 compiler.
  // TODO(nluehr) revert to 2D array when compiler is ready.
  // This is to mimic the following, but without any constructors:
  //   __shared__ storage_type<value_type> partial_sums[32 * 33];
  __shared__ __align__(
      alignof(value_type)) char partial_sums_raw[32 * 33 * sizeof(value_type)];
  value_type* partial_sums = reinterpret_cast<value_type*>(partial_sums_raw);

  const int rows_per_warp = 32 / num_cols;
  const int lane = threadIdx.x % 32;
  const int col = lane % num_cols;
  const int warp_first_row =
      rows_per_warp * (blockIdx.y * blockDim.y + threadIdx.y);
  const int total_elems = num_rows * num_cols;
  // Distance, in rows, between successive visits of the same lane.
  const unsigned int row_step = rows_per_warp * gridDim.y * blockDim.y;

  int row = warp_first_row + lane / num_cols;

  value_type acc = initVal;
  if (row * num_cols + col < total_elems) acc = in[row * num_cols + col];

  row += row_step;
  while (row < num_rows) {
    const int pos = row * num_cols + col;
    if (pos < total_elems) acc = op(acc, in[pos]);
    row += row_step;
  }

  // Fold the rows held by this warp into its first num_cols lanes.
  // Not the most efficient way to do this sum. Every lane takes part in the
  // shuffle; only the first num_cols lanes accumulate.
  const int rows_in_this_warp = min(rows_per_warp, num_rows - warp_first_row);
  for (int i = 1; i < rows_in_this_warp; ++i) {
    value_type other = cub::ShuffleIndex<32, value_type>(
        acc, static_cast<int>(threadIdx.x + i * num_cols), 0xffffffff);
    if (lane < num_cols) acc = op(acc, other);
  }

  if (lane < num_cols) partial_sums[lane * 33 + threadIdx.y] = acc;

  __syncthreads();

  // The first warp combines the per-warp partial results of each column.
  if (threadIdx.y == 0 && threadIdx.x < num_cols) {
    value_type total = partial_sums[threadIdx.x * 33];

    for (int warp = 1; warp < blockDim.y; ++warp) {
      value_type part = partial_sums[threadIdx.x * 33 + warp];
      total = op(total, part);
    }

    out[col * gridDim.y + blockIdx.y] = total;
  }
}
357 
// Maps each block to a range of 32 columns: threadIdx.x selects the column
// inside the range and threadIdx.y the starting row. Each thread folds the
// rows it visits (a whole grid apart) into a partial value, the partials of a
// block are exchanged through shared memory, and the threads with
// threadIdx.y == 0 write out[col * gridDim.y + blockIdx.y].
//
// The 32 * 33 shared buffer assumes blockDim.x <= 32 and blockDim.y <= 32.
template <typename T, typename outT, typename Op>
__global__ void ColumnReduceKernel(
    T in, outT out, int num_rows, int num_cols, Op op,
    typename std::iterator_traits<T>::value_type initVal) {
  typedef typename std::iterator_traits<T>::value_type value_type;

  // 1D array necessary due to bug in CUDA 9 compiler.
  // TODO(nluehr) revert to 2D array when compiler is ready.
  // This is to mimic the following, but without constructors:
  //     __shared__ storage_type<value_type> partial_sums[32 * 33];
  __shared__ __align__(
      alignof(value_type)) char partial_sums_raw[32 * 33 * sizeof(value_type)];
  value_type* partial_sums = reinterpret_cast<value_type*>(partial_sums_raw);

  const int col = blockIdx.x * 32 + threadIdx.x;
  const bool col_in_range = col < num_cols;
  // Distance, in rows, between successive visits of the same thread.
  const unsigned int row_step = gridDim.y * blockDim.y;

  int row = blockIdx.y * blockDim.y + threadIdx.y;

  value_type acc = initVal;
  if (row < num_rows && col_in_range) acc = in[row * num_cols + col];

  row += row_step;
  if (col_in_range) {
    while (row < num_rows) {
      acc = op(acc, in[row * num_cols + col]);
      row += row_step;
    }
  }

  // Every thread stores its partial (initVal if it saw no input) so that the
  // barrier below is reached uniformly.
  partial_sums[threadIdx.x * 33 + threadIdx.y] = acc;

  __syncthreads();

  if (threadIdx.y == 0 && col_in_range) {
    value_type total = partial_sums[threadIdx.x * 33];

    // only include input values in the reduction
    // elem   block_rows
    //  -         =
    //  -         =
    //  #         #  block boundary
    //  -         =
    //  -         =
    //  #         #  block boundary
    //  -         =
    //            =
    const int numRowsThisBlock =
        min(blockDim.y, num_rows - blockIdx.y * blockDim.y);

    for (int r = 1; r < numRowsThisBlock; ++r) {
      value_type part = partial_sums[threadIdx.x * 33 + r];
      total = op(total, part);
    }

    out[col * gridDim.y + blockIdx.y] = total;
  }
}
414 
// Does multiple warp-sized segmented reductions in parallel. Segments cannot
// cross warp boundaries (mainly used for reducing the segments that come from
// the Max16Columns column reduction kernel).
//
// Thread `tid` loads partial_sums[tid] when tid < segment_size * num_cols
// (initVal otherwise); the first thread of every segment then writes the
// segment's result to out[tid / segment_size]. `num_rows` is not used here.
template <typename T, typename outT, typename Op>
__global__ void CleanupSegments(
    T partial_sums, outT out, int num_rows, int num_cols, int segment_size,
    Op op, typename std::iterator_traits<T>::value_type initVal) {
  typedef typename std::iterator_traits<T>::value_type value_type;
  typedef cub::WarpReduce<value_type> WarpReduce;

  __shared__ typename WarpReduce::TempStorage temp_storage;

  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
  const bool in_range = tid < segment_size * num_cols;
  const bool head_flag = (threadIdx.x % segment_size) == 0;

  value_type val = initVal;
  if (in_range) val = partial_sums[tid];

  // All threads take part in the warp-level reduction, including those that
  // only contribute initVal.
  value_type sum =
      WarpReduce(temp_storage).HeadSegmentedReduce(val, head_flag, op);

  if (head_flag && in_range) {
    out[tid / segment_size] = sum;
  }
}
440 
// Assigns one thread to each (plane, column) pair: the thread walks down the
// rows of its column inside its plane and writes the result to
// out[plane * num_cols + col]. No initial value is used; the first two rows
// seed the accumulator, and a single row is copied through unchanged.
template <typename T, typename outT, typename Op>
__global__ void ColumnReduceSimpleKernel(T in, outT out, int num_planes,
                                         int num_rows, int num_cols, Op op) {
  typedef typename std::iterator_traits<T>::value_type value_type;
  const int gid = threadIdx.x + blockIdx.x * blockDim.x;

  const int plane = gid / num_cols;
  if (plane >= num_planes) return;

  const int col = gid % num_cols;
  const int elems_per_plane = num_rows * num_cols;
  // Index of row 0 of this thread's column inside its plane.
  const int base = plane * elems_per_plane + col;

  if (num_rows == 1) {
    // elems_per_plane == num_cols here, so `base` is also the output index.
    out[base] = in[base];
    return;
  }

  value_type acc = op(in[base], in[base + num_cols]);
  for (int row = 2; row < num_rows; ++row) {
    acc = op(acc, in[base + row * num_cols]);
  }

  out[plane * num_cols + col] = acc;
}
467 
// Unary functor mapping an index x to `cols_ * x`. The launchers in this file
// feed it row indices to produce segment offsets for CUB segmented reductions.
struct RowOffset {
  int cols_;

  __host__ __device__ explicit RowOffset(const int& cols) : cols_(cols) {}

  __host__ __device__ int operator()(const int& index) const {
    return cols_ * index;
  }
};
475 
// Index-remapping functor over a 3-D extent (extent_x, extent_y, extent_z).
// A flat index is split into a `group` (which encodes an (x, z) pair) and an
// `offset` along y, and mapped to x * extent_y * extent_z + offset * extent_z
// + z. The `kOne` flag chooses how the split is done:
//   kOne == true : group_size_ = extent_y,  group = ind / size, offset = ind % size
//   kOne == false: group_size_ = extent_x * extent_z,
//                  group = ind % size, offset = ind / size
struct GatherOp {
  __host__ __device__ GatherOp(const int& extent_x, const int& extent_y,
                               const int& extent_z, bool kOne)
      : extent_x_(extent_x),
        extent_y_(extent_y),
        extent_z_(extent_z),
        kOne_(kOne),
        group_size_(kOne ? extent_y : extent_x * extent_z) {}

  __host__ __device__ int operator()(const int& ind) const {
    int group;
    int offset;
    if (kOne_) {
      group = ind / group_size_;
      offset = ind % group_size_;
    } else {
      group = ind % group_size_;
      offset = ind / group_size_;
    }

    const int x = group / extent_z_;
    const int z = group % extent_z_;

    return x * extent_y_ * extent_z_ + z + offset * extent_z_;
  }

  int extent_x_;
  int extent_y_;
  int extent_z_;
  bool kOne_;
  int group_size_;
};
505 
// Reduces the `in_size` elements of `in` to a single value written to `out`,
// enqueuing the work on `cu_stream`. Three strategies, chosen by size:
//   * in_size <= 4096:  one BlockReduceKernel block of 256 threads.
//   * in_size <= 2^18:  up to 32 BlockReduceKernel blocks writing per-block
//                       results to a temporary buffer, followed by one warp of
//                       CleanupSegments that combines them.
//   * otherwise:        cub::DeviceReduce::Reduce.
// Failures are reported on `ctx` (allocation / CUB errors) or via TF_CHECK_OK
// (kernel launch errors).
template <typename T, typename Op, typename OUT_T, typename IN_T>
void LaunchScalarReduction(OpKernelContext* ctx, OUT_T out, IN_T in,
                           int in_size, Op op, T init,
                           const cudaStream_t& cu_stream) {
  // handle situations where low latency is important better than CUB
  if (in_size <= 4096) {
    const int num_blocks = 1;
    const int num_threads = 256;
    TF_CHECK_OK(CudaLaunchKernel(
        BlockReduceKernel<IN_T, OUT_T, num_threads, Op>, num_blocks,
        num_threads, 0, cu_stream, in, out, in_size, op, init));
    return;
  } else if (in_size <= (1 << 18)) {
    const int num_threads = 256;
    const int num_blocks = std::min(32, Eigen::divup(in_size, num_threads));
    // it seems like tailoring this to the GPU
    // would be more effective, but all attempts
    // at making this a multiple of the number of
    // multiprocessors have led to lower perf
    // in general
    // TODO(eriche) investigate this more

    // One T per block, allocated as raw bytes.
    Tensor temp_storage;
    OP_REQUIRES_OK(
        ctx,
        ctx->allocate_temp(
            DT_INT8, TensorShape({static_cast<int64>(num_blocks * sizeof(T))}),
            &temp_storage));
    T* block_results = reinterpret_cast<T*>(temp_storage.flat<int8_t>().data());

    TF_CHECK_OK(CudaLaunchKernel(BlockReduceKernel<IN_T, T*, num_threads, Op>,
                                 num_blocks, num_threads, 0, cu_stream, in,
                                 block_results, in_size, op, init));

    // take care that we only reduce blocks that had some valid elements in them
    // TODO(eriche): CUB currently has a bug in HeadSegmentedReduce that
    // requires it to be used with a full warp.  Can reduce 32 -> num_blocks
    // when this is fixed.
    // Arguments: num_rows = 1, num_cols = 1, segment_size = num_blocks.
    TF_CHECK_OK(CudaLaunchKernel(CleanupSegments<T*, OUT_T, Op>, 1, 32, 0,
                                 cu_stream, block_results, out, 1, 1,
                                 num_blocks, op, init));
    return;
  }

  size_t temp_storage_bytes = 0;
  auto reduce = [&](void* temp_storage_ptr) {
    auto success =
        cub::DeviceReduce::Reduce(temp_storage_ptr, temp_storage_bytes, in, out,
                                  in_size, op, init, cu_stream);

    OP_REQUIRES(
        ctx, success == 0,
        errors::Internal("CUB reduce error ", cudaGetErrorString(success)));
  };

  reduce(nullptr);  // Get required amount of temp storage.

  // OP_REQUIRES inside the lambda only returns from the lambda, so a failed
  // size query has to be checked here; otherwise we would go on to allocate
  // and launch with a meaningless temp_storage_bytes.
  if (!ctx->status().ok()) return;

  Tensor temp_storage;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(
               DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
               &temp_storage));

  reduce(temp_storage.flat<int8_t>().data());  // Do reduction.
}
572 
// Reduces every row of a row-major (num_rows x num_cols) input, writing one
// value per row to `out`, with the work enqueued on `cu_stream`.
//   * num_cols < 1024: RowReduceKernel, one warp per row (128-thread blocks).
//   * otherwise:       cub::DeviceSegmentedReduce::Reduce with one segment per
//                      row; segment offsets are generated on the fly as
//                      row_index * num_cols.
// Failures are reported on `ctx` (allocation / CUB errors) or via TF_CHECK_OK
// (kernel launch errors).
template <typename T, typename Op, typename OUT_T, typename IN_T>
void LaunchRowReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int num_rows,
                        int num_cols, Op op, T init,
                        const cudaStream_t& cu_stream) {
  if (num_cols < 1024) {
    const int threads_per_block = 128;
    const int warps_per_block = threads_per_block / 32;
    int num_blocks = (num_rows + warps_per_block - 1) / warps_per_block;

    TF_CHECK_OK(CudaLaunchKernel(RowReduceKernel<IN_T, OUT_T, Op>, num_blocks,
                                 threads_per_block, 0, cu_stream, in, out,
                                 num_rows, num_cols, op, init));
    return;
  }

  // setup segment offsets with counting and transform iterator
  RowOffset row_offset_op(num_cols);
  cub::CountingInputIterator<int> counting_iter(0);
  cub::TransformInputIterator<int, RowOffset, cub::CountingInputIterator<int>>
      transform_iter(counting_iter, row_offset_op);

  size_t temp_storage_bytes = 0;
  auto reduce = [&](void* temp_storage_ptr) {
    // Segment i spans [transform_iter[i], transform_iter[i + 1]).
    auto success = cub::DeviceSegmentedReduce::Reduce(
        temp_storage_ptr, temp_storage_bytes, in, out, num_rows, transform_iter,
        transform_iter + 1, op, init, cu_stream);

    // The trailing space keeps the CUDA error text separated from the prefix.
    OP_REQUIRES(ctx, success == 0,
                errors::Internal("CUB segmented reduce error ",
                                 cudaGetErrorString(success)));
  };

  reduce(nullptr);  // Get required amount of temp storage.

  // OP_REQUIRES inside the lambda only returns from the lambda, so a failed
  // size query has to be checked here before allocating and launching.
  if (!ctx->status().ok()) return;

  Tensor temp_storage;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(
               DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
               &temp_storage));

  reduce(temp_storage.flat<int8_t>().data());  // Do reduction.
}
615 
// Column reduction of a row-major (extent_x rows x extent_y columns) input for
// the narrow case dispatched by LaunchColumnReduction (extent_y <= 16).
// Launches ColumnReduceMax16ColumnsKernel with 32-wide blocks; when more than
// one block is used along y, the per-block partial results go to a temporary
// buffer and a CleanupSegments pass combines them into `out`.
template <typename T, typename Op, typename OUT_T, typename IN_T>
void LaunchColumnReduction_LTE16Cols(OpKernelContext* ctx, OUT_T out, IN_T in,
                                     int extent_x, int extent_y, Op op, T init,
                                     const cudaStream_t& cu_stream) {
  const int rows_per_warp = 32 / extent_y;

  // One warp per threadIdx.y, at most 32 warps per block.
  dim3 block_dim(32, std::min(Eigen::divup(extent_x, rows_per_warp), 32), 1);
  dim3 grid_dim(1,
                Eigen::divup(static_cast<unsigned int>(extent_x),
                             rows_per_warp * block_dim.y),
                1);

  // Cap the number of blocks at 32 ...
  grid_dim.y = std::min(static_cast<int>(grid_dim.y), 32);

  // ... and round counts strictly between 2 and 32 down to a power of two.
  if (grid_dim.y > 2 && grid_dim.y < 32) {
    const int exponent = Log2Floor(grid_dim.y);
    grid_dim.y = 1 << exponent;
  }

  // A single block produces the final result directly.
  if (grid_dim.y == 1) {
    TF_CHECK_OK(CudaLaunchKernel(
        ColumnReduceMax16ColumnsKernel<IN_T, OUT_T, Op>, grid_dim, block_dim, 0,
        cu_stream, in, out, extent_x, extent_y, op, init));
    return;
  }

  // Otherwise stage grid_dim.y partial results per column in scratch memory.
  Tensor temp_storage;
  OP_REQUIRES_OK(ctx,
                 ctx->allocate_temp(DT_INT8,
                                    TensorShape({static_cast<int64>(
                                        sizeof(T) * extent_y * grid_dim.y)}),
                                    &temp_storage));
  T* partials = (T*)temp_storage.flat<int8_t>().data();

  TF_CHECK_OK(CudaLaunchKernel(ColumnReduceMax16ColumnsKernel<IN_T, T*, Op>,
                               grid_dim, block_dim, 0, cu_stream, in, partials,
                               extent_x, extent_y, op, init));

  // Each column owns a segment of grid_dim.y consecutive partials.
  dim3 cleanup_grid((grid_dim.y * extent_y + 31) / 32, 1, 1);
  dim3 cleanup_block(128, 1, 1);
  TF_CHECK_OK(CudaLaunchKernel(CleanupSegments<T*, OUT_T, Op>, cleanup_grid,
                               cleanup_block, 0, cu_stream, partials, out,
                               extent_x, extent_y, grid_dim.y, op, init));
}
658 
// Column reduction of a row-major (extent_x rows x extent_y columns) input for
// the medium-width case dispatched by LaunchColumnReduction
// (16 < extent_y <= 4096). Launches ColumnReduceKernel with one block per 32
// columns; when fewer than 16 such column blocks exist, the rows are also
// split over up to 32 blocks along y, whose partial results are staged in a
// temporary buffer and combined by CleanupSegments.
template <typename T, typename Op, typename OUT_T, typename IN_T>
void LaunchColumnReduction_LTE4096Cols(OpKernelContext* ctx, OUT_T out, IN_T in,
                                       int extent_x, int extent_y, Op op,
                                       T init, const cudaStream_t& cu_stream) {
  dim3 block_dim(32, std::min(extent_x, 32), 1);
  dim3 grid_dim((extent_y + 31) / 32, 1, 1);

  if (grid_dim.x < 16) grid_dim.y = std::min((extent_x + 31) / 32, 32);

  // Round block counts strictly between 2 and 32 down to a power of two.
  if (grid_dim.y > 2 && grid_dim.y < 32) {
    int log2 = Log2Floor(grid_dim.y);
    grid_dim.y = 1 << log2;
  }

  if (grid_dim.y == 1) {
    // A single block along y produces the final result directly.
    TF_CHECK_OK(CudaLaunchKernel(ColumnReduceKernel<IN_T, OUT_T, Op>, grid_dim,
                                 block_dim, 0, cu_stream, in, out, extent_x,
                                 extent_y, op, init));
  } else {
    // grid_dim.y partial results per column, allocated as raw bytes.
    Tensor temp_storage;
    OP_REQUIRES_OK(ctx,
                   ctx->allocate_temp(DT_INT8,
                                      TensorShape({static_cast<int64>(
                                          sizeof(T) * extent_y * grid_dim.y)}),
                                      &temp_storage));

    TF_CHECK_OK(CudaLaunchKernel(
        ColumnReduceKernel<IN_T, T*, Op>, grid_dim, block_dim, 0, cu_stream, in,
        (T*)temp_storage.flat<int8_t>().data(), extent_x, extent_y, op, init));

    // NOTE(review): the cleanup pass is launched with the 2-D `block_dim`
    // (a previously declared but unused 128x1x1 block has been removed).
    // CleanupSegments only indexes with threadIdx.x / blockIdx.x, so each
    // threadIdx.y row appears to repeat the same work - confirm whether a 1-D
    // block, as in LaunchColumnReduction_LTE16Cols, was intended.
    dim3 new_grid_dim((grid_dim.y * extent_y + 31) / 32, 1, 1);
    TF_CHECK_OK(CudaLaunchKernel(CleanupSegments<T*, OUT_T, Op>, new_grid_dim,
                                 block_dim, 0, cu_stream,
                                 (T*)temp_storage.flat<int8_t>().data(), out,
                                 extent_x, extent_y, grid_dim.y, op, init));
  }
}
697 
// Dispatches a column reduction of a row-major (extent_x rows x extent_y
// columns) input by the number of columns:
//   * extent_y <= 16:   LaunchColumnReduction_LTE16Cols
//   * extent_y <= 4096: LaunchColumnReduction_LTE4096Cols
//   * otherwise:        ColumnReduceSimpleKernel, one thread per column,
//                       treating the input as a single plane (`init` is not
//                       used on this path).
template <typename T, typename Op, typename OUT_T, typename IN_T>
void LaunchColumnReduction(OpKernelContext* ctx, OUT_T out, IN_T in,
                           int extent_x, int extent_y, Op op, T init,
                           const cudaStream_t& cu_stream) {
  if (extent_y <= 16) {
    LaunchColumnReduction_LTE16Cols(ctx, out, in, extent_x, extent_y, op, init,
                                    cu_stream);
    return;
  }

  if (extent_y <= 4096) {
    LaunchColumnReduction_LTE4096Cols(ctx, out, in, extent_x, extent_y, op,
                                      init, cu_stream);
    return;
  }

  int threads_per_block = 128;
  int num_blocks = Eigen::divup(extent_y, threads_per_block);

  TF_CHECK_OK(CudaLaunchKernel(ColumnReduceSimpleKernel<IN_T, OUT_T, Op>,
                               num_blocks, threads_per_block, 0, cu_stream, in,
                               out, 1, extent_x, extent_y, op));
}
717 
// Reduces the middle (y) dimension of an (extent_x, extent_y, extent_z) input
// with ColumnReduceSimpleKernel: extent_x planes of extent_y rows by extent_z
// columns, one thread per (plane, column) pair. `ctx` and `init` are not used.
template <typename T, typename Op, typename OUT_T, typename IN_T>
void Launch3DYReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
                        int extent_y, int extent_z, Op op, T init,
                        const cudaStream_t& cu_stream) {
  int threads_per_block = 128;
  // One thread for every output element, rounded up to whole blocks.
  const int num_outputs = extent_x * extent_z;
  int num_blocks = (num_outputs + threads_per_block - 1) / threads_per_block;

  // TODO(eriche): this won't be very good in the case of small x
  //                small z and large y.
  TF_CHECK_OK(CudaLaunchKernel(ColumnReduceSimpleKernel<IN_T, OUT_T, Op>,
                               num_blocks, threads_per_block, 0, cu_stream, in,
                               out, extent_x, extent_y, extent_z, op));
}
732 
// Reduces the outer (x) and inner (z) dimensions of an
// (extent_x, extent_y, extent_z) input, producing extent_y values in `out`.
// The input is read through a permutation iterator driven by GatherOp so that
// all elements sharing a y index form one contiguous segment of
// extent_x * extent_z elements, which cub::DeviceSegmentedReduce then reduces.
// Failures are reported on `ctx`.
template <typename T, typename Op, typename OUT_T, typename IN_T>
void Launch3DXZReduction(OpKernelContext* ctx, OUT_T out, IN_T in, int extent_x,
                         int extent_y, int extent_z, Op op, T init,
                         const cudaStream_t& cu_stream) {
  // setup segment offsets with counting and transform iterator
  RowOffset row_offset_op(extent_x * extent_z);
  cub::CountingInputIterator<int> counting_iter(0);
  cub::TransformInputIterator<int, RowOffset, cub::CountingInputIterator<int>>
      transform_iter(counting_iter, row_offset_op);

  // Maps a position in the permuted (segment-major) order to the flat index
  // of the corresponding input element.
  GatherOp gather_op(extent_x, extent_y, extent_z, false);
  typedef cub::TransformInputIterator<int, GatherOp,
                                      cub::CountingInputIterator<int>>
      gatherIterType;
  gatherIterType gather_iter(counting_iter, gather_op);

  PermutationInputIterator<T, IN_T, gatherIterType> permute_iter(in,
                                                                 gather_iter);

  std::size_t temp_storage_bytes = 0;
  auto reduce = [&](void* temp_storage_ptr) {
    // Segment i spans [transform_iter[i], transform_iter[i + 1]).
    auto success = cub::DeviceSegmentedReduce::Reduce(
        temp_storage_ptr, temp_storage_bytes, permute_iter, out, extent_y,
        transform_iter, transform_iter + 1, op, init, cu_stream);

    // The trailing space keeps the CUDA error text separated from the prefix.
    OP_REQUIRES(ctx, success == 0,
                errors::Internal("CUB segmented reduce error ",
                                 cudaGetErrorString(success)));
  };

  reduce(nullptr);  // Get required amount of temp storage.

  // OP_REQUIRES inside the lambda only returns from the lambda, so a failed
  // size query has to be checked here before allocating and launching.
  if (!ctx->status().ok()) return;

  Tensor temp_storage;
  OP_REQUIRES_OK(
      ctx, ctx->allocate_temp(
               DT_INT8, TensorShape({static_cast<int64>(temp_storage_bytes)}),
               &temp_storage));

  reduce(temp_storage.flat<int8_t>().data());  // Do reduction.
}
773 
774 namespace reduction_op_helper {
775 
// Trait: `value` is true when Op is one of the sum operators known to this
// file (Sum<T>, cub::Sum or Eigen::internal::SumReducer<T>).
template <typename T, typename Op>
struct IsSum {
  static constexpr bool value =
      std::is_same<Op, Sum<T>>::value ||
      std::is_same<Op, cub::Sum>::value ||
      std::is_same<Op, Eigen::internal::SumReducer<T>>::value;
};
783 
// Trait: `value` is true when Op is cub::Max or
// Eigen::internal::MaxReducer<T>.
template <typename T, typename Op>
struct IsMax {
  static constexpr bool value =
      std::is_same<Op, Eigen::internal::MaxReducer<T>>::value ||
      std::is_same<Op, cub::Max>::value;
};
790 
// Trait: `value` is true when Op is cub::Min or
// Eigen::internal::MinReducer<T>.
template <typename T, typename Op>
struct IsMin {
  static constexpr bool value =
      std::is_same<Op, Eigen::internal::MinReducer<T>>::value ||
      std::is_same<Op, cub::Min>::value;
};
797 
// Trait: `value` is true when Op is Prod<T> or
// Eigen::internal::ProdReducer<T>.
template <typename T, typename Op>
struct IsProd {
  static constexpr bool value =
      std::is_same<Op, Eigen::internal::ProdReducer<T>>::value ||
      std::is_same<Op, Prod<T>>::value;
};
804 
// Functor yielding the identity element of reduction operator Op over T.
// It is meant to be called with no arguments: exactly one operator() overload
// is enabled (via std::enable_if) for a supported Op, and that overload's
// default argument is the identity value:
//   sum -> U(0), max -> NumTraits<U>::lowest(), min -> NumTraits<U>::highest(),
//   prod -> U(1), And -> true, Or -> false.
// Unsupported operators are rejected at compile time by the static_assert.
template <typename T, typename Op>
struct IdentityValue {
  static_assert(IsSum<T, Op>::value || IsMax<T, Op>::value ||
                    IsMin<T, Op>::value || IsProd<T, Op>::value ||
                    std::is_same<Op, And>::value || std::is_same<Op, Or>::value,
                "IdentityValue not yet defined for this type");

  // Sum: additive identity.
  template <typename U = T, typename OpT = Op>
  U operator()(
      typename std::enable_if<IsSum<U, OpT>::value, U>::type identity = U(0)) {
    return identity;
  }

  // Max: lowest representable value.
  template <typename U = T, typename OpT = Op>
  U operator()(
      typename std::enable_if<IsMax<U, OpT>::value, U>::type identity =
          Eigen::NumTraits<U>::lowest()) {
    return identity;
  }

  // Min: highest representable value.
  template <typename U = T, typename OpT = Op>
  U operator()(
      typename std::enable_if<IsMin<U, OpT>::value, U>::type identity =
          Eigen::NumTraits<U>::highest()) {
    return identity;
  }

  // Product: multiplicative identity.
  template <typename U = T, typename OpT = Op>
  U operator()(
      typename std::enable_if<IsProd<U, OpT>::value, U>::type identity = U(1)) {
    return identity;
  }

  // Logical and: true.
  template <typename U = T, typename OpT = Op>
  U operator()(
      typename std::enable_if<std::is_same<OpT, And>::value, bool>::type
          identity = true) {
    return identity;
  }

  // Logical or: false.
  template <typename U = T, typename OpT = Op>
  U operator()(
      typename std::enable_if<std::is_same<OpT, Or>::value, bool>::type
          identity = false) {
    return identity;
  }
};
848 
849 }  // namespace reduction_op_helper
850 
851 template <typename T, typename Op, typename OUT_T, typename IN_T,
852           typename ReductionAxes>
853 void ReduceImpl(OpKernelContext* ctx, OUT_T out, IN_T in, int in_rank,
854                 int in_dim0, int in_dim1, int in_dim2, int out_rank,
855                 const ReductionAxes& reduction_axes, Op op) {
856   T init = reduction_op_helper::IdentityValue<T, Op>()();
857   const cudaStream_t& cu_stream = GetCudaStream(ctx);
858   if (out_rank == 0) {
859     const int in_size = in_dim0 * in_dim1 * in_dim2;
860     LaunchScalarReduction(ctx, out, in, in_size, op, init, cu_stream);
861   } else if (in_rank == 2 && out_rank == 1 &&
862              reduction_axes[0] == 1) {  // row reduction
863     LaunchRowReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream);
864   } else if (in_rank == 2 && out_rank == 1 &&
865              reduction_axes[0] == 0) {  // column reduction
866     LaunchColumnReduction(ctx, out, in, in_dim0, in_dim1, op, init, cu_stream);
867   } else if (in_rank == 3 && out_rank == 2 && reduction_axes[0] == 1) {
868     Launch3DYReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init,
869                        cu_stream);
870   } else if (in_rank == 3 && out_rank == 1 && reduction_axes[0] == 0 &&
871              reduction_axes[1] == 2) {
872     Launch3DXZReduction(ctx, out, in, in_dim0, in_dim1, in_dim2, op, init,
873                         cu_stream);
874   } else {
875     std::stringstream ss;
876     ss << "Invalid reduction requested: in_rank, out_rank, axes " << in_rank
877        << " " << out_rank;
878     if (out_rank == 1) ss << " " << reduction_axes[0];
879     if (out_rank == 2) ss << " " << reduction_axes[1];
880     LOG(FATAL) << ss.str();
881   }
882 }
883 
884 template <typename Reducer>
885 struct ReduceFunctor<GPUDevice, Reducer> {
886   template <typename OUT_T, typename IN_T, typename ReductionAxes>
887   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
888                      const ReductionAxes& reduction_axes,
889                      const Reducer& reducer);
890 };
891 
892 template <typename T>
893 struct ReduceFunctor<GPUDevice, Eigen::internal::SumReducer<T>> {
894   template <typename OUT_T, typename IN_T, typename ReductionAxes>
895   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
896                      const ReductionAxes& reduction_axes,
897                      const Eigen::internal::SumReducer<T>& reducer) {
898     ReduceImpl<T, Sum<T>, T*, T*, ReductionAxes>(
899         ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
900         in.rank() >= 2 ? in.dimension(1) : 1,
901         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
902         Sum<T>());
903   }
904 
905   template <typename OUT_T>
906   static void FillIdentity(const GPUDevice& d, OUT_T out,
907                            const Eigen::internal::SumReducer<T>& reducer) {
908     FillIdentityEigenImpl(d, To32Bit(out), reducer);
909   }
910 };
911 
912 // TODO(rmlarsen): Specialize for float16.
913 template <typename T>
914 struct ReduceFunctor<GPUDevice, functor::EuclideanNormReducer<T>> {
915   template <typename OUT_T, typename IN_T, typename ReductionAxes>
916   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
917                      const ReductionAxes& reduction_axes,
918                      const functor::EuclideanNormReducer<T>& reducer) {
919     typedef cub::TransformInputIterator<T, Square<T>, T*> inputIterType;
920     inputIterType input_itr((T*)in.data(), Square<T>());
921     typedef TransformOutputIterator<T, T, Sqrt<T>> outputIterType;
922     outputIterType output_itr((T*)out.data(), Sqrt<T>());
923     ReduceImpl<T, Sum<T>, outputIterType, inputIterType, ReductionAxes>(
924         ctx, output_itr, input_itr, in.rank(), in.dimension(0),
925         in.rank() >= 2 ? in.dimension(1) : 1,
926         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
927         Sum<T>());
928   }
929 
930   template <typename OUT_T>
931   static void FillIdentity(const GPUDevice& d, OUT_T out,
932                            const functor::EuclideanNormReducer<T>& reducer) {
933     FillIdentityEigenImpl(d, To32Bit(out), reducer);
934   }
935 };
936 
937 template <typename T>
938 struct ReduceFunctor<GPUDevice, functor::MeanReducer<T>> {
939   template <typename OUT_T, typename IN_T, typename ReductionAxes>
940   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
941                      const ReductionAxes& reduction_axes,
942                      const functor::MeanReducer<T>& reducer) {
943     int divisor = 1;
944     if (out.rank() == 0)
945       divisor = in.size();
946     else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0)
947       divisor = in.dimension(0);
948     else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1)
949       divisor = in.dimension(1);
950     else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 &&
951              reduction_axes[1] == 2)
952       divisor = in.dimension(0) * in.dimension(2);
953     else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1)
954       divisor = in.dimension(1);
955 
956     DividesBy<T> div_op(static_cast<T>(divisor));
957     TransformOutputIterator<T, T, DividesBy<T>> itr((T*)out.data(), div_op);
958     ReduceImpl<T, Sum<T>, TransformOutputIterator<T, T, DividesBy<T>>, T*,
959                ReductionAxes>(ctx, itr, (T*)in.data(), in.rank(),
960                               in.dimension(0),
961                               in.rank() >= 2 ? in.dimension(1) : 1,
962                               in.rank() >= 3 ? in.dimension(2) : 1, out.rank(),
963                               reduction_axes, Sum<T>());
964   }
965 
966   template <typename OUT_T>
967   static void FillIdentity(const GPUDevice& d, OUT_T out,
968                            const functor::MeanReducer<T>& reducer) {
969     FillIdentityEigenImpl(d, To32Bit(out), reducer);
970   }
971 };
972 
973 template <>
974 struct ReduceFunctor<GPUDevice, functor::MeanReducer<Eigen::half>> {
975   template <typename OUT_T, typename IN_T, typename ReductionAxes>
976   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
977                      const ReductionAxes& reduction_axes,
978                      const functor::MeanReducer<Eigen::half>& reducer) {
979     float divisor = 1.f;
980     if (out.rank() == 0)
981       divisor = in.size();
982     else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 0)
983       divisor = in.dimension(0);
984     else if (out.rank() == 1 && in.rank() == 2 && reduction_axes[0] == 1)
985       divisor = in.dimension(1);
986     else if (out.rank() == 1 && in.rank() == 3 && reduction_axes[0] == 0 &&
987              reduction_axes[1] == 2)
988       divisor = in.dimension(0) * in.dimension(2);
989     else if (out.rank() == 2 && in.rank() == 3 && reduction_axes[0] == 1)
990       divisor = in.dimension(1);
991     DividesBy<float, Eigen::half> div_op(divisor);
992 
993     typedef cub::TransformInputIterator<float, HalfToFloat, Eigen::half*>
994         inputIterType;
995     inputIterType input_itr((Eigen::half*)in.data(), HalfToFloat());
996 
997     typedef TransformOutputIterator<Eigen::half, float,
998                                     DividesBy<float, Eigen::half>>
999         outputIterType;
1000     outputIterType itr((Eigen::half*)out.data(), div_op);
1001 
1002     ReduceImpl<float, cub::Sum, outputIterType, inputIterType, ReductionAxes>(
1003         ctx, itr, input_itr, in.rank(), in.dimension(0),
1004         in.rank() >= 2 ? in.dimension(1) : 1,
1005         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
1006         cub::Sum());
1007   }
1008 
1009   template <typename OUT_T>
1010   static void FillIdentity(const GPUDevice& d, OUT_T out,
1011                            const functor::MeanReducer<Eigen::half>& reducer) {
1012     FillIdentityEigenImpl(d, To32Bit(out), reducer);
1013   }
1014 };
1015 
1016 template <typename T>
1017 struct ReduceFunctor<GPUDevice, Eigen::internal::MaxReducer<T>> {
1018   template <typename OUT_T, typename IN_T, typename ReductionAxes>
1019   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
1020                      const ReductionAxes& reduction_axes,
1021                      const Eigen::internal::MaxReducer<T>& reducer) {
1022     ReduceImpl<T, cub::Max, T*, T*, ReductionAxes>(
1023         ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
1024         in.rank() >= 2 ? in.dimension(1) : 1,
1025         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
1026         cub::Max());
1027   }
1028 
1029   template <typename OUT_T>
1030   static void FillIdentity(const GPUDevice& d, OUT_T out,
1031                            const Eigen::internal::MaxReducer<T>& reducer) {
1032     FillIdentityEigenImpl(d, To32Bit(out), reducer);
1033   }
1034 };
1035 
1036 template <typename T>
1037 struct ReduceFunctor<GPUDevice, Eigen::internal::MinReducer<T>> {
1038   template <typename OUT_T, typename IN_T, typename ReductionAxes>
1039   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
1040                      const ReductionAxes& reduction_axes,
1041                      const Eigen::internal::MinReducer<T>& reducer) {
1042     ReduceImpl<T, cub::Min, T*, T*, ReductionAxes>(
1043         ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
1044         in.rank() >= 2 ? in.dimension(1) : 1,
1045         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
1046         cub::Min());
1047   }
1048 
1049   template <typename OUT_T>
1050   static void FillIdentity(const GPUDevice& d, OUT_T out,
1051                            const Eigen::internal::MinReducer<T>& reducer) {
1052     FillIdentityEigenImpl(d, To32Bit(out), reducer);
1053   }
1054 };
1055 
1056 template <typename T>
1057 struct ReduceFunctor<GPUDevice, Eigen::internal::ProdReducer<T>> {
1058   template <typename OUT_T, typename IN_T, typename ReductionAxes>
1059   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
1060                      const ReductionAxes& reduction_axes,
1061                      const Eigen::internal::ProdReducer<T>& reducer) {
1062     ReduceImpl<T, Prod<T>, T*, T*, ReductionAxes>(
1063         ctx, (T*)out.data(), (T*)in.data(), in.rank(), in.dimension(0),
1064         in.rank() >= 2 ? in.dimension(1) : 1,
1065         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
1066         Prod<T>());
1067   }
1068 
1069   template <typename OUT_T>
1070   static void FillIdentity(const GPUDevice& d, OUT_T out,
1071                            const Eigen::internal::ProdReducer<T>& reducer) {
1072     FillIdentityEigenImpl(d, To32Bit(out), reducer);
1073   }
1074 };
1075 
1076 template <>
1077 struct ReduceFunctor<GPUDevice, Eigen::internal::AndReducer> {
1078   template <typename OUT_T, typename IN_T, typename ReductionAxes>
1079   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
1080                      const ReductionAxes& reduction_axes,
1081                      const Eigen::internal::AndReducer& reducer) {
1082     ReduceImpl<bool, And, bool*, bool*, ReductionAxes>(
1083         ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0),
1084         in.rank() >= 2 ? in.dimension(1) : 1,
1085         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes,
1086         And());
1087   }
1088 
1089   template <typename OUT_T>
1090   static void FillIdentity(const GPUDevice& d, OUT_T out,
1091                            const Eigen::internal::AndReducer& reducer) {
1092     FillIdentityEigenImpl(d, To32Bit(out), reducer);
1093   }
1094 };
1095 
1096 template <>
1097 struct ReduceFunctor<GPUDevice, Eigen::internal::OrReducer> {
1098   template <typename OUT_T, typename IN_T, typename ReductionAxes>
1099   static void Reduce(OpKernelContext* ctx, OUT_T out, IN_T in,
1100                      const ReductionAxes& reduction_axes,
1101                      const Eigen::internal::OrReducer& reducer) {
1102     ReduceImpl<bool, Or, bool*, bool*, ReductionAxes>(
1103         ctx, (bool*)out.data(), (bool*)in.data(), in.rank(), in.dimension(0),
1104         in.rank() >= 2 ? in.dimension(1) : 1,
1105         in.rank() >= 3 ? in.dimension(2) : 1, out.rank(), reduction_axes, Or());
1106   }
1107 
1108   template <typename OUT_T>
1109   static void FillIdentity(const GPUDevice& d, OUT_T out,
1110                            const Eigen::internal::OrReducer& reducer) {
1111     FillIdentityEigenImpl(d, To32Bit(out), reducer);
1112   }
1113 };
1114 
1115 }  // namespace functor
1116 }  // namespace tensorflow
1117 
1118 #endif  // GOOGLE_CUDA
1119 
1120 #endif  // TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
1121