/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_GPU_LAUNCH_CONFIG_H_
#define TENSORFLOW_CORE_UTIL_GPU_LAUNCH_CONFIG_H_

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <algorithm>
#include <array>
#include <tuple>
#include <type_traits>

#include "absl/base/casts.h"
#include "absl/utility/utility.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/gpu_cuda_alias.h"

// Usage of GetGpuLaunchConfig, GetGpu2DLaunchConfig, and
// GetGpu3DLaunchConfig:
//
// There are two versions of GetGpuLaunchConfig and GetGpu2DLaunchConfig: one
// version uses heuristics without any knowledge of the device kernel, the
// other version uses cudaOccupancyMaxPotentialBlockSize to determine the
// theoretical launch parameters that maximize occupancy. Currently, only the
// maximum-occupancy version of GetGpu3DLaunchConfig is available.
//
// For a large number of work elements, the convention is that each kernel
// iterates through its assigned range. The return value of GetGpuLaunchConfig
// is the struct GpuLaunchConfig, which contains all the information needed
// for the kernel launch: the virtual number of threads, the number of threads
// per block, and the number of blocks, the latter two being the values used
// inside <<< >>> of a kernel launch. GetGpu2DLaunchConfig and
// GetGpu3DLaunchConfig do the same thing as GetGpuLaunchConfig; the only
// difference is the dimensionality. The macros GPU_1D_KERNEL_LOOP and
// GPU_AXIS_KERNEL_LOOP can be used to implement the inner loop.
//
/* Sample code:

__global__ void MyKernel1D(GpuLaunchConfig config, other_args...) {
  GPU_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
    do_your_job_here;
  }
}

__global__ void MyKernel2D(Gpu2DLaunchConfig config, other_args...) {
  GPU_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
    GPU_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
      do_your_job_here;
    }
  }
}

__global__ void MyKernel3D(Gpu3DLaunchConfig config, other_args...) {
  GPU_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
    GPU_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
      GPU_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
        do_your_job_here;
      }
    }
  }
}

void MyDriverFunc(const Eigen::GpuDevice& d) {
  // Use heuristics.
  GpuLaunchConfig cfg1 = GetGpuLaunchConfig(10240, d);
  MyKernel1D <<<cfg1.block_count,
                cfg1.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
  Gpu2DLaunchConfig cfg2 = GetGpu2DLaunchConfig(10240, 10240, d);
  MyKernel2D <<<cfg2.block_count,
                cfg2.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);
  // Note: there is no heuristics-only GetGpu3DLaunchConfig; for 3D launches,
  // use the maximum-occupancy version below.

  // Maximize occupancy.
  GpuLaunchConfig cfg3 = GetGpuLaunchConfig(10240, d, MyKernel1D, 0, 0);
  MyKernel1D <<<cfg3.block_count,
                cfg3.thread_per_block, 0, d.stream()>>> (cfg3, other_args...);
  Gpu2DLaunchConfig cfg4 = GetGpu2DLaunchConfig(10240, 10240, d,
                                                MyKernel2D, 0, 0);
  MyKernel2D <<<cfg4.block_count,
                cfg4.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
  Gpu3DLaunchConfig cfg5 = GetGpu3DLaunchConfig(4096, 4096, 100, d,
                                                MyKernel3D, 0, 0);
  MyKernel3D <<<cfg5.block_count,
                cfg5.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
}

// See the tests for more examples:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/gpu_kernel_helper_test.cu.cc

*/

namespace tensorflow {

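// Divides a by b, rounding up to the nearest integer (ceiling division).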
inline int DivUp(int a, int b) { return (a + b - 1) / b; }

struct GpuLaunchConfig {
  // Logical number of threads that work on the elements. If each logical
  // thread works on exactly a single element, this is the same as the working
  // element count.
  int virtual_thread_count = -1;
  // Number of threads per block.
  int thread_per_block = -1;
  // Number of blocks for GPU kernel launch.
  int block_count = -1;
};
CREATE_CUDA_TYPE_ALIAS(GpuLaunchConfig, CudaLaunchConfig);

// Calculate the GPU launch config we should use for a kernel launch.
// This is assuming the kernel is quite simple and will largely be
// memory-limited.
// REQUIRES: work_element_count > 0.
inline GpuLaunchConfig GetGpuLaunchConfig(int work_element_count,
                                          const Eigen::GpuDevice& d) {
  CHECK_GT(work_element_count, 0);
  GpuLaunchConfig config;
  const int virtual_thread_count = work_element_count;
  const int physical_thread_count = std::min(
      d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(),
      virtual_thread_count);
  const int thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock());
  const int block_count =
      std::min(DivUp(physical_thread_count, thread_per_block),
               d.getNumGpuMultiProcessors());

  config.virtual_thread_count = virtual_thread_count;
  config.thread_per_block = thread_per_block;
  config.block_count = block_count;
  return config;
}
#ifndef TENSORFLOW_USE_ROCM
inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
                                            const Eigen::GpuDevice& d) {
  return GetGpuLaunchConfig(work_element_count, d);
}
#endif

// Calculate the GPU launch config we should use for a kernel launch. This
// variant takes the resource limits of func into account to maximize
// occupancy.
// REQUIRES: work_element_count > 0.
template <typename DeviceFunc>
GpuLaunchConfig GetGpuLaunchConfig(int work_element_count,
                                   const Eigen::GpuDevice& d, DeviceFunc func,
                                   size_t dynamic_shared_memory_size,
                                   int block_size_limit) {
  CHECK_GT(work_element_count, 0);
  GpuLaunchConfig config;
  int block_count = 0;
  int thread_per_block = 0;

#if GOOGLE_CUDA
  cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
      &block_count, &thread_per_block, func, dynamic_shared_memory_size,
      block_size_limit);
  CHECK_EQ(err, cudaSuccess);
#elif TENSORFLOW_USE_ROCM
  hipError_t err = hipOccupancyMaxPotentialBlockSize(
      &block_count, &thread_per_block, func, dynamic_shared_memory_size,
      block_size_limit);
  CHECK_EQ(err, hipSuccess);
#endif

  block_count =
      std::min(block_count, DivUp(work_element_count, thread_per_block));

  config.virtual_thread_count = work_element_count;
  config.thread_per_block = thread_per_block;
  config.block_count = block_count;
  return config;
}
CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpuLaunchConfig, GetCudaLaunchConfig);

// Calculate the GPU launch config we should use for a kernel launch. This
// variant takes the resource limits of func into account to maximize
// occupancy. The returned launch config has thread_per_block set to
// fixed_block_size.
// REQUIRES: work_element_count > 0.
template <typename DeviceFunc>
GpuLaunchConfig GetGpuLaunchConfigFixedBlockSize(
    int work_element_count, const Eigen::GpuDevice& d, DeviceFunc func,
    size_t dynamic_shared_memory_size, int fixed_block_size) {
  CHECK_GT(work_element_count, 0);
  GpuLaunchConfig config;
  int block_count = 0;

#if GOOGLE_CUDA
  cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
      &block_count, func, fixed_block_size, dynamic_shared_memory_size);
  CHECK_EQ(err, cudaSuccess);
#elif TENSORFLOW_USE_ROCM
  hipError_t err = hipOccupancyMaxActiveBlocksPerMultiprocessor(
      &block_count, func, fixed_block_size, dynamic_shared_memory_size);
  CHECK_EQ(err, hipSuccess);
#endif
  block_count = std::min(block_count * d.getNumGpuMultiProcessors(),
                         DivUp(work_element_count, fixed_block_size));

  config.virtual_thread_count = work_element_count;
  config.thread_per_block = fixed_block_size;
  config.block_count = block_count;
  return config;
}
CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpuLaunchConfigFixedBlockSize,
                                GetCudaLaunchConfigFixedBlockSize);
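
// Example usage (a minimal sketch; MyKernel, my_args, and work_element_count
// are placeholders, not part of this header):
//   GpuLaunchConfig cfg = GetGpuLaunchConfigFixedBlockSize(
//       work_element_count, d, MyKernel,
//       /*dynamic_shared_memory_size=*/0, /*fixed_block_size=*/256);
//   MyKernel<<<cfg.block_count, cfg.thread_per_block, 0, d.stream()>>>(
//       cfg, my_args...);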

struct Gpu2DLaunchConfig {
  dim3 virtual_thread_count = dim3(0, 0, 0);
  dim3 thread_per_block = dim3(0, 0, 0);
  dim3 block_count = dim3(0, 0, 0);
};
CREATE_CUDA_TYPE_ALIAS(Gpu2DLaunchConfig, Cuda2DLaunchConfig);

inline Gpu2DLaunchConfig GetGpu2DLaunchConfig(int xdim, int ydim,
                                              const Eigen::GpuDevice& d) {
  Gpu2DLaunchConfig config;

  if (xdim <= 0 || ydim <= 0) {
    return config;
  }

  const int kThreadsPerBlock = 256;
  int block_cols = std::min(xdim, kThreadsPerBlock);
  // OK to round down here and just do more loops in the kernel.
  int block_rows = std::max(kThreadsPerBlock / block_cols, 1);

  const int physical_thread_count =
      d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor();

  const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);

  config.virtual_thread_count = dim3(xdim, ydim, 1);
  config.thread_per_block = dim3(block_cols, block_rows, 1);

  int grid_x = std::min(DivUp(xdim, block_cols), max_blocks);

  config.block_count = dim3(
      grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);
  return config;
}
#ifndef TENSORFLOW_USE_ROCM
inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
                                                const Eigen::GpuDevice& d) {
  return GetGpu2DLaunchConfig(xdim, ydim, d);
}
#endif

// Calculate the GPU 2D and 3D launch config we should use for a kernel
// launch. This variant takes the resource limits of func into account to
// maximize occupancy.
using Gpu3DLaunchConfig = Gpu2DLaunchConfig;
CREATE_CUDA_TYPE_ALIAS(Gpu3DLaunchConfig, Cuda3DLaunchConfig);

template <typename DeviceFunc>
Gpu3DLaunchConfig GetGpu3DLaunchConfig(int xdim, int ydim, int zdim,
                                       const Eigen::GpuDevice& d,
                                       DeviceFunc func,
                                       size_t dynamic_shared_memory_size,
                                       int block_size_limit) {
  Gpu3DLaunchConfig config;

  if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
    return config;
  }

  int dev;
#if GOOGLE_CUDA
  cudaGetDevice(&dev);
  cudaDeviceProp deviceProp;
  cudaGetDeviceProperties(&deviceProp, dev);
#elif TENSORFLOW_USE_ROCM
  hipGetDevice(&dev);
  hipDeviceProp_t deviceProp;
  hipGetDeviceProperties(&deviceProp, dev);
#endif
  int xthreadlimit = deviceProp.maxThreadsDim[0];
  int ythreadlimit = deviceProp.maxThreadsDim[1];
  int zthreadlimit = deviceProp.maxThreadsDim[2];
  int xgridlimit = deviceProp.maxGridSize[0];
  int ygridlimit = deviceProp.maxGridSize[1];
  int zgridlimit = deviceProp.maxGridSize[2];

  int block_count = 0;
  int thread_per_block = 0;

#if GOOGLE_CUDA
  cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
      &block_count, &thread_per_block, func, dynamic_shared_memory_size,
      block_size_limit);
  CHECK_EQ(err, cudaSuccess);
#elif TENSORFLOW_USE_ROCM
  // ROCM TODO: re-enable this after hipOccupancyMaxPotentialBlockSize is
  // implemented.
  // hipError_t err = hipOccupancyMaxPotentialBlockSize(
  //    &block_count, &thread_per_block, func, dynamic_shared_memory_size,
  //    block_size_limit);
  // CHECK_EQ(err, hipSuccess);

  // Until then, fall back to the same heuristics as the kernel-agnostic
  // GetGpuLaunchConfig above.
  const int physical_thread_count =
      d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor();
  thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock());
  block_count = std::min(DivUp(physical_thread_count, thread_per_block),
                         d.getNumGpuMultiProcessors());
#endif

  int threadsx = std::min({xdim, thread_per_block, xthreadlimit});
  int threadsy =
      std::min({ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit});
  int threadsz =
      std::min({zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
                zthreadlimit});

  int blocksx = std::min({block_count, DivUp(xdim, threadsx), xgridlimit});
  int blocksy = std::min(
      {DivUp(block_count, blocksx), DivUp(ydim, threadsy), ygridlimit});
  int blocksz = std::min({DivUp(block_count, (blocksx * blocksy)),
                          DivUp(zdim, threadsz), zgridlimit});

  config.virtual_thread_count = dim3(xdim, ydim, zdim);
  config.thread_per_block = dim3(threadsx, threadsy, threadsz);
  config.block_count = dim3(blocksx, blocksy, blocksz);
  return config;
}
CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpu3DLaunchConfig, GetCuda3DLaunchConfig);

template <typename DeviceFunc>
Gpu2DLaunchConfig GetGpu2DLaunchConfig(int xdim, int ydim,
                                       const Eigen::GpuDevice& d,
                                       DeviceFunc func,
                                       size_t dynamic_shared_memory_size,
                                       int block_size_limit) {
  return GetGpu3DLaunchConfig(xdim, ydim, 1, d, func,
                              dynamic_shared_memory_size, block_size_limit);
}
CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpu2DLaunchConfig, GetCuda2DLaunchConfig);

#if GOOGLE_CUDA
template <typename DeviceFunc>
Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
                                         const Eigen::GpuDevice& d,
                                         DeviceFunc func,
                                         size_t dynamic_shared_memory_size,
                                         int block_size_limit) {
  return GetGpu2DLaunchConfig(xdim, ydim, d, func, dynamic_shared_memory_size,
                              block_size_limit);
}
#endif  // GOOGLE_CUDA

namespace detail {
template <typename... Ts, size_t... Is>
std::array<void*, sizeof...(Ts)> GetArrayOfElementPointersImpl(
    std::tuple<Ts...>* tuple, absl::index_sequence<Is...>) {
  return {{&std::get<Is>(*tuple)...}};
}
// Returns an array of void pointers to the elements of the given tuple.
template <typename... Ts>
std::array<void*, sizeof...(Ts)> GetArrayOfElementPointers(
    std::tuple<Ts...>* tuple) {
  return GetArrayOfElementPointersImpl(tuple,
                                       absl::index_sequence_for<Ts...>{});
}
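// Example usage (a minimal sketch; kernel_arg0 and kernel_arg1 are
// hypothetical values). The result is the void** argument array expected by
// driver-level launch APIs such as cudaLaunchKernel:
//   auto args = std::make_tuple(kernel_arg0, kernel_arg1);
//   std::array<void*, 2> arg_ptrs = GetArrayOfElementPointers(&args);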

template <bool...>
struct BoolPack;
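// NoneTrue<Bs...> is std::true_type iff every value in Bs... is false: the
// two BoolPack instantiations below are the same type exactly when appending
// false to the pack yields the same sequence as prepending it, which can only
// happen if the pack contains no true values.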
template <bool... Bs>
using NoneTrue = std::is_same<BoolPack<Bs..., false>, BoolPack<false, Bs...>>;
// Returns whether none of the types in Ts is a reference.
template <typename... Ts>
constexpr bool NoneIsReference() {
  return NoneTrue<(std::is_reference<Ts>::value)...>::value;
}
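// Example usage (a minimal sketch):
//   static_assert(NoneIsReference<int, float*>(), "no references");
//   static_assert(!NoneIsReference<int, float&>(), "float& is a reference");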
}  // namespace detail
}  // namespace tensorflow
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif  // TENSORFLOW_CORE_UTIL_GPU_LAUNCH_CONFIG_H_