/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_
#define TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include <algorithm>
#include <array>
#include <limits>
#include <utility>

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

template <typename T, bool conjugate>
struct maybe_conj {
  __device__ static __inline__ T run(T x) {
    if (conjugate) {
      return Eigen::numext::conj(x);
    } else {
      return x;
    }
  }
};

// Partial specializations for Gpu types used to store complex numbers.
template <bool conjugate>
struct maybe_conj<float2, conjugate> {
  __device__ static __inline__ float2 run(float2 c) {
    if (conjugate) {
      float2 c_conj;
      c_conj.x = c.x;
      c_conj.y = -c.y;
      return c_conj;
    } else {
      return c;
    }
  }
};

template <bool conjugate>
struct maybe_conj<double2, conjugate> {
  __device__ static __inline__ double2 run(double2 c) {
    if (conjugate) {
      double2 c_conj;
      c_conj.x = c.x;
      c_conj.y = -c.y;
      return c_conj;
    } else {
      return c;
    }
  }
};

// TODO(mjanusz): Move this to a shared util file.
// A simple array that contains data that can be passed between CPU and GPU.
template <typename T, int IndexCount, T DefaultValue>
struct Array {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator[](int index) const {
    return data[index];
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T& operator[](int index) {
    return data[index];
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array() {
    for (int i = 0; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0) {
    data[0] = a0;
    for (int i = 1; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1) {
    data[0] = a0;
    data[1] = a1;
    for (int i = 2; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1, T a2) {
    data[0] = a0;
    data[1] = a1;
    data[2] = a2;
    for (int i = 3; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_STRONG_INLINE Array(const std::array<T, IndexCount>& array) {
    for (int i = 0; i < IndexCount; i++) {
      data[i] = array[i];
    }
  }
  T data[IndexCount];
};

// A dimension type with compile-time known size.
template <int IndexCount>
struct Dimension : Array<int, IndexCount, 1> {
  typedef Array<int, IndexCount, 1> Base;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension() : Base() {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0) : Base(a0) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1)
      : Base(a0, a1) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1, int a2)
      : Base(a0, a1, a2) {}
  EIGEN_STRONG_INLINE Dimension(const std::array<int, IndexCount>& array)
      : Base(array) {}
};

// An index type with compile-time known size.
template <int IndexCount>
struct Index : Array<int, IndexCount, 0> {
  typedef Array<int, IndexCount, 0> Base;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index() : Base() {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0) : Base(a0) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1) : Base(a0, a1) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1, int a2)
      : Base(a0, a1, a2) {}
};

// A helper function that converts a tensor index into a flat array index.
template <int IndexCount>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int TensorIndexToFlat(
    const Index<IndexCount>& index, const Dimension<IndexCount>& dims) {
  int flat_index = index[0];
  for (int i = 1; i < IndexCount; i++) {
    flat_index = flat_index * dims[i] + index[i];
  }
  return flat_index;
}

// A helper function that converts a flat array index into a tensor index.
template <int IndexCount>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex(
    int index, const Dimension<IndexCount>& dims) {
  Index<IndexCount> tensor_index;
  for (int i = IndexCount - 1; i >= 0; i--) {
    int new_index = index / dims[i];
    tensor_index[i] = index - dims[i] * new_index;
    index = new_index;
  }
  return tensor_index;
}

// A simple CUDA custom kernel to shuffle dimensions of a 3D tensor according to
// the given shuffle permutation in template parameters. Shuffle permutation
// <sp0, sp1, sp2> shuffles dimensions such that input dimension 0 goes to sp0,
// 1 goes to sp1 and 2 goes to sp2. For example, shuffle permutation <2, 0, 1>
// will populate output so that input[x][y][z] is equal to (*output)[y][z][x].
//
// Requires that nthreads is equal to the total number of elements in the input
// tensor.
template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
__global__ void ShuffleInTensor3Simple(int nthreads,
                                       const T* __restrict__ input,
                                       Dimension<3> input_dims,
                                       T* __restrict__ output) {
  Dimension<3> output_dims;
  output_dims[sp0] = input_dims[0];
  output_dims[sp1] = input_dims[1];
  output_dims[sp2] = input_dims[2];

  // Iterate over the output rather than the input for better performance:
  // iterating over the output generates sequential writes and random reads,
  // which perform better than sequential reads and random writes.
  GPU_1D_KERNEL_LOOP(output_index, nthreads) {
    Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);

    Index<3> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[sp0];
    input_tensor_index[1] = output_tensor_index[sp1];
    input_tensor_index[2] = output_tensor_index[sp2];

    int input_index = TensorIndexToFlat(input_tensor_index, input_dims);

    output[output_index] =
        maybe_conj<T, conjugate>::run(ldg(input + input_index));
  }
}
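
// A minimal host-side launch sketch for the kernel above (illustrative only;
// it mirrors the TransformFilter functor further below, where d is the
// GPUDevice, input and output are device pointers, and the permutation
// <2, 1, 0> reverses all three dimensions):
//
//   Dimension<3> dims = {64, 32, 32};
//   GpuLaunchConfig config = GetGpuLaunchConfig(64 * 32 * 32, d);
//   TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<float, 2, 1, 0>,
//                               config.block_count, config.thread_per_block,
//                               0, d.stream(), config.virtual_thread_count,
//                               input, dims, output));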

static constexpr int kUnroll = 4;

template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
__global__ void ShuffleInTensor3SimpleVector(int nthreads,
                                             const T* __restrict__ input,
                                             Dimension<3> input_dims,
                                             T* __restrict__ output) {
  Dimension<3> output_dims;
  output_dims[sp0] = input_dims[0];
  output_dims[sp1] = input_dims[1];
  output_dims[sp2] = input_dims[2];

  const int stride = blockDim.x * gridDim.x * kUnroll;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  T buf[kUnroll];

  int output_index;
  for (output_index = tid * kUnroll; output_index + kUnroll - 1 < nthreads;
       output_index += stride) {
#pragma unroll
    for (int i = 0; i < kUnroll; i++) {
      int output_index_i = output_index + i;
      Index<3> output_tensor_index =
          FlatToTensorIndex(output_index_i, output_dims);
      Index<3> input_tensor_index;
      input_tensor_index[0] = output_tensor_index[sp0];
      input_tensor_index[1] = output_tensor_index[sp1];
      input_tensor_index[2] = output_tensor_index[sp2];

      int input_index_i = TensorIndexToFlat(input_tensor_index, input_dims);
      buf[i] = maybe_conj<T, conjugate>::run(ldg(input + input_index_i));
    }
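    // With kUnroll == 4 and sizeof(T) == 2 (the only case for which this
    // kernel is launched below), the four buffered elements occupy 8 bytes,
    // so they can be written back with a single vectorized float2 store.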
    float2* out = reinterpret_cast<float2*>(output + output_index);
    *out = *reinterpret_cast<float2*>(buf);
  }

  for (; output_index < nthreads; ++output_index) {
    Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);

    Index<3> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[sp0];
    input_tensor_index[1] = output_tensor_index[sp1];
    input_tensor_index[2] = output_tensor_index[sp2];

    int input_index = TensorIndexToFlat(input_tensor_index, input_dims);

    output[output_index] =
        maybe_conj<T, conjugate>::run(ldg(input + input_index));
  }
}

// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor,
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
//
// Each thread block operates on a single tile, a rectangle of dimensions
// TileSizeI x TileSizeJ.
//
// In general, for best performance, you should probably set TileSizeI,
// TileSizeJ equal to the number of threads in a warp (32 on NVIDIA GPUs).
// With a TileSizeI, TileSizeJ of 32, NumThreads of 128 or 256 seems to get
// the best performance on K40 GPUs.
template <typename T, int NumThreads, int TileSizeI, int TileSizeJ,
          bool conjugate = false>
__global__ void SwapDimension1And2InTensor3UsingTiles(
    const T* __restrict__ input, Dimension<3> input_dims,
    T* __restrict__ output) {
  eigen_assert(blockDim.x == NumThreads);
  eigen_assert(blockDim.y == 1);
  eigen_assert(blockDim.z == 1);
  eigen_assert(gridDim.y == 1);
  eigen_assert(gridDim.z == 1);

  constexpr int ReadRowPerPass = NumThreads / TileSizeJ;
  constexpr int WriteRowPerPass = NumThreads / TileSizeI;
  // One extra line in the inner dimension to avoid shared memory bank
  // conflicts. This is to mimic the following, but no constructor of T can be
  // invoked.
  //     __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
#if GOOGLE_CUDA
  __shared__ __align__(
      alignof(T)) char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)];
  typedef T(*SharedMemoryTile)[TileSizeJ + 1];
  SharedMemoryTile shared_memory_tile =
      reinterpret_cast<SharedMemoryTile>(shared_mem_raw);
#elif TENSORFLOW_USE_ROCM
  __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
#endif

  int x = threadIdx.x;

  Dimension<3> output_dims = {
      input_dims[0],
      input_dims[2],
      input_dims[1],
  };

  Dimension<3> input_dims_in_tiles = {
      input_dims[0],
      (input_dims[1] + TileSizeI - 1) / TileSizeI,
      (input_dims[2] + TileSizeJ - 1) / TileSizeJ,
  };

  Index<3> input_tile_index =
      FlatToTensorIndex(blockIdx.x, input_dims_in_tiles);

  Index<3> input_tile_origin = {
      input_tile_index[0],
      input_tile_index[1] * TileSizeI,
      input_tile_index[2] * TileSizeJ,
  };

  int input_origin_flat_index =
      TensorIndexToFlat(input_tile_origin, input_dims);

  bool full_tile = true;
  int tile_width = TileSizeJ;

  // Only the last row or column may not have the full size.
  if (input_tile_index[2] == input_dims_in_tiles[2] - 1) {
    tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileSizeJ;
    full_tile &= false;
  }

  int tile_height = TileSizeI;

  if (input_tile_index[1] == input_dims_in_tiles[1] - 1) {
    tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSizeI;
    full_tile &= false;
  }

  // Calculate the effective thread number. This ensures that we use the
  // largest number of threads available to form a regular thread block with no
  // trailing incomplete lines.
  constexpr int in_effective_thread_num = NumThreads / TileSizeJ * TileSizeJ;

  if (x < in_effective_thread_num) {
    // Orient the logical thread block with respect to the input array,
    // i.e. align the contiguous dimension of thread blocks with the contiguous
    // dimension of the input array.
    int ti = x / TileSizeJ;
    int tj = x % TileSizeJ;
    int input_index = input_origin_flat_index + ti * input_dims[2] + tj;
    int input_increment = ReadRowPerPass * input_dims[2];

    if (full_tile) {
#pragma unroll
      for (int i_loc = ti; i_loc < (TileSizeI); i_loc += ReadRowPerPass) {
        shared_memory_tile[i_loc][tj] =
            maybe_conj<T, conjugate>::run(input[input_index]);
        input_index += input_increment;
      }
    } else {
      if (tj < tile_width) {
        for (int i_loc = ti; i_loc < (tile_height); i_loc += ReadRowPerPass) {
          shared_memory_tile[i_loc][tj] =
              maybe_conj<T, conjugate>::run(input[input_index]);
          input_index += input_increment;
        }
      }
    }
  }

  __syncthreads();

  Index<3> output_tile_index = {
      input_tile_index[0],
      input_tile_index[2],
      input_tile_index[1],
  };

  Index<3> output_tile_origin = {
      output_tile_index[0],
      output_tile_index[1] * TileSizeJ,
      output_tile_index[2] * TileSizeI,
  };

  int output_origin_flat_index =
      TensorIndexToFlat(output_tile_origin, output_dims);

  constexpr int out_effective_thread_num = NumThreads / TileSizeI * TileSizeI;

  if (x < out_effective_thread_num) {
    // Re-orient the logical thread block with respect to the output array,
    // i.e. align the contiguous dimension of thread blocks with the contiguous
    // dimension of the output array.
    int ti = x / TileSizeI;
    int tj = x % TileSizeI;
    int output_index = output_origin_flat_index + ti * output_dims[2] + tj;
    int output_increment = WriteRowPerPass * output_dims[2];

    if (full_tile) {
#pragma unroll
      for (int i_loc = ti; i_loc < (TileSizeJ); i_loc += WriteRowPerPass) {
        output[output_index] = shared_memory_tile[tj][i_loc];
        output_index += output_increment;
      }
    } else {
      if (tj < tile_height) {
        for (int i_loc = ti; i_loc < (tile_width); i_loc += WriteRowPerPass) {
          output[output_index] = shared_memory_tile[tj][i_loc];
          output_index += output_increment;
        }
      }
    }
  }
}
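
// A minimal host-side launch sketch for the tiled kernel above (illustrative
// only; it mirrors the full-tile path of RunSwapDimension1And2InTensor3
// further below, with one thread block of 256 threads per 32 x 32 tile):
//
//   Dimension<3> input_dims_in_tiles = {
//       input_dims[0],
//       MathUtil::CeilOfRatio<int>(input_dims[1], 32),
//       MathUtil::CeilOfRatio<int>(input_dims[2], 32),
//   };
//   int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
//                           input_dims_in_tiles[2];
//   TF_CHECK_OK(GpuLaunchKernel(
//       SwapDimension1And2InTensor3UsingTiles<float, 256, 32, 32>,
//       total_tiles_count, 256, 0, d.stream(), input, input_dims, output));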

// A GPU custom kernel that converts input to output, given proper padding on
// the left and the top.
template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNHWC(
    int nthreads, const T* __restrict__ input, Dimension<NDIMS> input_dims,
    T* __restrict__ output, Dimension<NDIMS> output_dims,
    Dimension<NDIMS - 2> padding_left, T padding_value) {
  GPU_1D_KERNEL_LOOP(index, nthreads) {
    int output_index = index;
    Index<NDIMS> output_tensor_index =
        FlatToTensorIndex(output_index, output_dims);

    Index<NDIMS> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[0];  // batch
    bool ok = true;
    for (int i = 1; i < NDIMS - 1; i++) {
      input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 1];
      ok &=
          (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]);
    }
    input_tensor_index[NDIMS - 1] = output_tensor_index[NDIMS - 1];  // channels

    if (ok) {
      const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
      output[output_index] = input[input_index];
    } else {
      output[output_index] = padding_value;
    }
  }
}

template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNCHW(
    int nthreads, const T* __restrict__ input, Dimension<NDIMS> input_dims,
    T* __restrict__ output, Dimension<NDIMS> output_dims,
    Dimension<NDIMS - 2> padding_left, T padding_value) {
  GPU_1D_KERNEL_LOOP(index, nthreads) {
    int output_index = index;
    Index<NDIMS> output_tensor_index =
        FlatToTensorIndex(output_index, output_dims);

    Index<NDIMS> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[0];  // batch
    input_tensor_index[1] = output_tensor_index[1];  // channels
    bool ok = true;
    for (int i = 2; i < NDIMS; i++) {
      input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 2];
      ok &=
          (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]);
    }

    if (ok) {
      const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
      output[output_index] = input[input_index];
    } else {
      output[output_index] = padding_value;
    }
  }
}

// A GPU helper function that converts TensorFlow filter format to Cudnn filter
// format.
template <typename T, int NDIMS>
struct TransformFilter<GPUDevice, T, int, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, FilterTensorFormat dst_filter_format,
                  typename TTypes<T, NDIMS, int>::ConstTensor in,
                  typename TTypes<T, NDIMS, int>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // spatial dimensions
    for (int i = 1; i < NDIMS - 2; i++) {
      combined_dims[0] *= in.dimension(i);
    }
    combined_dims[1] = in.dimension(NDIMS - 2);  // input filters
    combined_dims[2] = in.dimension(NDIMS - 1);  // output filters
    GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);

    if (dst_filter_format == FORMAT_OIHW) {
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else if (dst_filter_format == FORMAT_OHWI) {
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 1, 2, 0>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else {
      LOG(ERROR) << "Unsupported filter format: "
                 << ToString(dst_filter_format);
    }
  }
};

// Converts Cudnn filter format OIHW or OHWI back to TensorFlow filter format
// HWIO.
template <typename T, int NDIMS>
struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, FilterTensorFormat src_filter_format,
                  typename TTypes<T, NDIMS>::ConstTensor in,
                  typename TTypes<T, NDIMS>::Tensor out) {
    Dimension<3> combined_dims;

    if (src_filter_format == FORMAT_OIHW) {
      combined_dims[0] = in.dimension(0);  // output filters
      combined_dims[1] = in.dimension(1);  // input filters
      combined_dims[2] = in.dimension(2);  // spatial dimensions
      for (int i = 3; i < NDIMS; ++i) {
        combined_dims[2] *= in.dimension(i);
      }

      GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else if (src_filter_format == FORMAT_OHWI) {
      combined_dims[0] = in.dimension(0);  // output filters
      combined_dims[1] = in.dimension(1);  // spatial dimensions
      for (int i = 2; i < NDIMS - 1; i++) {
        combined_dims[1] *= in.dimension(i);
      }
      combined_dims[2] = in.dimension(NDIMS - 1);  // input filters

      GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 0, 1>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else {
      // TODO(ezhulenev): Set error status in OpKernelContext instead.
      LOG(FATAL) << "Unsupported filter format: "
                 << ToString(src_filter_format);
    }
  }
};

// A GPU helper function that converts an input tensor to a larger output
// tensor, given proper padding values. Elements that fall outside the original
// input are set to padding_value.
template <typename T, int NDIMS>
struct PadInput<GPUDevice, T, int, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d,
                  typename TTypes<T, NDIMS, int>::ConstTensor in,
                  const std::array<int, NDIMS - 2>& padding_left,
                  const std::array<int, NDIMS - 2>& padding_right,
                  typename TTypes<T, NDIMS, int>::Tensor out,
                  TensorFormat format, const T& padding_value) {
    GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
    Dimension<NDIMS> input_dims;
    for (int i = 0; i < NDIMS; ++i) {
      input_dims[i] = in.dimension(i);
    }
    Dimension<NDIMS> output_dims;
    for (int i = 0; i < NDIMS; ++i) {
      output_dims[i] = out.dimension(i);
    }

    const Dimension<NDIMS - 2> padding_left_dim(padding_left);

    if (format == FORMAT_NHWC) {
      TF_CHECK_OK(GpuLaunchKernel(
          PadInputCustomKernelNHWC<T, NDIMS>, config.block_count,
          config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
          in.data(), input_dims, out.data(), output_dims, padding_left_dim,
          padding_value));
    } else if (format == FORMAT_NCHW) {
      TF_CHECK_OK(GpuLaunchKernel(
          PadInputCustomKernelNCHW<T, NDIMS>, config.block_count,
          config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
          in.data(), input_dims, out.data(), output_dims, padding_left_dim,
          padding_value));
    } else {
      LOG(FATAL) << "Invalid data format: " << format;
    }
  }
};

// We want std::equal_to and std::greater, but they're not constexpr until
// C++14.
struct EqualTo {
  constexpr bool operator()(int a, int b) const { return a == b; }
};

struct GreaterThan {
  constexpr bool operator()(int a, int b) const { return a > b; }
};

// For each data type, the tile size possibility frontier denotes the tile size
// combinations that consume the most computational resources constrained by
// - number of threads per SM limit,
// - limit on size of the short dimension (<=15) due to the definition of
//   narrow matrix,
// - shared memory limit and
// - some experimentally determined, type-specific constraint on the product of
//   two side lengths to increase grid-level parallelism.
//
// A tile size combination lies on the frontier if and only if one or more of
// the constraints mentioned above is hit. Tile size combinations lying outside
// this frontier are either not possible, or are slower than the alternatives.
//
// It is instrumental to consider, for each data type, two subsets of the
// corresponding frontier:
// - long side frontier: the union of the biggest tile size combination for
//   each legal long side len.
// - non long side frontier: the frontier set minus the long side frontier.
//
// TileSizePossibilityFrontierCheck defines the frontier using only the long
// side frontier tile size combinations (since one can easily extrapolate
// the entire frontier from this subset). It serves as a utility function
// to help us determine where a tile size combination of interest lies with
// respect to the frontier.
template <typename Op>
constexpr bool TileSizePossibilityFrontierCheck(int TileLongSide,
                                                int TileShortSide,
                                                int size_of_t, Op op) {
  // clang-format off

  return (size_of_t == 16 && ((TileLongSide == 32   && op(TileShortSide, 4))  ||
                             (TileLongSide == 64   && op(TileShortSide, 4))  ||
                             (TileLongSide == 128  && op(TileShortSide, 4))  ||
                             (TileLongSide == 256  && op(TileShortSide, 2)))) ||
          (size_of_t == 8 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                             (TileLongSide == 64   && op(TileShortSide, 15)) ||
                             (TileLongSide == 128  && op(TileShortSide, 8))  ||
                             (TileLongSide == 256  && op(TileShortSide, 4))  ||
                             (TileLongSide == 512  && op(TileShortSide, 2)))) ||
          (size_of_t == 4 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                             (TileLongSide == 64   && op(TileShortSide, 15)) ||
                             (TileLongSide == 128  && op(TileShortSide, 15)) ||
                             (TileLongSide == 256  && op(TileShortSide, 8))  ||
                             (TileLongSide == 512  && op(TileShortSide, 4))  ||
                             (TileLongSide == 1024 && op(TileShortSide, 2)))) ||
          (size_of_t == 2 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                             (TileLongSide == 64   && op(TileShortSide, 15)) ||
                             (TileLongSide == 128  && op(TileShortSide, 15)) ||
                             (TileLongSide == 256  && op(TileShortSide, 8))  ||
                             (TileLongSide == 512  && op(TileShortSide, 4))  ||
                             (TileLongSide == 1024 && op(TileShortSide, 2)))) ||
          (size_of_t == 1 && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                             (TileLongSide == 64   && op(TileShortSide, 15)) ||
                             (TileLongSide == 128  && op(TileShortSide, 15)) ||
                             (TileLongSide == 256  && op(TileShortSide, 8))  ||
                             (TileLongSide == 512  && op(TileShortSide, 4))  ||
                             (TileLongSide == 1024 && op(TileShortSide, 2))));

  // clang-format on
}

constexpr bool TileSizeOnLongSideFrontier(int TileLongSide, int TileShortSide,
                                          int size_of_t) {
  return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide,
                                          size_of_t, EqualTo());
}
constexpr bool TileSizeOutsideFrontier(int TileLongSide, int TileShortSide,
                                       int size_of_t) {
  return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide,
                                          size_of_t, GreaterThan());
}
constexpr bool TileSizeOnNonLongSideFrontier(int TileLongSide,
                                             int TileShortSide, int size_of_t) {
  // For a tile size combination (longside, shortside), lying on the frontier
  // implies that (longside, shortside) is on or within the frontier but
  // (longside*2, shortside) or (longside, shortside+1) is not. With the above
  // criterion, we simply need to use !TileSizeOnLongSideFrontier to ensure that
  // it is not on the long side frontier.
  return !TileSizeOutsideFrontier(TileLongSide, TileShortSide, size_of_t) &&
         (TileSizeOutsideFrontier(TileLongSide * 2, TileShortSide, size_of_t) ||
          TileSizeOutsideFrontier(TileLongSide, TileShortSide + 1,
                                  size_of_t)) &&
         !TileSizeOnLongSideFrontier(TileLongSide, TileShortSide, size_of_t);
}
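
// For example, for 4-byte element types, (128, 15) lies on the long side
// frontier, (128, 16) lies outside the frontier, and (128, 9) lies on the non
// long side frontier (doubling the long side to 256 would exceed that long
// side's short-side limit of 8). Illustrative compile-time checks:
//
//   static_assert(TileSizeOnLongSideFrontier(128, 15, 4), "");
//   static_assert(TileSizeOutsideFrontier(128, 16, 4), "");
//   static_assert(TileSizeOnNonLongSideFrontier(128, 9, 4), "");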

// Helper function to launch a batch narrow matrix transpose kernel.
template <typename T, int TileLongSide, int TileShortSide>
void LaunchBatchNarrowMatrixTransposeKernel(
    const GPUDevice& d, int tile_size_i, int tile_size_j, int total_tiles_count,
    const T* input, const Dimension<3>& input_dims, T* output) {
  constexpr int NumThreads = TileLongSide;
  if (tile_size_i <= TileLongSide && tile_size_j <= TileShortSide) {
    TF_CHECK_OK(GpuLaunchKernel(
        SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileLongSide,
                                              TileShortSide>,
        total_tiles_count, NumThreads, 0, d.stream(), input, input_dims,
        output));
  } else {
    TF_CHECK_OK(GpuLaunchKernel(
        SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileShortSide,
                                              TileLongSide>,
        total_tiles_count, NumThreads, 0, d.stream(), input, input_dims,
        output));
  }
}

// Recursive template function to search, in a trial-and-error manner, for the
// minimum tile size configuration satisfying the requested tile side lengths.
// An important invariant of this search procedure is that for an unsatisfied
// request, we always try doubling the long side len first, and only after
// the request is satisfied for the long side len do we begin incrementing
// the short side len.
//
// We have three specializations of this search function depending on where the
// current tile size combination lies with respect to the frontier.
// - It lies within the frontier. If the request is not satisfied, for the next
// tile size combination, we first try doubling the long side len and if that
// does not work, we then increment the short side len.
// - It lies on the non long side frontier. If the request is not satisfied, we
// can only increment the short side len.
// - It lies on the long side frontier. We launch the kernel without checking if
// the request is satisfied or not.
template <typename T, int TileLongSide, int TileShortSide,
          typename dummy = void>
struct BatchNarrowMatrixTransposeDispatcher {
  static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
                   int total_tiles_count, const T* input,
                   const Dimension<3>& input_dims, T* output) {
    static_assert(
        (TileLongSide & (TileLongSide - 1)) == 0,
        "The length of the longer side of the tile is always a power of 2.");
    bool request_satisfied =
        std::max(tile_size_i, tile_size_j) <= TileLongSide &&
        std::min(tile_size_i, tile_size_j) <= TileShortSide;

    if (request_satisfied) {
      LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
          d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
          output);
      return;
    }

    // If the execution reaches here, then the kernel was not launched; we then
    // determine whether it is the long side or the short side that falls short
    // of the request and increase that parameter accordingly.
    const bool long_side_request_not_satisfied =
        std::max(tile_size_i, tile_size_j) > TileLongSide;

    if (long_side_request_not_satisfied) {
      BatchNarrowMatrixTransposeDispatcher<
          T, TileLongSide * 2, TileShortSide>::DoIt(d, tile_size_i, tile_size_j,
                                                    total_tiles_count, input,
                                                    input_dims, output);
    } else {
      BatchNarrowMatrixTransposeDispatcher<
          T, TileLongSide, TileShortSide + 1>::DoIt(d, tile_size_i, tile_size_j,
                                                    total_tiles_count, input,
                                                    input_dims, output);
    }
  }
};

template <typename T, int TileLongSide, int TileShortSide>
struct BatchNarrowMatrixTransposeDispatcher<
    T, TileLongSide, TileShortSide,
    typename std::enable_if<TileSizeOnNonLongSideFrontier(
                                TileLongSide, TileShortSide, sizeof(T)),
                            void>::type> {
  static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
                   int total_tiles_count, const T* input,
                   const Dimension<3>& input_dims, T* output) {
    static_assert(
        (TileLongSide & (TileLongSide - 1)) == 0,
        "The length of the longer side of the tile is always a power of 2.");
    bool request_satisfied =
        std::max(tile_size_i, tile_size_j) <= TileLongSide &&
        std::min(tile_size_i, tile_size_j) <= TileShortSide;

    if (request_satisfied) {
      LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
          d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
          output);
      return;
    }

    // If the execution reaches here, then the kernel was not launched; since
    // we are on the non long side frontier, we increment the short dimension
    // and try again.
    BatchNarrowMatrixTransposeDispatcher<
        T, TileLongSide, TileShortSide + 1>::DoIt(d, tile_size_i, tile_size_j,
                                                  total_tiles_count, input,
                                                  input_dims, output);
  }
};

template <typename T, int TileLongSide, int TileShortSide>
struct BatchNarrowMatrixTransposeDispatcher<
    T, TileLongSide, TileShortSide,
    typename std::enable_if<TileSizeOnLongSideFrontier(
                                TileLongSide, TileShortSide, sizeof(T)),
                            void>::type> {
  static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
                   int total_tiles_count, const T* input,
                   const Dimension<3>& input_dims, T* output) {
    static_assert(
        (TileLongSide & (TileLongSide - 1)) == 0,
        "The length of the longer side of the tile is always a power of 2.");

    LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
        d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
        output);
  }
};

// This function tries to recover, in a brute force way, the frontier defined in
// TileSizePossibilityFrontierCheck as a vector of tile size combinations lying
// on the long side frontier. This vector is sufficient to determine the entire
// frontier.
//
// Note that if one changes the frontier definition in
// TileSizePossibilityFrontierCheck and forgets to set the largest short
// side len of the largest legal long side len to 2, this function will fail
// and crash the program.
template <int SizeOfT>
const std::vector<std::pair<int, int>>& GetTileSizesFrontier() {
  static_assert(
      SizeOfT <= 16,
      "Currently, only data types of sizes 16 bytes or less are supported.");
  static_assert((SizeOfT & (SizeOfT - 1)) == 0,
                "Data types must have sizes that are powers of 2.");

  // Expensive work to populate sizes, lazily run in a thread-safe
  // manner the first time GetTileSizesFrontier<N> is called.
  static auto* frontier = [] {
    auto* frontier = new std::vector<std::pair<int, int>>();
    const int kMaxLongSideLen = 1024;
    const int kMaxShortSideLen = 15;
    for (int long_side = 32; long_side <= kMaxLongSideLen; long_side *= 2) {
      for (int short_side = 2; short_side <= kMaxShortSideLen;
           short_side += 1) {
        if (TileSizeOnLongSideFrontier(long_side, short_side, SizeOfT)) {
          // The current combination lies on the frontier, thus we
          // add it to the frontier definition.
          frontier->push_back(std::make_pair(long_side, short_side));

          // The long side length is the largest one allowed iff its
          // corresponding short side length is 2.
          if (short_side == 2) return frontier;

          // We have exhausted all the possibilities in the frontier
          // with the given long side length.
          break;
        }
      }
    }
    LOG(FATAL)
        << "The corresponding short side length of the largest long side "
           "length has to be 2.";
  }();
  return *frontier;
}
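
// For example, for 4-byte element types the returned frontier is
//   {(32, 15), (64, 15), (128, 15), (256, 8), (512, 4), (1024, 2)},
// i.e. for each legal long side length, the largest short side length allowed
// by TileSizePossibilityFrontierCheck.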

// Helper structs to help determine which data type to use given the size of
// the matrix data type. A transpose of elements of size N will use a kernel
// which operates on an array of TransposeElemType<N>::type.
template <int ElemBytes>
struct TransposeElemType;
template <>
struct TransposeElemType<1> {
  using type = uint8;
};
template <>
struct TransposeElemType<2> {
  using type = uint16;
};
template <>
struct TransposeElemType<4> {
  using type = uint32;
};
template <>
struct TransposeElemType<8> {
  using type = uint64;
};
template <>
struct TransposeElemType<16> {
  using type = float4;
};
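
// For example, Eigen::half (2 bytes) is transposed as uint16 values, float
// (4 bytes) as uint32, and complex128 (16 bytes) as float4; only the bit
// pattern is moved, so reinterpreting the element type is safe.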

// A helper function to make RunSwapDimension1And2InTensor3 concise. This
// helper function looks at the data type and input matrix sizes and decides
// the thread numbers and tile sizes to use.
template <typename T, bool conjugate = false>
void SwapDimension1And2InTensor3WithNarrowMatrices(
    const GPUDevice& d, const T* input, const Dimension<3>& input_dims,
    T* output, const int kMinDimensionToUseTiles) {
  // Get available tile sizes here for the data type requested:
  const auto& tile_spec = GetTileSizesFrontier<sizeof(T)>();

  int tile_long_side_len = 0;
  int tile_short_side_len = 0;
  float lowest_cost = std::numeric_limits<float>::max();
  int data_long_side = std::max(input_dims[1], input_dims[2]);

  for (auto tile_size_pair : tile_spec) {
    int proposed_tile_long_side_len = tile_size_pair.first;

    // Number of threads that will not be doing anything useful when reading
    // the matrix because the thread block size is bigger than the data block
    // size.
    int num_wasted_threads =
        data_long_side - MathUtil::FloorOfRatio<int>(
                             data_long_side, proposed_tile_long_side_len) *
                             proposed_tile_long_side_len;

    int num_full_tiles = MathUtil::FloorOfRatio<int>(
        data_long_side, proposed_tile_long_side_len);

    float cost = 0;

    // However, if we can execute two or more full tiles, then we gladly
    // accept any number of wasted threads and ignore this cost.
    if (num_full_tiles <= 1) cost = num_wasted_threads;

    // Using less than or equal to here because given the same cost, we
    // would like to launch as many threads as possible.
    if (cost <= lowest_cost) {
      tile_long_side_len = proposed_tile_long_side_len;
      tile_short_side_len = tile_size_pair.second;
      lowest_cost = cost;
    }
  }

  // Request tile sizes such that the longer side of the thread block aligns
  // with the longer side of the input data block to maximize read throughput.
  // The ideal tile shape is one where the length of the shorter side of the
  // tile is equal to the length of the shorter side of the input matrix.
  int requested_tile_size_i = input_dims[1] >= kMinDimensionToUseTiles
                                  ? tile_long_side_len
                                  : input_dims[1];
  int requested_tile_size_j = input_dims[1] >= kMinDimensionToUseTiles
                                  ? input_dims[2]
                                  : tile_long_side_len;

  // Truncate the shorter size requested according to the manual limit set in
  // tile_spec to make sure that we do not launch configurations violating
  // hardware limits.
  requested_tile_size_i =
      requested_tile_size_i == tile_long_side_len
          ? tile_long_side_len
          : std::min(requested_tile_size_i, tile_short_side_len);
  requested_tile_size_j =
      requested_tile_size_j == tile_long_side_len
          ? tile_long_side_len
          : std::min(requested_tile_size_j, tile_short_side_len);

  Dimension<3> input_dims_in_tiles = {
      input_dims[0],
      MathUtil::CeilOfRatio<int>(input_dims[1], requested_tile_size_i),
      MathUtil::CeilOfRatio<int>(input_dims[2], requested_tile_size_j),
  };

  int total_tiles_count =
      input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2];

  using ElemType = typename TransposeElemType<sizeof(T)>::type;
  static_assert(alignof(T) >= alignof(ElemType), "Unexpected data alignment.");
  BatchNarrowMatrixTransposeDispatcher<ElemType, 32, 2>::DoIt(
      d, requested_tile_size_i, requested_tile_size_j, total_tiles_count,
      reinterpret_cast<const ElemType*>(input), input_dims,
      reinterpret_cast<ElemType*>(output));
}

// Launch the GPU kernel that would swap dimension-1 and dimension-2 in a
// 3D tensor. It looks at the shape of the incoming data, and decides the best
// strategy to launch.
template <typename T, bool conjugate = false>
void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
                                    const Dimension<3>& input_dims, T* output) {
  // If both swapped dimensions are large, use shared-memory tiles for the
  // actual swapping. If only one of them is large (a narrow matrix), use the
  // narrow-matrix kernel. Otherwise, the simple element-wise swap relying on
  // the ldg cache is efficient enough.
  static const int kMinDimensionToUseTiles = 16;
  static const int kMinDimensionToUseRectTiles = 96;

  bool large_matrix = input_dims[1] >= kMinDimensionToUseTiles &&
                      input_dims[2] >= kMinDimensionToUseTiles;
  bool narrow_matrix = input_dims[1] >= kMinDimensionToUseRectTiles ||
                       input_dims[2] >= kMinDimensionToUseRectTiles;
  if (large_matrix) {
    // We get best performance when kTileSize is the number of threads in a warp
    // (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256
    // threads.
    constexpr int kTileSize = 32;
    constexpr int kNumThreads = 256;

    Dimension<3> input_dims_in_tiles = {
        input_dims[0],
        MathUtil::CeilOfRatio<int>(input_dims[1], kTileSize),
        MathUtil::CeilOfRatio<int>(input_dims[2], kTileSize),
    };

    int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
                            input_dims_in_tiles[2];
    TF_CHECK_OK(GpuLaunchKernel(
        SwapDimension1And2InTensor3UsingTiles<T, kNumThreads, kTileSize,
                                              kTileSize, conjugate>,
        total_tiles_count, kNumThreads, 0, d.stream(), input, input_dims,
        output));

  } else if (narrow_matrix) {
    SwapDimension1And2InTensor3WithNarrowMatrices<T, conjugate>(
        d, input, input_dims, output, kMinDimensionToUseTiles);
  } else {
    int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
    GpuLaunchConfig config = GetGpuLaunchConfig(total_element_count, d);
    TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 0, 2, 1, conjugate>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), config.virtual_thread_count, input,
                                input_dims, output));
  }
}

// A GPU helper functor that does a general dimension 1 and 2 swap for a 3D
// tensor.
template <typename T, bool conjugate>
struct SwapDimension1And2InTensor3<GPUDevice, T, conjugate> {
  typedef GPUDevice Device;
  void operator()(const Device& d, const T* in,
                  const gtl::ArraySlice<int64>& combined_dims, T* out) {
    Dimension<3> input_dims = {static_cast<int>(combined_dims[0]),
                               static_cast<int>(combined_dims[1]),
                               static_cast<int>(combined_dims[2])};
    RunSwapDimension1And2InTensor3<T, conjugate>(d, in, input_dims, out);
  }
};

// A GPU helper functor that does a general dimension 0 and 2 swap for a 3D
// tensor.
template <typename T, bool conjugate>
struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
  typedef GPUDevice Device;
  void operator()(const Device& d, const T* in,
                  const gtl::ArraySlice<int64>& combined_dims, T* out) {
    Dimension<3> input_dims = {static_cast<int>(combined_dims[0]),
                               static_cast<int>(combined_dims[1]),
                               static_cast<int>(combined_dims[2])};
    size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
    GpuLaunchConfig config = GetGpuLaunchConfig(total_size, d);

    auto out_ptr = reinterpret_cast<uintptr_t>(out);
    bool aligned = out_ptr % 16 == 0;

    bool use_vector = false;
    bool use_custom_config = false;
    if ((input_dims[0] <= 128 && input_dims[2] <= 128) ||
        input_dims[0] * input_dims[1] <= 128 ||
        input_dims[1] * input_dims[2] <= 8) {
      use_vector = true;
      use_custom_config = true;
    } else if (input_dims[1] * input_dims[2] <= 16384) {
      use_vector = true;
    }

    if (sizeof(T) == 2 && aligned && use_vector) {
      int block_count;
      if (use_custom_config) {
        block_count = (total_size + config.thread_per_block - 1) /
                      config.thread_per_block;
      } else {
        block_count = config.block_count;
      }

      TF_CHECK_OK(
          GpuLaunchKernel(ShuffleInTensor3SimpleVector<T, 2, 1, 0, conjugate>,
                          block_count, config.thread_per_block / kUnroll, 0,
                          d.stream(), total_size, in, input_dims, out));
    } else {
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in, input_dims, out));
    }
  }
};

// A GPU helper functor that converts NHWC TensorFlow data format to
// NCHW format that is accepted by Cudnn.
template <typename T, int NDIMS>
struct NHWCToNCHW<GPUDevice, T, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
                  typename TTypes<T, NDIMS>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // N (batch)
    combined_dims[1] = in.dimension(1);  // spatial dimensions (HW)
    for (int i = 2; i < NDIMS - 1; ++i) {
      combined_dims[1] *= in.dimension(i);
    }
    combined_dims[2] = in.dimension(NDIMS - 1);  // C (channels)
    RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
  }
};
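
// For example, for a 5-D NDHWC input of shape [2, 3, 4, 5, 6], the functor
// above collapses the spatial dimensions into combined_dims = {2, 60, 6} and
// swaps the last two combined dimensions, producing an NCDHW output of shape
// [2, 6, 3, 4, 5].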

// A GPU helper functor that converts NCHW Cudnn data format to NHWC TensorFlow
// format.
template <typename T, int NDIMS>
struct NCHWToNHWC<GPUDevice, T, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
                  typename TTypes<T, NDIMS>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // N (batch)
    combined_dims[1] = in.dimension(1);  // C (channel)
    combined_dims[2] = in.dimension(2);  // spatial dimensions (HW)
    for (int i = 3; i < NDIMS; ++i) {
      combined_dims[2] *= in.dimension(i);
    }
    RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_
1139