1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/flags.h"
17 
18 #include <mutex>  // NOLINT
19 
20 #include "absl/base/call_once.h"
21 #include "absl/strings/numbers.h"
22 #include "absl/strings/str_split.h"
23 #include "absl/strings/strip.h"
24 #include "tensorflow/compiler/xla/parse_flags_from_env.h"
25 #include "tensorflow/core/platform/macros.h"
26 #include "tensorflow/core/util/command_line_flags.h"
27 
28 namespace tensorflow {
29 namespace {
30 
31 BuildXlaOpsPassFlags* build_ops_flags;
32 MarkForCompilationPassFlags* mark_for_compilation_flags;
33 XlaDeviceFlags* device_flags;
34 XlaOpsCommonFlags* ops_flags;
35 IntroduceFloatingPointJitterPassFlags* jitter_flags;
36 MlirCommonFlags* mlir_flags;
37 
38 std::vector<Flag>* flag_list;
39 absl::once_flag flags_init;
40 
SetterForXlaAutoJitFlag(const string & value)41 bool SetterForXlaAutoJitFlag(const string& value) {
42   int32 opt_level;
43   // We need to use the mark_for_compilation_flags directly here instead of
44   // going via GetMarkForCompilationPassFlags() to avoid infinite recursion. The
45   // latter will try to setup and parse flags, which would bring us back to this
46   // setter.
47   if (absl::SimpleAtoi(value, &opt_level)) {
48     mark_for_compilation_flags->xla_auto_jit_flag
49         .optimization_level_single_gpu = opt_level;
50     mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general =
51         opt_level;
52     return true;
53   }
54 
55   if (value == "fusible") {
56     mark_for_compilation_flags->xla_auto_jit_flag
57         .optimization_level_single_gpu = 1;
58     mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general =
59         1;
60     mark_for_compilation_flags->tf_xla_ops_to_cluster = "FUSIBLE";
61     return true;
62   }
63 
64   absl::string_view value_sv(value);
65   if (!absl::ConsumePrefix(&value_sv, "single-gpu(") ||
66       !absl::ConsumeSuffix(&value_sv, ")") ||
67       !absl::SimpleAtoi(value_sv, &opt_level)) {
68     return false;
69   }
70 
71   mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu =
72       opt_level;
73   return true;
74 }
75 
AppendMarkForCompilationPassFlagsInternal(std::vector<Flag> * flag_list)76 void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
77   std::vector<Flag> new_flags = {
78       Flag("tf_xla_auto_jit", SetterForXlaAutoJitFlag, "0",
79            "Control compilation of operators into XLA computations on CPU and "
80            "GPU devices.  0 = use ConfigProto setting; -1 = off; 1 = on for "
81            "things very likely to be improved; 2 = on for everything; "
82            "(experimental) fusible = only for Tensorflow operations that XLA "
83            "knows how to fuse.  "
84            "If set to single-gpu(<N>) then this resolves to <N> for single-GPU "
85            "graphs (graphs that have at least one node placed on a GPU and no "
86            "more than one GPU is in use through the entire graph) and 0 "
87            "otherwise.  Experimental."),
88       Flag("tf_xla_min_cluster_size",
89            &mark_for_compilation_flags->tf_xla_min_cluster_size,
90            "Minimum number of operators in an XLA compilation. Ignored for "
91            "operators placed on an XLA device or operators explicitly marked "
92            "for compilation."),
93       Flag("tf_xla_max_cluster_size",
94            &mark_for_compilation_flags->tf_xla_max_cluster_size,
95            "Maximum number of operators in an XLA compilation."),
96       Flag(
97           "tf_xla_ops_to_cluster",
98           &mark_for_compilation_flags->tf_xla_ops_to_cluster,
99           "(experimental) "
100           "Limit the operations clustered by XLA to these operations. "
101           "If multiple, separate them with commas. Shortcuts: "
102           " PW: All point-wise operations."
103           " RED: All reduction operations."
104           " MISC: Mixed operations."
105           " PWRED: TF operations that get converted to PW+RED operation in XLA."
106           " REDUCEWINDOW: TF operations like MaxPool/AvgPool that get "
107           "converted to ReduceWindow in XLA."
108           " REDUCEWINDOWPW: Operation that get converted to ReduceWindow + PW "
109           "(LRN, LRNGrad)."
110           " BN: TF FusedBatchNorm* operations."
111           " FUSIBLE: All TF operations that XLA can fuse (All the above). "
112           "You can also put any TF operation name, e.g. 'FUSIBLE,MatMul'."),
113       Flag("tf_xla_clustering_debug",
114            &mark_for_compilation_flags->tf_xla_clustering_debug,
115            "Dump graphs during XLA compilation."),
116       Flag("tf_xla_cpu_global_jit",
117            &mark_for_compilation_flags->tf_xla_cpu_global_jit,
118            "Enables global JIT compilation for CPU via SessionOptions."),
119       Flag("tf_xla_clustering_fuel",
120            &mark_for_compilation_flags->tf_xla_clustering_fuel,
121            "Places an artificial limit on the number of ops marked as "
122            "eligible for clustering."),
123       Flag("tf_xla_disable_deadness_safety_checks_for_debugging",
124            &mark_for_compilation_flags
125                 ->tf_xla_disable_deadness_safety_checks_for_debugging,
126            "Disable deadness related safety checks when clustering (this is "
127            "unsound)."),
128       Flag("tf_xla_disable_resource_variable_safety_checks_for_debugging",
129            &mark_for_compilation_flags
130                 ->tf_xla_disable_resource_variable_safety_checks_for_debugging,
131            "Disable resource variables related safety checks when clustering "
132            "(this is unsound).")};
133   flag_list->insert(flag_list->end(), new_flags.begin(), new_flags.end());
134 }
135 
AllocateAndParseFlags()136 void AllocateAndParseFlags() {
137   build_ops_flags = new BuildXlaOpsPassFlags;
138   build_ops_flags->tf_xla_enable_lazy_compilation = true;
139   build_ops_flags->tf_xla_print_cluster_outputs = false;
140   build_ops_flags->tf_xla_check_cluster_input_numerics = false;
141   build_ops_flags->tf_xla_check_cluster_output_numerics = false;
142   build_ops_flags->tf_xla_disable_constant_folding = false;
143 
144   mark_for_compilation_flags = new MarkForCompilationPassFlags;
145   mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_single_gpu =
146       0;
147   mark_for_compilation_flags->xla_auto_jit_flag.optimization_level_general = 0;
148   mark_for_compilation_flags->tf_xla_min_cluster_size = 4;
149   mark_for_compilation_flags->tf_xla_max_cluster_size =
150       std::numeric_limits<int32>::max();
151   mark_for_compilation_flags->tf_xla_clustering_debug = false;
152   mark_for_compilation_flags->tf_xla_cpu_global_jit = false;
153   mark_for_compilation_flags->tf_xla_clustering_fuel =
154       std::numeric_limits<int64>::max();
155   mark_for_compilation_flags
156       ->tf_xla_disable_deadness_safety_checks_for_debugging = false;
157   mark_for_compilation_flags
158       ->tf_xla_disable_resource_variable_safety_checks_for_debugging = false;
159 
160   device_flags = new XlaDeviceFlags;
161   device_flags->tf_xla_compile_on_demand = false;
162   device_flags->tf_xla_enable_xla_devices = false;
163 
164   ops_flags = new XlaOpsCommonFlags;
165   ops_flags->tf_xla_always_defer_compilation = false;
166 
167   jitter_flags = new IntroduceFloatingPointJitterPassFlags;
168   jitter_flags->jitter_amount = 1e-5;
169 
170   // The `enable_mlir_bridge` flag allows the user to explicitly request that
171   // their program is (or isn't) compiled using the MLIR-based TF-to-XLA bridge.
172   //
173   // The `enable_mlir_bridge_is_explicit` variable tracks whether or not the
174   // user has made an explicit request. That is, if this variable is set to
175   // true, the program honors the user's request as per `enable_mlir_bridge`; if
176   // it's set to false, the default behavior is used (which may run either
177   // bridge, on a per-graph basis).
178   bool enable_mlir_bridge = false;
179   bool enable_mlir_bridge_is_explicit = false;
180   bool mlir_bridge_safe_mode = false;
181 
182   auto setter_for_jitter_tensor_names = [](string sequence) {
183     jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
184     return true;
185   };
186 
187   flag_list = new std::vector<Flag>(
188       {Flag("tf_xla_enable_lazy_compilation",
189             &build_ops_flags->tf_xla_enable_lazy_compilation, ""),
190        Flag("tf_xla_print_cluster_outputs",
191             &build_ops_flags->tf_xla_print_cluster_outputs,
192             "If true then insert Print nodes to print out values produced by "
193             "XLA clusters."),
194        Flag("tf_xla_check_cluster_input_numerics",
195             &build_ops_flags->tf_xla_check_cluster_input_numerics,
196             "If true then insert CheckNumerics nodes to check all cluster "
197             "inputs."),
198        Flag("tf_xla_check_cluster_output_numerics",
199             &build_ops_flags->tf_xla_check_cluster_output_numerics,
200             "If true then insert CheckNumerics nodes to check all cluster "
201             "outputs."),
202        Flag("tf_xla_disable_constant_folding",
203             &build_ops_flags->tf_xla_disable_constant_folding,
204             "If true then disables constant folding on TF graph before XLA "
205             "compilation."),
206 
207        Flag("tf_xla_compile_on_demand", &device_flags->tf_xla_compile_on_demand,
208             "Switch a device into 'on-demand' mode, where instead of "
209             "autoclustering ops are compiled one by one just-in-time."),
210 
211        Flag("tf_xla_enable_xla_devices",
212             &device_flags->tf_xla_enable_xla_devices,
213             "Generate XLA_* devices, where placing a computation on such a "
214             "device"
215             "forces compilation by XLA. Deprecated."),
216 
217        Flag("tf_xla_always_defer_compilation",
218             &ops_flags->tf_xla_always_defer_compilation, ""),
219 
220        Flag("tf_introduce_floating_point_jitter_to_tensors",
221             setter_for_jitter_tensor_names, "",
222             "The Tensors to add the jitter to.  The tensors are named in the "
223             "TensorId format of <node name>:<output idx>."),
224        Flag("tf_introduce_floating_point_jitter_amount",
225             &jitter_flags->jitter_amount,
226             "The amount of jitter to introduce.  This amount is added to each "
227             "element in the tensors named in `tensor_names."),
228 
229        Flag("tf_mlir_enable_mlir_bridge", &enable_mlir_bridge,
230             "Enables experimental MLIR-Based TensorFlow Compiler Bridge.",
231             &enable_mlir_bridge_is_explicit),
232        Flag(
233            "tf_mlir_bridge_safe_mode", &mlir_bridge_safe_mode,
234            "When tf_mlir_enable_mlir_bridge is true, this field can enable "
235            "the MLIR bridge's safe mode. When the MLIR bridge is in safe mode, "
236            "it only runs for graphs that use features MLIR bridge currently "
237            "supports.")});
238 
239   AppendMarkForCompilationPassFlagsInternal(flag_list);
240   xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
241 
242   mlir_flags = new MlirCommonFlags;
243   if (!enable_mlir_bridge_is_explicit) {
244     mlir_flags->tf_mlir_enable_mlir_bridge =
245         ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_UNSPECIFIED;
246   } else if (enable_mlir_bridge) {
247     mlir_flags->tf_mlir_enable_mlir_bridge =
248         (mlir_bridge_safe_mode)
249             ? ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_SAFE_MODE_ENABLED
250             : ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
251   } else {
252     mlir_flags->tf_mlir_enable_mlir_bridge =
253         ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_DISABLED;
254   }
255 }
256 
257 }  // namespace
258 
SetXlaAutoJitFlagFromFlagString(const string & value)259 bool SetXlaAutoJitFlagFromFlagString(const string& value) {
260   absl::call_once(flags_init, &AllocateAndParseFlags);
261   return SetterForXlaAutoJitFlag(value);
262 }
263 
GetBuildXlaOpsPassFlags()264 BuildXlaOpsPassFlags* GetBuildXlaOpsPassFlags() {
265   absl::call_once(flags_init, &AllocateAndParseFlags);
266   return build_ops_flags;
267 }
268 
GetMarkForCompilationPassFlags()269 MarkForCompilationPassFlags* GetMarkForCompilationPassFlags() {
270   absl::call_once(flags_init, &AllocateAndParseFlags);
271   return mark_for_compilation_flags;
272 }
273 
GetXlaDeviceFlags()274 XlaDeviceFlags* GetXlaDeviceFlags() {
275   absl::call_once(flags_init, &AllocateAndParseFlags);
276   return device_flags;
277 }
278 
GetXlaOpsCommonFlags()279 const XlaOpsCommonFlags& GetXlaOpsCommonFlags() {
280   absl::call_once(flags_init, &AllocateAndParseFlags);
281   return *ops_flags;
282 }
283 
284 const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags()285 GetIntroduceFloatingPointJitterPassFlags() {
286   absl::call_once(flags_init, &AllocateAndParseFlags);
287   return *jitter_flags;
288 }
289 
GetMlirCommonFlags()290 MlirCommonFlags* GetMlirCommonFlags() {
291   absl::call_once(flags_init, &AllocateAndParseFlags);
292   return mlir_flags;
293 }
294 
AppendMarkForCompilationPassFlags(std::vector<Flag> * flag_list)295 void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
296   absl::call_once(flags_init, &AllocateAndParseFlags);
297   AppendMarkForCompilationPassFlagsInternal(flag_list);
298 }
299 
// Process-wide switch set by DisableXlaCompilation() and read by
// FailOnXlaCompilation(). Nothing in this file ever clears it again.
static std::atomic<bool> xla_compilation_disabled{false};

// Sets the switch reported by FailOnXlaCompilation().
void DisableXlaCompilation() { xla_compilation_disabled.store(true); }

// Returns true once DisableXlaCompilation() has been called.
bool FailOnXlaCompilation() { return xla_compilation_disabled.load(); }
305 
306 }  // namespace tensorflow
307