1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7    http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16syntax = "proto3";
17
18package xla;
19
20option cc_enable_arenas = true;
21
// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
//
// Enum numbers are part of the wire format and are not assigned in declaration
// order (e.g. BF16 = 16 is declared between F32 = 11 and F64 = 12). Never
// renumber an existing value; a new value takes the number recorded in the
// "Next" comment at the end of the enum, and that comment must then be bumped.
enum PrimitiveType {
  // Invalid primitive type to serve as default.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  //
  // Note: if f16s are not natively supported on the device, they will be
  // converted to f16 from f32 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
  // and 7 bits for the mantissa.
  BF16 = 16;

  F64 = 12;

  // Complex values of fixed width.
  C64 = 15;   // Paired F32 (real, imag), as in std::complex<float>.
  C128 = 18;  // Paired F64 (real, imag), as in std::complex<double>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. They are used for things like returning multiple values from a
  // computation; e.g. a computation that returns weights and biases may have a
  // signature that results in a tuple like (f32[784x2000], f32[2000]).
  //
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context-specific data to a custom
  // operation. Shapes of this primitive type will have empty dimensions and
  // tuple_shapes fields.
  OPAQUE = 14;

  // A token type threaded between side-effecting operations. Shapes of this
  // primitive type will have empty dimensions and tuple_shapes fields.
  TOKEN = 17;

  // Next = 19
}
82
// Describes the padding configuration for a Pad operation. The padding amount
// on both edges, as well as between the elements, is specified for each
// dimension.
message PaddingConfig {
  // Describes the padding configuration for a single dimension.
  message PaddingConfigDimension {
    // Padding amount on the low end (next to index 0). May be negative.
    int64 edge_padding_low = 1;

    // Padding amount on the high end (next to the highest index). May be
    // negative.
    int64 edge_padding_high = 2;

    // Padding amount between the elements. May not be negative.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions, one entry per dimension.
  repeated PaddingConfigDimension dimensions = 1;
}
102
// A format specifies the method used by a layout to store an array in memory.
enum Format {
  // Zero value; also what an unset Format field reads as in proto3.
  // TODO(b/120869032): Rename this to FORMAT_NONE or something else which
  // better corresponds to its meaning.
  INVALID_FORMAT = 0;
  // The default layout, with exactly one storage location per element.
  DENSE = 1;
  // A sparsely encoded layout, providing only the index/value pairs of non-zero
  // elements.
  SPARSE = 2;
}
114
// Describes a tile used in a tiling-based layout. Refer to
// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for
// details about tiling-based layout.
message TileProto {
  // Number of elements in each dimension of the tile, ordered from the most
  // major dimension of the tile to the most minor dimension of the tile.
  // The dimensions correspond to a suffix of the dimensions of the shape being
  // tiled.
  repeated int64 dimensions = 1;
}
125
// A layout describes how the array is placed in (1D) memory space. This
// includes the minor-to-major ordering of dimensions within a shape.
//
// Clients must specify the layouts of input Literals to the
// computation. Layouts specified in interior operations which take Shapes (for
// example, Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message LayoutProto {
  // The method used to store the data in memory. The format determines which of
  // the other fields are used by the layout.
  Format format = 4;

  // Sequence of dimension numbers, from minor (fastest varying index) to major
  // (slowest varying index). This field is required.
  repeated int64 minor_to_major = 1;

  // Numbers and names of removed fields. They stay reserved so that they are
  // never reused with a different meaning.
  reserved 2;
  reserved "padded_dimensions";

  reserved 3;
  reserved "padding_value";

  // The maximum number of elements that can be stored for SPARSE formats. This
  // can be used to determine the maximum size in bytes of arrays stored in
  // memory. This field must be unset unless the format is SPARSE.
  int64 max_sparse_elements = 5;

  // A sequence of tiles, starting from the tile that's applied first to the
  // Shape.
  //
  // TODO(b/119839262): implement tiling in each backend or add Unimplemented
  // error.
  repeated TileProto tiles = 6;

  // Bit size of each element. If the size is bigger than what the element
  // type requires, the value is stored in the least significant
  // bits and the additional most significant bits are filled with 0's.
  //
  // TODO(b/119839262): implement in each backend or add Unimplemented error.
  int64 element_size_in_bits = 7;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
  // LayoutUtil::Hash appropriately to account for the new field.
}
173// LINT.ThenChange( \
174//     https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,      \
175//     https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc)
176
// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message ShapeProto {
  // Removed field; reserved so its number and name are never reused.
  reserved 1;
  reserved "rank";

  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension, or an upper bound on the
  // size if the dimension is dynamic. In XLA, dimensions are numbered from 0
  // to N-1 for an N-dimensional array. The first element of 'dimensions' is the
  // size of dimension 0, the second element is the size of dimension 1, and so
  // forth. An empty list indicates a scalar.
  //
  // If the respective element in 'is_dynamic_dimension' is true then the value
  // in this field represents an upper bound on the size of the dimension.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of the constituent elements of the tuple, in
  // order.
  repeated ShapeProto tuple_shapes = 4;

  // The layout used to back this shape.
  LayoutProto layout = 5;

  // For arrays, this indicates whether or not each dimension is
  // dynamically-sized. The number of elements in this repeated field should be
  // zero (indicating that no dimensions are dynamic) or equal to the number of
  // elements in the 'dimensions' field.
  repeated bool is_dynamic_dimension = 6;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal(),
  // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for
  // the new field.
}
219// LINT.ThenChange( \
220//     https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc)
221
// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShapeProto {
  // The shape of each parameter, in parameter order.
  repeated ShapeProto parameters = 1;

  // The shape of the value the computation produces.
  ShapeProto result = 2;

  // Names of the parameters.
  // NOTE(review): presumably parallel to 'parameters' (entry i names
  // parameters[i]); confirm against the code that populates this message.
  repeated string parameter_names = 3;
}
229
// Statistics of a computation.
message ComputationStats {
  // The number of floating-point operations in the computation.
  double flop_count = 1;

  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}
