1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/use_cudnn.h"
17 
18 #include "tensorflow/core/lib/core/stringpiece.h"
19 #include "tensorflow/core/lib/strings/str_util.h"
20 #include "tensorflow/core/platform/types.h"
21 #include "tensorflow/core/util/env_var.h"
22 
23 namespace tensorflow {
24 
25 #define ADD_BOOL_CUDNN_FLAG(func_name, flag_name, default_value)           \
26   bool func_name() {                                                       \
27     bool value = default_value;                                            \
28     Status status = ReadBoolFromEnvVar(#flag_name, default_value, &value); \
29     if (!status.ok()) {                                                    \
30       LOG(ERROR) << status;                                                \
31     }                                                                      \
32     return value;                                                          \
33   }
34 
35 ADD_BOOL_CUDNN_FLAG(CanUseCudnn, TF_USE_CUDNN, true);
36 ADD_BOOL_CUDNN_FLAG(CudnnUseAutotune, TF_CUDNN_USE_AUTOTUNE, true);
37 // Whether to auto-tuning Cudnn RNN forward and backward pass to pick
38 // statistically the best cudnnRNNAlgo_t and cudnnMathType_t.
39 // The flag is disabled when TF_DEBUG_CUDNN_RNN is turned on.
40 ADD_BOOL_CUDNN_FLAG(CudnnRnnUseAutotune, TF_CUDNN_RNN_USE_AUTOTUNE, true);
41 ADD_BOOL_CUDNN_FLAG(CudnnDisableConv1x1Optimization,
42                     TF_CUDNN_DISABLE_CONV_1X1_OPTIMIZATION, false);
43 
44 // Whether to run Cudnn RNN forward and backward in debug mode, where users can
45 // force a specified cudnnRNNAlgo_t and cudnnMathType_t, when used together with
46 // the following two env vars:
47 // TF_DEBUG_CUDNN_RNN_USE_TENSOR_OPS
48 // TF_DEBUG_CUDNN_RNN_ALGO
49 // By default it is disabled and only intended for testing and profiling.
50 ADD_BOOL_CUDNN_FLAG(DebugCudnnRnn, TF_DEBUG_CUDNN_RNN, false);
51 // If using TENSOR_OP_MATH in Cudnn RNN for both forward and backward pass. Only
52 // effective when TF_DEBUG_CUDNN_RNN is true.
53 // Note none of the persistent RNN algorithm support TENSOR_OP_MATH before
54 // Cudnn 7.1. See Nvidia Cudnn manual for more details.
55 ADD_BOOL_CUDNN_FLAG(DebugCudnnRnnUseTensorOps,
56                     TF_DEBUG_CUDNN_RNN_USE_TENSOR_OPS, false);
57 #undef ADD_BOOL_CUDNN_FLAG
58 
59 #define ADD_INT64_CUDNN_FLAG(func_name, flag_name, default_value)           \
60   int64 func_name() {                                                       \
61     int64 value = default_value;                                            \
62     Status status = ReadInt64FromEnvVar(#flag_name, default_value, &value); \
63     if (!status.ok()) {                                                     \
64       LOG(ERROR) << status;                                                 \
65     }                                                                       \
66     return value;                                                           \
67   }
68 // Cudnn RNN algorithm to use for both forward and backward pass. Only effective
69 // when TF_DEBUG_CUDNN_RNN is true. See Nvidia Cudnn manual for allowed
70 // cudnnRNNAlgo_t.
71 ADD_INT64_CUDNN_FLAG(DebugCudnnRnnAlgo, TF_DEBUG_CUDNN_RNN_ALGO, -1);
72 #undef ADD_INT64_CUDNN_FLAG
73 
CudnnConvComputeMode()74 FP16ConvMode CudnnConvComputeMode() {
75   string value;
76   Status status = ReadStringFromEnvVar("TF_FP16_CONV_MODE", "accurate", &value);
77   if (!status.ok()) {
78     LOG(ERROR) << status;
79   }
80   string lowercase_value = str_util::Lowercase(value);
81   if (lowercase_value == "accurate") {
82     return FP16ConvMode::kAccurate;
83   } else if (lowercase_value == "fast") {
84     return FP16ConvMode::kFast;
85   } else {
86     LOG(ERROR) << "FP16ConvMode only supports two modes, ACCURATE and FAST. "
87                   "Got unknown mode: "
88                << value;
89   }
90   return FP16ConvMode::kAccurate;
91 }
92 
93 }  // namespace tensorflow
94