1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_
18 
19 #include "tensorflow/core/lib/gtl/flatset.h"
20 #include "tensorflow/core/lib/strings/str_util.h"
21 #include "tensorflow/core/util/env_var.h"
22 
23 namespace tensorflow {
24 namespace grappler {
25 
26 // Represents the four lists of ops: the allow list, infer list, deny list, and
27 // clear list. These lists determine which ops are converted to fp16/bf16
28 // (referred to as 'f16' for short) and which ops stay as fp32.
29 class AutoMixedPrecisionLists {
30  public:
~AutoMixedPrecisionLists()31   virtual ~AutoMixedPrecisionLists() {}
32 
33   // Returns the set of ops that are considered numerically-safe (for execution
34   // in f16), performance-critical, and can run in f16. These ops are always
35   // converted to f16.
36   virtual gtl::FlatSet<string> AllowList() = 0;
37   // Returns the set of ops that can run in f16 and are considered numerically-
38   // safe (for execution in f16), but which may be made unsafe by an upstream
39   // denylist op.
40   virtual gtl::FlatSet<string> InferList() = 0;
41   // Returns the set of ops that are considered numerically-dangerous (i.e.,
42   // unsafe for execution in f16) and whose effects may also be observed in
43   // downstream nodes (e.g. for f16, in Exp -> Add, the Add is unsafe due to
44   // the Exp).
45   virtual gtl::FlatSet<string> DenyList() = 0;
46   // Returns the set of ops that do not have numerically-significant effects
47   // (i.e., they are always considered safe for execution in f16 precision), and
48   // can run in f16.
49   virtual gtl::FlatSet<string> ClearList() = 0;
50 
51  protected:
52   // Adds or removes ops from list if certain environmental variables are set.
UpdateList(const string & list_name,gtl::FlatSet<string> * list)53   static void UpdateList(const string& list_name, gtl::FlatSet<string>* list) {
54     CHECK(list_name == "ALLOWLIST" || list_name == "INFERLIST" ||  // Crash OK.
55           list_name == "DENYLIST" || list_name == "CLEARLIST" ||
56           // TODO(reedwm): for bkwds compat; remove when no longer necessary:
57           list_name == "WHITELIST" || list_name == "GRAYLIST" ||
58           list_name == "BLACKLIST");
59     string add_env_var =
60         "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD";
61     string remove_env_var =
62         "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_REMOVE";
63     string to_add, to_remove;
64     TF_CHECK_OK(ReadStringFromEnvVar(add_env_var, "", &to_add));
65     TF_CHECK_OK(ReadStringFromEnvVar(remove_env_var, "", &to_remove));
66     for (const auto& x : str_util::Split(to_add, ",")) {
67       list->insert(x);
68     }
69     for (const auto& x : str_util::Split(to_remove, ",")) {
70       list->erase(x);
71     }
72   }
73 
74   // Subclasses should include these on the ClearList.
AddTensorListOps(gtl::FlatSet<string> * list)75   static void AddTensorListOps(gtl::FlatSet<string>* list) {
76     // Note: if a data structure op (such as TensorListPopBack) is added here,
77     // IsTensorListReaderOp or IsTensorListWriterOp may need to be modified
78     // LINT.IfChange
79     constexpr const char* tensor_list_ops[] = {
80         "TensorListConcat",     "TensorListConcatLists",
81         "TensorListConcatV2",   "TensorListGather",
82         "TensorListGetItem",    "TensorListPopBack",
83         "TensorListPushBack",   "TensorListPushBackBatch",
84         "TensorListFromTensor", "TensorListScatter",
85         "TensorListScatterV2",  "TensorListScatterIntoExistingList",
86         "TensorListSetItem",    "TensorListSplit",
87         "TensorListStack"};
88     // LINT.ThenChange(//tensorflow/core/grappler/optimizers/auto_mixed_precision.cc)
89     for (auto op : tensor_list_ops) {
90       list->insert(op);
91     }
92   }
93 };
94 
95 class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists {
96  private:
IsPseudoFastMath()97   static bool IsPseudoFastMath() {
98     string optimization_level;
99     TF_CHECK_OK(
100         ReadStringFromEnvVar("TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_LEVEL", "",
101                              &optimization_level));
102     optimization_level = str_util::Uppercase(optimization_level);
103     return optimization_level == "TENSOR_CORES_ONLY";
104   }
105 
106  public:
AutoMixedPrecisionListsCuda(int cuda_version,int cudnn_version)107   AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version)
108       : cuda_version_(cuda_version), cudnn_version_(cudnn_version) {}
109 
AllowList()110   gtl::FlatSet<string> AllowList() override {
111     auto list = gtl::FlatSet<string>{
112         "BlockLSTM",
113         "BlockLSTMV2",
114         "BlockLSTMGrad",
115         "BlockLSTMGradV2",
116         "Conv2D",
117         "Conv2DBackpropFilter",
118         "Conv2DBackpropInput",
119         "CudnnRNN",
120         "CudnnRNNBackprop",
121         "CudnnRNNBackpropV2",
122         "CudnnRNNBackpropV3",
123         "CudnnRNNV2",
124         "CudnnRNNV3",
125         "Einsum",
126         "GRUBlockCell",
127         "GRUBlockCellGrad",
128         "LSTMBlockCell",
129         "LSTMBlockCellGrad",
130         "MatMul",
131     };
132     if (cuda_version_ >= 9010) {
133       // Fp16 BatchMatMul is slow before CUDA 9.1.
134       list.insert("BatchMatMul");
135       list.insert("BatchMatMulV2");
136     }
137     if (cudnn_version_ >= 7602) {
138       // Fp16 3D conv is slow before CUDNN 7.6.2.
139       list.insert("Conv3D");
140       list.insert("Conv3DBackpropFilter");
141       list.insert("Conv3DBackpropFilterV2");
142       list.insert("Conv3DBackpropInput");
143       list.insert("Conv3DBackpropInputV2");
144     }
145     if (cudnn_version_ >= 8000) {
146       list.insert("DepthwiseConv2dNative");
147       list.insert("DepthwiseConv2dNativeBackpropFilter");
148       list.insert("DepthwiseConv2dNativeBackpropInput");
149     }
150     UpdateList("ALLOWLIST", &list);
151     // For backwards compatibility, keeping the original env variable here.
152     // TODO(reedwm): This should be removed if we don't have active users.
153     UpdateList("WHITELIST", &list);
154 
155     return list;
156   }
157 
InferList()158   gtl::FlatSet<string> InferList() override {
159     if (IsPseudoFastMath()) {
160       return gtl::FlatSet<string>{};
161     }
162 
163     auto list = gtl::FlatSet<string>{
164         "Add",
165         "AddN",
166         "AddV2",
167         "AvgPool",
168         "AvgPool3D",
169         "AvgPool3DGrad",
170         "AvgPoolGrad",
171         "BiasAdd",
172         "BiasAddGrad",
173         "BiasAddV1",
174         "Elu",
175         "EluGrad",
176         "Erf",
177         "Erfc",
178         "FloorDiv",
179         "FusedBatchNormV2",
180         "FusedBatchNormGradV2",
181         "FusedBatchNormV3",
182         "FusedBatchNormGradV3",
183         "_FusedBatchNormEx",
184         "Inv",
185         "LeakyRelu",
186         "LeakyReluGrad",
187         "Log",
188         "Log1p",
189         "LogSoftmax",
190         "Mul",
191         "Prod",
192         "RealDiv",
193         "Reciprocal",
194         "Selu",
195         "SeluGrad",
196         "Sigmoid",
197         "SigmoidGrad",
198         "Softmax",
199         "Softplus",
200         "SoftplusGrad",
201         "Softsign",
202         "SoftsignGrad",
203         "Sqrt",
204         "Sub",
205         "Tanh",
206         "TanhGrad",
207     };
208     UpdateList("INFERLIST", &list);
209     // For backwards compatibility, keeping the original env variable here.
210     // TODO(reedwm): This should be removed if we don't have active users.
211     UpdateList("GRAYLIST", &list);
212     return list;
213   }
214 
DenyList()215   gtl::FlatSet<string> DenyList() override {
216     if (IsPseudoFastMath()) {
217       return gtl::FlatSet<string>{};
218     }
219 
220     auto list = gtl::FlatSet<string>{
221         "Exp",
222         "Expm1",
223         "L2Loss",
224         "Mean",
225         "Pow",
226         "SaveV2",
227         "SoftmaxCrossEntropyWithLogits",
228         "SparseSoftmaxCrossEntropyWithLogits",
229         "Sum",
230     };
231     UpdateList("DENYLIST", &list);
232     // For backwards compatibility, keeping the original env variable here.
233     // TODO(reedwm): This should be removed if we don't have active users.
234     UpdateList("BLACKLIST", &list);
235     return list;
236   }
237 
ClearList()238   gtl::FlatSet<string> ClearList() override {
239     if (IsPseudoFastMath()) {
240       return gtl::FlatSet<string>{};
241     }
242 
243     auto list = gtl::FlatSet<string>{
244         "Abs",
245         "ArgMax",
246         "ArgMin",
247         "BatchToSpace",
248         "BatchToSpaceND",
249         "BroadcastTo",
250         "Ceil",
251         "CheckNumerics",
252         "ClipByValue",
253         "Concat",
254         "ConcatV2",
255         "DepthToSpace",
256         "DynamicPartition",
257         "DynamicStitch",
258         "Enter",
259         "EnsureShape",
260         "Equal",
261         "Exit",
262         "ExpandDims",
263         "Fill",
264         "Floor",
265         "Gather",
266         "GatherNd",
267         "GatherV2",
268         "Greater",
269         "GreaterEqual",
270         "Identity",
271         "IdentityN",
272         "IsFinite",
273         "IsInf",
274         "IsNan",
275         "Less",
276         "LessEqual",
277         "Max",
278         "MaxPool",
279         "MaxPool3D",
280         "MaxPool3DGrad",
281         "MaxPool3DGradGrad",
282         "MaxPoolGrad",
283         "MaxPoolGradGrad",
284         "MaxPoolGradGradV2",
285         "MaxPoolGradV2",
286         "MaxPoolV2",
287         "Maximum",
288         "Merge",
289         "Min",
290         "Minimum",
291         "MirrorPad",
292         "MirrorPadGrad",
293         "Neg",
294         "NextIteration",
295         "NotEqual",
296         "OneHot",
297         "OnesLike",
298         "Pack",
299         "Pad",
300         "PadV2",
301         "PreventGradient",
302         "Rank",
303         "Relu",
304         "Relu6",
305         "Relu6Grad",
306         "ReluGrad",
307         "Reshape",
308         "ResizeNearestNeighbor",
309         "ResizeNearestNeighborGrad",
310         "Reverse",
311         "ReverseSequence",
312         "ReverseV2",
313         "Round",
314         "Select",
315         "SelectV2",
316         "Shape",
317         "ShapeN",
318         "Sign",
319         "Size",
320         "Slice",
321         "Snapshot",
322         "SpaceToBatch",
323         "SpaceToBatchND",
324         "SpaceToDepth",
325         "Split",
326         "SplitV",
327         "Squeeze",
328         "StopGradient",
329         "StridedSlice",
330         "StridedSliceGrad",
331         "Switch",
332         "Tile",
333         "TopK",
334         "TopKV2",
335         "Transpose",
336         "Where",
337         "ZerosLike",
338     };
339     AddTensorListOps(&list);
340     UpdateList("CLEARLIST", &list);
341     return list;
342   }
343 
344  private:
345   int cuda_version_;
346   int cudnn_version_;
347 };
348 
349 class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists {
350  public:
AutoMixedPrecisionListsMkl()351   AutoMixedPrecisionListsMkl() {}
352 
353   // Only ops which are supported by MKL in bfloat16 should be added to the
354   // allow list, infer list, or clear list.
AllowList()355   gtl::FlatSet<string> AllowList() override {
356     auto list = gtl::FlatSet<string>{"Conv2D",
357                                      "Conv2DBackpropFilter",
358                                      "Conv2DBackpropInput",
359                                      "Conv3D",
360                                      "Conv3DBackpropFilterV2",
361                                      "Conv3DBackpropInputV2",
362                                      "DepthwiseConv2dNative",
363                                      "DepthwiseConv2dNativeBackpropFilter",
364                                      "DepthwiseConv2dNativeBackpropInput",
365                                      "MatMul",
366                                      "BatchMatMul",
367                                      "BatchMatMulV2"};
368 
369     UpdateList("ALLOWLIST", &list);
370     // For backwards compatibility, keeping the original env variable here.
371     // TODO(reedwm): This should be removed if we don't have active users.
372     UpdateList("WHITELIST", &list);
373     return list;
374   }
375 
InferList()376   gtl::FlatSet<string> InferList() override {
377     auto list = gtl::FlatSet<string>{
378         "Add",
379         "AddN",
380         "AddV2",
381         "AvgPool",
382         "AvgPool3D",
383         "AvgPool3DGrad",
384         "AvgPoolGrad",
385         "BiasAdd",
386         "BiasAddGrad",
387         "BiasAddV1",
388         "FusedBatchNormV2",
389         "FusedBatchNormGradV2",
390         "FusedBatchNormV3",
391         "FusedBatchNormGradV3",
392         "LeakyRelu",
393         "LeakyReluGrad",
394         "Mul",
395         "Sub",
396     };
397     UpdateList("INFERLIST", &list);
398     // For backwards compatibility, keeping the original env variable here.
399     // TODO(reedwm): This should be removed if we don't have active users.
400     UpdateList("GRAYLIST", &list);
401     return list;
402   }
403 
DenyList()404   gtl::FlatSet<string> DenyList() override {
405     auto list = gtl::FlatSet<string>{
406         "Exp",
407         "Expm1",
408         "L2Loss",
409         "Mean",
410         "Pow",
411         "SaveV2",
412         "Softmax",
413         "SoftmaxCrossEntropyWithLogits",
414         "SparseSoftmaxCrossEntropyWithLogits",
415         "Sum",
416     };
417     UpdateList("DENYLIST", &list);
418     // For backwards compatibility, keeping the original env variable here.
419     // TODO(reedwm): This should be removed if we don't have active users.
420     UpdateList("BLACKLIST", &list);
421     return list;
422   }
423 
ClearList()424   gtl::FlatSet<string> ClearList() override {
425     auto list = gtl::FlatSet<string>{
426         "Concat",          "ConcatV2",  "Enter",         "EnsureShape",
427         "Equal",           "Exit",      "ExpandDims",    "Identity",
428         "MaxPool",         "MaxPool3D", "MaxPool3DGrad", "MaxPoolGrad",
429         "MaxPoolV2",       "Maximum",   "Merge",         "NextIteration",
430         "PreventGradient", "Relu",      "Relu6",         "Relu6Grad",
431         "ReluGrad",        "Reshape",   "Select",        "SelectV2",
432         "Shape",           "ShapeN",    "Slice",         "Split",
433         "SplitV",          "Squeeze",   "StopGradient",  "Switch",
434         "Transpose",       "ZerosLike",
435     };
436     AddTensorListOps(&list);
437     UpdateList("CLEARLIST", &list);
438     return list;
439   }
440 };
441 
442 }  // end namespace grappler
443 }  // end namespace tensorflow
444 
445 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_AUTO_MIXED_PRECISION_LISTS_H_
446