238
// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  //
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have the op_type be "SoftMax".)
  string op_type = 1;
  // The user-specified name of the op.
  //
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;
  // Indicates a file and line in a user's program that this op is associated
  // with.
  //
  // e.g. it could be the file and line of user code that generated the op.
  string source_file = 3;
  int32 source_line = 4;
}
263
// Profile data from the execution of a computation.
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only
  // set if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This is
  // a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. The current implementation does not spend any
  // cycles on the input data transfer, since the memory is initialized with the
  // proper values before the execution.
  int64 compute_and_transfer_time_ns = 5;

  // The size of the binary code in the executable, in bytes.
  int64 executable_size_in_bytes = 6;
}
291
// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  // Opaque identifier of the execution.
  int64 handle = 1;
}
297
// Handle given to a user that represents a globally accessible allocation.
// Contrast this with a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  // Opaque identifier of the allocation.
  int64 handle = 1;
}
304
// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution, where N is
// the number of replicas.
message DeviceHandle {
  // Opaque identifier of the device.
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}
315
// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
  // Opaque identifier of the channel.
  int64 handle = 1;
  // The endpoints a channel connects.
  enum ChannelType {
    // Invalid channel type to serve as default.
    CHANNEL_TYPE_INVALID = 0;

    // A channel for sending data between devices.
    DEVICE_TO_DEVICE = 1;

    // A channel for sending data from the device to the host. Can only be used
    // with a Send operation.
    DEVICE_TO_HOST = 2;

    // A channel for sending data from the host to the device. Can only be used
    // with a Recv operation.
    HOST_TO_DEVICE = 3;
  }
  // The kind of this channel.
  ChannelType type = 2;
}
338
// DeviceAssignmentProto is a serialized form of the DeviceAssignment class,
// which represents the device ids assigned to a set of replicated computations.
// See the xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  // Number of replicas of each logical computation.
  int32 replica_count = 1;
  // Number of logical computations.
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    repeated int32 replica_device_ids = 1;
  }
  // NOTE(review): presumably one entry per logical computation
  // (computation_count entries); confirm against DeviceAssignment.
  repeated ComputationDevice computation_devices = 3;
}
353
// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape. Field numbers are not in
// declaration order; a new field takes the number in the "Next" comment below.
message LiteralProto {
  ShapeProto shape = 1;
  repeated bool preds = 2;
  bytes s8s = 15;
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;    // Stored as interleaved real, imag floats.
  repeated double c128s = 18;  // Stored as interleaved real, imag doubles.
  // For tuple shapes, one literal per tuple element.
  repeated LiteralProto tuple_literals = 10;
  // The F16s, BF16s, U16s and S16s are encoded in little-endian byte order.
  bytes f16s = 11;
  bytes bf16s = 13;
  bytes u16s = 16;
  bytes s16s = 17;
  // NOTE(review): presumably the element indices for SPARSE-format layouts;
  // confirm against the literal serialization code.
  repeated int64 sparse_indices = 14;
  // Next = 19
}
382
// Describes the window and its movement along a single dimension of the base
// area. See Window, below.
message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, means the amount of padding to add to the base area at the low
  // end of this dimension; if negative, its negative means the number of
  // elements removed from the low end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of padding
  // values to pad on the left, given that indices increase when going right.
  // The actual padding value depends upon the context. Convolution pads with
  // zeros. ReduceWindow and SelectAndScatter pad with the reduce function's
  // init value.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For example, in the
  // horizontal dimension of a rectangle, this would be the number of values to
  // pad on the right, given that indices increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation factor
  // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each kernel element. This value may not be less
  // than 1. See documentation for convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of 1
  // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
  // placed between each base area element. This value may not be less than 1.
  // See documentation for convolution.
  int64 base_dilation = 6;

  // Window reversal means that this dimension was logically reversed before the
  // operation.
  bool window_reversal = 7;
}
424
// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
  // One entry per dimension of the base area.
  repeated WindowDimension dimensions = 1;
}
433
// Describes the dimension numbers for a gather operation.
//
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for
// more details.
message GatherDimensionNumbers {
  // "Window indices" is a term for a set of indices that index into the
  // interior of a dynamic-slice from the input tensor, the starting indices for
  // which were computed from output_gather_dims (see the operation semantic for
  // how this is defined) and the start_indices tensor.
  // NOTE(review): output_gather_dims is not a field of this message; it looks
  // like a name from an older version of the operation semantics. Confirm.
  //
  // The window indices for a specific output index Out is computed as:
  //
  //  i = 0
  //  for (k : [0, input_tensor_shape.rank))
  //    window_indices[k] =
  //      if k in collapsed_slice_dims
  //      then 0
  //      else Out[offset_dims[i++]]
  //
  // offset_dims: the output dimensions that supply, in order, the window index
  // for each input dimension not listed in collapsed_slice_dims.
  repeated int64 offset_dims = 1;
  // collapsed_slice_dims: the input dimensions whose window index is always 0.
  repeated int64 collapsed_slice_dims = 2;

  // This is interpreted as a map from i to start_index_map[i]. It
  // transforms the gather index looked up from the start_indices tensor into
  // the starting index in the input space.
  repeated int64 start_index_map = 3;

  // The dimension in the start_indices input that contains the starting
  // indices.
  int64 index_vector_dim = 4;
}
464
// Describes the dimension numbers for a scatter operation.
//
// All the fields are similar to the corresponding fields in
// GatherDimensionNumbers. Differences are noted below.
message ScatterDimensionNumbers {
  // The set of dimensions in the updates shape that are window dimensions.
  repeated int64 update_window_dims = 1;
  // The set of window dimensions that must be inserted into the updates shape.
  repeated int64 inserted_window_dims = 2;

  // NOTE(review): presumably the scatter counterpart of
  // GatherDimensionNumbers.start_index_map (same field number); confirm.
  repeated int64 scatter_dims_to_operand_dims = 3;
  // The dimension in the scatter indices input that contains the starting
  // indices (as in GatherDimensionNumbers.index_vector_dim).
  int64 index_vector_dim = 4;
}
478
// Describes which dimensions of a convolution's input (lhs), kernel (rhs) and
// output play the batch, feature and spatial roles.
message ConvolutionDimensionNumbers {
  // Field numbers 1, 2 and 5 are not used by any field below (the next free
  // number is tracked as 13), presumably because the fields that owned them
  // were removed. Reserve them so they are never reassigned with a different
  // meaning, which would silently misinterpret previously serialized data.
  reserved 1, 2, 5;

  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the input.
  repeated int64 input_spatial_dimensions = 11;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in
  // the convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the kernel (rhs). window.strides(0) is the
  // stride in the kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the output.
  repeated int64 output_spatial_dimensions = 12;

  // Next = 13
}
515
// Selects the variant of fast Fourier transform (FFT) to compute.
//
// Note: FFT is the zero value, so an unset FftType field reads as a forward
// complex FFT rather than as an invalid/unspecified type.
enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out.
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              //                   fft_length real out.
}
523
// Describes which dimensions of the operands of a dot (generalized matrix
// multiplication) are contracted and which are treated as batch dimensions.
message DotDimensionNumbers {
  // The dimension numbers that represent the 'lhs' contracting dimensions.
  repeated int64 lhs_contracting_dimensions = 1;
  // The dimension numbers that represent the 'rhs' contracting dimensions.
  repeated int64 rhs_contracting_dimensions = 2;
  // The dimension numbers that represent the 'lhs' batch dimensions.
  repeated int64 lhs_batch_dimensions = 3;
  // The dimension numbers that represent the 'rhs' batch dimensions.
  repeated int64 rhs_batch_dimensions = 4;
}
534
// Selects the probability distribution used to generate random numbers.
enum RandomDistribution {
  // Zero value; not a valid distribution.
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;

