1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_UTIL_GPU_LAUNCH_CONFIG_H_
17 #define TENSORFLOW_CORE_UTIL_GPU_LAUNCH_CONFIG_H_
18
19 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
20
21 #include <algorithm>
22
23 #include "absl/base/casts.h"
24 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
25 #include "tensorflow/core/framework/op_kernel.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/stream_executor.h"
28 #include "tensorflow/core/platform/types.h"
29 #include "tensorflow/core/util/gpu_cuda_alias.h"
30
31 // Usage of GetGpuLaunchConfig, GetGpu2DLaunchConfig, and
32 // GetGpu3DLaunchConfig:
33 //
34 // There are two versions of GetGpuLaunchConfig and GetGpu2DLaunchConfig, one
35 // version uses heuristics without any knowledge of the device kernel, the other
36 // version uses cudaOccupancyMaxPotentialBlockSize to determine the theoretical
37 // launch parameters that maximize occupancy. Currently, only the maximum
38 // occupancy version of GetGpu3DLaunchConfig is available.
39 //
// For a large number of work elements, the convention is that each kernel
// iterates through its assigned range. GetGpuLaunchConfig returns a
// GpuLaunchConfig struct, which contains all the information needed for the
// kernel launch: the virtual number of threads, the number of threads per
// block, and the number of blocks, the latter two being used inside the
// <<< >>> of a kernel launch. GetGpu2DLaunchConfig and GetGpu3DLaunchConfig do
// the same thing as GetGpuLaunchConfig; the only difference is the number of
// dimensions. The macros GPU_1D_KERNEL_LOOP and GPU_AXIS_KERNEL_LOOP may be
// used for the inner loop.
48 //
49 /* Sample code:
50
51 __global__ void MyKernel1D(GpuLaunchConfig config, other_args...) {
52 GPU_1D_KERNEL_LOOP(x, config.virtual_thread_count) {
53 do_your_job_here;
54 }
55 }
56
57 __global__ void MyKernel2D(Gpu2DLaunchConfig config, other_args...) {
58 GPU_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
59 GPU_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
60 do_your_job_here;
61 }
62 }
63 }
64
65 __global__ void MyKernel3D(Gpu3DLaunchConfig config, other_args...) {
66 GPU_AXIS_KERNEL_LOOP(x, config.virtual_thread_count, x) {
67 GPU_AXIS_KERNEL_LOOP(y, config.virtual_thread_count, y) {
68 GPU_AXIS_KERNEL_LOOP(z, config.virtual_thread_count, z) {
69 do_your_job_here;
70 }
71 }
72 }
73 }
74
void MyDriverFunc(const Eigen::GpuDevice &d) {
  // use heuristics
  GpuLaunchConfig cfg1 = GetGpuLaunchConfig(10240, d);
  MyKernel1D <<<cfg1.block_count,
      cfg1.thread_per_block, 0, d.stream()>>> (cfg1, other_args...);
  Gpu2DLaunchConfig cfg2 = GetGpu2DLaunchConfig(10240, 10240, d);
  MyKernel2D <<<cfg2.block_count,
      cfg2.thread_per_block, 0, d.stream()>>> (cfg2, other_args...);
  // Note: there is no heuristic GetGpu3DLaunchConfig; 3D kernels must use the
  // occupancy-maximizing version shown below.

  // maximize occupancy
  GpuLaunchConfig cfg4 = GetGpuLaunchConfig(10240, d, MyKernel1D, 0, 0);
  MyKernel1D <<<cfg4.block_count,
      cfg4.thread_per_block, 0, d.stream()>>> (cfg4, other_args...);
  Gpu2DLaunchConfig cfg5 = GetGpu2DLaunchConfig(10240, 10240, d,
                                                MyKernel2D, 0, 0);
  MyKernel2D <<<cfg5.block_count,
      cfg5.thread_per_block, 0, d.stream()>>> (cfg5, other_args...);
  Gpu3DLaunchConfig cfg6 = GetGpu3DLaunchConfig(4096, 4096, 100, d,
                                                MyKernel3D, 0, 0);
  MyKernel3D <<<cfg6.block_count,
      cfg6.thread_per_block, 0, d.stream()>>> (cfg6, other_args...);
}
100
// See the test for more examples:
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/util/gpu_kernel_helper_test.cu.cc
104
105 */
106
107 namespace tensorflow {
108
// Returns ceil(a / b) for integers a and b.
// REQUIRES: b != 0, and not (a == INT_MIN && b == -1).
//
// Computed from the truncating quotient and remainder instead of the classic
// (a + b - 1) / b. The classic form overflows (undefined behavior for signed
// int) as soon as a + b - 1 exceeds INT_MAX. That is reachable here because
// element and thread counts are plain ints.
inline int DivUp(int a, int b) {
  const int quotient = a / b;
  const int remainder = a % b;
  // Integer division truncates toward zero. The true quotient lies above the
  // truncated one exactly when the remainder is non-zero and has the same
  // sign as the divisor; in that case round up by one.
  const bool round_up = remainder != 0 && ((remainder > 0) == (b > 0));
  return round_up ? quotient + 1 : quotient;
}
110
111 struct GpuLaunchConfig {
112 // Logical number of thread that works on the elements. If each logical
113 // thread works on exactly a single element, this is the same as the working
114 // element count.
115 int virtual_thread_count = -1;
116 // Number of threads per block.
117 int thread_per_block = -1;
118 // Number of blocks for GPU kernel launch.
119 int block_count = -1;
120 };
121 CREATE_CUDA_TYPE_ALIAS(GpuLaunchConfig, CudaLaunchConfig);
122
123 // Calculate the GPU launch config we should use for a kernel launch.
124 // This is assuming the kernel is quite simple and will largely be
125 // memory-limited.
126 // REQUIRES: work_element_count > 0.
GetGpuLaunchConfig(int work_element_count,const Eigen::GpuDevice & d)127 inline GpuLaunchConfig GetGpuLaunchConfig(int work_element_count,
128 const Eigen::GpuDevice& d) {
129 CHECK_GT(work_element_count, 0);
130 GpuLaunchConfig config;
131 const int virtual_thread_count = work_element_count;
132 const int physical_thread_count = std::min(
133 d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor(),
134 virtual_thread_count);
135 const int thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock());
136 const int block_count =
137 std::min(DivUp(physical_thread_count, thread_per_block),
138 d.getNumGpuMultiProcessors());
139
140 config.virtual_thread_count = virtual_thread_count;
141 config.thread_per_block = thread_per_block;
142 config.block_count = block_count;
143 return config;
144 }
145 #ifndef TENSORFLOW_USE_ROCM
GetCudaLaunchConfig(int work_element_count,const Eigen::GpuDevice & d)146 inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count,
147 const Eigen::GpuDevice& d) {
148 return GetGpuLaunchConfig(work_element_count, d);
149 }
150 #endif
151
152 // Calculate the GPU launch config we should use for a kernel launch. This
153 // variant takes the resource limits of func into account to maximize occupancy.
154 // REQUIRES: work_element_count > 0.
155 template <typename DeviceFunc>
GetGpuLaunchConfig(int work_element_count,const Eigen::GpuDevice & d,DeviceFunc func,size_t dynamic_shared_memory_size,int block_size_limit)156 GpuLaunchConfig GetGpuLaunchConfig(int work_element_count,
157 const Eigen::GpuDevice& d, DeviceFunc func,
158 size_t dynamic_shared_memory_size,
159 int block_size_limit) {
160 CHECK_GT(work_element_count, 0);
161 GpuLaunchConfig config;
162 int block_count = 0;
163 int thread_per_block = 0;
164
165 #if GOOGLE_CUDA
166 cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
167 &block_count, &thread_per_block, func, dynamic_shared_memory_size,
168 block_size_limit);
169 CHECK_EQ(err, cudaSuccess);
170 #elif TENSORFLOW_USE_ROCM
171 hipError_t err = hipOccupancyMaxPotentialBlockSize(
172 &block_count, &thread_per_block, func, dynamic_shared_memory_size,
173 block_size_limit);
174 CHECK_EQ(err, hipSuccess);
175 #endif
176
177 block_count =
178 std::min(block_count, DivUp(work_element_count, thread_per_block));
179
180 config.virtual_thread_count = work_element_count;
181 config.thread_per_block = thread_per_block;
182 config.block_count = block_count;
183 return config;
184 }
185 CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpuLaunchConfig, GetCudaLaunchConfig);
186
187 // Calculate the GPU launch config we should use for a kernel launch. This
188 // variant takes the resource limits of func into account to maximize occupancy.
189 // The returned launch config has thread_per_block set to fixed_block_size.
190 // REQUIRES: work_element_count > 0.
191 template <typename DeviceFunc>
GetGpuLaunchConfigFixedBlockSize(int work_element_count,const Eigen::GpuDevice & d,DeviceFunc func,size_t dynamic_shared_memory_size,int fixed_block_size)192 GpuLaunchConfig GetGpuLaunchConfigFixedBlockSize(
193 int work_element_count, const Eigen::GpuDevice& d, DeviceFunc func,
194 size_t dynamic_shared_memory_size, int fixed_block_size) {
195 CHECK_GT(work_element_count, 0);
196 GpuLaunchConfig config;
197 int block_count = 0;
198
199 #if GOOGLE_CUDA
200 cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
201 &block_count, func, fixed_block_size, dynamic_shared_memory_size);
202 CHECK_EQ(err, cudaSuccess);
203 #elif TENSORFLOW_USE_ROCM
204 hipError_t err = hipOccupancyMaxActiveBlocksPerMultiprocessor(
205 &block_count, func, fixed_block_size, dynamic_shared_memory_size);
206 CHECK_EQ(err, hipSuccess);
207 #endif
208 block_count = std::min(block_count * d.getNumGpuMultiProcessors(),
209 DivUp(work_element_count, fixed_block_size));
210
211 config.virtual_thread_count = work_element_count;
212 config.thread_per_block = fixed_block_size;
213 config.block_count = block_count;
214 return config;
215 }
216 CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpuLaunchConfigFixedBlockSize,
217 GetCudaLaunchConfigFixedBlockSize);
218
219 struct Gpu2DLaunchConfig {
220 dim3 virtual_thread_count = dim3(0, 0, 0);
221 dim3 thread_per_block = dim3(0, 0, 0);
222 dim3 block_count = dim3(0, 0, 0);
223 };
224 CREATE_CUDA_TYPE_ALIAS(Gpu2DLaunchConfig, Cuda2DLaunchConfig);
225
GetGpu2DLaunchConfig(int xdim,int ydim,const Eigen::GpuDevice & d)226 inline Gpu2DLaunchConfig GetGpu2DLaunchConfig(int xdim, int ydim,
227 const Eigen::GpuDevice& d) {
228 Gpu2DLaunchConfig config;
229
230 if (xdim <= 0 || ydim <= 0) {
231 return config;
232 }
233
234 const int kThreadsPerBlock = 256;
235 int block_cols = std::min(xdim, kThreadsPerBlock);
236 // ok to round down here and just do more loops in the kernel
237 int block_rows = std::max(kThreadsPerBlock / block_cols, 1);
238
239 const int physical_thread_count =
240 d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor();
241
242 const int max_blocks = std::max(physical_thread_count / kThreadsPerBlock, 1);
243
244 config.virtual_thread_count = dim3(xdim, ydim, 1);
245 config.thread_per_block = dim3(block_cols, block_rows, 1);
246
247 int grid_x = std::min(DivUp(xdim, block_cols), max_blocks);
248
249 config.block_count = dim3(
250 grid_x, std::min(max_blocks / grid_x, std::max(ydim / block_rows, 1)), 1);
251 return config;
252 }
253 #ifndef TENSORFLOW_USE_ROCM
GetCuda2DLaunchConfig(int xdim,int ydim,const Eigen::GpuDevice & d)254 inline Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
255 const Eigen::GpuDevice& d) {
256 return GetGpu2DLaunchConfig(xdim, ydim, d);
257 }
258 #endif
259
260 // Calculate the GPU 2D and 3D launch config we should use for a kernel launch.
261 // This variant takes the resource limits of func into account to maximize
262 // occupancy.
263 using Gpu3DLaunchConfig = Gpu2DLaunchConfig;
264 CREATE_CUDA_TYPE_ALIAS(Gpu3DLaunchConfig, Cuda3DLaunchConfig);
265
266 template <typename DeviceFunc>
GetGpu3DLaunchConfig(int xdim,int ydim,int zdim,const Eigen::GpuDevice & d,DeviceFunc func,size_t dynamic_shared_memory_size,int block_size_limit)267 Gpu3DLaunchConfig GetGpu3DLaunchConfig(int xdim, int ydim, int zdim,
268 const Eigen::GpuDevice& d,
269 DeviceFunc func,
270 size_t dynamic_shared_memory_size,
271 int block_size_limit) {
272 Gpu3DLaunchConfig config;
273
274 if (xdim <= 0 || ydim <= 0 || zdim <= 0) {
275 return config;
276 }
277
278 int dev;
279 #if GOOGLE_CUDA
280 cudaGetDevice(&dev);
281 cudaDeviceProp deviceProp;
282 cudaGetDeviceProperties(&deviceProp, dev);
283 #elif TENSORFLOW_USE_ROCM
284 hipGetDevice(&dev);
285 hipDeviceProp_t deviceProp;
286 hipGetDeviceProperties(&deviceProp, dev);
287 #endif
288 int xthreadlimit = deviceProp.maxThreadsDim[0];
289 int ythreadlimit = deviceProp.maxThreadsDim[1];
290 int zthreadlimit = deviceProp.maxThreadsDim[2];
291 int xgridlimit = deviceProp.maxGridSize[0];
292 int ygridlimit = deviceProp.maxGridSize[1];
293 int zgridlimit = deviceProp.maxGridSize[2];
294
295 int block_count = 0;
296 int thread_per_block = 0;
297
298 #if GOOGLE_CUDA
299 cudaError_t err = cudaOccupancyMaxPotentialBlockSize(
300 &block_count, &thread_per_block, func, dynamic_shared_memory_size,
301 block_size_limit);
302 CHECK_EQ(err, cudaSuccess);
303 #elif TENSORFLOW_USE_ROCM
304 // ROCM TODO re-enable this after hipOccupancyMaxPotentialBlockSize is
305 // implemented
306 // hipError_t err = hipOccupancyMaxPotentialBlockSize(
307 // &block_count, &thread_per_block, func, dynamic_shared_memory_size,
308 // block_size_limit);
309 // CHECK_EQ(err, hipSuccess);
310
311 const int physical_thread_count =
312 d.getNumGpuMultiProcessors() * d.maxGpuThreadsPerMultiProcessor();
313 thread_per_block = std::min(1024, d.maxGpuThreadsPerBlock());
314 block_count = std::min(DivUp(physical_thread_count, thread_per_block),
315 d.getNumGpuMultiProcessors());
316 #endif
317
318 int threadsx = std::min({xdim, thread_per_block, xthreadlimit});
319 int threadsy =
320 std::min({ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit});
321 int threadsz =
322 std::min({zdim, std::max(thread_per_block / (threadsx * threadsy), 1),
323 zthreadlimit});
324
325 int blocksx = std::min({block_count, DivUp(xdim, threadsx), xgridlimit});
326 int blocksy = std::min(
327 {DivUp(block_count, blocksx), DivUp(ydim, threadsy), ygridlimit});
328 int blocksz = std::min({DivUp(block_count, (blocksx * blocksy)),
329 DivUp(zdim, threadsz), zgridlimit});
330
331 config.virtual_thread_count = dim3(xdim, ydim, zdim);
332 config.thread_per_block = dim3(threadsx, threadsy, threadsz);
333 config.block_count = dim3(blocksx, blocksy, blocksz);
334 return config;
335 }
336 CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpu3DLaunchConfig, GetCuda3DLaunchConfig);
337
338 template <typename DeviceFunc>
GetGpu2DLaunchConfig(int xdim,int ydim,const Eigen::GpuDevice & d,DeviceFunc func,size_t dynamic_shared_memory_size,int block_size_limit)339 Gpu2DLaunchConfig GetGpu2DLaunchConfig(int xdim, int ydim,
340 const Eigen::GpuDevice& d,
341 DeviceFunc func,
342 size_t dynamic_shared_memory_size,
343 int block_size_limit) {
344 return GetGpu3DLaunchConfig(xdim, ydim, 1, d, func,
345 dynamic_shared_memory_size, block_size_limit);
346 }
347 CREATE_CUDA_HOST_FUNCTION_ALIAS(GetGpu2DLaunchConfig, GetCuda2DLaunchConfig);
348
#if GOOGLE_CUDA
// CUDA-named spelling of the occupancy-driven GetGpu2DLaunchConfig; it
// forwards every argument unchanged.
template <typename DeviceFunc>
Cuda2DLaunchConfig GetCuda2DLaunchConfig(int xdim, int ydim,
                                         const Eigen::GpuDevice& d,
                                         DeviceFunc func,
                                         size_t dynamic_shared_memory_size,
                                         int block_size_limit) {
  Cuda2DLaunchConfig forwarded = GetGpu2DLaunchConfig(
      xdim, ydim, d, func, dynamic_shared_memory_size, block_size_limit);
  return forwarded;
}
#endif  // GOOGLE_CUDA
360
361 namespace detail {
362 template <typename... Ts, size_t... Is>
GetArrayOfElementPointersImpl(std::tuple<Ts...> * tuple,absl::index_sequence<Is...>)363 std::array<void*, sizeof...(Ts)> GetArrayOfElementPointersImpl(
364 std::tuple<Ts...>* tuple, absl::index_sequence<Is...>) {
365 return {{&std::get<Is>(*tuple)...}};
366 }
367 // Returns an array of void pointers to the elements of the given tuple.
368 template <typename... Ts>
GetArrayOfElementPointers(std::tuple<Ts...> * tuple)369 std::array<void*, sizeof...(Ts)> GetArrayOfElementPointers(
370 std::tuple<Ts...>* tuple) {
371 return GetArrayOfElementPointersImpl(tuple,
372 absl::index_sequence_for<Ts...>{});
373 }
374
375 template <bool...>
376 struct BoolPack;
377 template <bool... Bs>
378 using NoneTrue = std::is_same<BoolPack<Bs..., false>, BoolPack<false, Bs...>>;
379 // Returns whether none of the types in Ts is a reference.
380 template <typename... Ts>
NoneIsReference()381 constexpr bool NoneIsReference() {
382 return NoneTrue<(std::is_reference<Ts>::value)...>::value;
383 }
384 } // namespace detail
385 } // namespace tensorflow
386 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
387 #endif // TENSORFLOW_CORE_UTIL_GPU_LAUNCH_CONFIG_H_
388