/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package xla;
option cc_enable_arenas = true;

// Primitive types are the individual values that can be held in rectangular
// multidimensional arrays. A description of the rectangular multidimensional
// array dimensions / primitive type is given by Shape, below.
enum PrimitiveType {
  // Invalid primitive type to serve as default.
  PRIMITIVE_TYPE_INVALID = 0;

  // Predicates are two-state booleans.
  PRED = 1;

  // Signed integral values of fixed width.
  S8 = 2;
  S16 = 3;
  S32 = 4;
  S64 = 5;

  // Unsigned integral values of fixed width.
  U8 = 6;
  U16 = 7;
  U32 = 8;
  U64 = 9;

  // Floating-point values of fixed width.
  //
  // Note: if F16s are not natively supported on the device, they will be
  // converted to F32 and back to F16 at arbitrary points in the computation.
  F16 = 10;
  F32 = 11;

  // Truncated 16-bit floating-point format. This is similar to IEEE's 16-bit
  // floating-point format, but uses 1 bit for the sign, 8 bits for the
  // exponent, and 7 bits for the mantissa.
  BF16 = 16;

  F64 = 12;

  // Complex values of fixed width.
  C64 = 15;  // Paired F32 (real, imag), as in std::complex<float>.

  // A tuple is a polymorphic sequence; e.g. a shape that holds different
  // sub-shapes. They are used for things like returning multiple values from a
  // computation; e.g. a computation that returns weights and biases may have a
  // signature that results in a tuple like (f32[784x2000], f32[2000]).
  //
  // If a shape proto has the tuple element type, it may not have any entries
  // in the dimensions field.
  TUPLE = 13;

  // An opaque type used for passing context-specific data to a custom
  // operation.
  OPAQUE = 14;

  // Next = 17
}

// Describes the value held inside padding elements.
enum PaddingValue {
  INVALID_PAD = 0;

  // Zero padding must be 0-values that correspond to the shape's element type.
  ZERO_PAD = 1;

  // One padding must be 1-values that correspond to the shape's element type.
  ONE_PAD = 2;

  // "Lowest" padding must be the lowest values in the shape's element type,
  // used as padding for operations like max-accumulation.
  LOWEST_PAD = 3;

  // "Highest" padding must be the largest values in the shape's element type,
  // used as padding for operations like min-accumulation.
  HIGHEST_PAD = 4;

  // Unknown padding could be anything; e.g. floating-point NaNs!
  UNKNOWN_PAD = 5;
}

// Describes the padding configuration for the Pad operation. The padding
// amount on both edges of each dimension, as well as between the elements,
// is specified per dimension.
message PaddingConfig {
  // Describes the padding configuration for a dimension.
  message PaddingConfigDimension {
    // Padding amount on the low end (next to index 0).
    int64 edge_padding_low = 1;

    // Padding amount on the high end (next to the highest index).
    int64 edge_padding_high = 2;

    // Padding amount between the elements.
    int64 interior_padding = 3;
  }

  // The padding configuration for all dimensions.
  repeated PaddingConfigDimension dimensions = 1;
}
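
// Example (illustrative only; values are hypothetical): padding the 1-D array
// [a, b, c] with padding value p and a PaddingConfig whose single dimension
// has
//   edge_padding_low = 1, edge_padding_high = 2, interior_padding = 1
// produces the 8-element array [p, a, p, b, p, c, p, p]: one leading pad, one
// pad between each pair of neighbors, and two trailing pads.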

// A format specifies the method used by a layout to store an array in memory.
enum Format {
  INVALID_FORMAT = 0;
  // The default layout, with exactly one storage location per element
  // (ignoring padding).
  DENSE = 1;
  // A sparsely encoded layout, providing only the index/value pairs of
  // non-zero elements.
  SPARSE = 2;
}

// A layout describes how the array is placed in (1D) memory space.  This
// includes the minor-to-major ordering of dimensions within a shape, as well
// as any padding present in those dimensions.
//
// Clients must specify the layouts of input Literals to the
// computation. Layouts specified in interior operations which take Shapes (for
// example, Convert) are ignored.
//
// See the XLA documentation for more information on shapes and layouts.
message Layout {
  // The method used to store the data in memory. The format determines which
  // of the other fields are used by the layout.
  Format format = 4;

  // Sequence of dimension numbers, from minor (fastest varying index) to major
  // (slowest varying index). This field is required.
  repeated int64 minor_to_major = 1;

  // The width to which each dimension is padded. If present, the size of
  // padded_dimensions must equal the rank of the shape. The padding appears at
  // the end of a dimension, not at the beginning. This kind of padding, unlike
  // padding in e.g. convolution, is not part of the shape. This field must be
  // unset unless the format is DENSE.
  repeated int64 padded_dimensions = 2;

  // Describes the values in the padding specified by padded_dimensions. This
  // field must be unset unless the format is DENSE.
  PaddingValue padding_value = 3;

  // The maximum number of elements that can be stored for SPARSE formats.
  // This can be used to determine the maximum size in bytes of arrays stored
  // in memory.  This field must be unset unless the format is SPARSE.
  int64 max_sparse_elements = 5;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal()
  // appropriately to account for the new field.
}
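
// Example (illustrative only): for a 2-D array with dimensions [2, 3] laid
// out row-major (the last dimension varies fastest in memory), the layout
// would be
//   format: DENSE
//   minor_to_major: [1, 0]
// whereas a column-major layout of the same array would use
//   minor_to_major: [0, 1].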

// A shape describes the number of dimensions in the array, the size of each
// dimension, and the primitive component type.
//
// Tuples are a special case in that they have rank zero and have tuple_shapes
// defined.
//
// See the XLA documentation for more information on shapes and layouts.
message Shape {
  reserved 1;
  reserved "rank";

  // The element type for this shape.
  PrimitiveType element_type = 2;

  // The size (number of elements) for each dimension.
  // In XLA, dimensions are numbered from 0 to N-1 for an
  // N-dimensional array. The first element of 'dimensions' is the size of
  // dimension 0, the second element is the size of dimension 1, and so forth.
  // An empty list indicates a scalar.
  repeated int64 dimensions = 3;

  // For tuples only, the shapes of the constituent shapes in the tuple
  // sequence.
  repeated Shape tuple_shapes = 4;

  // The layout used to back this shape.
  Layout layout = 5;

  // Important: if any field is added, be sure to modify ShapeUtil::Equal() and
  // ShapeUtil::Compatible() appropriately to account for the new field.
}
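
// Example (illustrative only; values are hypothetical): the array shape
// f32[2,3] with a row-major layout could be written in text format as
//   element_type: F32
//   dimensions: [2, 3]
//   layout { format: DENSE minor_to_major: [1, 0] }
// and the tuple shape (f32[2,3], pred[]) would set element_type: TUPLE, leave
// dimensions empty, and list the two constituent shapes in tuple_shapes.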

// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShape {
  repeated Shape parameters = 1;
  Shape result = 2;
  repeated string parameter_names = 3;
}
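
// Example (illustrative only; the parameter names are hypothetical): a
// computation with signature (x: f32[10], alpha: f32[]) -> f32[10] would have
// two entries in 'parameters' (f32[10] and f32[]), 'result' set to f32[10],
// and parameter_names of ["x", "alpha"].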

// Statistics of a computation.
message ComputationStats {
  // The number of floating point operations in the computation.
  double flop_count = 1;

  // The number of transcendental operations (e.g., exp) in the computation.
  double transcendental_count = 2;
}

// Symbolization metadata for HLO Instructions.
//
// This metadata is used for debugging XLA code generation, as well as
// performance profiling of XLA-generated executables.
message OpMetadata {
  // The framework op name that generated this XLA op.
  //
  // Frameworks that build on top of XLA should mirror the names of their ops
  // back to users by specifying the op_type. In this way, even if the
  // framework's "ops" are implemented as multiple XLA HLO Ops, they can be
  // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as
  // multiple ops, then each op should have the op_type be "SoftMax".)
  string op_type = 1;
  // The user-specified name of the op.
  //
  // This name is often unique within a computation. Note: some frameworks
  // add auto-generated names if the user does not provide one.
  string op_name = 2;
  // Indicates the file and line in the user's program with which this op is
  // associated.
  //
  // e.g. it could be the file and line of user code that generated the op.
  string source_file = 3;
  int32 source_line = 4;
}

// Profile data from the execution of a computation.
message ExecutionProfile {
  // Whether the executable was read from the compilation cache.
  bool compilation_cache_hit = 1;

  // The time in milliseconds spent to compile the computation. This is only
  // set if the executable was not read from the compilation cache
  // (compilation_cache_hit == false).
  int64 compile_time_ms = 2;

  // The number of cycles spent for the computation. This does not include the
  // time taken for the data transfers between the host and the device. This is
  // a target-dependent field and only used for debugging purposes.
  int64 compute_cycle_count = 3;

  // The time in nanoseconds spent for the computation, without data transfer.
  int64 compute_time_ns = 4;

  // The time in nanoseconds spent for the entire computation, including the
  // result data transfer time. The current implementation does not spend any
  // cycles for the input data transfer since the memory is initialized with
  // the proper values before the execution.
  int64 compute_and_transfer_time_ns = 5;
}

// Handle given to a user that represents a computation that the user builds up
// before execution.
message ComputationHandle {
  int64 handle = 1;
}

// Handle given to a user that represents an execution that the user launched
// asynchronously on the device.
message ExecutionHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a globally accessible allocation.
// Contrast this against a ComputationDataHandle, which is not globally
// accessible, since it only exists within a specific computation.
message GlobalDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a data result in a computation.
// This is used to pass to subsequent computations that depend upon the data as
// an operand.
message ComputationDataHandle {
  int64 handle = 1;
}

// Handle given to a user that represents a replicated virtual device. Each
// replicated device represents N physical devices for execution where N is the
// number of replicas.
message DeviceHandle {
  int64 handle = 1;

  // The number of model-parallel virtual devices that communicate via XLA
  // Send/Recv instructions.
  int64 device_count = 2;
}

// Handle given to a user to represent a channel between two computations
// via a Send and Recv instruction pair. Channels are unbuffered, so Send
// instructions will be blocked until the data is transferred.
message ChannelHandle {
  int64 handle = 1;
}

// DeviceAssignmentProto is a serialized form of the DeviceAssignment class,
// which represents the device ids assigned to a set of replicated
// computations. See the xla::DeviceAssignment class comment for more details.
message DeviceAssignmentProto {
  int32 replica_count = 1;
  int32 computation_count = 2;

  // Each logical computation runs on replica_count physical devices.
  // ComputationDevice represents the device ids assigned to the replicas.
  message ComputationDevice {
    repeated int32 replica_device_ids = 1;
  }
  repeated ComputationDevice computation_devices = 3;
}
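
// Example (illustrative only; the device ids are hypothetical): with
// computation_count = 2 and replica_count = 2, one possible assignment places
// computation 0 on devices {0, 1} and computation 1 on devices {2, 3}:
//   computation_devices { replica_device_ids: [0, 1] }
//   computation_devices { replica_device_ids: [2, 3] }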

// Literals are used when the server and client need to exchange materialized
// data / results. Literals are also used to describe constants used in
// computations.
//
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message LiteralProto {
  Shape shape = 1;
  repeated bool preds = 2;
  bytes u8s = 3;
  repeated int32 s32s = 4;
  repeated int64 s64s = 5;
  repeated uint32 u32s = 6;
  repeated uint64 u64s = 7;
  repeated float f32s = 8;
  repeated double f64s = 9;
  repeated float c64s = 12;  // Stored as interleaved real, imag floats.
  repeated LiteralProto tuple_literals = 10;
  // The F16s and BF16s are encoded in little-endian byte order.
  bytes f16s = 11;
  bytes bf16s = 13;
  repeated int64 sparse_indices = 14;
  // Next = 15
}
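
// Example (illustrative only; values are hypothetical): the constant
// f32[2,2] {{1, 2}, {3, 4}} could be expressed as a literal whose shape is
// f32[2,2] and whose f32s field holds the elements in the order implied by
// the shape's layout, e.g.
//   shape { element_type: F32 dimensions: [2, 2]
//           layout { format: DENSE minor_to_major: [1, 0] } }
//   f32s: [1, 2, 3, 4]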

message WindowDimension {
  // The size of the window in this dimension. For a rectangle, this would be
  // the width or height.
  int64 size = 1;

  // The stride at which the window moves across the base area in this
  // dimension. In other words, this is the spacing between different
  // positions of the window in this dimension.
  int64 stride = 2;

  // If positive, the amount of padding with zeroes to add to the base area at
  // the low end of this dimension; if negative, its absolute value is the
  // number of elements removed from the low end of this dimension. For
  // example, in the horizontal dimension of a rectangle, this would be the
  // number of zeroes to pad on the left, given that indices increase when
  // going right.
  int64 padding_low = 3;

  // As padding_low, but on the high end of this dimension. For
  // example, in the horizontal dimension of a rectangle, this would
  // be the number of zeroes to pad on the right, given that indices
  // increase when going right.
  int64 padding_high = 4;

  // Dilation factor of the sliding window in this dimension. A dilation factor
  // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are
  // implicitly placed between each kernel element. See documentation for
  // convolution.
  int64 window_dilation = 5;

  // Dilation factor of the base area in this dimension. A dilation factor of 1
  // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly
  // placed between each base area element. See documentation for convolution.
  int64 base_dilation = 6;

  // Window reversal means that this dimension was logically reversed before the
  // operation.
  bool window_reversal = 7;
}

// Describes the windowing in an operation such as convolution.
//
// The window is moved across a base area and for each position of the
// window a computation is performed. The field below describes the
// window and the movement of the window across a base area.
message Window {
  repeated WindowDimension dimensions = 1;
}
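
// Example (illustrative only; values are hypothetical): a 3x3 pooling window
// with stride 2 and symmetric padding of 1 could be described per spatial
// dimension as
//   dimensions { size: 3 stride: 2 padding_low: 1 padding_high: 1
//                window_dilation: 1 base_dilation: 1 }
// repeated once for each spatial dimension being pooled.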

// Describes the dimension numbers for a gather operation.
//
// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for
// more details.
message GatherDimensionNumbers {
  // "Window indices" is a term for a set of indices that index into the
  // interior of a dynamic-slice from the input tensor, the starting indices for
  // which were computed from output_gather_dims (see the operation semantic for
  // how this is defined) and the gather_indices tensor.
  //
  // The window indices for a specific output index Out are computed as:
  //
  //  i = 0
  //  for (k : [0, input_tensor_shape.rank))
  //    window_indices[k] =
  //      if k in elided_window_dims
  //      then 0
  //      else Out[output_window_dims[i++]]
  repeated int64 output_window_dims = 1;
  repeated int64 elided_window_dims = 2;

  // This is interpreted as a map from i to gather_dims_to_operand_dims[i]. It
  // transforms the gather index looked up from the gather_indices tensor into
  // the starting index in the input space.
  repeated int64 gather_dims_to_operand_dims = 3;
}

// Operation requests that are all collected as a tagged union with a oneof
// field in OpRequest.

message ConstantRequest {
  LiteralProto literal = 2;
}

message GetTupleElementRequest {
  ComputationDataHandle operand = 2;
  int64 index = 3;
}

message SliceRequest {
  ComputationDataHandle operand = 2;
  repeated int64 start_indices = 3;
  repeated int64 limit_indices = 4;
  repeated int64 strides = 5;
}
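
// Example (illustrative only; values are hypothetical): slicing a 1-D array
// of 10 elements with
//   start_indices: [2], limit_indices: [8], strides: [2]
// selects the elements at indices 2, 4, and 6 (the limit index is exclusive),
// producing a 3-element result.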

message DynamicSliceRequest {
  // Operand from which to slice at dynamic 'start_indices'.
  ComputationDataHandle operand = 2;
  // Dynamically computed 'start_indices' for slice operation.
  ComputationDataHandle start_indices = 3;
  // Slice sizes for each dimension (note that index calculations are computed
  // modulo dimension sizes to avoid out-of-bound array accesses).
  repeated int64 slice_sizes = 4;
}

message DynamicUpdateSliceRequest {
  // Operand on which slice 'update' is to be applied.
  ComputationDataHandle operand = 2;
  // The slice update to apply to 'operand'.
  ComputationDataHandle update = 3;
  // Dynamically computed start indices for the update slice operation.
  ComputationDataHandle start_indices = 4;
}

message ConvolutionDimensionNumbers {
  // The number of the dimension that represents batch in the input.
  int64 input_batch_dimension = 7;

  // The number of the dimension that represents features in the input.
  int64 input_feature_dimension = 8;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the input.
  repeated int64 input_spatial_dimensions = 11;

  // The number of the dimension that represents input features in the
  // convolutional kernel (rhs).
  int64 kernel_input_feature_dimension = 3;

  // The number of the dimension that represents output features in
  // the convolutional kernel (rhs).
  int64 kernel_output_feature_dimension = 4;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the kernel (rhs). window.strides(0) is the
  // stride in the kernel_spatial_dimensions(0) dimension.
  repeated int64 kernel_spatial_dimensions = 6;

  // The number of the dimension that represents batch in the output.
  int64 output_batch_dimension = 9;

  // The number of the dimension that represents features in the output.
  int64 output_feature_dimension = 10;

  // The dimension numbers for the spatial dimensions that the window
  // moves through in the output.
  repeated int64 output_spatial_dimensions = 12;

  // Next = 13
};
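
// Example (illustrative only): a conventional NHWC input with an HWIO kernel
// and NHWC output could be described as
//   input_batch_dimension: 0, input_feature_dimension: 3,
//   input_spatial_dimensions: [1, 2],
//   kernel_input_feature_dimension: 2, kernel_output_feature_dimension: 3,
//   kernel_spatial_dimensions: [0, 1],
//   output_batch_dimension: 0, output_feature_dimension: 3,
//   output_spatial_dimensions: [1, 2].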

message ConvolveRequest {
  ComputationDataHandle lhs = 2;
  ComputationDataHandle rhs = 3;  // This is the filter/kernel.
  Window window = 4;              // Describes the filter/kernel's window.
  ConvolutionDimensionNumbers dimension_numbers = 5;
}

enum FftType {
  FFT = 0;    // Forward FFT; complex in, complex out.
  IFFT = 1;   // Inverse FFT; complex in, complex out.
  RFFT = 2;   // Forward real FFT; real in, fft_length / 2 + 1 complex out.
  IRFFT = 3;  // Inverse real FFT; fft_length / 2 + 1 complex in,
              //                   fft_length real out.
}
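
// Example (illustrative only): an RFFT with fft_length = [16] applied to a
// real input whose innermost dimension has 16 elements produces
// 16 / 2 + 1 = 9 complex values along that dimension; the matching IRFFT with
// the same fft_length maps those 9 complex values back to 16 real values.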

message FftRequest {
  FftType fft_type = 1;
  repeated int64 fft_length = 2;  // Multivalent for higher-order FFT.
  ComputationDataHandle operand = 3;
}

message InfeedRequest {
  // The shape of the data returned by reading the device's infeed buffer.
  Shape shape = 2;

  // Additional infeed configuration for the backend.
  bytes config = 3;
}

message OutfeedRequest {
  // The shape of the data returned by reading the device's outfeed buffer.
  Shape shape = 1;

  // Operand to the Outfeed. Supports tuple.
  ComputationDataHandle operand = 2;

  // Backend-specific information for how to perform the outfeed.
  bytes outfeed_config = 3;
}

message CallRequest {
  ComputationHandle to_apply = 2;
  repeated ComputationDataHandle operands = 3;
}

message CustomCallRequest {
  string call_target_name = 2;
  repeated ComputationDataHandle operands = 3;
  Shape shape = 4;
}

message HostComputeRequest {
  // Operand to the HostCompute. Supports tuple.
  repeated ComputationDataHandle operands = 1;

  // Name used to identify HostSend/Recv channels.
  string channel_name = 2;

  // Cost estimate in nanoseconds.
  int64 cost_estimate_ns = 3;

  // The shape of any data returned by the host.
  Shape shape = 4;
}

message DotDimensionNumbers {
  // The dimension numbers that represent the 'lhs' contracting dimensions.
  repeated int64 lhs_contracting_dimensions = 1;
  // The dimension numbers that represent the 'rhs' contracting dimensions.
  repeated int64 rhs_contracting_dimensions = 2;
  // The dimension numbers that represent the 'lhs' batch dimensions.
  repeated int64 lhs_batch_dimensions = 3;
  // The dimension numbers that represent the 'rhs' batch dimensions.
  repeated int64 rhs_batch_dimensions = 4;
};
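
// Example (illustrative only): an ordinary matrix multiplication of f32[m, k]
// by f32[k, n] contracts the second dimension of the lhs with the first
// dimension of the rhs:
//   lhs_contracting_dimensions: [1], rhs_contracting_dimensions: [0]
// A batched matmul over f32[b, m, k] and f32[b, k, n] additionally sets
//   lhs_batch_dimensions: [0], rhs_batch_dimensions: [0]
// and contracts dimension 2 of the lhs with dimension 1 of the rhs.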

message DotRequest {
  ComputationDataHandle lhs = 2;
  ComputationDataHandle rhs = 3;
  DotDimensionNumbers dimension_numbers = 4;
}

message MapRequest {
  repeated ComputationDataHandle operands = 2;
  ComputationHandle to_apply = 3;
  repeated ComputationDataHandle static_operands = 4;
  // The dimensions over which to map.
  // Example mapping a Dot operation along the batch dimension 0:
  //   operand0.shape = [2, 2, 2], operand1.shape = [2, 2, 3]
  //   Map({operand0, operand1}, Dot, {0})
  repeated int64 dimensions = 5;
}

message ReduceRequest {
  // Operand to the reduction.
  ComputationDataHandle operand = 2;

  // Initial value for the reduction. This must be consistent with the result
  // shape of to_apply.
  ComputationDataHandle init_value = 3;

  // The dimensions to reduce over.
  repeated int64 dimensions = 4;

  // The computation to apply in the reduction.
  ComputationHandle to_apply = 5;
}

message ReduceWindowRequest {
  ComputationDataHandle operand = 2;
  ComputationDataHandle init_value = 3;
  Window window = 4;
  ComputationHandle to_apply = 5;
}

message BatchNormTrainingRequest {
  ComputationDataHandle operand = 1;
  ComputationDataHandle scale = 2;
  ComputationDataHandle offset = 3;
  float epsilon = 4;
  int64 feature_index = 5;
}

message BatchNormInferenceRequest {
  ComputationDataHandle operand = 1;
  ComputationDataHandle scale = 2;
  ComputationDataHandle offset = 3;
  ComputationDataHandle mean = 4;
  ComputationDataHandle variance = 5;
  float epsilon = 6;
  int64 feature_index = 7;
}

message BatchNormGradRequest {
  ComputationDataHandle operand = 1;
  ComputationDataHandle scale = 2;
  ComputationDataHandle mean = 3;
  ComputationDataHandle variance = 4;
  ComputationDataHandle grad_output = 5;
  float epsilon = 6;
  int64 feature_index = 7;
}

message CrossReplicaSumRequest {
  ComputationDataHandle operand = 2;
}

message SelectAndScatterRequest {
  // Operand array on which the windows slide.
  ComputationDataHandle operand = 2;

  // Source array for the data to scatter.
  ComputationDataHandle source = 3;

  // Initial scalar value for each element in the output.
  ComputationDataHandle init_value = 4;

  // Window configuration.
  Window window = 5;

  // Binary function used to select an element from each window.
  ComputationHandle select = 6;

  // Binary function used to combine each scattered value from source with the
  // current output value at the selected location.
  ComputationHandle scatter = 7;
}

message ReverseRequest {
  ComputationDataHandle operand = 2;
  repeated int64 dimensions = 3;
}

message BroadcastRequest {
  ComputationDataHandle operand = 2;
  repeated int64 broadcast_sizes = 3;
}
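
// Example (illustrative only; values are hypothetical): broadcasting an
// operand of shape f32[3, 4] with broadcast_sizes = [10] prepends a new major
// dimension, producing a result of shape f32[10, 3, 4] in which each of the
// 10 slices is a copy of the operand.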

message PadRequest {
  ComputationDataHandle operand = 2;
  ComputationDataHandle padding_value = 3;
  PaddingConfig padding_config = 4;
}

message ReshapeRequest {
  ComputationDataHandle operand = 2;

  // The dimension order for collapse (from fastest-changing to slowest).
  repeated int64 dimensions = 3;

  // The new dimension sizes (from dimension 0 to n-1).
  repeated int64 new_sizes = 4;
}
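
// Example (illustrative only; values are hypothetical): an f32[2, 3] operand
// could be reshaped into f32[6] by passing new_sizes = [6] and listing both
// input dimensions in 'dimensions' in the collapse order described above; the
// total element count (here 6) must match between the old and new shapes.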

message TransposeRequest {
  ComputationDataHandle operand = 2;

  // The permutation of the operand's dimensions (in the range 0 to n-1).
  repeated int64 dimensions = 3;
}

message ParameterRequest {
  Shape shape = 2;
  int64 parameter = 3;
  string name = 4;
}

message GetLocalShapeRequest {
  ComputationHandle computation = 1;
  ComputationDataHandle operand = 2;
}

message GetLocalShapeResponse {
  Shape shape = 1;
}

message TraceRequest {
  string tag = 2;
  ComputationDataHandle operand = 3;
}

message ConvertRequest {
  ComputationDataHandle operand = 2;
  PrimitiveType new_element_type = 3;
}

message ConcatenateRequest {
  repeated ComputationDataHandle operands = 2;
  // The dimension in which we concatenate; e.g. if you had dimension arrays of
  // [4, 1] and [5, 1], you'd concatenate in dimension 0 to produce a [9, 1].
  // Attempting to concatenate those in dimension 1 would produce an error, as
  // 4 != 5 (and there is no ragged array support).
  int64 dimension = 3;
}

message ConditionalRequest {
  ComputationDataHandle predicate = 2;
  ComputationDataHandle true_operand = 3;
  ComputationHandle true_computation = 4;
  ComputationDataHandle false_operand = 5;
  ComputationHandle false_computation = 6;
}

message WhileRequest {
  ComputationHandle condition = 2;
  ComputationHandle body = 3;
  ComputationDataHandle init = 4;
}

enum UnaryOperation {
  UNOP_INVALID = 0;

  // Elementwise, logical negation on booleans and bitwise negation on ints.
  UNOP_NOT = 1;

  // Elementwise, computes e^x.
  UNOP_EXP = 2;

  // Elementwise, computes -x.
  UNOP_NEGATE = 3;

  // Puts the elements in the operand into sorted order.
  UNOP_SORT = 4;

  // Elementwise, computes tanh(x).
  UNOP_TANH = 5;

  // Elementwise, computes the natural logarithm of x.
  UNOP_LOG = 6;

  // Elementwise, computes the floor of x.
  UNOP_FLOOR = 7;

  // Elementwise, computes the ceil of x.
  UNOP_CEIL = 8;

  // Elementwise, computes the abs of x.
  UNOP_ABS = 9;

  // Elementwise, computes the sign of x.
  UNOP_SIGN = 10;

  // Elementwise, tests if values are finite (not NaN or inf).
  UNOP_IS_FINITE = 11;

  // Elementwise, computes the cosine of x.
  UNOP_COS = 12;

  // Elementwise, computes the sine of x.
  UNOP_SIN = 13;

  // Elementwise, rounds x to nearest integral value, rounding half-way cases
  // away from zero.
  UNOP_ROUND_NEAREST_AFZ = 14;

  // Elementwise, extracts the real component of complex x.
  UNOP_REAL = 15;

  // Elementwise, extracts the imaginary component of complex x.
  UNOP_IMAG = 16;
}

message UnaryOpRequest {
  UnaryOperation unop = 2;
  ComputationDataHandle operand = 3;
}

enum BinaryOperation {
  BINOP_INVALID = 0;

  // Arithmetic operations.
  BINOP_ADD = 1;
  BINOP_DIV = 2;
  BINOP_MUL = 3;
  BINOP_SUB = 4;

  // Comparison operators.
  BINOP_EQ = 5;
  BINOP_GE = 6;
  BINOP_GT = 7;
  BINOP_LE = 8;
  BINOP_LT = 9;
  BINOP_NE = 10;

  // Element-wise maximum.
  BINOP_MAX = 14;

  // Element-wise minimum.
  BINOP_MIN = 15;

  // Raises the left-hand-side to the right-hand-side power.
  BINOP_POW = 16;

  // Remainder operation.
  BINOP_REM = 17;

  // Element-wise, logical operators on booleans and bitwise operators on ints.
  BINOP_AND = 18;
  BINOP_OR = 19;

  BINOP_SHIFT_LEFT = 20;
  BINOP_SHIFT_RIGHT_ARITHMETIC = 21;
  BINOP_SHIFT_RIGHT_LOGICAL = 22;

  // Complex from real, imag.
  BINOP_COMPLEX = 23;

  // Computes the 4-quadrant arctangent of the y, x input arguments.
  BINOP_ATAN2 = 24;
}

message BinaryOpRequest {
  BinaryOperation binop = 2;
  ComputationDataHandle lhs = 3;
  ComputationDataHandle rhs = 4;
  repeated int64 broadcast_dimensions = 5;
}
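
// Example (illustrative only; values are hypothetical): adding a bias vector
// of shape f32[10] to a matrix of shape f32[5, 10] with BINOP_ADD could set
// broadcast_dimensions = [1], meaning the single dimension of the smaller
// operand lines up with dimension 1 of the larger operand, so the vector is
// added to every row.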

enum RandomDistribution {
  RNG_INVALID = 0;

  // Creates a uniform-distribution-generated random number on the semi-open
  // interval [parameter[0], parameter[1]).
  RNG_UNIFORM = 1;

  // Creates a normal-distribution-generated random number with mean
  // parameter[0] and standard deviation parameter[1].
  RNG_NORMAL = 2;

  // Next: 4
}

message RngRequest {
  RandomDistribution distribution = 2;
  repeated ComputationDataHandle parameter = 3;
  Shape shape = 4;
}

enum TernaryOperation {
  TRIOP_INVALID = 0;

  // Given a predicate and two operands, selects operand0 if the predicate is
  // true and operand1 if the predicate is false.
  TRIOP_SELECT = 1;

  // Given a min, a max, and an operand, returns the operand if it is between
  // min and max; otherwise returns min if the operand is less than min, or max
  // if the operand is greater than max.
  TRIOP_CLAMP = 3;
}

message TernaryOpRequest {
  TernaryOperation triop = 2;
  ComputationDataHandle lhs = 3;
  ComputationDataHandle rhs = 4;
  ComputationDataHandle ehs = 5;
}

enum VariadicOperation {
  VAROP_INVALID = 0;

  // Creates a tuple from its operands.
  VAROP_TUPLE = 1;
}

message VariadicOpRequest {
  VariadicOperation varop = 2;
  repeated ComputationDataHandle operands = 3;
}

message ReducePrecisionRequest {
  ComputationDataHandle operand = 1;
  int32 exponent_bits = 2;
  int32 mantissa_bits = 3;
}

message SendRequest {
  ComputationDataHandle operand = 1;
  ChannelHandle channel_handle = 2;
}

message RecvRequest {
  Shape shape = 1;
  ChannelHandle channel_handle = 2;
}

message GatherRequest {
  ComputationDataHandle input = 1;
  ComputationDataHandle gather_indices = 2;
  GatherDimensionNumbers dimension_numbers = 3;
  repeated int64 window_bounds = 4;
}

message OpSharding {
  enum Type {
    // This sharding is replicated across all devices (implies maximal,
    // all other fields are unused).
    REPLICATED = 0;
    // This sharding is maximal - one device runs the entire operation.
    MAXIMAL = 1;
    // This sharding is a tuple - only the tuple_shardings field is valid.
    TUPLE = 2;
    // None of the above; tile_shape and tile_assignment are both used.
    OTHER = 3;
  }
  Type type = 1;
  // The shape of the sharded tile.
  Shape tile_shape = 2;
  // The shape of the tile assignment tensor - this must be the same rank as
  // tile_shape and the product of its dimensions must equal
  // tile_assignment_devices.size().
  repeated int64 tile_assignment_dimensions = 3;
  // Flattened list of device IDs. The order of flattening is the same as used
  // by IndexUtil::MultiToLinearIndex(tile_assignment_shape).
  repeated int64 tile_assignment_devices = 4;
  // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape,
  // in pre-order. The tuple shape could be nested; here we store just a
  // flattened list of all leaves in the tuple shape. Note that the tuple shape
  // is not stored here; shardings do not store the shapes to which they are
  // applied, as this is inferred from the instruction to which the sharding is
  // attached.
  repeated OpSharding tuple_shardings = 5;
}
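
// Example (illustrative only; values are hypothetical): splitting an
// f32[16, 8] array across 4 devices along its first dimension (type == OTHER)
// could use a tile_shape of f32[4, 8], tile_assignment_dimensions of [4, 1],
// and tile_assignment_devices of [0, 1, 2, 3], so that device i holds rows
// [4 * i, 4 * i + 4) of the array.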

message OpRequest {
  ComputationHandle computation = 1;
  OpMetadata metadata = 33;
  OpSharding sharding = 40;

  oneof op {
    BinaryOpRequest binary_op_request = 2;
    BroadcastRequest broadcast_request = 3;
    CallRequest call_request = 4;
    ConcatenateRequest concatenate_request = 5;
    ConstantRequest constant_request = 6;
    ConvertRequest convert_request = 7;
    ConvolveRequest convolve_request = 8;
    CrossReplicaSumRequest cross_replica_sum_request = 9;
    CustomCallRequest custom_call_request = 10;
    DotRequest dot_request = 43;
    DynamicSliceRequest dynamic_slice_request = 11;
    DynamicUpdateSliceRequest dynamic_update_slice_request = 12;
    GetTupleElementRequest get_tuple_element_request = 13;
    InfeedRequest infeed_request = 14;
    MapRequest map_request = 15;
    PadRequest pad_request = 16;
    ParameterRequest parameter_request = 17;
    ReducePrecisionRequest reduce_precision_request = 36;
    ReduceRequest reduce_request = 18;
    ReduceWindowRequest reduce_window_request = 19;
    ReshapeRequest reshape_request = 20;
    ReverseRequest reverse_request = 21;
    RngRequest rng_request = 22;
    SelectAndScatterRequest select_and_scatter_request = 23;
    SliceRequest slice_request = 24;
    TernaryOpRequest ternary_op_request = 25;
    TraceRequest trace_request = 26;
    TransposeRequest transpose_request = 34;
    UnaryOpRequest unary_op_request = 27;
    VariadicOpRequest variadic_op_request = 28;
    WhileRequest while_request = 29;
    SendRequest send_request = 30;
    RecvRequest recv_request = 31;
    OutfeedRequest outfeed_request = 32;
    BatchNormTrainingRequest batch_norm_training_request = 35;
    BatchNormGradRequest batch_norm_grad_request = 37;
    BatchNormInferenceRequest batch_norm_inference_request = 38;
    FftRequest fft_request = 41;
    ConvertRequest bitcast_convert_request = 42;
    ConditionalRequest conditional_request = 44;
    HostComputeRequest host_compute_request = 45;
    GatherRequest gather_request = 46;
    // Next: 47
  }
}

message OpResponse {
  ComputationDataHandle output = 1;
}