  // Value 3 is skipped (the next free value is tracked as 4), presumably
  // because a former distribution was removed. Reserve it so it can never be
  // reassigned with a different meaning and misread old serialized data.
  reserved 3;

  // Next: 4
}
548
// Options for a triangular solve, which solves a system of linear equations
// whose coefficient matrix 'a' is triangular.
message TriangularSolveOptions {
  // If true, solves ax = b. If false, solves xa = b.
  bool left_side = 1;

  // If true, 'a' is lower triangular. If false, 'a' is upper triangular.
  bool lower = 2;

  // If true, the diagonal elements of 'a' are assumed to be 1 and not accessed.
  bool unit_diagonal = 3;

  // Should we transpose or use the adjoint of 'a'?
  enum Transpose {
    TRANSPOSE_INVALID = 0;
    NO_TRANSPOSE = 1;  // Don't transpose 'a'.
    TRANSPOSE = 2;     // Transpose 'a'.
    ADJOINT = 3;       // Complex conjugate and transpose 'a'.
  }
  // How 'a' is transformed before solving.
  Transpose transpose_a = 4;
}
568
// Options for a Cholesky decomposition of a matrix `a`.
message CholeskyOptions {
  // If true, uses the lower triangle of `a`. If false, uses the upper triangle
  // of `a`.
  bool lower = 1;
}
574
// Describes how an operation is partitioned (sharded) across devices.
message OpSharding {
  // The kind of sharding. Note that the zero value, REPLICATED, has meaning:
  // an unset type field reads as a replicated sharding.
  enum Type {
    // This sharding is replicated across all devices (implies maximal,
    // all other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
  }
  Type type = 1;
  // The shape of the sharded tile.
  ShapeProto tile_shape = 2;
  // The shape of the tile assignment tensor - this must be the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape), where
  // tile_assignment_shape is presumably the shape described by
  // tile_assignment_dimensions (NOTE(review): confirm).
  repeated int64 tile_assignment_devices = 4;
  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
  // in pre-order. The tuple shape could be nested; here we store just a
  // flattened list of all leaves in the tuple shape. Note that the tuple shape
  // is not stored here; shardings do not store the shapes to which they are
  // applied, this is inferred from the instruction this sharding gets attached
  // to.
  repeated OpSharding tuple_shardings = 5;
}
605
// Describes a replica group in a cross-replica op (e.g., all-reduce and
// all-to-all).
message ReplicaGroup {
  // The ids of the replicas that belong to the same group. The ordering of the
  // ids matters in some ops (e.g., all-to-all).
  repeated int64 replica_ids = 1;
}
613
// Describes a (source, target) pair in the collective permute op.
message SourceTarget {
  // Where the data is sent from.
  // NOTE(review): presumably a replica id; confirm against collective permute.
  int64 source = 1;
  // Where the data is sent to (same id space as 'source').
  int64 target = 2;
}
619
// Used to indicate the precision configuration. It has backend-specific
// meaning.
message PrecisionConfig {
  // Requested precision level. DEFAULT is the zero value, so it is also what
  // an unset entry reads as.
  enum Precision {
    DEFAULT = 0;
    HIGH = 1;
    HIGHEST = 2;

    // Next: 3
  }
  // NOTE(review): presumably one entry per operand of the instruction; confirm.
  repeated Precision operand_precision = 1;

  // Next: 2
}
634
// Describes whether all data-parallelism replicas will receive the same
// parameter data at each buffer.
message ParameterReplication {
  // A list of boolean values for the flattened leaf buffers. Each value
  // indicates whether the corresponding leaf buffer is replicated.
  //
  // If this field is empty, it means no buffer is replicated. Otherwise, the
  // number of elements in this field must match the number of leaf buffers in
  // the HLO instruction's shape.
  repeated bool replicated_at_leaf_buffers = 1;
}
646
// A backend-config for kWhile loops that stores the loop's trip count, if it is
// known.
//
// This is useful for backends that can implement a `for i in 0..N` loop more
// efficiently than a `while` loop. For example, on GPUs, we can implement a
// `for i in 0..N` loop by enqueueing the kernels for the loop body N times,
// whereas implementing a `while` loop requires a host-device sync on each
// iteration.
message WhileLoopBackendConfig {
  message KnownTripCount {
    // The number of iterations the loop body executes.
    int64 n = 1;
  }
  // This indirection lets us distinguish between known-trip-count == 0 and
  // unknown-trip-count: message fields have explicit presence, so an unset
  // field means the trip count is unknown.
  KnownTripCount known_trip_count = 1;
}
663