/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_
#define TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#define EIGEN_USE_GPU

#include <algorithm>
#include <array>
#include <limits>
#include <utility>

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#endif
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {

template <typename T, bool conjugate>
struct maybe_conj {
  __device__ static __inline__ T run(T x) {
    if (conjugate) {
      return Eigen::numext::conj(x);
    } else {
      return x;
    }
  }
};

// Partial specializations for Gpu types used to store complex numbers.
template <bool conjugate>
struct maybe_conj<float2, conjugate> {
  __device__ static __inline__ float2 run(float2 c) {
    if (conjugate) {
      float2 c_conj;
      c_conj.x = c.x;
      c_conj.y = -c.y;
      return c_conj;
    } else {
      return c;
    }
  }
};

template <bool conjugate>
struct maybe_conj<double2, conjugate> {
  __device__ static __inline__ double2 run(double2 c) {
    if (conjugate) {
      double2 c_conj;
      c_conj.x = c.x;
      c_conj.y = -c.y;
      return c_conj;
    } else {
      return c;
    }
  }
};
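// Illustrative example (not used by the kernels below): with conjugate =
// true, maybe_conj<float2, true>::run applied to the value (1.0f, 2.0f)
// yields (1.0f, -2.0f); with conjugate = false the value passes through
// unchanged.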
// TODO(mjanusz): Move this to a shared util file.
// A simple array that contains data that can be passed between CPU and GPU.
template <typename T, int IndexCount, T DefaultValue>
struct Array {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T& operator[](int index) const {
    return data[index];
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T& operator[](int index) {
    return data[index];
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array() {
    for (int i = 0; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0) {
    data[0] = a0;
    for (int i = 1; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1) {
    data[0] = a0;
    data[1] = a1;
    for (int i = 2; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Array(T a0, T a1, T a2) {
    data[0] = a0;
    data[1] = a1;
    data[2] = a2;
    for (int i = 3; i < IndexCount; i++) {
      data[i] = DefaultValue;
    }
  }
  EIGEN_STRONG_INLINE Array(const std::array<T, IndexCount>& array) {
    for (int i = 0; i < IndexCount; i++) {
      data[i] = array[i];
    }
  }
  T data[IndexCount];
};

// A dimension type with compile-time known size.
template <int IndexCount>
struct Dimension : Array<int, IndexCount, 1> {
  typedef Array<int, IndexCount, 1> Base;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension() : Base() {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0) : Base(a0) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1)
      : Base(a0, a1) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Dimension(int a0, int a1, int a2)
      : Base(a0, a1, a2) {}
  EIGEN_STRONG_INLINE Dimension(const std::array<int, IndexCount>& array)
      : Base(array) {}
};

// An index type with compile-time known size.
template <int IndexCount>
struct Index : Array<int, IndexCount, 0> {
  typedef Array<int, IndexCount, 0> Base;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index() : Base() {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0) : Base(a0) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1) : Base(a0, a1) {}
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index(int a0, int a1, int a2)
      : Base(a0, a1, a2) {}
};

// A helper function that converts a tensor index into a flat array index.
template <int IndexCount>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int TensorIndexToFlat(
    const Index<IndexCount>& index, const Dimension<IndexCount>& dims) {
  int flat_index = index[0];
  for (int i = 1; i < IndexCount; i++) {
    flat_index = flat_index * dims[i] + index[i];
  }
  return flat_index;
}

// A helper function that converts a flat array index into a tensor index.
template <int IndexCount>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex(
    int index, const Dimension<IndexCount>& dims) {
  Index<IndexCount> tensor_index;
  for (int i = IndexCount - 1; i >= 0; i--) {
    int new_index = index / dims[i];
    tensor_index[i] = index - dims[i] * new_index;
    index = new_index;
  }
  return tensor_index;
}
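// Worked example (illustrative only): for row-major dims {2, 3, 4}, the
// tensor index {1, 1, 1} maps to flat index 1 * (3 * 4) + 1 * 4 + 1 == 17,
// and FlatToTensorIndex(17, {2, 3, 4}) recovers {1, 1, 1}; the two helpers
// are inverses of each other for any in-range index.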
// A simple CUDA custom kernel that shuffles the dimensions of a 3D tensor
// according to the shuffle permutation given in the template parameters.
// Shuffle permutation <sp0, sp1, sp2> shuffles dimensions such that input
// dimension 0 goes to sp0, 1 goes to sp1 and 2 goes to sp2. For example,
// shuffle permutation <2, 0, 1> will populate output so that input[x][y][z]
// is equal to (*output)[y][z][x].
//
// Requires that nthreads is equal to the total number of elements in the
// input tensor.
template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
__global__ void ShuffleInTensor3Simple(int nthreads,
                                       const T* __restrict__ input,
                                       Dimension<3> input_dims,
                                       T* __restrict__ output) {
  Dimension<3> output_dims;
  output_dims[sp0] = input_dims[0];
  output_dims[sp1] = input_dims[1];
  output_dims[sp2] = input_dims[2];

  // Iterate over the output rather than the input for better performance:
  // iterating over the output generates sequential writes and random reads,
  // which perform better than sequential reads and random writes.
  GPU_1D_KERNEL_LOOP(output_index, nthreads) {
    Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);

    Index<3> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[sp0];
    input_tensor_index[1] = output_tensor_index[sp1];
    input_tensor_index[2] = output_tensor_index[sp2];

    int input_index = TensorIndexToFlat(input_tensor_index, input_dims);

    output[output_index] =
        maybe_conj<T, conjugate>::run(ldg(input + input_index));
  }
}

static constexpr int kUnroll = 4;

template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
__global__ void ShuffleInTensor3SimpleVector(int nthreads,
                                             const T* __restrict__ input,
                                             Dimension<3> input_dims,
                                             T* __restrict__ output) {
  Dimension<3> output_dims;
  output_dims[sp0] = input_dims[0];
  output_dims[sp1] = input_dims[1];
  output_dims[sp2] = input_dims[2];

  const int stride = blockDim.x * gridDim.x * kUnroll;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  T buf[kUnroll];

  int output_index;
  for (output_index = tid * kUnroll; output_index + kUnroll - 1 < nthreads;
       output_index += stride) {
#pragma unroll
    for (int i = 0; i < kUnroll; i++) {
      int output_index_i = output_index + i;
      Index<3> output_tensor_index =
          FlatToTensorIndex(output_index_i, output_dims);
      Index<3> input_tensor_index;
      input_tensor_index[0] = output_tensor_index[sp0];
      input_tensor_index[1] = output_tensor_index[sp1];
      input_tensor_index[2] = output_tensor_index[sp2];

      int input_index_i = TensorIndexToFlat(input_tensor_index, input_dims);
      buf[i] = maybe_conj<T, conjugate>::run(ldg(input + input_index_i));
    }
    float2* out = reinterpret_cast<float2*>(output + output_index);
    *out = *reinterpret_cast<float2*>(buf);
  }

  for (; output_index < nthreads; ++output_index) {
    Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);

    Index<3> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[sp0];
    input_tensor_index[1] = output_tensor_index[sp1];
    input_tensor_index[2] = output_tensor_index[sp2];

    int input_index = TensorIndexToFlat(input_tensor_index, input_dims);

    output[output_index] =
        maybe_conj<T, conjugate>::run(ldg(input + input_index));
  }
}
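// Note on the vectorized path above (an observation, not enforced by the
// kernel itself): the float2 store assumes that kUnroll consecutive elements
// of T exactly fill one float2, i.e. kUnroll * sizeof(T) == sizeof(float2),
// which holds only for 2-byte element types such as Eigen::half. The caller
// (SwapDimension0And2InTensor3 below) guards this path with sizeof(T) == 2
// and a 16-byte alignment check on the output pointer.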
// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor,
// where dimensions are zero-based: output[i][j][k] = input[i][k][j].
//
// Each thread block operates on a single tile, a rectangle of dimensions
// TileSizeI x TileSizeJ.
//
// In general, for best performance, you should probably set TileSizeI,
// TileSizeJ equal to the number of threads in a warp (32 in NVIDIA GPUs).
// With a TileSizeI, TileSizeJ of 32, NumThreads of 128 or 256 seems to get
// the best performance on K40 GPUs.
template <typename T, int NumThreads, int TileSizeI, int TileSizeJ,
          bool conjugate = false>
__global__ void SwapDimension1And2InTensor3UsingTiles(
    const T* __restrict__ input, Dimension<3> input_dims,
    T* __restrict__ output) {
  eigen_assert(blockDim.x == NumThreads);
  eigen_assert(blockDim.y == 1);
  eigen_assert(blockDim.z == 1);
  eigen_assert(gridDim.y == 1);
  eigen_assert(gridDim.z == 1);

  constexpr int ReadRowPerPass = NumThreads / TileSizeJ;
  constexpr int WriteRowPerPass = NumThreads / TileSizeI;
  // One extra line in the inner dimension to avoid shared memory bank
  // conflicts. This mimics the following, but no constructor of T can be
  // invoked:
  //   __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
#if GOOGLE_CUDA
  __shared__ __align__(
      alignof(T)) char shared_mem_raw[TileSizeI * (TileSizeJ + 1) * sizeof(T)];
  typedef T(*SharedMemoryTile)[TileSizeJ + 1];
  SharedMemoryTile shared_memory_tile =
      reinterpret_cast<SharedMemoryTile>(shared_mem_raw);
#elif TENSORFLOW_USE_ROCM
  __shared__ T shared_memory_tile[TileSizeI][TileSizeJ + 1];
#endif

  int x = threadIdx.x;

  Dimension<3> output_dims = {
      input_dims[0],
      input_dims[2],
      input_dims[1],
  };

  Dimension<3> input_dims_in_tiles = {
      input_dims[0],
      (input_dims[1] + TileSizeI - 1) / TileSizeI,
      (input_dims[2] + TileSizeJ - 1) / TileSizeJ,
  };

  Index<3> input_tile_index =
      FlatToTensorIndex(blockIdx.x, input_dims_in_tiles);

  Index<3> input_tile_origin = {
      input_tile_index[0],
      input_tile_index[1] * TileSizeI,
      input_tile_index[2] * TileSizeJ,
  };

  int input_origin_flat_index =
      TensorIndexToFlat(input_tile_origin, input_dims);

  bool full_tile = true;
  int tile_width = TileSizeJ;

  // Only the last row or column may not have the full size.
  if (input_tile_index[2] == input_dims_in_tiles[2] - 1) {
    tile_width = input_dims[2] - (input_dims_in_tiles[2] - 1) * TileSizeJ;
    full_tile &= false;
  }

  int tile_height = TileSizeI;

  if (input_tile_index[1] == input_dims_in_tiles[1] - 1) {
    tile_height = input_dims[1] - (input_dims_in_tiles[1] - 1) * TileSizeI;
    full_tile &= false;
  }

  // Calculate the effective thread number. This ensures that we use the
  // largest number of threads available to form a regular thread block with
  // no trailing incomplete lines.
  constexpr int in_effective_thread_num = NumThreads / TileSizeJ * TileSizeJ;

  if (x < in_effective_thread_num) {
    // Orient the logical thread block with respect to the input array,
    // i.e., align the contiguous dimension of thread blocks with the
    // contiguous dimension of the input array.
    int ti = x / TileSizeJ;
    int tj = x % TileSizeJ;
    int input_index = input_origin_flat_index + ti * input_dims[2] + tj;
    int input_increment = ReadRowPerPass * input_dims[2];

    if (full_tile) {
#pragma unroll
      for (int i_loc = ti; i_loc < (TileSizeI); i_loc += ReadRowPerPass) {
        shared_memory_tile[i_loc][tj] =
            maybe_conj<T, conjugate>::run(input[input_index]);
        input_index += input_increment;
      }
    } else {
      if (tj < tile_width) {
        for (int i_loc = ti; i_loc < (tile_height); i_loc += ReadRowPerPass) {
          shared_memory_tile[i_loc][tj] =
              maybe_conj<T, conjugate>::run(input[input_index]);
          input_index += input_increment;
        }
      }
    }
  }

  __syncthreads();

  Index<3> output_tile_index = {
      input_tile_index[0],
      input_tile_index[2],
      input_tile_index[1],
  };

  Index<3> output_tile_origin = {
      output_tile_index[0],
      output_tile_index[1] * TileSizeJ,
      output_tile_index[2] * TileSizeI,
  };

  int output_origin_flat_index =
      TensorIndexToFlat(output_tile_origin, output_dims);

  constexpr int out_effective_thread_num = NumThreads / TileSizeI * TileSizeI;

  if (x < out_effective_thread_num) {
    // Re-orient the logical thread block with respect to the output array,
    // i.e., align the contiguous dimension of thread blocks with the
    // contiguous dimension of the output array.
    int ti = x / TileSizeI;
    int tj = x % TileSizeI;
    int output_index = output_origin_flat_index + ti * output_dims[2] + tj;
    int output_increment = WriteRowPerPass * output_dims[2];

    if (full_tile) {
#pragma unroll
      for (int i_loc = ti; i_loc < (TileSizeJ); i_loc += WriteRowPerPass) {
        output[output_index] = shared_memory_tile[tj][i_loc];
        output_index += output_increment;
      }
    } else {
      if (tj < tile_height) {
        for (int i_loc = ti; i_loc < (tile_width); i_loc += WriteRowPerPass) {
          output[output_index] = shared_memory_tile[tj][i_loc];
          output_index += output_increment;
        }
      }
    }
  }
}
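// Worked example (illustrative only): with 32x32 tiles and input_dims
// {N, H, W} = {8, 100, 70}, input_dims_in_tiles is {8, 4, 3}, so 8 * 4 * 3 =
// 96 thread blocks are launched, one per tile. Blocks in the last tile row
// or column handle the partial tiles (height 4 and width 6 here).
// RunSwapDimension1And2InTensor3 below launches this kernel with exactly
// one block per tile.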
// A GPU custom kernel that converts the input to the output, given proper
// padding on the left and the top.
template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNHWC(
    int nthreads, const T* __restrict__ input, Dimension<NDIMS> input_dims,
    T* __restrict__ output, Dimension<NDIMS> output_dims,
    Dimension<NDIMS - 2> padding_left, T padding_value) {
  GPU_1D_KERNEL_LOOP(index, nthreads) {
    int output_index = index;
    Index<NDIMS> output_tensor_index =
        FlatToTensorIndex(output_index, output_dims);

    Index<NDIMS> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[0];  // batch
    bool ok = true;
    for (int i = 1; i < NDIMS - 1; i++) {
      input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 1];
      ok &=
          (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]);
    }
    input_tensor_index[NDIMS - 1] = output_tensor_index[NDIMS - 1];  // channels

    if (ok) {
      const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
      output[output_index] = input[input_index];
    } else {
      output[output_index] = padding_value;
    }
  }
}

template <typename T, int NDIMS>
__global__ void PadInputCustomKernelNCHW(
    int nthreads, const T* __restrict__ input, Dimension<NDIMS> input_dims,
    T* __restrict__ output, Dimension<NDIMS> output_dims,
    Dimension<NDIMS - 2> padding_left, T padding_value) {
  GPU_1D_KERNEL_LOOP(index, nthreads) {
    int output_index = index;
    Index<NDIMS> output_tensor_index =
        FlatToTensorIndex(output_index, output_dims);

    Index<NDIMS> input_tensor_index;
    input_tensor_index[0] = output_tensor_index[0];  // batch
    input_tensor_index[1] = output_tensor_index[1];  // channels
    bool ok = true;
    for (int i = 2; i < NDIMS; i++) {
      input_tensor_index[i] = output_tensor_index[i] - padding_left[i - 2];
      ok &=
          (input_tensor_index[i] >= 0 && input_tensor_index[i] < input_dims[i]);
    }

    if (ok) {
      const int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
      output[output_index] = input[input_index];
    } else {
      output[output_index] = padding_value;
    }
  }
}
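// Worked example (illustrative only): for a 4-D NHWC tensor and
// padding_left = {pad_top, pad_left}, output element (n, y, x, c) is read
// from input element (n, y - pad_top, x - pad_left, c) whenever that index
// lies inside the input bounds, and is set to padding_value otherwise.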
// A GPU helper function that converts TensorFlow filter format to Cudnn
// filter format.
template <typename T, int NDIMS>
struct TransformFilter<GPUDevice, T, int, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, FilterTensorFormat dst_filter_format,
                  typename TTypes<T, NDIMS, int>::ConstTensor in,
                  typename TTypes<T, NDIMS, int>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // spatial dimensions
    for (int i = 1; i < NDIMS - 2; i++) {
      combined_dims[0] *= in.dimension(i);
    }
    combined_dims[1] = in.dimension(NDIMS - 2);  // input filters
    combined_dims[2] = in.dimension(NDIMS - 1);  // output filters
    GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);

    if (dst_filter_format == FORMAT_OIHW) {
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else if (dst_filter_format == FORMAT_OHWI) {
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 1, 2, 0>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else {
      LOG(ERROR) << "Unsupported filter format: "
                 << ToString(dst_filter_format);
    }
  }
};

// Converts Cudnn filter format OIHW or OHWI back to TensorFlow filter format
// HWIO.
template <typename T, int NDIMS>
struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, FilterTensorFormat src_filter_format,
                  typename TTypes<T, NDIMS>::ConstTensor in,
                  typename TTypes<T, NDIMS>::Tensor out) {
    Dimension<3> combined_dims;

    if (src_filter_format == FORMAT_OIHW) {
      combined_dims[0] = in.dimension(0);  // output filters
      combined_dims[1] = in.dimension(1);  // input filters
      combined_dims[2] = in.dimension(2);  // spatial dimensions
      for (int i = 3; i < NDIMS; ++i) {
        combined_dims[2] *= in.dimension(i);
      }

      GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else if (src_filter_format == FORMAT_OHWI) {
      combined_dims[0] = in.dimension(0);  // output filters
      combined_dims[1] = in.dimension(1);  // spatial dimensions
      for (int i = 2; i < NDIMS - 1; i++) {
        combined_dims[1] *= in.dimension(i);
      }
      combined_dims[2] = in.dimension(NDIMS - 1);  // input filters

      GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 0, 1>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in.data(), combined_dims, out.data()));

    } else {
      // TODO(ezhulenev): Set error status in OpKernelContext instead.
      LOG(FATAL) << "Unsupported filter format: "
                 << ToString(src_filter_format);
    }
  }
};
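// Worked example (illustrative only): a 4-D HWIO filter of shape
// [3, 3, 64, 128] is viewed in TransformFilter as combined_dims =
// {9, 64, 128} (spatial, input filters, output filters). The <2, 1, 0>
// shuffle used for FORMAT_OIHW produces {128, 64, 9}, i.e. an OIHW filter of
// shape [128, 64, 3, 3], while the <1, 2, 0> shuffle used for FORMAT_OHWI
// produces {128, 9, 64}, i.e. an OHWI filter of shape [128, 3, 3, 64].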
// A GPU helper function that converts the input tensor to a larger output
// tensor, given proper padding values. Elements that fall outside the input
// are set to padding_value.
template <typename T, int NDIMS>
struct PadInput<GPUDevice, T, int, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d,
                  typename TTypes<T, NDIMS, int>::ConstTensor in,
                  const std::array<int, NDIMS - 2>& padding_left,
                  const std::array<int, NDIMS - 2>& padding_right,
                  typename TTypes<T, NDIMS, int>::Tensor out,
                  TensorFormat format, const T& padding_value) {
    GpuLaunchConfig config = GetGpuLaunchConfig(out.size(), d);
    Dimension<NDIMS> input_dims;
    for (int i = 0; i < NDIMS; ++i) {
      input_dims[i] = in.dimension(i);
    }
    Dimension<NDIMS> output_dims;
    for (int i = 0; i < NDIMS; ++i) {
      output_dims[i] = out.dimension(i);
    }

    const Dimension<NDIMS - 2> padding_left_dim(padding_left);

    if (format == FORMAT_NHWC) {
      TF_CHECK_OK(GpuLaunchKernel(
          PadInputCustomKernelNHWC<T, NDIMS>, config.block_count,
          config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
          in.data(), input_dims, out.data(), output_dims, padding_left_dim,
          padding_value));
    } else if (format == FORMAT_NCHW) {
      TF_CHECK_OK(GpuLaunchKernel(
          PadInputCustomKernelNCHW<T, NDIMS>, config.block_count,
          config.thread_per_block, 0, d.stream(), config.virtual_thread_count,
          in.data(), input_dims, out.data(), output_dims, padding_left_dim,
          padding_value));
    } else {
      LOG(FATAL) << "Invalid data format: " << format;
    }
  }
};

// We want std::equal_to and std::greater, but they're not constexpr until
// C++14.
struct EqualTo {
  constexpr bool operator()(int a, int b) const { return a == b; }
};

struct GreaterThan {
  constexpr bool operator()(int a, int b) const { return a > b; }
};

// For each data type, the tile size possibility frontier denotes the tile
// size combinations that consume the most computational resources,
// constrained by
// - the limit on the number of threads per SM,
// - the limit on the size of the short dimension (<= 15) due to the
//   definition of a narrow matrix,
// - the shared memory limit, and
// - an experimentally determined, type-specific constraint on the product of
//   the two side lengths, used to increase grid-level parallelism.
//
// A tile size combination lies on the frontier if and only if one or more of
// the constraints mentioned above is hit. Tile size combinations lying
// outside this frontier are either not possible, or are slower than the
// alternatives.
//
// It is useful to consider, for each data type, two subsets of the
// corresponding frontier:
// - long side frontier: the union of the biggest tile size combination for
//   each legal long side length.
// - non long side frontier: the frontier set minus the long side frontier.
//
// TileSizePossibilityFrontierCheck defines the frontier using only the long
// side frontier tile size combinations (since one can easily extrapolate
// the entire frontier from this subset). It serves as a utility function
// to help us determine where a tile size combination of interest lies with
// respect to the frontier.
template <typename Op>
constexpr bool TileSizePossibilityFrontierCheck(int TileLongSide,
                                                int TileShortSide,
                                                int size_of_t, Op op) {
  // clang-format off

  return (size_of_t == 16 && ((TileLongSide == 32   && op(TileShortSide, 4)) ||
                              (TileLongSide == 64   && op(TileShortSide, 4)) ||
                              (TileLongSide == 128  && op(TileShortSide, 4)) ||
                              (TileLongSide == 256  && op(TileShortSide, 2)))) ||
         (size_of_t == 8  && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                              (TileLongSide == 64   && op(TileShortSide, 15)) ||
                              (TileLongSide == 128  && op(TileShortSide, 8)) ||
                              (TileLongSide == 256  && op(TileShortSide, 4)) ||
                              (TileLongSide == 512  && op(TileShortSide, 2)))) ||
         (size_of_t == 4  && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                              (TileLongSide == 64   && op(TileShortSide, 15)) ||
                              (TileLongSide == 128  && op(TileShortSide, 15)) ||
                              (TileLongSide == 256  && op(TileShortSide, 8)) ||
                              (TileLongSide == 512  && op(TileShortSide, 4)) ||
                              (TileLongSide == 1024 && op(TileShortSide, 2)))) ||
         (size_of_t == 2  && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                              (TileLongSide == 64   && op(TileShortSide, 15)) ||
                              (TileLongSide == 128  && op(TileShortSide, 15)) ||
                              (TileLongSide == 256  && op(TileShortSide, 8)) ||
                              (TileLongSide == 512  && op(TileShortSide, 4)) ||
                              (TileLongSide == 1024 && op(TileShortSide, 2)))) ||
         (size_of_t == 1  && ((TileLongSide == 32   && op(TileShortSide, 15)) ||
                              (TileLongSide == 64   && op(TileShortSide, 15)) ||
                              (TileLongSide == 128  && op(TileShortSide, 15)) ||
                              (TileLongSide == 256  && op(TileShortSide, 8)) ||
                              (TileLongSide == 512  && op(TileShortSide, 4)) ||
                              (TileLongSide == 1024 && op(TileShortSide, 2))));

  // clang-format on
}

constexpr bool TileSizeOnLongSideFrontier(int TileLongSide, int TileShortSide,
                                          int size_of_t) {
  return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide,
                                          size_of_t, EqualTo());
}
constexpr bool TileSizeOutsideFrontier(int TileLongSide, int TileShortSide,
                                       int size_of_t) {
  return TileSizePossibilityFrontierCheck(TileLongSide, TileShortSide,
                                          size_of_t, GreaterThan());
}
constexpr bool TileSizeOnNonLongSideFrontier(int TileLongSide,
                                             int TileShortSide, int size_of_t) {
  // For a tile size combination (longside, shortside), lying on the frontier
  // implies that (longside, shortside) is on or within the frontier but
  // (longside*2, shortside) or (longside, shortside+1) is not. With the above
  // criterion, we simply need to use !TileSizeOnLongSideFrontier to ensure
  // that it is not on the long side frontier.
  return !TileSizeOutsideFrontier(TileLongSide, TileShortSide, size_of_t) &&
         (TileSizeOutsideFrontier(TileLongSide * 2, TileShortSide, size_of_t) ||
          TileSizeOutsideFrontier(TileLongSide, TileShortSide + 1,
                                  size_of_t)) &&
         !TileSizeOnLongSideFrontier(TileLongSide, TileShortSide, size_of_t);
}
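// A few illustrative sanity checks (a sketch of how the predicates above
// classify concrete tile sizes; the values follow directly from the table in
// TileSizePossibilityFrontierCheck for 4-byte element types, whose long side
// frontier is (32, 15), (64, 15), (128, 15), (256, 8), (512, 4), (1024, 2)):
static_assert(TileSizeOnLongSideFrontier(256, 8, 4),
              "(256, 8) lies on the long side frontier for 4-byte types.");
static_assert(TileSizeOutsideFrontier(256, 9, 4),
              "(256, 9) lies outside the frontier for 4-byte types.");
static_assert(TileSizeOnNonLongSideFrontier(512, 3, 4),
              "(512, 3) lies on the non long side frontier for 4-byte types.");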
// Helper function to launch a batch narrow matrix transpose kernel.
template <typename T, int TileLongSide, int TileShortSide>
void LaunchBatchNarrowMatrixTransposeKernel(
    const GPUDevice& d, int tile_size_i, int tile_size_j, int total_tiles_count,
    const T* input, const Dimension<3>& input_dims, T* output) {
  constexpr int NumThreads = TileLongSide;
  if (tile_size_i <= TileLongSide && tile_size_j <= TileShortSide) {
    TF_CHECK_OK(GpuLaunchKernel(
        SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileLongSide,
                                              TileShortSide>,
        total_tiles_count, NumThreads, 0, d.stream(), input, input_dims,
        output));
  } else {
    TF_CHECK_OK(GpuLaunchKernel(
        SwapDimension1And2InTensor3UsingTiles<T, NumThreads, TileShortSide,
                                              TileLongSide>,
        total_tiles_count, NumThreads, 0, d.stream(), input, input_dims,
        output));
  }
}

// Recursive template function to search, in a trial-and-error manner, for the
// minimum tile size configuration satisfying the requested tile side lengths.
// An important invariant of this search procedure is that for an unsatisfied
// request, we always try doubling the long side len first, and only after
// the request is satisfied for the long side len do we begin incrementing
// the short side len.
//
// We have three specializations of this search function depending on where
// the current tile size combination lies with respect to the frontier.
// - It lies within the frontier. If the request is not satisfied, for the
//   next tile size combination, we first try doubling the long side len and
//   if that does not work, we then increment the short side len.
// - It lies on the non long side frontier. If the request is not satisfied,
//   we can only increment the short side len.
// - It lies on the long side frontier. We launch the kernel without checking
//   if the request is satisfied or not.
template <typename T, int TileLongSide, int TileShortSide,
          typename dummy = void>
struct BatchNarrowMatrixTransposeDispatcher {
  static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
                   int total_tiles_count, const T* input,
                   const Dimension<3>& input_dims, T* output) {
    static_assert(
        (TileLongSide & (TileLongSide - 1)) == 0,
        "The length of the longer side of the tile is always a power of 2.");
    bool request_satisfied =
        std::max(tile_size_i, tile_size_j) <= TileLongSide &&
        std::min(tile_size_i, tile_size_j) <= TileShortSide;

    if (request_satisfied) {
      LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
          d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
          output);
      return;
    }

    // If the execution reaches here, then the kernel was not launched; we
    // then determine whether it is the long side or the short side that falls
    // short of the request and increase that parameter accordingly.
    const bool long_side_request_not_satisfied =
        std::max(tile_size_i, tile_size_j) > TileLongSide;

    if (long_side_request_not_satisfied) {
      BatchNarrowMatrixTransposeDispatcher<
          T, TileLongSide * 2, TileShortSide>::DoIt(d, tile_size_i, tile_size_j,
                                                    total_tiles_count, input,
                                                    input_dims, output);
    } else {
      BatchNarrowMatrixTransposeDispatcher<
          T, TileLongSide, TileShortSide + 1>::DoIt(d, tile_size_i, tile_size_j,
                                                    total_tiles_count, input,
                                                    input_dims, output);
    }
  }
};

template <typename T, int TileLongSide, int TileShortSide>
struct BatchNarrowMatrixTransposeDispatcher<
    T, TileLongSide, TileShortSide,
    typename std::enable_if<TileSizeOnNonLongSideFrontier(
                                TileLongSide, TileShortSide, sizeof(T)),
                            void>::type> {
  static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
                   int total_tiles_count, const T* input,
                   const Dimension<3>& input_dims, T* output) {
    static_assert(
        (TileLongSide & (TileLongSide - 1)) == 0,
        "The length of the longer side of the tile is always a power of 2.");
    bool request_satisfied =
        std::max(tile_size_i, tile_size_j) <= TileLongSide &&
        std::min(tile_size_i, tile_size_j) <= TileShortSide;

    if (request_satisfied) {
      LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
          d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
          output);
      return;
    }

    // If the execution reaches here, then the kernel was not launched; since
    // we are on the non long side frontier, we increment the short dimension
    // and try again.
    BatchNarrowMatrixTransposeDispatcher<
        T, TileLongSide, TileShortSide + 1>::DoIt(d, tile_size_i, tile_size_j,
                                                  total_tiles_count, input,
                                                  input_dims, output);
  }
};

template <typename T, int TileLongSide, int TileShortSide>
struct BatchNarrowMatrixTransposeDispatcher<
    T, TileLongSide, TileShortSide,
    typename std::enable_if<TileSizeOnLongSideFrontier(
                                TileLongSide, TileShortSide, sizeof(T)),
                            void>::type> {
  static void DoIt(const GPUDevice& d, int tile_size_i, int tile_size_j,
                   int total_tiles_count, const T* input,
                   const Dimension<3>& input_dims, T* output) {
    static_assert(
        (TileLongSide & (TileLongSide - 1)) == 0,
        "The length of the longer side of the tile is always a power of 2.");

    LaunchBatchNarrowMatrixTransposeKernel<T, TileLongSide, TileShortSide>(
        d, tile_size_i, tile_size_j, total_tiles_count, input, input_dims,
        output);
  }
};
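// Illustrative walk-through (a sketch of the search, assuming a 4-byte
// element type): a request of (tile_size_i, tile_size_j) = (48, 3) starts at
// the instantiation <T, 32, 2>. The long side is too small (48 > 32), so the
// search doubles it to <T, 64, 2>; the short side is still too small
// (3 > 2), so it is incremented to <T, 64, 3>, which satisfies the request
// and launches SwapDimension1And2InTensor3UsingTiles<T, 64, 64, 3>.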
// This function tries to recover, in a brute force way, the frontier defined
// in TileSizePossibilityFrontierCheck as a vector of tile size combinations
// lying on the long side frontier. This vector is sufficient to determine
// the entire frontier.
//
// Note that if one changes the frontier definition in
// TileSizePossibilityFrontierCheck and forgets to set the largest short
// side len of the largest legal long side len to 2, this function will fail
// and crash the program.
template <int SizeOfT>
const std::vector<std::pair<int, int>>& GetTileSizesFrontier() {
  static_assert(
      SizeOfT <= 16,
      "Currently, only data types of sizes 16 bytes or less are supported.");
  static_assert((SizeOfT & (SizeOfT - 1)) == 0,
                "Data types must have sizes that are powers of 2.");

  // Expensive work to populate sizes, lazily run in a thread-safe
  // manner the first time GetTileSizesFrontier<N> is called.
  static auto* frontier = [] {
    auto* frontier = new std::vector<std::pair<int, int>>();
    const int kMaxLongSideLen = 1024;
    const int kMaxShortSideLen = 15;
    for (int long_side = 32; long_side <= kMaxLongSideLen; long_side *= 2) {
      for (int short_side = 2; short_side <= kMaxShortSideLen;
           short_side += 1) {
        if (TileSizeOnLongSideFrontier(long_side, short_side, SizeOfT)) {
          // The current combination lies on the frontier, thus we
          // add it to the frontier definition.
          frontier->push_back(std::make_pair(long_side, short_side));

          // The long side length is the largest one allowed iff its
          // corresponding short side length is 2.
          if (short_side == 2) return frontier;

          // We have exhausted all the possibilities in the frontier
          // with the given long side length.
          break;
        }
      }
    }
    LOG(FATAL)
        << "The corresponding short side length of the largest long side "
           "length has to be 2.";
  }();
  return *frontier;
}

// Helper structs to help determine which data type to use given the size of
// the matrix data type. A transpose of elements of size N will use a kernel
// which operates on an array of TransposeElemType<N>::type.
template <int ElemBytes>
struct TransposeElemType;
template <>
struct TransposeElemType<1> {
  using type = uint8;
};
template <>
struct TransposeElemType<2> {
  using type = uint16;
};
template <>
struct TransposeElemType<4> {
  using type = uint32;
};
template <>
struct TransposeElemType<8> {
  using type = uint64;
};
template <>
struct TransposeElemType<16> {
  using type = float4;
};
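// For example (illustrative only): an 8-byte element type such as complex64
// is reinterpreted as uint64 by the narrow-matrix path below, which is safe
// because that path only copies elements and never inspects their values.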
// A helper function to make RunSwapDimension1And2InTensor3 concise. This
// helper function looks at the data type and input matrix sizes and decides
// the thread numbers and tile sizes to use.
template <typename T, bool conjugate = false>
void SwapDimension1And2InTensor3WithNarrowMatrices(
    const GPUDevice& d, const T* input, const Dimension<3>& input_dims,
    T* output, const int kMinDimensionToUseTiles) {
  // Get available tile sizes here for the data type requested:
  const auto& tile_spec = GetTileSizesFrontier<sizeof(T)>();

  int tile_long_side_len = 0;
  int tile_short_side_len = 0;
  float lowest_cost = std::numeric_limits<float>::max();
  int data_long_side = std::max(input_dims[1], input_dims[2]);

  for (auto tile_size_pair : tile_spec) {
    int proposed_tile_long_side_len = tile_size_pair.first;

    // Number of threads that will not be doing anything useful when reading
    // the matrix because the thread block size is bigger than the data block
    // size.
    int num_wasted_threads =
        data_long_side - MathUtil::FloorOfRatio<int>(
                             data_long_side, proposed_tile_long_side_len) *
                             proposed_tile_long_side_len;

    int num_full_tiles = MathUtil::FloorOfRatio<int>(
        data_long_side, proposed_tile_long_side_len);

    float cost = 0;

    // However, if we can execute two or more full tiles, then we gladly
    // accept any number of wasted threads and ignore its cost.
    if (num_full_tiles <= 1) cost = num_wasted_threads;

    // Using less than or equal to here because given the same cost, we
    // would like to launch as many threads as possible.
    if (cost <= lowest_cost) {
      tile_long_side_len = proposed_tile_long_side_len;
      tile_short_side_len = tile_size_pair.second;
      lowest_cost = cost;
    }
  }

  // Request tile sizes such that the longer side of the thread block aligns
  // with the longer side of the input data block to maximize read throughput.
  // The ideal tile shape is one where the length of the shorter side of the
  // tile is equal to the length of the shorter side of the input matrix.
  int requested_tile_size_i = input_dims[1] >= kMinDimensionToUseTiles
                                  ? tile_long_side_len
                                  : input_dims[1];
  int requested_tile_size_j = input_dims[1] >= kMinDimensionToUseTiles
                                  ? input_dims[2]
                                  : tile_long_side_len;

  // Truncate the shorter size requested according to the manual limit set in
  // tile_spec to make sure that we do not launch configurations violating
  // hardware limits.
  requested_tile_size_i =
      requested_tile_size_i == tile_long_side_len
          ? tile_long_side_len
          : std::min(requested_tile_size_i, tile_short_side_len);
  requested_tile_size_j =
      requested_tile_size_j == tile_long_side_len
          ? tile_long_side_len
          : std::min(requested_tile_size_j, tile_short_side_len);

  Dimension<3> input_dims_in_tiles = {
      input_dims[0],
      MathUtil::CeilOfRatio<int>(input_dims[1], requested_tile_size_i),
      MathUtil::CeilOfRatio<int>(input_dims[2], requested_tile_size_j),
  };

  int total_tiles_count =
      input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2];

  using ElemType = typename TransposeElemType<sizeof(T)>::type;
  static_assert(alignof(T) >= alignof(ElemType), "Unexpected data alignment.");
  BatchNarrowMatrixTransposeDispatcher<ElemType, 32, 2>::DoIt(
      d, requested_tile_size_i, requested_tile_size_j, total_tiles_count,
      reinterpret_cast<const ElemType*>(input), input_dims,
      reinterpret_cast<ElemType*>(output));
}

// Launch the GPU kernel that would swap dimension-1 and dimension-2 in a
// 3D tensor. It looks at the shape of the incoming data and decides the best
// launch strategy.
template <typename T, bool conjugate = false>
void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
                                    const Dimension<3>& input_dims, T* output) {
  // If both non-batch dimensions are large, use tiles for the actual
  // swapping. If the matrix is narrow (one dimension small, the other long),
  // use the narrow-matrix tile kernels. Otherwise, the trivial swapping
  // relying on the ldg cache is more efficient.
  static const int kMinDimensionToUseTiles = 16;
  static const int kMinDimensionToUseRectTiles = 96;

  bool large_matrix = input_dims[1] >= kMinDimensionToUseTiles &&
                      input_dims[2] >= kMinDimensionToUseTiles;
  bool narrow_matrix = input_dims[1] >= kMinDimensionToUseRectTiles ||
                       input_dims[2] >= kMinDimensionToUseRectTiles;
  if (large_matrix) {
    // We get best performance when kTileSize equals the number of threads in
    // a warp (32 on our GPUs) and the block runs kNumThreads = 256 threads,
    // i.e. 8 warps per block.
    constexpr int kTileSize = 32;
    constexpr int kNumThreads = 256;

    Dimension<3> input_dims_in_tiles = {
        input_dims[0],
        MathUtil::CeilOfRatio<int>(input_dims[1], kTileSize),
        MathUtil::CeilOfRatio<int>(input_dims[2], kTileSize),
    };

    int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] *
                            input_dims_in_tiles[2];
    TF_CHECK_OK(GpuLaunchKernel(
        SwapDimension1And2InTensor3UsingTiles<T, kNumThreads, kTileSize,
                                              kTileSize, conjugate>,
        total_tiles_count, kNumThreads, 0, d.stream(), input, input_dims,
        output));

  } else if (narrow_matrix) {
    SwapDimension1And2InTensor3WithNarrowMatrices<T, conjugate>(
        d, input, input_dims, output, kMinDimensionToUseTiles);
  } else {
    int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
    GpuLaunchConfig config = GetGpuLaunchConfig(total_element_count, d);
    TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 0, 2, 1, conjugate>,
                                config.block_count, config.thread_per_block, 0,
                                d.stream(), config.virtual_thread_count, input,
                                input_dims, output));
  }
}

// A GPU helper functor that does a general dimension 1 and 2 swap for a 3D
// tensor.
template <typename T, bool conjugate>
struct SwapDimension1And2InTensor3<GPUDevice, T, conjugate> {
  typedef GPUDevice Device;
  void operator()(const Device& d, const T* in,
                  const gtl::ArraySlice<int64>& combined_dims, T* out) {
    Dimension<3> input_dims = {static_cast<int>(combined_dims[0]),
                               static_cast<int>(combined_dims[1]),
                               static_cast<int>(combined_dims[2])};
    RunSwapDimension1And2InTensor3<T, conjugate>(d, in, input_dims, out);
  }
};
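// Illustrative usage sketch (assumption: d is a GPUDevice and in/out point
// to device memory holding batch * 64 * 32 floats):
//
//   functor::SwapDimension1And2InTensor3<GPUDevice, float, false>()(
//       d, in, {batch, 64, 32}, out);
//
// transposes each of the batch 64x32 matrices into the corresponding 32x64
// matrix in out.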
// A GPU helper functor that does a general dimension 0 and 2 swap for a 3D
// tensor.
template <typename T, bool conjugate>
struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
  typedef GPUDevice Device;
  void operator()(const Device& d, const T* in,
                  const gtl::ArraySlice<int64>& combined_dims, T* out) {
    Dimension<3> input_dims = {static_cast<int>(combined_dims[0]),
                               static_cast<int>(combined_dims[1]),
                               static_cast<int>(combined_dims[2])};
    size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
    GpuLaunchConfig config = GetGpuLaunchConfig(total_size, d);

    auto out_ptr = reinterpret_cast<uintptr_t>(out);
    bool aligned = out_ptr % 16 == 0;

    bool use_vector = false;
    bool use_custom_config = false;
    if ((input_dims[0] <= 128 && input_dims[2] <= 128) ||
        input_dims[0] * input_dims[1] <= 128 ||
        input_dims[1] * input_dims[2] <= 8) {
      use_vector = true;
      use_custom_config = true;
    } else if (input_dims[1] * input_dims[2] <= 16384) {
      use_vector = true;
    }

    if (sizeof(T) == 2 && aligned && use_vector) {
      int block_count;
      if (use_custom_config) {
        block_count = (total_size + config.thread_per_block - 1) /
                      config.thread_per_block;
      } else {
        block_count = config.block_count;
      }

      TF_CHECK_OK(
          GpuLaunchKernel(ShuffleInTensor3SimpleVector<T, 2, 1, 0, conjugate>,
                          block_count, config.thread_per_block / kUnroll, 0,
                          d.stream(), total_size, in, input_dims, out));
    } else {
      TF_CHECK_OK(GpuLaunchKernel(ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>,
                                  config.block_count, config.thread_per_block,
                                  0, d.stream(), config.virtual_thread_count,
                                  in, input_dims, out));
    }
  }
};

// A GPU helper functor that converts NHWC TensorFlow data format to
// NCHW format that is accepted by Cudnn.
template <typename T, int NDIMS>
struct NHWCToNCHW<GPUDevice, T, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
                  typename TTypes<T, NDIMS>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // N (batch)
    combined_dims[1] = in.dimension(1);  // spatial dimensions (HW)
    for (int i = 2; i < NDIMS - 1; ++i) {
      combined_dims[1] *= in.dimension(i);
    }
    combined_dims[2] = in.dimension(NDIMS - 1);  // C (channels)
    RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
  }
};

// A GPU helper functor that converts NCHW Cudnn data format to NHWC
// TensorFlow format.
template <typename T, int NDIMS>
struct NCHWToNHWC<GPUDevice, T, NDIMS> {
  typedef GPUDevice Device;
  void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
                  typename TTypes<T, NDIMS>::Tensor out) {
    Dimension<3> combined_dims;
    combined_dims[0] = in.dimension(0);  // N (batch)
    combined_dims[1] = in.dimension(1);  // C (channel)
    combined_dims[2] = in.dimension(2);  // spatial dimensions (HW)
    for (int i = 3; i < NDIMS; ++i) {
      combined_dims[2] *= in.dimension(i);
    }
    RunSwapDimension1And2InTensor3(d, in.data(), combined_dims, out.data());
  }
};

}  // namespace functor
}  // namespace tensorflow

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#endif  // TENSORFLOW_CORE_KERNELS_CONV_2D_GPU_H_