1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_ 18 19 #include "tensorflow/core/lib/gtl/flatset.h" 20 #include "tensorflow/core/lib/strings/str_util.h" 21 #include "tensorflow/core/util/env_var.h" 22 23 namespace tensorflow { 24 namespace grappler { 25 26 // Represents the four lists of ops: the allow list, infer list, deny list, and 27 // clear list. These lists determine which ops are converted to fp16/bf16 28 // (referred to as 'f16' for short) and which ops stay as fp32. 29 class AutoMixedPrecisionLists { 30 public: ~AutoMixedPrecisionLists()31 virtual ~AutoMixedPrecisionLists() {} 32 33 // Returns the set of ops that are considered numerically-safe (for execution 34 // in f16), performance-critical, and can run in f16. These ops are always 35 // converted to f16. 36 virtual gtl::FlatSet<string> AllowList() = 0; 37 // Returns the set of ops that can run in f16 and are considered numerically- 38 // safe (for execution in f16), but which may be made unsafe by an upstream 39 // denylist op. 40 virtual gtl::FlatSet<string> InferList() = 0; 41 // Returns the set of ops that are considered numerically-dangerous (i.e., 42 // unsafe for execution in f16) and whose effects may also be observed in 43 // downstream nodes (e.g. for f16, in Exp -> Add, the Add is unsafe due to 44 // the Exp). 45 virtual gtl::FlatSet<string> DenyList() = 0; 46 // Returns the set of ops that do not have numerically-significant effects 47 // (i.e., they are always considered safe for execution in f16 precision), and 48 // can run in f16. 49 virtual gtl::FlatSet<string> ClearList() = 0; 50 51 protected: 52 // Adds or removes ops from list if certain environmental variables are set. UpdateList(const string & list_name,gtl::FlatSet<string> * list)53 static void UpdateList(const string& list_name, gtl::FlatSet<string>* list) { 54 CHECK(list_name == "ALLOWLIST" || list_name == "INFERLIST" || // Crash OK. 55 list_name == "DENYLIST" || list_name == "CLEARLIST" || 56 // TODO(reedwm): for bkwds compat; remove when no longer necessary: 57 list_name == "WHITELIST" || list_name == "GRAYLIST" || 58 list_name == "BLACKLIST"); 59 string add_env_var = 60 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD"; 61 string remove_env_var = 62 "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_REMOVE"; 63 string to_add, to_remove; 64 TF_CHECK_OK(ReadStringFromEnvVar(add_env_var, "", &to_add)); 65 TF_CHECK_OK(ReadStringFromEnvVar(remove_env_var, "", &to_remove)); 66 for (const auto& x : str_util::Split(to_add, ",")) { 67 list->insert(x); 68 } 69 for (const auto& x : str_util::Split(to_remove, ",")) { 70 list->erase(x); 71 } 72 } 73 74 // Subclasses should include these on the ClearList. AddTensorListOps(gtl::FlatSet<string> * list)75 static void AddTensorListOps(gtl::FlatSet<string>* list) { 76 // Note: if a data structure op (such as TensorListPopBack) is added here, 77 // IsTensorListReaderOp or IsTensorListWriterOp may need to be modified 78 // LINT.IfChange 79 constexpr const char* tensor_list_ops[] = { 80 "TensorListConcat", "TensorListConcatLists", 81 "TensorListConcatV2", "TensorListGather", 82 "TensorListGetItem", "TensorListPopBack", 83 "TensorListPushBack", "TensorListPushBackBatch", 84 "TensorListFromTensor", "TensorListScatter", 85 "TensorListScatterV2", "TensorListScatterIntoExistingList", 86 "TensorListSetItem", "TensorListSplit", 87 "TensorListStack"}; 88 // LINT.ThenChange(//tensorflow/core/grappler/optimizers/auto_mixed_precision.cc) 89 for (auto op : tensor_list_ops) { 90 list->insert(op); 91 } 92 } 93 }; 94 95 class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { 96 private: IsPseudoFastMath()97 static bool IsPseudoFastMath() { 98 string optimization_level; 99 TF_CHECK_OK( 100 ReadStringFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "", 101 &optimization_level)); 102 optimization_level = str_util::Uppercase(optimization_level); 103 return optimization_level == "TENSOR_CORES_ONLY"; 104 } 105 106 public: AutoMixedPrecisionListsCuda(int cuda_version,int cudnn_version)107 AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version) 108 : cuda_version_(cuda_version), cudnn_version_(cudnn_version) {} 109 AllowList()110 gtl::FlatSet<string> AllowList() override { 111 auto list = gtl::FlatSet<string>{ 112 "BlockLSTM", 113 "BlockLSTMV2", 114 "BlockLSTMGrad", 115 "BlockLSTMGradV2", 116 "Conv2D", 117 "Conv2DBackpropFilter", 118 "Conv2DBackpropInput", 119 "CudnnRNN", 120 "CudnnRNNBackprop", 121 "CudnnRNNBackpropV2", 122 "CudnnRNNBackpropV3", 123 "CudnnRNNV2", 124 "CudnnRNNV3", 125 "Einsum", 126 "GRUBlockCell", 127 "GRUBlockCellGrad", 128 "LSTMBlockCell", 129 "LSTMBlockCellGrad", 130 "MatMul", 131 }; 132 if (cuda_version_ >= 9010) { 133 // Fp16 BatchMatMul is slow before CUDA 9.1. 134 list.insert("BatchMatMul"); 135 list.insert("BatchMatMulV2"); 136 } 137 if (cudnn_version_ >= 7602) { 138 // Fp16 3D conv is slow before CUDNN 7.6.2. 139 list.insert("Conv3D"); 140 list.insert("Conv3DBackpropFilter"); 141 list.insert("Conv3DBackpropFilterV2"); 142 list.insert("Conv3DBackpropInput"); 143 list.insert("Conv3DBackpropInputV2"); 144 } 145 if (cudnn_version_ >= 8000) { 146 list.insert("DepthwiseConv2dNative"); 147 list.insert("DepthwiseConv2dNativeBackpropFilter"); 148 list.insert("DepthwiseConv2dNativeBackpropInput"); 149 } 150 UpdateList("ALLOWLIST", &list); 151 // For backwards compatibility, keeping the original env variable here. 152 // TODO(reedwm): This should be removed if we don't have active users. 153 UpdateList("WHITELIST", &list); 154 155 return list; 156 } 157 InferList()158 gtl::FlatSet<string> InferList() override { 159 if (IsPseudoFastMath()) { 160 return gtl::FlatSet<string>{}; 161 } 162 163 auto list = gtl::FlatSet<string>{ 164 "Add", 165 "AddN", 166 "AddV2", 167 "AvgPool", 168 "AvgPool3D", 169 "AvgPool3DGrad", 170 "AvgPoolGrad", 171 "BiasAdd", 172 "BiasAddGrad", 173 "BiasAddV1", 174 "Elu", 175 "EluGrad", 176 "Erf", 177 "Erfc", 178 "FloorDiv", 179 "FusedBatchNormV2", 180 "FusedBatchNormGradV2", 181 "FusedBatchNormV3", 182 "FusedBatchNormGradV3", 183 "_FusedBatchNormEx", 184 "Inv", 185 "LeakyRelu", 186 "LeakyReluGrad", 187 "Log", 188 "Log1p", 189 "LogSoftmax", 190 "Mul", 191 "Prod", 192 "RealDiv", 193 "Reciprocal", 194 "Selu", 195 "SeluGrad", 196 "Sigmoid", 197 "SigmoidGrad", 198 "Softmax", 199 "Softplus", 200 "SoftplusGrad", 201 "Softsign", 202 "SoftsignGrad", 203 "Sqrt", 204 "Sub", 205 "Tanh", 206 "TanhGrad", 207 }; 208 UpdateList("INFERLIST", &list); 209 // For backwards compatibility, keeping the original env variable here. 210 // TODO(reedwm): This should be removed if we don't have active users. 211 UpdateList("GRAYLIST", &list); 212 return list; 213 } 214 DenyList()215 gtl::FlatSet<string> DenyList() override { 216 if (IsPseudoFastMath()) { 217 return gtl::FlatSet<string>{}; 218 } 219 220 auto list = gtl::FlatSet<string>{ 221 "Exp", 222 "Expm1", 223 "L2Loss", 224 "Mean", 225 "Pow", 226 "SaveV2", 227 "SoftmaxCrossEntropyWithLogits", 228 "SparseSoftmaxCrossEntropyWithLogits", 229 "Sum", 230 }; 231 UpdateList("DENYLIST", &list); 232 // For backwards compatibility, keeping the original env variable here. 233 // TODO(reedwm): This should be removed if we don't have active users. 234 UpdateList("BLACKLIST", &list); 235 return list; 236 } 237 ClearList()238 gtl::FlatSet<string> ClearList() override { 239 if (IsPseudoFastMath()) { 240 return gtl::FlatSet<string>{}; 241 } 242 243 auto list = gtl::FlatSet<string>{ 244 "Abs", 245 "ArgMax", 246 "ArgMin", 247 "BatchToSpace", 248 "BatchToSpaceND", 249 "BroadcastTo", 250 "Ceil", 251 "CheckNumerics", 252 "ClipByValue", 253 "Concat", 254 "ConcatV2", 255 "DepthToSpace", 256 "DynamicPartition", 257 "DynamicStitch", 258 "Enter", 259 "EnsureShape", 260 "Equal", 261 "Exit", 262 "ExpandDims", 263 "Fill", 264 "Floor", 265 "Gather", 266 "GatherNd", 267 "GatherV2", 268 "Greater", 269 "GreaterEqual", 270 "Identity", 271 "IdentityN", 272 "IsFinite", 273 "IsInf", 274 "IsNan", 275 "Less", 276 "LessEqual", 277 "Max", 278 "MaxPool", 279 "MaxPool3D", 280 "MaxPool3DGrad", 281 "MaxPool3DGradGrad", 282 "MaxPoolGrad", 283 "MaxPoolGradGrad", 284 "MaxPoolGradGradV2", 285 "MaxPoolGradV2", 286 "MaxPoolV2", 287 "Maximum", 288 "Merge", 289 "Min", 290 "Minimum", 291 "MirrorPad", 292 "MirrorPadGrad", 293 "Neg", 294 "NextIteration", 295 "NotEqual", 296 "OneHot", 297 "OnesLike", 298 "Pack", 299 "Pad", 300 "PadV2", 301 "PreventGradient", 302 "Rank", 303 "Relu", 304 "Relu6", 305 "Relu6Grad", 306 "ReluGrad", 307 "Reshape", 308 "ResizeNearestNeighbor", 309 "ResizeNearestNeighborGrad", 310 "Reverse", 311 "ReverseSequence", 312 "ReverseV2", 313 "Round", 314 "Select", 315 "SelectV2", 316 "Shape", 317 "ShapeN", 318 "Sign", 319 "Size", 320 "Slice", 321 "Snapshot", 322 "SpaceToBatch", 323 "SpaceToBatchND", 324 "SpaceToDepth", 325 "Split", 326 "SplitV", 327 "Squeeze", 328 "StopGradient", 329 "StridedSlice", 330 "StridedSliceGrad", 331 "Switch", 332 "Tile", 333 "TopK", 334 "TopKV2", 335 "Transpose", 336 "Where", 337 "ZerosLike", 338 }; 339 AddTensorListOps(&list); 340 UpdateList("CLEARLIST", &list); 341 return list; 342 } 343 344 private: 345 int cuda_version_; 346 int cudnn_version_; 347 }; 348 349 class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { 350 public: AutoMixedPrecisionListsMkl()351 AutoMixedPrecisionListsMkl() {} 352 353 // Only ops which are supported by MKL in bfloat16 should be added to the 354 // allow list, infer list, or clear list. AllowList()355 gtl::FlatSet<string> AllowList() override { 356 auto list = gtl::FlatSet<string>{"Conv2D", 357 "Conv2DBackpropFilter", 358 "Conv2DBackpropInput", 359 "Conv3D", 360 "Conv3DBackpropFilterV2", 361 "Conv3DBackpropInputV2", 362 "DepthwiseConv2dNative", 363 "DepthwiseConv2dNativeBackpropFilter", 364 "DepthwiseConv2dNativeBackpropInput", 365 "MatMul", 366 "BatchMatMul", 367 "BatchMatMulV2"}; 368 369 UpdateList("ALLOWLIST", &list); 370 // For backwards compatibility, keeping the original env variable here. 371 // TODO(reedwm): This should be removed if we don't have active users. 372 UpdateList("WHITELIST", &list); 373 return list; 374 } 375 InferList()376 gtl::FlatSet<string> InferList() override { 377 auto list = gtl::FlatSet<string>{ 378 "Add", 379 "AddN", 380 "AddV2", 381 "AvgPool", 382 "AvgPool3D", 383 "AvgPool3DGrad", 384 "AvgPoolGrad", 385 "BiasAdd", 386 "BiasAddGrad", 387 "BiasAddV1", 388 "FusedBatchNormV2", 389 "FusedBatchNormGradV2", 390 "FusedBatchNormV3", 391 "FusedBatchNormGradV3", 392 "LeakyRelu", 393 "LeakyReluGrad", 394 "Mul", 395 "Sub", 396 }; 397 UpdateList("INFERLIST", &list); 398 // For backwards compatibility, keeping the original env variable here. 399 // TODO(reedwm): This should be removed if we don't have active users. 400 UpdateList("GRAYLIST", &list); 401 return list; 402 } 403 DenyList()404 gtl::FlatSet<string> DenyList() override { 405 auto list = gtl::FlatSet<string>{ 406 "Exp", 407 "Expm1", 408 "L2Loss", 409 "Mean", 410 "Pow", 411 "SaveV2", 412 "Softmax", 413 "SoftmaxCrossEntropyWithLogits", 414 "SparseSoftmaxCrossEntropyWithLogits", 415 "Sum", 416 }; 417 UpdateList("DENYLIST", &list); 418 // For backwards compatibility, keeping the original env variable here. 419 // TODO(reedwm): This should be removed if we don't have active users. 420 UpdateList("BLACKLIST", &list); 421 return list; 422 } 423 ClearList()424 gtl::FlatSet<string> ClearList() override { 425 auto list = gtl::FlatSet<string>{ 426 "Concat", "ConcatV2", "Enter", "EnsureShape", 427 "Equal", "Exit", "ExpandDims", "Identity", 428 "MaxPool", "MaxPool3D", "MaxPool3DGrad", "MaxPoolGrad", 429 "MaxPoolV2", "Maximum", "Merge", "NextIteration", 430 "PreventGradient", "Relu", "Relu6", "Relu6Grad", 431 "ReluGrad", "Reshape", "Select", "SelectV2", 432 "Shape", "ShapeN", "Slice", "Split", 433 "SplitV", "Squeeze", "StopGradient", "Switch", 434 "Transpose", "ZerosLike", 435 }; 436 AddTensorListOps(&list); 437 UpdateList("CLEARLIST", &list); 438 return list; 439 } 440 }; 441 442 } // end namespace grappler 443 } // end namespace tensorflow 444 445 #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_ 446