/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/debug_options_flags.h"

#include <vector>

#include "absl/base/call_once.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/xla/debug_options_parsers.h"
#include "tensorflow/compiler/xla/parse_flags_from_env.h"

namespace xla {

DebugOptions DefaultDebugOptionsIgnoringFlags() {
  DebugOptions opts;
  opts.set_xla_llvm_enable_alias_scope_metadata(true);
  opts.set_xla_llvm_enable_noalias_metadata(true);
  opts.set_xla_llvm_enable_invariant_load_metadata(true);
  opts.set_xla_llvm_disable_expensive_passes(false);
  opts.set_xla_backend_optimization_level(3);
  opts.set_xla_gpu_autotune_level(4);
  opts.set_xla_cpu_multi_thread_eigen(true);
  opts.set_xla_gpu_cuda_data_dir("./cuda_sdk_lib");
  opts.set_xla_gpu_asm_extra_flags("");
  opts.set_xla_eliminate_hlo_implicit_broadcast(true);
  opts.set_xla_dump_hlo_as_html(false);
  opts.set_xla_dump_fusion_visualization(false);
  opts.set_xla_dump_include_timestamp(true);
  opts.set_xla_dump_max_hlo_modules(-1);
  opts.set_xla_dump_module_metadata(false);
#ifdef INTEL_MKL
  opts.set_xla_cpu_use_mkl_dnn(true);
#endif  // INTEL_MKL
  opts.set_xla_gpu_max_kernel_unroll_factor(4);
  // Set cudnn batchnorm off by default; it does not provide a performance win
  // on average.
  opts.set_xla_gpu_use_cudnn_batchnorm(false);

  // Run all GPU work on one stream by default.  Using multiple streams
  // increases memory usage and we lack strong motivating benchmarks for tuning
  // the heuristics needed to decide when to run on multiple streams.  See
  // b/77879207.
  opts.set_xla_gpu_disable_multi_streaming(true);

  // Disable forms of fast math that have caused users problems in the past.
  opts.set_xla_cpu_enable_fast_math(true);
  opts.set_xla_cpu_fast_math_honor_nans(true);
  opts.set_xla_cpu_fast_math_honor_infs(true);
  opts.set_xla_cpu_fast_math_honor_functions(true);
  opts.set_xla_cpu_fast_math_honor_division(true);

  // By default, copy TF's Eigen style min_max behavior with nans.
  opts.set_xla_cpu_enable_fast_min_max(true);

  opts.set_xla_gpu_enable_fast_min_max(true);

  opts.set_xla_allow_excess_precision(true);
  opts.set_xla_force_host_platform_device_count(1);
  opts.set_xla_gpu_deterministic_reductions(false);
  opts.set_xla_cpu_enable_xprof_traceme(false);
  opts.set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(false);
  opts.set_xla_multiheap_size_constraint_per_heap(-1);
  opts.set_xla_detailed_logging(true);
  return opts;
}

static absl::once_flag flags_init;
static DebugOptions* flag_values;
static std::vector<tensorflow::Flag>* flag_objects;

// Maps pass -> initial fuel values (parsed when AllocateFlags was run).
static absl::flat_hash_map<string, int64>* initial_fuel;

// Maps pass -> whether fuel was ever consumed for that pass.
static absl::node_hash_map<string, std::atomic<bool>>* fuel_ever_consumed;

// Maps pass -> remaining fuel.
//
// All threads start off using this global fuel pool, but ResetThreadLocalFuel()
// switches them to a thread-local fuel pool.
static absl::node_hash_map<string, std::atomic<int64>>* global_fuel;

// If we're using thread-local fuel, this stores it.
static thread_local std::unique_ptr<
    absl::node_hash_map<string, std::atomic<int64>>>
    thread_fuel;  // NOLINT (global variable with nontrivial destructor)
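
// Illustrative example (the pass name "algsimp" is hypothetical here):
// --xla_fuel=algsimp=10 seeds initial_fuel and global_fuel with
// {"algsimp": 10}; each ConsumeFuel("algsimp", ...) call then decrements that
// counter, and ConsumeFuel returns false once the counter is exhausted.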

// Logs a warning if a pass's fuel was never consumed, on the theory that this
// may be a typo in the flag value.  Called atexit.
static void WarnIfFuelWasNeverConsumed() {
  CHECK(fuel_ever_consumed != nullptr);
  for (const auto& kv : *fuel_ever_consumed) {
    absl::string_view pass = kv.first;
    bool was_consumed = kv.second;
    if (!was_consumed) {
      LOG(ERROR) << absl::StreamFormat(
          "Compiler fuel for \"%s\" was never consumed. This may be a typo in "
          "the --xla_fuel flag you passed.",
          pass);
    }
  }
}

// Allocates flag_values and flag_objects; this function must not be called
// more than once - its invocation is guarded by call_once.
static void AllocateFlags() {
  flag_values = new DebugOptions(DefaultDebugOptionsIgnoringFlags());

  // Returns a lambda that calls "member_setter" on "flag_values" with the
  // argument passed in to the lambda.
  auto bool_setter_for = [](void (DebugOptions::*member_setter)(bool)) {
    return [member_setter](bool value) {
      (flag_values->*member_setter)(value);
      return true;
    };
  };

  // Returns a lambda that calls "member_setter" on "flag_values" with the
  // argument passed in to the lambda.
  auto int32_setter_for = [](void (DebugOptions::*member_setter)(int32)) {
    return [member_setter](int32 value) {
      (flag_values->*member_setter)(value);
      return true;
    };
  };

  auto string_setter_for =
      [](void (DebugOptions::*member_setter)(const string& value)) {
        return [member_setter](const string& value) {
          (flag_values->*member_setter)(value);
          return true;
        };
      };

  // Custom "sub-parser" lambda for xla_disable_hlo_passes.
  auto setter_for_xla_disable_hlo_passes = [](string comma_separated_values) {
    for (const auto& passname :
         std::vector<string>(absl::StrSplit(comma_separated_values, ','))) {
      flag_values->add_xla_disable_hlo_passes(passname);
    }
    return true;
  };

  // Custom "sub-parser" lambda for xla_enable_hlo_passes_only.
  auto setter_for_xla_enable_hlo_passes_only =
      [](string comma_separated_values) {
        for (const auto& passname :
             std::vector<string>(absl::StrSplit(comma_separated_values, ','))) {
          flag_values->add_xla_enable_hlo_passes_only(passname);
        }
        return true;
      };

  // Custom "sub-parser" lambda for xla_gpu_ptx_file.
  auto setter_for_xla_gpu_ptx_file = [](string value) {
    flag_values->add_xla_gpu_ptx_file(value);
    return true;
  };

  // Custom "sub-parser" lambda for xla_backend_extra_options.
  auto setter_for_xla_backend_extra_options =
      [](string comma_separated_values) {
        auto* extra_options_map =
            flag_values->mutable_xla_backend_extra_options();
        parse_xla_backend_extra_options(extra_options_map,
                                        comma_separated_values);
        return true;
      };

  // Custom "sub-parser" for xla_fuel.  Note that ConsumeFuel does not do any
  // locking on the fuel global variables.  This means that it's
  // illegal/undefined behavior to modify this flag value while the compiler is
  // running.
  initial_fuel = new absl::flat_hash_map<string, int64>();
  fuel_ever_consumed = new absl::node_hash_map<string, std::atomic<bool>>();
  global_fuel = new absl::node_hash_map<string, std::atomic<int64>>();
  auto setter_for_xla_fuel = [](string xla_fuel_value) {
    initial_fuel->clear();
    global_fuel->clear();
    fuel_ever_consumed->clear();

    for (const auto& kv : absl::StrSplit(xla_fuel_value, ',')) {
      std::vector<string> pass_and_fuel = absl::StrSplit(kv, '=');
      if (pass_and_fuel.size() != 2) {
        LOG(ERROR) << absl::StreamFormat(
            "Illegal value for --xla_fuel. Saw %s, but expected token %s to "
            "have format X=INTEGER.",
            xla_fuel_value, kv);
        return false;
      }
      const auto& pass = pass_and_fuel[0];
      const auto& fuel_str = pass_and_fuel[1];
      int64 fuel;
      if (!absl::SimpleAtoi(fuel_str, &fuel)) {
        LOG(ERROR) << absl::StreamFormat(
            "Illegal value for --xla_fuel. Saw %s, but expected token %s to be "
            "an integer.",
            xla_fuel_value, fuel_str);
        return false;
      }
      initial_fuel->emplace(pass, fuel);
      global_fuel->emplace(pass, fuel);
      fuel_ever_consumed->emplace(pass, false);
    }

    // If --xla_fuel was specified, register an atexit handler which logs a
    // warning if a pass was specified but never consumed any fuel, on the
    // theory that this may be a typo.
    if (!initial_fuel->empty()) {
      static absl::once_flag register_atexit_once;
      absl::call_once(
          register_atexit_once,
          +[] { std::atexit(WarnIfFuelWasNeverConsumed); });
    }
    return true;
  };

  flag_objects = new std::vector<tensorflow::Flag>();
  // Don't use an initializer list for initializing the vector; this would
  // create a temporary copy, and exceeds the stack space when compiling with
  // certain configurations.
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_enable_fast_math",
      bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_math),
      flag_values->xla_cpu_enable_fast_math(),
      "Enable unsafe fast-math optimizations in the CPU compiler; this may "
      "produce faster code at the expense of some accuracy."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_fast_math_honor_nans",
      bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_nans),
      flag_values->xla_cpu_fast_math_honor_nans(),
      "When xla_cpu_enable_fast_math is true then this controls whether we "
      "allow operations to produce NaNs.  Ignored when "
      "xla_cpu_enable_fast_math is false."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_fast_math_honor_infs",
      bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_infs),
      flag_values->xla_cpu_fast_math_honor_infs(),
      "When xla_cpu_enable_fast_math is true then this controls whether we "
      "allow operations to produce infinities.  Ignored when "
      "xla_cpu_enable_fast_math is false."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_fast_math_honor_division",
      bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_division),
      flag_values->xla_cpu_fast_math_honor_division(),
      "When xla_cpu_enable_fast_math is true then this controls whether we "
      "forbid the use of multiplication by the reciprocal instead of division. "
      "Ignored when xla_cpu_enable_fast_math is false."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_fast_math_honor_functions",
      bool_setter_for(&DebugOptions::set_xla_cpu_fast_math_honor_functions),
      flag_values->xla_cpu_fast_math_honor_functions(),
      "When xla_cpu_enable_fast_math is true then this controls whether we "
      "forbid approximate calculations for functions. Ignored when "
      "xla_cpu_enable_fast_math is false."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_enable_fast_min_max",
      bool_setter_for(&DebugOptions::set_xla_cpu_enable_fast_min_max),
      flag_values->xla_cpu_enable_fast_min_max(),
      "Enable fast floating point min/max lowering that always propagates "
      "NaNs."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_enable_fast_min_max",
      bool_setter_for(&DebugOptions::set_xla_gpu_enable_fast_min_max),
      flag_values->xla_gpu_enable_fast_min_max(),
      "Enable fast floating point min/max lowering that does not propagate "
      "NaNs."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_llvm_enable_alias_scope_metadata",
      bool_setter_for(&DebugOptions::set_xla_llvm_enable_alias_scope_metadata),
      flag_values->xla_llvm_enable_alias_scope_metadata(),
      "In LLVM-based backends, enable the emission of !alias.scope metadata in "
      "the generated IR."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_llvm_enable_noalias_metadata",
      bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata),
      flag_values->xla_llvm_enable_noalias_metadata(),
      "In LLVM-based backends, enable the emission of !noalias metadata in the "
      "generated IR."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_llvm_enable_invariant_load_metadata",
      bool_setter_for(
          &DebugOptions::set_xla_llvm_enable_invariant_load_metadata),
      flag_values->xla_llvm_enable_invariant_load_metadata(),
      "In LLVM-based backends, enable the emission of !invariant.load metadata "
      "in the generated IR."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_llvm_disable_expensive_passes",
      bool_setter_for(&DebugOptions::set_xla_llvm_disable_expensive_passes),
      flag_values->xla_llvm_disable_expensive_passes(),
      "In LLVM-based backends, disable a custom set of expensive optimization "
      "passes."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_backend_optimization_level",
      int32_setter_for(&DebugOptions::set_xla_backend_optimization_level),
      flag_values->xla_backend_optimization_level(),
      "Numerical optimization level for the XLA compiler backend."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "",
      "Comma-separated list of hlo passes to be disabled. These names must "
      "exactly match the passes' names; no whitespace around commas."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_enable_hlo_passes_only", setter_for_xla_enable_hlo_passes_only, "",
      "Comma-separated list of hlo passes to be enabled. These names must "
      "exactly match the passes' names; no whitespace around commas. The "
      "unspecified passes are all disabled."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_disable_all_hlo_passes",
      bool_setter_for(&DebugOptions::set_xla_disable_all_hlo_passes), false,
      "Disables all HLO passes.  Note that some passes are necessary for "
      "correctness and the invariants that must be satisfied by 'fully "
      "optimized' HLO are different for different devices and may change "
      "over time.  The only 'guarantee', such as it is, is that if you compile "
      "XLA and dump the optimized HLO for some graph, you should be able to "
      "run it again on the same device with the same build of XLA."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_embed_ir_in_executable",
      bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable),
      flag_values->xla_embed_ir_in_executable(),
      "Embed the compiler IR as a string in the executable."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_eliminate_hlo_implicit_broadcast",
      bool_setter_for(&DebugOptions::set_xla_eliminate_hlo_implicit_broadcast),
      flag_values->xla_eliminate_hlo_implicit_broadcast(),
      "Eliminate implicit broadcasts when lowering user computations to HLO "
      "instructions; use explicit broadcast instead."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_multi_thread_eigen",
      bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen),
      flag_values->xla_cpu_multi_thread_eigen(),
      "When generating calls to Eigen in the CPU backend, use multi-threaded "
      "Eigen mode."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_cuda_data_dir", flag_values->mutable_xla_gpu_cuda_data_dir(),
      "If non-empty, specifies a local directory containing ptxas and nvvm "
      "libdevice files; otherwise we use those from runfile directories."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_ftz", bool_setter_for(&DebugOptions::set_xla_gpu_ftz),
      flag_values->xla_gpu_ftz(),
      "If true, flush-to-zero semantics are enabled in the code generated for "
      "GPUs."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_disable_multi_streaming",
      bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming),
      flag_values->xla_gpu_disable_multi_streaming(),
      "If true, multi-streaming in the GPU backend is disabled."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_max_kernel_unroll_factor",
      int32_setter_for(&DebugOptions::set_xla_gpu_max_kernel_unroll_factor),
      flag_values->xla_gpu_max_kernel_unroll_factor(),
      "Specify the maximum kernel unroll factor for the GPU backend."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_ptx_file", setter_for_xla_gpu_ptx_file, "",
      "If non-empty, specifies a file containing ptx to use. The filename "
      "prefix must have the same pattern as PTX dumped by XLA. This allows "
      "matching one specific module. General workflow: get the generated "
      "module ptx from XLA, modify it, then pass it back via this option."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_test_all_output_layouts",
      bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
      flag_values->xla_test_all_output_layouts(),
      "Let ClientLibraryTestBase::ComputeAndCompare* test all permutations of "
      "output layouts. For example, with a 3D shape, all permutations of the "
      "set {0, 1, 2} are tried."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_test_all_input_layouts",
      bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts),
      flag_values->xla_test_all_input_layouts(),
      "Let ClientLibraryTestBase::ComputeAndCompare* test all permutations of "
      "*input* layouts. For example, for 2 input arguments with 2D shape and "
      "4D shape, the computation will run 2! * 4! times, once for each "
      "possible combination of layouts."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_hlo_profile", bool_setter_for(&DebugOptions::set_xla_hlo_profile),
      flag_values->xla_hlo_profile(),
      "Instrument the computation to collect per-HLO cycle counts"));
  flag_objects->push_back(tensorflow::Flag(
      "xla_backend_extra_options", setter_for_xla_backend_extra_options, "",
      "Extra options to pass to a backend; comma-separated list of 'key=val' "
      "strings (=val may be omitted); no whitespace around commas."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_use_cudnn_batchnorm",
      bool_setter_for(&DebugOptions::set_xla_gpu_use_cudnn_batchnorm),
      flag_values->xla_gpu_use_cudnn_batchnorm(),
      "Allows the GPU backend to implement batchnorm HLOs using cudnn, rather "
      "than expanding them to a soup of HLOs."));
  flag_objects->push_back(
      tensorflow::Flag("xla_cpu_use_mkl_dnn",
                       bool_setter_for(&DebugOptions::set_xla_cpu_use_mkl_dnn),
                       flag_values->xla_cpu_use_mkl_dnn(),
                       "Generate calls to MKL-DNN in the CPU backend."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_crash_on_verification_failures",
      bool_setter_for(
          &DebugOptions::set_xla_gpu_crash_on_verification_failures),
      flag_values->xla_gpu_crash_on_verification_failures(),
      "Crashes the program on extra verification failures, e.g. cuDNN cross "
      "checking failures"));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_autotune_level",
      int32_setter_for(&DebugOptions::set_xla_gpu_autotune_level),
      flag_values->xla_gpu_autotune_level(),
      "Set GEMM and Convolution auto-tuning level. 0 = off; 1 = on; 2 = "
      "on+init; 3 = on+init+reinit; 4 = on+init+reinit+check."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_force_host_platform_device_count",
      int32_setter_for(&DebugOptions::set_xla_force_host_platform_device_count),
      flag_values->xla_force_host_platform_device_count(),
      "Force the host platform to pretend that there are this many host "
      "\"devices\". All of these host devices are backed by the same "
      "threadpool. Setting this to anything other than 1 can increase overhead "
      "from context switching but we let the user override this behavior to "
      "help run tests on the host that run models in parallel across multiple "
      "devices."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_disable_gpuasm_optimizations",
      bool_setter_for(&DebugOptions::set_xla_gpu_disable_gpuasm_optimizations),
      flag_values->xla_gpu_disable_gpuasm_optimizations(),
      "In XLA:GPU run ptxas in -O0 (default is -O3)."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_asm_extra_flags",
      string_setter_for(&DebugOptions::set_xla_gpu_asm_extra_flags), "",
      "Pass extra parameters to the GPU assembler tool (i.e., ptxas for CUDA). "
      "If multiple parameters, separate them by comma."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_fuel", setter_for_xla_fuel, /*default_value_for_display=*/"",
      "Sets compiler fuel, useful for bisecting bugs in passes.  Format "
      "--xla_fuel=PASS1=NUM1,PASS2=NUM2,..."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_to", string_setter_for(&DebugOptions::set_xla_dump_to),
      flag_values->xla_dump_to(),
      "Directory into which debugging data is written. If not specified but "
      "another dumping flag is passed, data will be written to stdout. To "
      "explicitly write to stdout, set this to \"-\". The values \"sponge\" "
      "and \"test_undeclared_outputs_dir\" have a special meaning: They cause "
      "us to dump into the directory specified by the environment variable "
      "TEST_UNDECLARED_OUTPUTS_DIR."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_as_text",
      bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_text),
      flag_values->xla_dump_hlo_as_text(),
      "Dumps HLO modules as text before and after optimizations. Results are "
      "written to the --xla_dump_to dir, or, if no dir is specified, to "
      "stdout."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_as_proto",
      bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_proto),
      flag_values->xla_dump_hlo_as_proto(),
      "Dumps HLO modules as HloProtos to the directory specified by "
      "--xla_dump_to."));
  flag_objects->push_back(
      tensorflow::Flag("xla_dump_hlo_as_dot",
                       bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_dot),
                       flag_values->xla_dump_hlo_as_dot(),
                       "Dumps HLO modules rendered as dot files to the "
                       "directory specified by --xla_dump_to."));
  flag_objects->push_back(
      tensorflow::Flag("xla_dump_hlo_as_html",
                       bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_html),
                       flag_values->xla_dump_hlo_as_html(),
                       "Dumps HLO modules rendered as HTML files to the "
                       "directory specified by --xla_dump_to."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_as_url",
      bool_setter_for(&DebugOptions::set_xla_dump_hlo_as_url),
      flag_values->xla_dump_hlo_as_url(),
      "Tries to dump HLO modules rendered as URLs to stdout (and also to the "
      "directory specified by --xla_dump_to). This is not implemented by "
      "default; you need to add a plugin which calls "
      "RegisterGraphToURLRenderer()."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_fusion_visualization",
      bool_setter_for(&DebugOptions::set_xla_dump_fusion_visualization),
      flag_values->xla_dump_fusion_visualization(),
      "Tries to generate HLO fusion visualization as an HTML page in the "
      "directory specified by --xla_dump_to. This is not implemented by "
      "default; you need to add a plugin which calls "
      "RegisterGraphToURLRenderer(). Generates a file per computation. "
      "Currently only implemented for the GPU backend."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_snapshots",
      bool_setter_for(&DebugOptions::set_xla_dump_hlo_snapshots),
      flag_values->xla_dump_hlo_snapshots(),
      "Every time an HLO module is run, dumps an HloSnapshot to the directory "
      "specified by --xla_dump_to."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_module_re",
      string_setter_for(&DebugOptions::set_xla_dump_hlo_module_re),
      flag_values->xla_dump_hlo_module_re(),
      "Limits dumping only to modules which match this regular expression. "
      "Default is to dump all modules."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_hlo_pass_re",
      string_setter_for(&DebugOptions::set_xla_dump_hlo_pass_re),
      flag_values->xla_dump_hlo_pass_re(),
      "If specified, dumps HLO before and after optimization passes which "
      "match this regular expression, in addition to dumping at the very "
      "beginning and end of compilation."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_include_timestamp",
      bool_setter_for(&DebugOptions::set_xla_dump_include_timestamp),
      flag_values->xla_dump_include_timestamp(),
      "If specified, includes a timestamp in the dumped filenames."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_max_hlo_modules",
      int32_setter_for(&DebugOptions::set_xla_dump_max_hlo_modules),
      flag_values->xla_dump_max_hlo_modules(),
      "Max number of hlo module dumps in a directory. Set to < 0 for "
      "unbounded."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_dump_module_metadata",
      bool_setter_for(&DebugOptions::set_xla_dump_module_metadata),
      flag_values->xla_dump_module_metadata(),
      "Dumps HloModuleMetadata as text protos to the directory specified "
      "by --xla_dump_to."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_hlo_graph_addresses",
      bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses),
      flag_values->xla_hlo_graph_addresses(),
      "When rendering graphs (--xla_dump_hlo_as_{dot,html,url}), displays "
      "the address in memory of each HloInstruction object."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_hlo_graph_sharding_color",
      bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color),
      flag_values->xla_hlo_graph_sharding_color(),
      "Assign colors based on sharding assignments when generating the HLO "
      "graphs."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_allow_excess_precision",
      bool_setter_for(&DebugOptions::set_xla_allow_excess_precision),
      flag_values->xla_allow_excess_precision(),
      "Allow xla to increase the output precision of an instruction."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_force_conv_nchw",
      bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nchw),
      flag_values->xla_gpu_force_conv_nchw(),
      "For cuDNN convolutions, always use NCHW layouts."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_force_conv_nhwc",
      bool_setter_for(&DebugOptions::set_xla_gpu_force_conv_nhwc),
      flag_values->xla_gpu_force_conv_nhwc(),
      "For cuDNN convolutions, always use NHWC layouts."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_algorithm_denylist_path",
      string_setter_for(&DebugOptions::set_xla_gpu_algorithm_denylist_path),
      flag_values->xla_gpu_algorithm_denylist_path(),
      "An AlgorithmDenylist text proto file as a denylist of convolutions to "
      "avoid using."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_deterministic_reductions",
      bool_setter_for(&DebugOptions::set_xla_gpu_deterministic_reductions),
      flag_values->xla_gpu_deterministic_reductions(),
      "Always run deterministic reductions on GPU"));
  flag_objects->push_back(tensorflow::Flag(
      "xla_tpu_detect_nan",
      bool_setter_for(&DebugOptions::set_xla_tpu_detect_nan),
      flag_values->xla_tpu_detect_nan(),
      "Trigger error on execution on TPU if a NaN value is detected"));
  flag_objects->push_back(tensorflow::Flag(
      "xla_tpu_detect_inf",
      bool_setter_for(&DebugOptions::set_xla_tpu_detect_inf),
      flag_values->xla_tpu_detect_inf(),
      "Trigger error on execution on TPU if an Inf value is detected"));
  flag_objects->push_back(tensorflow::Flag(
      "xla_cpu_enable_xprof_traceme",
      bool_setter_for(&DebugOptions::set_xla_cpu_enable_xprof_traceme),
      flag_values->xla_cpu_enable_xprof_traceme(),
      "If true, XLA CPU generates code to call "
      "TraceMe::Activity{Start|End} around HLO operations."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found",
      bool_setter_for(
          &DebugOptions::
              set_xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found),
      flag_values->xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found(),
      "If true, XLA GPU falls back to the driver if ptxas is not found. Note "
      "that falling back to the driver can have drawbacks like using more "
      "memory and/or other bugs during compilation, so we recommend setting "
      "this flag to false."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_multiheap_size_constraint_per_heap",
      int32_setter_for(
          &DebugOptions::set_xla_multiheap_size_constraint_per_heap),
      flag_values->xla_multiheap_size_constraint_per_heap(),
      "Generates multiple heaps (i.e., temp buffers) with a size "
      "constraint on each heap to avoid Out-of-Memory due to memory "
      "fragmentation. The constraint is soft, so it works with tensors "
      "larger than the given constraint size. -1 corresponds to no "
      "constraints."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_force_compilation_parallelism",
      int32_setter_for(
          &DebugOptions::set_xla_gpu_force_compilation_parallelism),
      flag_values->xla_gpu_force_compilation_parallelism(),
      "Overrides normal multi-threaded compilation setting to use this many "
      "threads. Setting to 0 (the default value) means no enforcement."));
  flag_objects->push_back(tensorflow::Flag(
      "xla_gpu_deterministic_ops",
      bool_setter_for(&DebugOptions::set_xla_gpu_deterministic_ops),
      flag_values->xla_gpu_deterministic_ops(),
      "Guarantees run-to-run determinism on GPU."));

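  // Illustrative usage (the flag values below are only examples): the flags
  // registered above are typically supplied through the XLA_FLAGS environment
  // variable, e.g.
  //   XLA_FLAGS="--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text"
  // ParseFlagsFromEnvAndDieIfUnknown parses that variable and dies if it
  // contains a flag that is not registered in flag_objects.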
  ParseFlagsFromEnvAndDieIfUnknown("XLA_FLAGS", *flag_objects);
}

void AppendDebugOptionsFlags(std::vector<tensorflow::Flag>* flag_list) {
  absl::call_once(flags_init, &AllocateFlags);
  flag_list->insert(flag_list->end(), flag_objects->begin(),
                    flag_objects->end());
}

xla::DebugOptions GetDebugOptionsFromFlags() {
  absl::call_once(flags_init, &AllocateFlags);
  return *flag_values;
}

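// Sketch of intended use (caller code is hypothetical): a compilation thread
// that wants fuel accounting isolated from other threads can call
// ResetThreadLocalFuel() once before compiling; subsequent ConsumeFuel calls
// on that thread then draw from thread_fuel instead of global_fuel.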
void ResetThreadLocalFuel() {
  absl::call_once(flags_init, &AllocateFlags);

  thread_fuel.reset(new absl::node_hash_map<string, std::atomic<int64>>());
  CHECK(initial_fuel != nullptr);
  for (const auto& kv : *initial_fuel) {
    thread_fuel->emplace(kv.first, kv.second);
  }
}

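// Example caller (the pass name "algsimp" is hypothetical): a pass can gate
// each rewrite on fuel, e.g.
//   bool just_ran_out = false;
//   if (ConsumeFuel("algsimp", &just_ran_out)) {
//     // ... apply one rewrite ...
//   }
// The function returns true while fuel remains (or when no fuel was
// configured for the pass), and sets *just_ran_out to true on the first call
// made after the pass's fuel is exhausted.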
bool ConsumeFuel(absl::string_view pass, bool* just_ran_out) {
  absl::call_once(flags_init, &AllocateFlags);
  if (just_ran_out != nullptr) {
    *just_ran_out = false;
  }
  auto* fuel_pool = thread_fuel ? thread_fuel.get() : global_fuel;
  if (fuel_pool->empty()) {
    return true;
  }
  auto it = fuel_pool->find(pass);
  if (it == fuel_pool->end()) {
    return true;
  }
  std::atomic<int64>& remaining_fuel = it->second;
  std::atomic<bool>& fuel_has_been_consumed = fuel_ever_consumed->at(pass);
  fuel_has_been_consumed = true;

  int64 remaining = remaining_fuel.fetch_sub(1);
  if (just_ran_out != nullptr) {
    *just_ran_out = remaining == 0;
  }
  return remaining > 0;
}

}  // namespace xla