1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16syntax = "proto3";
17
18package xla;
19
20import "tensorflow/compiler/xla/service/hlo.proto";
21import "tensorflow/compiler/xla/xla_data.proto";
22
// Debugging options for XLA. These options may change at any time - there are
// no guarantees about backward or forward compatibility for these fields.
message DebugOptions {
  // Show addresses of HLO ops in graph dump.
  bool xla_hlo_graph_addresses = 2;

  // Instrument the computation to collect per-HLO cycle counts.
  bool xla_hlo_profile = 9;

  // List of HLO passes to disable/enable. These names must exactly match the
  // pass names as specified by the HloPassInterface::name() method.
  //
  // At least one of xla_disable_hlo_passes and xla_enable_hlo_passes_only must
  // be empty.
  repeated string xla_disable_hlo_passes = 30;
  repeated string xla_enable_hlo_passes_only = 124;

  // Disables all HLO passes. Note that some passes are necessary for
  // correctness and the invariants that must be satisfied by "fully optimized"
  // HLO are different for different devices and may change over time. The only
  // "guarantee", such as it is, is that if you compile XLA and dump the
  // optimized HLO for some graph, you should be able to run it again on the
  // same device with the same build of XLA.
  bool xla_disable_all_hlo_passes = 104;

  // Numerical optimization level for the XLA compiler backend; the specific
  // interpretation of this value is left to the backends.
  int32 xla_backend_optimization_level = 31;

  // Embed the compiler IR as a string in the executable.
  bool xla_embed_ir_in_executable = 33;

  // Eliminate implicit broadcasts when lowering user computations to HLO
  // instructions; use explicit broadcast instead.
  bool xla_eliminate_hlo_implicit_broadcast = 35;

  // When generating calls to Eigen in the CPU backend, use multi-threaded Eigen
  // mode.
  bool xla_cpu_multi_thread_eigen = 60;

  // Path to directory with cuda/ptx tools and libraries.
  string xla_gpu_cuda_data_dir = 61;

  // Enable flush-to-zero semantics in the GPU backend.
  bool xla_gpu_ftz = 62;

  // Disable multi-streaming in the GPU backend.
  bool xla_gpu_disable_multi_streaming = 63;

  // Debugging feature: if enabled, the GPU backend will assign HLO operators to
  // randomly chosen streams. This is intended to trigger concurrency bugs.
  bool xla_gpu_use_random_streams = 134;

  // If true, in LLVM-based backends, emit !alias.scope metadata in the
  // generated IR.
  bool xla_llvm_enable_alias_scope_metadata = 70;

  // If true, in LLVM-based backends, emit !noalias metadata in the
  // generated IR.
  bool xla_llvm_enable_noalias_metadata = 71;

  // If true, in LLVM-based backends, emit !invariant.load metadata in
  // the generated IR.
  bool xla_llvm_enable_invariant_load_metadata = 72;

  // If true, a set of expensive LLVM optimization passes will not be run.
  bool xla_llvm_disable_expensive_passes = 73;

  reserved 80;  // Was hlo_reduce_precision_options

  // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
  // computation will run n! times with all permutations of layouts for the
  // output shape in rank n. For example, with a 3D shape, all permutations of
  // the set {0, 1, 2} are tried.
  bool xla_test_all_output_layouts = 90;

  // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
  // computation will run for all permutations of layouts of all input
  // arguments. For example, with 2 input arguments in 2D and 4D shapes, the
  // computation will run 2! * 4! times.
  bool xla_test_all_input_layouts = 91;

  // Assign colors based on sharding information when generating the Graphviz
  // HLO graph.
  bool xla_hlo_graph_sharding_color = 92;

  reserved 93;  // Was xla_hlo_tfgraph_device_scopes

  // If true, the GPU backend is free to use cudnn for HLO batch normalization
  // ops.
  bool xla_gpu_use_cudnn_batchnorm = 94;

  // Generate calls to MKL-DNN in the CPU backend.
  bool xla_cpu_use_mkl_dnn = 97;

  // Maximum kernel unroll factor for the GPU backend.
  int32 xla_gpu_max_kernel_unroll_factor = 98;

  // When true, "unsafe" mathematical optimizations are enabled. These
  // transformations include but are not limited to:
  //
  //  - Reducing the precision of operations (e.g. using an approximate sin
  //    function, or transforming x/y into x * (1/y)).
  //  - Assuming that operations never produce or consume NaN or +/- Inf (this
  //    behavior can be adjusted using xla_cpu_fast_math_honor_{nans|infs}).
  //  - Assuming that +0 and -0 are indistinguishable.
  bool xla_cpu_enable_fast_math = 99;

  // When xla_cpu_enable_fast_math is true then this controls whether we allow
  // operations to produce NaNs. Ignored when xla_cpu_enable_fast_math is
  // false.
  bool xla_cpu_fast_math_honor_nans = 120;

  // When xla_cpu_enable_fast_math is true then this controls whether we allow
  // operations to produce infinities. Ignored when xla_cpu_enable_fast_math is
  // false.
  bool xla_cpu_fast_math_honor_infs = 121;

  // When xla_cpu_enable_fast_math is true then this controls whether we forbid
  // using the reciprocal of an argument instead of division. Ignored when
  // xla_cpu_enable_fast_math is false.
  bool xla_cpu_fast_math_honor_division = 126;

  // When xla_cpu_enable_fast_math is true then this controls whether we forbid
  // approximating calculations for functions. Ignored when
  // xla_cpu_enable_fast_math is false.
  bool xla_cpu_fast_math_honor_functions = 129;

  // When false we lower the Minimum and Maximum hlos in the CPU backend such
  // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN. In other words, if this
  // flag is false we always propagate NaNs through Min and Max.
  //
  // Note, this does not correspond to the exact same behavior as the gpu flag
  // below!
  bool xla_cpu_enable_fast_min_max = 140;

  // When true we lower the Minimum and Maximum hlos in the GPU backend such
  // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if this
  // flag is true we don't propagate NaNs through Min and Max.
  //
  // Note, this does not correspond to the exact same behavior as the cpu flag
  // above!
  bool xla_gpu_enable_fast_min_max = 100;

  // Allows xla to increase the output precision of floating point operations.
  bool xla_allow_excess_precision = 122;

  // Crashes the program when any kind of verification fails, instead of just
  // logging the failures. One example is cross checking of convolution results
  // among different algorithms.
  bool xla_gpu_crash_on_verification_failures = 101;

  // Level of GEMM and convolution auto-tuning performed by the GPU backend.
  // NOTE(review): this comment used to read "Disable GEMM and Convolution
  // auto-tuning", which does not describe an integer level. Presumably 0
  // disables auto-tuning and larger values enable progressively more
  // thorough auto-tuning -- confirm against the flag definition.
  int32 xla_gpu_autotune_level = 123;

  // Force the host platform to pretend that there are this many host
  // "devices". All these devices are backed by the same threadpool. Defaults
  // to 1.
  //
  // Setting this to anything other than 1 can increase overhead from context
  // switching but we let the user override this behavior to help run tests on
  // the host that run models in parallel across multiple devices.
  int32 xla_force_host_platform_device_count = 102;

  // If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3).
  bool xla_gpu_disable_gpuasm_optimizations = 103;

  // Enable fast math with eigen in the HLO evaluator.
  bool xla_hlo_evaluator_use_fast_path = 106;

  // Temporary option to allow support for both the R1 and the scalar index
  // versions of DynamicSlice and DynamicUpdateSlice. Only used for testing.
  bool xla_allow_scalar_index_dynamic_ops = 107;

  enum StepMarkerLocation {
    // Generate a step marker at the program entry. This handles the case where
    // each step is done by one or multiple program execution(s). Only the first
    // program will be tagged for generating a step marker at the program entry.
    // This is the default.
    STEP_MARK_AT_ENTRY = 0;
    // Generate a step marker at each iteration of the top level while loop,
    // which is assumed to be a training loop.
    STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP = 1;
    // Generate a step marker at each iteration of the second level while loops,
    // which is assumed to be a training or eval loop.
    STEP_MARK_AT_SECOND_LEVEL_WHILE_LOOP = 3;
    // No step marker generated.
    STEP_MARK_NONE = 2;
  }
  // Option to emit a target-specific marker to indicate the start of a training
  // step. The location of the marker (if any) is determined by the option
  // value.
  StepMarkerLocation xla_step_marker_location = 108;

  //
  // BEGIN flags controlling dumping HLO modules for debugging.
  //
  // When dumping is enabled, HLO modules are dumped at the very beginning and
  // end of compilation, and optionally also during the pass pipeline.
  //
  // In general, if you set one of these flags, we will try to infer reasonable
  // defaults for the others. For example:
  //
  //  * Setting --xla_dump_to=/tmp/foo without specifying a format
  //    with --xla_dump_hlo_as_* will turn on --xla_dump_hlo_as_text.
  //
  //  * Setting --xla_dump_hlo_as_text without specifying --xla_dump_to will
  //    dump to stdout.
  //

  // Directory to dump into.
  string xla_dump_to = 109;

  // If specified, will only dump modules which match this regexp.
  string xla_dump_hlo_module_re = 110;

  // If this flag is specified, will also dump HLO before and after passes that
  // match this regular expression. Set to .* to dump before/after all passes.
  string xla_dump_hlo_pass_re = 111;

  // Specifies the format that HLO is dumped in. Multiple of these may be
  // specified.
  bool xla_dump_hlo_as_text = 112;
  bool xla_dump_hlo_as_proto = 113;
  bool xla_dump_hlo_as_dot = 114;
  bool xla_dump_hlo_as_url = 115;

  // Dump HLO graphs as HTML (DOT -> SVG inlined in HTML).
  bool xla_dump_hlo_as_html = 116;

  // Dump the visualization of the fusion progress.
  bool xla_dump_fusion_visualization = 149;

  // If true, every time an HLO module is run, we will dump an HloSnapshot
  // (essentially, a serialized module plus its inputs) to the --xla_dump_to
  // directory.
  bool xla_dump_hlo_snapshots = 118;

  // Include a timestamp in the dumped filenames.
  bool xla_dump_include_timestamp = 131;

  // Max number of hlo module dumps in a directory. Set to < 0 for unbounded.
  int32 xla_dump_max_hlo_modules = 132;

  // Dump HloModuleMetadata as a text proto for each HLO module.
  bool xla_dump_module_metadata = 144;

  //
  // END flags controlling dumping HLO modules.
  //

  // Overrides for XLA GPU's convolution layout heuristic.
  bool xla_gpu_force_conv_nchw = 125;
  bool xla_gpu_force_conv_nhwc = 146;

  // Paths to files with ptx code.
  repeated string xla_gpu_ptx_file = 127;

  // Denylist for cuDNN convolutions.
  string xla_gpu_algorithm_denylist_path = 128;

  // Guarantee run-to-run determinism from reductions on XLA:GPU.
  bool xla_gpu_deterministic_reductions = 130;

  // Debug options that trigger execution errors when NaN or Inf are detected.
  bool xla_tpu_detect_nan = 135;
  bool xla_tpu_detect_inf = 136;

  // True if TraceMe annotations are enabled for XLA:CPU.
  bool xla_cpu_enable_xprof_traceme = 137;

  // If true, allow falling back to the driver when `ptxas` is not found
  // (presumably for PTX compilation on XLA:GPU -- confirm). It is usually
  // preferable not to fall back to the driver; it can consume more memory, or
  // have bugs.
  bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found = 138;

  // Extra parameters to pass to the GPU assembler.
  string xla_gpu_asm_extra_flags = 141;

  // Per-heap size constraint. New heaps will be created if per-heap max size is
  // reached.
  int32 xla_multiheap_size_constraint_per_heap = 142;

  // Enable detailed logging into vlog.
  bool xla_detailed_logging = 143;

  // Overrides the normal multi-threaded compilation setting to use this many
  // threads. Setting to 0 (the default value) means no enforcement.
  int32 xla_gpu_force_compilation_parallelism = 147;

  // Guarantees run-to-run determinism. At present, the HLO ops Scatter and
  // SelectAndScatter do not have deterministic XLA:GPU implementations.
  // Compilation errors out if these ops are encountered.
  bool xla_gpu_deterministic_ops = 148;

  // Next id: 150

  // Extra options to pass to the compilation backend (e.g. LLVM); specific
  // interpretation of these values is left to the backend.
  map<string, string> xla_backend_extra_options = 500;

  reserved 5, 117, 133,
      139;  // were xla_hlo_dump_as_graphdef, xla_dump_to,
            // xla_gpu_use_horizontal_fusion, and
            // xla_gpu_unsafe_fallback_to_driver_on_ptxas_error
  // NOTE(review): `xla_dump_to` is currently live as field 109, so the name
  // listed for 117 refers to an earlier field of the same name. Do not add it
  // to a `reserved "..."` name list.
}
328
// These settings control how XLA compiles and/or runs code. Not all settings
// will have an effect on every platform.
//
// When adding new fields, keep in mind that boolean fields default to false.
message ExecutionOptions {
  // Field number 1 is not used by any current field; reserve it so it can never
  // be reassigned with a different meaning.
  reserved 1;

  // This optional field's layout is used as a hint when storing the output of
  // this computation. Subsequent transfers of this output array to the client
  // may be faster when using this layout.
  //
  // We use a Shape here to accommodate computations that return a tuple.
  ShapeProto shape_with_output_layout = 2;

  // Used to seed random-number generators used in this computation. If this is
  // 0, we generate a seed ourselves.
  //
  // TODO(b/32083678): Changing the seed unnecessarily forces a recompilation.
  uint64 seed = 3;

  // Debugging options applied to this compilation/execution.
  DebugOptions debug_options = 4;

  // This optional field specifies a particular set of devices to run the
  // computation on. The computation will be partitioned across these devices.
  // If not provided, the default device will be chosen.
  repeated DeviceHandle device_handles = 5;

  // Number of replicas of the computation to run. If zero, uses the default
  // number of replicas for the XLA service.
  int32 num_replicas = 6;

  // This optional field specifies the device assignment if known at compile
  // time.
  DeviceAssignmentProto device_assignment = 7;

  // Alias input and output buffers for parameters that are passed through XLA
  // modules without being changed.
  bool alias_passthrough_params = 8;

  // Number of partitions of the computation to run (model parallelism).
  // If zero, uses the default number of partitions for the XLA service.
  int32 num_partitions = 9;

  // Used to identify a set of programs that should be launched together.
  int32 launch_id = 10;

  // Indicates whether to use SPMD (true) or MPMD (false) partitioning when
  // num_partitions > 1 and XLA is requested to partition the input program.
  bool use_spmd_partitioning = 11;

  // If set, deduplicate hlo into function calls to reduce binary size. Only
  // works on TPU.
  bool deduplicate_hlo = 12;

  // If set, broadcast replicated parameters to all replicas, using collectives.
  // Only applicable to TPU.
  bool broadcast_replicated_parameters_via_collectives = 13;
}
385
// Asks the service for a number of device handles.
message GetDeviceHandlesRequest {
  // How many device handles are being requested.
  int64 device_count = 1;
}
389
// Reply to GetDeviceHandlesRequest.
message GetDeviceHandlesResponse {
  // The device handles returned by the service.
  repeated DeviceHandle device_handles = 1;
}
393
// Requests that a piece of global data be transferred back to the client.
message TransferToClientRequest {
  // Handle of the data to transfer.
  GlobalDataHandle data = 1;

  // Optional. When set, the service returns the literal laid out as described
  // by this shape. A full shape (not just a layout) is used so that tuples can
  // be described as well.
  ShapeProto shape_with_layout = 2;
}
401
// Reply to TransferToClientRequest.
message TransferToClientResponse {
  // The transferred data, as a literal.
  LiteralProto literal = 1;
}
405
// Requests that a client-side literal be transferred to the service.
message TransferToServerRequest {
  // The value to transfer.
  LiteralProto literal = 1;

  // The device associated with the transfer.
  DeviceHandle device_handle = 2;
}
410
// Reply to TransferToServerRequest.
message TransferToServerResponse {
  // Handle referring to the transferred data.
  GlobalDataHandle data = 1;
}
414
// Requests that a literal be transferred to an infeed.
message TransferToInfeedRequest {
  // The value to transfer.
  LiteralProto literal = 1;

  // Replica targeted by the transfer.
  int64 replica_id = 2;

  // Device targeted by the transfer.
  DeviceHandle device_handle = 3;
}
420
// Reply to TransferToInfeedRequest. Currently carries no fields.
message TransferToInfeedResponse {
}
422
// Requests that a literal be transferred out of an outfeed.
message TransferFromOutfeedRequest {
  // Optional. When set, the service returns the literal laid out as described
  // by this shape. A full shape (not just a layout) is used so that tuples can
  // be described as well.
  ShapeProto shape_with_layout = 1;

  // Replica targeted by the transfer.
  int64 replica_id = 2;

  // Device targeted by the transfer.
  DeviceHandle device_handle = 3;
}
431
// Reply to TransferFromOutfeedRequest.
message TransferFromOutfeedResponse {
  // The transferred data, as a literal.
  LiteralProto literal = 1;
}
435
// Requests a reset of a device.
message ResetDeviceRequest {
  // The device to reset.
  DeviceHandle device_handle = 1;
}
439
// Reply to ResetDeviceRequest. Currently carries no fields.
message ResetDeviceResponse {
}
441
// Requests statistics about an HLO module.
message ComputationGraphStatsRequest {
  // The module to gather statistics for.
  HloModuleProto computation = 1;

  // Debugging options accompanying the request.
  DebugOptions debug_options = 2;
}
446
// Carries statistics about a computation.
message ComputationStatsResponse {
  // The gathered statistics.
  ComputationStats stats = 1;
}
450
// Requests creation of a channel handle.
message CreateChannelHandleRequest {
  // Type of the channel to create.
  ChannelHandle.ChannelType channel_type = 1;
}
454
// Reply to CreateChannelHandleRequest.
message CreateChannelHandleResponse {
  // The created channel handle.
  ChannelHandle channel = 1;
}
458
// Requests that one or more global data handles be unregistered.
message UnregisterRequest {
  // The handles to unregister.
  repeated GlobalDataHandle data = 1;
}
462
// Reply to UnregisterRequest. Currently carries no fields.
message UnregisterResponse {
}
464
// Requests compilation of an HLO module into an executable.
message CompileRequest {
  // The HLO module to compile.
  HloModuleProto computation = 1;

  // Options influencing how XLA compiles code for this request.
  ExecutionOptions execution_options = 2;

  // Layouts of the input arguments; the default layout is used when these are
  // not set. The actual argument values are not needed to compile, but their
  // layouts can still influence the compiled code.
  repeated ShapeProto input_shape_with_layout = 3;
}
477
// Reply to CompileRequest.
message CompileResponse {
  // Handle identifying the compiled executable.
  ExecutionHandle handle = 1;
}
482
// Requests execution of a previously compiled executable.
message ExecuteRequest {
  // Handle of the executable to run (e.g. the one returned in a
  // CompileResponse).
  ExecutionHandle handle = 1;

  // The shapes and layouts of the arguments must be the same as those of the
  // executable's parameters.
  repeated GlobalDataHandle arguments = 2;
}
490
// Requests that an HLO module be compiled and run in a single call.
//
// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace
// the uses with calls to Compile and Execute.
message ExecuteGraphRequest {
  // The HLO module to run.
  HloModuleProto computation = 1;

  // Arguments passed to the module.
  repeated GlobalDataHandle arguments = 2;

  // Options influencing how XLA compiles and runs code for this request.
  ExecutionOptions execution_options = 3;
}
500
// A batch of ExecuteGraphRequests (presumably run in parallel, as the name
// suggests -- confirm with the service implementation).
message ExecuteGraphParallelRequest {
  // The individual requests in the batch.
  repeated ExecuteGraphRequest requests = 1;
}
504
// Result of an execution.
message ExecuteResponse {
  // Handle of the execution's output.
  GlobalDataHandle output = 1;

  // Profiling information for the execution.
  ExecutionProfile profile = 2;
}
509
// Results of a batch of executions.
message ExecuteParallelResponse {
  // One response per execution (presumably in request order -- confirm).
  repeated ExecuteResponse responses = 1;
}
513
// Requests to wait for an execution.
message WaitForExecutionRequest {
  // Handle of the execution to wait for.
  ExecutionHandle execution = 1;
}
517
// Reply to WaitForExecutionRequest.
message WaitForExecutionResponse {
  // Handle of the execution's output.
  GlobalDataHandle output = 1;

  // Profiling information for the execution.
  ExecutionProfile profile = 2;
}
522
// Requests that an HLO module be computed to a constant value.
message ComputeConstantGraphRequest {
  // The HLO module to compute.
  HloModuleProto computation = 1;

  // Layout requested for the resulting value.
  LayoutProto output_layout = 2;
}
527
// Reply carrying a computed constant.
message ComputeConstantResponse {
  // The result, returned directly as a literal.
  LiteralProto literal = 1;
}
532
// Requests that a tuple be split into handles for its elements.
message DeconstructTupleRequest {
  // Field number 1 is not used by any current field; reserve it so it can never
  // be reassigned with a different meaning.
  reserved 1;

  // Handle of the tuple to deconstruct.
  GlobalDataHandle tuple_handle = 2;
}
536
// Reply to DeconstructTupleRequest.
message DeconstructTupleResponse {
  // Handles of the tuple's elements.
  repeated GlobalDataHandle element_handles = 1;
}
540
// Requests that data be loaded from a ColumnIO tablet.
message LoadDataRequest {
  // Path of the ColumnIO tablet to read from.
  string columnio_tablet_path = 1;

  // Name of the field within the tablet to read.
  string columnio_field = 2;

  // Shape of a single element, not counting rows.
  ShapeProto element_shape = 3;

  // Caution: ColumnIO has no random access, so a nonzero offset may be costly
  // in performance-critical scenarios.
  int64 offset = 4;

  // Upper bound on how many elements (each of shape element_shape) to load.
  int64 limit = 5;

  // When more than one item is requested (limit > 1), zip together the
  // produced vectors.
  bool zip = 6;
}
562
// Reply to LoadDataRequest.
message LoadDataResponse {
  // Handle of the loaded data.
  GlobalDataHandle data = 1;

  // Shape of the loaded data.
  ShapeProto data_shape = 2;

  // Number of rows available (presumably in the source tablet -- confirm).
  int64 available_rows = 3;

  // Number of rows actually loaded.
  int64 rows_loaded = 4;

  // Elapsed time in nanoseconds (presumably for the load -- confirm).
  int64 nanoseconds = 5;
}
570
// Requests the shape of a piece of global data.
message GetShapeRequest {
  // Handle of the data whose shape is requested.
  GlobalDataHandle data = 1;
}
574
// Reply to GetShapeRequest.
message GetShapeResponse {
  // The shape of the requested data.
  ShapeProto shape = 1;
}
578
// Requests that a piece of global data be unpacked.
message UnpackRequest {
  // Handle of the data to unpack.
  GlobalDataHandle data = 1;
}
582
// Reply to UnpackRequest.
message UnpackResponse {
  // Handles produced by the unpack operation.
  repeated GlobalDataHandle tied_data = 1;
}
586