1# Operation Semantics
2
3The following describes the semantics of operations defined in the
4[`XlaBuilder`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
5interface. Typically, these operations map one-to-one to operations defined in
6the RPC interface in
7[`xla_data.proto`](https://www.tensorflow.org/code/tensorflow/compiler/xla/xla_data.proto).
8
9A note on nomenclature: the generalized data type XLA deals with is an
10N-dimensional array holding elements of some uniform type (such as 32-bit
11float). Throughout the documentation, *array* is used to denote an
12arbitrary-dimensional array. For convenience, special cases have more specific
13and familiar names; for example a *vector* is a 1-dimensional array and a
14*matrix* is a 2-dimensional array.
15
16## AfterAll
17
18See also
19[`XlaBuilder::AfterAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
20
21AfterAll takes a variadic number of tokens and produces a single token. Tokens
22are primitive types which can be threaded between side-effecting operations to
enforce ordering. `AfterAll` can be used as a join of tokens for ordering an
operation after a set of operations.
25
26<b> `AfterAll(operands)` </b>
27
28Arguments  | Type    | Semantics
29---------- | ------- | -------------------------
30`operands` | `XlaOp` | variadic number of tokens
31
32## AllGather
33
34See also
35[`XlaBuilder::AllGather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
36
37Performs concatenation across replicas.
38
<b> `AllGather(operand, all_gather_dim, shard_count, replica_groups,
channel_id)` </b>
41
| Arguments        | Type                 | Semantics                   |
| ---------------- | -------------------- | --------------------------- |
| `operand`        | `XlaOp`              | Array to concatenate across |
:                  :                      : replicas.                   :
| `all_gather_dim` | `int64`              | Concatenation dimension.    |
| `shard_count`    | `int64`              | Size of each replica group. |
| `replica_groups` | vector of vectors of | Groups between which the    |
:                  : `int64`              : concatenation is performed. :
| `channel_id`     | optional `int64`     | Optional channel ID for     |
:                  :                      : cross-module communication. :
51
-   `replica_groups` is a list of replica groups between which the concatenation
    is performed (the replica id for the current replica can be retrieved using
    [`ReplicaId`](#replicaid)). The order of replicas in each group determines
    the order in which their inputs are located in the result. `replica_groups`
    must either be empty (in which case all replicas belong to a single group,
    ordered from `0` to `N - 1`), or contain the same number of elements as the
    number of replicas. For example, `replica_groups = {{0, 2}, {1, 3}}` performs
    concatenation between replicas `0` and `2`, and between `1` and `3`.
-   `shard_count` is the size of each replica group. It is needed when
    `replica_groups` is empty.
62-   `channel_id` is used for cross-module communication: only `all-gather`
63    operations with the same `channel_id` can communicate to each other.
64
65The output shape is the input shape with the `all_gather_dim` made `shard_count`
66times larger. For example, if there are two replicas and the operand has the
67value `[1.0, 2.5]` and `[3.0, 5.25]` respectively on the two replicas, then the
68output value from this op where `all_gather_dim` is `0` will be `[1.0, 2.5, 3.0,
695.25]` on both replicas.
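
The example above can be written with the builder API roughly as follows (a
minimal sketch; treat the exact `AllGather` overload and its default arguments
as assumptions):

```
XlaBuilder b("all_gather");
// Each replica contributes a 2-element vector, e.g. [1.0, 2.5] on replica 0
// and [3.0, 5.25] on replica 1.
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "x");
// With two replicas and the default (empty) replica_groups, the result is the
// concatenation of both inputs: f32[4] on every replica.
AllGather(x, /*all_gather_dimension=*/0, /*shard_count=*/2);
```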
70
71## AllReduce
72
73See also
74[`XlaBuilder::AllReduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
75
76Performs a custom computation across replicas.
77
<b> `AllReduce(operand, computation, replica_groups, channel_id)` </b>
79
80| Arguments        | Type                 | Semantics                         |
81| ---------------- | -------------------- | --------------------------------- |
82| `operand`        | `XlaOp`              | Array or a non-empty tuple of     |
83:                  :                      : arrays to reduce across replicas. :
84| `computation`    | `XlaComputation`     | Reduction computation             |
85| `replica_groups` | vector of vectors of | Groups between which the          |
86:                  : `int64`              : reductions are performed          :
87| `channel_id`     | optional `int64`     | Optional channel ID for           |
88:                  :                      : cross-module communication        :
89
90-   When `operand` is a tuple of arrays, the all-reduce is performed on each
91    element of the tuple.
-   `replica_groups` is a list of replica groups between which the reduction is
    performed (the replica id for the current replica can be retrieved using
    [`ReplicaId`](#replicaid)). `replica_groups` must either be empty (in which
    case all replicas belong to a single group), or contain the same number of
    elements as the number of replicas. For example, `replica_groups = {{0, 2},
    {1, 3}}` performs reduction between replicas `0` and `2`, and between `1`
    and `3`.
99-   `channel_id` is used for cross-module communication: only `all-reduce`
100    operations with the same `channel_id` can communicate to each other.
101
The output shape is the same as the input shape. For example, if there are two
replicas and the operand has the value `[1.0, 2.5]` and `[3.0, 5.25]`
respectively on the two replicas, then the output value from this op with a
summation computation will be `[4.0, 7.75]` on both replicas. If the input is a
tuple, the output is a tuple as well.
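
As an illustration, a summation all-reduce can be assembled roughly like this
(a minimal sketch; the builder calls and the `Build()` error handling are
simplified and should be treated as assumptions):

```
// Build a scalar addition computation to use as the reduction.
XlaBuilder add_builder("add");
auto a0 = Parameter(&add_builder, 0, ShapeUtil::MakeShape(F32, {}), "a0");
auto a1 = Parameter(&add_builder, 1, ShapeUtil::MakeShape(F32, {}), "a1");
Add(a0, a1);
XlaComputation add = add_builder.Build().ValueOrDie();

XlaBuilder b("all_reduce");
// Each replica contributes a 2-element vector.
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "x");
// Element-wise sum across all replicas; the result keeps the input shape.
AllReduce(x, add);
```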
107
Computing the result of `AllReduce` requires having one input from each replica,
so if one replica executes an `AllReduce` node more times than another, then the
former replica will wait forever. Since the replicas are all running the same
program, there are not many ways for that to happen, but it is possible when a
while loop's condition depends on data from infeed and the infed data causes the
while loop to iterate more times on one replica than on another.
114
115## AllToAll
116
117See also
118[`XlaBuilder::AllToAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
119
120AllToAll is a collective operation that sends data from all cores to all cores.
121It has two phases:
122
1.  The scatter phase. On each core, the operand is split into `split_count`
    blocks along the `split_dimension`, and the blocks are scattered to all
    cores, e.g., the ith block is sent to the ith core.
1262.  The gather phase. Each core concatenates the received blocks along the
127    `concat_dimension`.
128
129The participating cores can be configured by:
130
-   `replica_groups`: each `ReplicaGroup` contains a list of replica ids
    participating in the computation (the replica id for the current replica
    can be retrieved using [`ReplicaId`](#replicaid)). AllToAll will be applied
    within subgroups in the specified order. For example, `replica_groups =
    {{1,2,3}, {4,5,0}}` means that an AllToAll will be applied within replicas
    `{1, 2, 3}`, and in the gather phase the received blocks will be
    concatenated in the same order of 1, 2, 3. Then, another AllToAll will be
    applied within replicas 4, 5, 0, and the concatenation order is also 4, 5,
    0. If `replica_groups` is empty, all replicas belong to one group, in the
    concatenation order of their appearance.
141
142Prerequisites:
143
144-   The dimension size of the operand on the `split_dimension` is divisible by
145`split_count`.
-   The operand's shape is not a tuple.
147
148<b> `AllToAll(operand, split_dimension, concat_dimension, split_count,
149replica_groups)` </b>
150
151
152| Arguments          | Type                  | Semantics                       |
153| ------------------ | --------------------- | ------------------------------- |
154| `operand`          | `XlaOp`               | n dimensional input array       |
155| `split_dimension`  | `int64`               | A value in the interval `[0,    |
156:                    :                       : n)` that names the dimension    :
157:                    :                       : along which the operand is      :
158:                    :                       : split                           :
159| `concat_dimension` | `int64`               | a value in the interval `[0,    |
160:                    :                       : n)` that names the dimension    :
161:                    :                       : along which the split blocks    :
162:                    :                       : are concatenated                :
| `split_count`      | `int64`               | the number of cores that        |
:                    :                       : participate in this operation.  :
:                    :                       : If `replica_groups` is empty,   :
:                    :                       : this should be the number of    :
:                    :                       : replicas; otherwise, this       :
:                    :                       : should be equal to the number   :
:                    :                       : of replicas in each group.      :
| `replica_groups`   | `ReplicaGroup` vector | each group contains a list of   |
:                    :                       : replica ids.                    :
172
Below is an example of AllToAll.
174
175```
176XlaBuilder b("alltoall");
177auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
178AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);
179```
180
181<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
182<img style="width:100%" src="./images/ops_alltoall.png">
183</div>
184
In this example, there are 4 cores participating in the AllToAll. On each core,
the operand is split into 4 parts along dimension 1, so each part has shape
f32[4,4]. The 4 parts are scattered to all cores. Then each core concatenates
the received parts along dimension 0, in the order of cores 0-3. So the output
on each core has shape f32[16,4].
190
191## BatchNormGrad
192
193See also
194[`XlaBuilder::BatchNormGrad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
195and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
196for a detailed description of the algorithm.
197
198Calculates gradients of batch norm.
199
200<b> `BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)` </b>
201
202| Arguments       | Type                    | Semantics                        |
203| --------------- | ----------------------- | -------------------------------- |
204| `operand`       | `XlaOp`                 | n dimensional array to be        |
205:                 :                         : normalized (x)                   :
206| `scale`         | `XlaOp`                 | 1 dimensional array              |
207:                 :                         : (\\(\gamma\\))                   :
208| `mean`          | `XlaOp`                 | 1 dimensional array (\\(\mu\\))  |
209| `variance`      | `XlaOp`                 | 1 dimensional array              |
210:                 :                         : (\\(\sigma^2\\))                 :
211| `grad_output`   | `XlaOp`                 | Gradients passed to              |
212:                 :                         : `BatchNormTraining`              :
213:                 :                         : (\\( \nabla y\\))                :
214| `epsilon`       | `float`                 | Epsilon value (\\(\epsilon\\))   |
215| `feature_index` | `int64`                 | Index to feature dimension in    |
216:                 :                         : `operand`                        :
217
218For each feature in the feature dimension (`feature_index` is the index for the
219feature dimension in `operand`), the operation calculates the gradients with
220respect to `operand`, `offset` and `scale` across all the other dimensions. The
221`feature_index` must be a valid index for the feature dimension in `operand`.
222
223The three gradients are defined by the following formulas (assuming a
2244-dimensional array as `operand` and with feature dimension index `l`, batch
225size `m` and spatial sizes `w` and `h`):
226
227\\[ \begin{split} c_l&=
228\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h
229\left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right)
230\\\\
231\nabla x_{ijkl} &= \frac{\gamma_{l}}{\sqrt{\sigma^2_{l}+\epsilon}}
232\left( \nabla y_{ijkl} - \mathrm{mean}(\nabla y) - c_l (x_{ijkl} - \mu_{l})
233\right)
234\\\\
235\nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl}
236\frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon}} \right)
\\\\
238\nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl}
239\end{split} \\]
240
The inputs `mean` and `variance` represent moment values across the batch and
spatial dimensions.
243
244The output type is a tuple of three handles:
245
246| Outputs        | Type                    | Semantics                         |
247| -------------  | ----------------------- | --------------------------------- |
248| `grad_operand` | `XlaOp`                 | gradient with respect to input    |
249:                :                         : `operand` (\\( \nabla x\\))       :
250| `grad_scale`   | `XlaOp`                 | gradient with respect to input    |
251:                :                         : `scale` (\\( \nabla \gamma\\))    :
252| `grad_offset`  | `XlaOp`                 | gradient with respect to input    |
253:                :                         : `offset`(\\( \nabla \beta\\))     :
254
255## BatchNormInference
256
257See also
258[`XlaBuilder::BatchNormInference`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
259and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
260for a detailed description of the algorithm.
261
262Normalizes an array across batch and spatial dimensions.
263
264<b> `BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)` </b>
265
266Arguments       | Type    | Semantics
267--------------- | ------- | ---------------------------------------
268`operand`       | `XlaOp` | n dimensional array to be normalized
269`scale`         | `XlaOp` | 1 dimensional array
270`offset`        | `XlaOp` | 1 dimensional array
271`mean`          | `XlaOp` | 1 dimensional array
272`variance`      | `XlaOp` | 1 dimensional array
273`epsilon`       | `float` | Epsilon value
274`feature_index` | `int64` | Index to feature dimension in `operand`
275
276For each feature in the feature dimension (`feature_index` is the index for the
277feature dimension in `operand`), the operation calculates the mean and variance
278across all the other dimensions and uses the mean and variance to normalize each
279element in `operand`. The `feature_index` must be a valid index for the feature
280dimension in `operand`.
281
282`BatchNormInference`  is equivalent to calling `BatchNormTraining` without
283computing `mean` and `variance` for each batch. It uses the input `mean` and
284`variance` instead as estimated values. The purpose of this op is to reduce
285latency in inference, hence the name `BatchNormInference`.
286
287The output is an n-dimensional, normalized array with the same shape as input
288`operand`.
289
290## BatchNormTraining
291
292See also
293[`XlaBuilder::BatchNormTraining`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
294and [`the original batch normalization paper`](https://arxiv.org/abs/1502.03167)
295for a detailed description of the algorithm.
296
297Normalizes an array across batch and spatial dimensions.
298
299<b> `BatchNormTraining(operand, scale, offset, epsilon, feature_index)` </b>
300
301Arguments       | Type    | Semantics
302--------------- | ------- | ----------------------------------------
303`operand`       | `XlaOp` | n dimensional array to be normalized (x)
304`scale`         | `XlaOp` | 1 dimensional array (\\(\gamma\\))
305`offset`        | `XlaOp` | 1 dimensional array (\\(\beta\\))
306`epsilon`       | `float` | Epsilon value (\\(\epsilon\\))
307`feature_index` | `int64` | Index to feature dimension in `operand`
308
309For each feature in the feature dimension (`feature_index` is the index for the
310feature dimension in `operand`), the operation calculates the mean and variance
311across all the other dimensions and uses the mean and variance to normalize each
312element in `operand`. The `feature_index` must be a valid index for the feature
313dimension in `operand`.
314
315The algorithm goes as follows for each batch in `operand` \\(x\\) that
316contains `m` elements with `w` and `h` as the size of spatial dimensions
(assuming `operand` is a 4 dimensional array):
318
319- Calculates batch mean \\(\mu_l\\) for each feature `l` in feature dimension:
320\\(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\\)
321
322- Calculates batch variance \\(\sigma^2_l\\):
323\\(\sigma^2_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h (x_{ijkl} - \mu_l)^2\\)
324
325- Normalizes, scales and shifts:
326\\(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt[2]{\sigma^2_l+\epsilon}}+\beta_l\\)
327
328The epsilon value, usually a small number, is added to avoid divide-by-zero errors.
329
330The output type is a tuple of three `XlaOp`s:
331
332| Outputs      | Type                    | Semantics                            |
333| ------------ | ----------------------- | -------------------------------------|
334| `output`     | `XlaOp`                 | n dimensional array with the same    |
335:              :                         : shape as input `operand` (y)         :
336| `batch_mean` | `XlaOp`                 | 1 dimensional array (\\(\mu\\))      |
337| `batch_var`  | `XlaOp`                 | 1 dimensional array (\\(\sigma^2\\)) |
338
339The `batch_mean` and `batch_var` are moments calculated across the batch and
340spatial dimensions using the formulas above.
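
A minimal builder sketch of a call (the input layout and the choice of
`feature_index` here are illustrative assumptions):

```
XlaBuilder b("batch_norm_training");
// NHWC-style input: batch=8, 4x4 spatial dimensions, 16 features.
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4, 4, 16}), "x");
auto scale = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16}), "scale");
auto offset = Parameter(&b, 2, ShapeUtil::MakeShape(F32, {16}), "offset");
// Returns a tuple of (normalized output, batch_mean, batch_var).
BatchNormTraining(x, scale, offset, /*epsilon=*/0.001, /*feature_index=*/3);
```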
341
342## BitcastConvertType
343
344See also
345[`XlaBuilder::BitcastConvertType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
346
347Similar to a `tf.bitcast` in TensorFlow, performs an element-wise bitcast
348operation from a data shape to a target shape. The dimensions must match, and
349the conversion is an element-wise one; e.g. `s32` elements become `f32` elements
via a bitcast routine. Bitcast is implemented as a low-level cast, so machines
351with different floating-point representations will give different results.
352
353<b> `BitcastConvertType(operand, new_element_type)` </b>
354
355Arguments          | Type            | Semantics
356------------------ | --------------- | ---------------------------
357`operand`          | `XlaOp`         | array of type T with dims D
358`new_element_type` | `PrimitiveType` | type U
359
360The dimensions of the operand and the target shape must match. The bit-width of
361the source and destination element types must be equal. The source
362and destination element types must not be tuples.
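
For instance, reinterpreting `s32` bits as `f32` might look like this (a
minimal sketch):

```
XlaBuilder b("bitcast_convert");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {3}), "x");
// Reinterprets each 32-bit integer's bit pattern as an f32; the shape stays [3].
BitcastConvertType(x, F32);
```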
363
364## Broadcast
365
366See also
367[`XlaBuilder::Broadcast`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
368
369Adds dimensions to an array by duplicating the data in the array.
370
371<b> `Broadcast(operand, broadcast_sizes)` </b>
372
373Arguments         | Type                | Semantics
374----------------- | ------------------- | -------------------------------
375`operand`         | `XlaOp`             | The array to duplicate
376`broadcast_sizes` | `ArraySlice<int64>` | The sizes of the new dimensions
377
378The new dimensions are inserted on the left, i.e. if `broadcast_sizes` has
379values `{a0, ..., aN}` and the operand shape has dimensions `{b0, ..., bM}` then
380the shape of the output has dimensions `{a0, ..., aN, b0, ..., bM}`.
381
382The new dimensions index into copies of the operand, i.e.
383
384```
385output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
386```
387
388For example, if `operand` is a scalar `f32` with value `2.0f`, and
389`broadcast_sizes` is `{2, 3}`, then the result will be an array with shape
390`f32[2, 3]` and all the values in the result will be `2.0f`.
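
That example corresponds to roughly the following builder code (a minimal
sketch):

```
XlaBuilder b("broadcast");
auto two = ConstantR0<float>(&b, 2.0f);
// Prepends two new dimensions of sizes 2 and 3; the result has shape f32[2, 3]
// and every element is 2.0f.
Broadcast(two, /*broadcast_sizes=*/{2, 3});
```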
391
392## BroadcastInDim
393
394See also
395[`XlaBuilder::BroadcastInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
396
397Expands the size and rank of an array by duplicating the data in the array.
398
399<b> `BroadcastInDim(operand, out_dim_size, broadcast_dimensions)` </b>
400
401| Arguments              | Type                | Semantics                     |
402| ---------------------- | ------------------- | ----------------------------- |
403| `operand`              | `XlaOp`             | The array to duplicate        |
404| `out_dim_size`         | `ArraySlice<int64>` | The sizes of the dimensions   |
405:                        :                     : of the target shape           :
406| `broadcast_dimensions` | `ArraySlice<int64>` | Which dimension in the target |
407:                        :                     : shape each dimension of the   :
408:                        :                     : operand shape corresponds to  :
409
410Similar to Broadcast, but allows adding dimensions anywhere and expanding
411existing dimensions with size 1.
412
413The `operand` is broadcast to the shape described by `out_dim_size`.
414`broadcast_dimensions` maps the dimensions of `operand` to the dimensions of the
415target shape, i.e. the i'th dimension of the operand is mapped to the
416broadcast_dimension\[i\]'th dimension of the output shape. The dimensions of
417`operand` must have size 1 or be the same size as the dimension in the output
418shape they are mapped to. The remaining dimensions are filled with dimensions of
419size 1. Degenerate-dimension broadcasting then broadcasts along these degenerate
420dimensions to reach the output shape. The semantics are described in detail on
421the [broadcasting page](broadcasting.md).
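
For example, broadcasting a vector into the rows of a matrix could be written
as follows (a minimal sketch; the shapes are illustrative assumptions):

```
XlaBuilder b("broadcast_in_dim");
auto v = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3}), "v");
// Maps operand dimension 0 to output dimension 1, so each of the 2 rows of the
// f32[2, 3] result is a copy of v.
BroadcastInDim(v, /*out_dim_size=*/{2, 3}, /*broadcast_dimensions=*/{1});
```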
422
423## Call
424
425See also
426[`XlaBuilder::Call`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
427
428Invokes a computation with the given arguments.
429
430<b> `Call(computation, args...)` </b>
431
432| Arguments     | Type                   | Semantics                           |
433| ------------- | ---------------------- | ----------------------------------- |
434| `computation` | `XlaComputation`       | computation of type `T_0, T_1, ..., |
435:               :                        : T_{N-1} -> S` with N parameters of  :
436:               :                        : arbitrary type                      :
437| `args`        | sequence of N `XlaOp`s | N arguments of arbitrary type       |
438
439The arity and types of the `args` must match the parameters of the
440`computation`. It is allowed to have no `args`.
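
A minimal sketch of building a computation and calling it (the names and the
`Build()` error handling are illustrative assumptions):

```
// A computation that adds one to its scalar argument.
XlaBuilder callee("add_one");
auto p = Parameter(&callee, 0, ShapeUtil::MakeShape(F32, {}), "p");
Add(p, ConstantR0<float>(&callee, 1.0f));
XlaComputation add_one = callee.Build().ValueOrDie();

XlaBuilder b("caller");
auto x = ConstantR0<float>(&b, 41.0f);
// Invokes add_one with x; the result is an f32 scalar with value 42.
Call(&b, add_one, {x});
```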
441
442## Cholesky
443
444See also
445[`XlaBuilder::Cholesky`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
446
447Computes the
448[Cholesky decomposition](https://en.wikipedia.org/wiki/Cholesky_decomposition)
449of a batch of symmetric (Hermitian) positive definite matrices.
450
451<b> `Cholesky(a, lower)` </b>
452
453Arguments | Type    | Semantics
454--------- | ------- | -----------------------------------------------------
`a`       | `XlaOp` | a rank >= 2 array of a complex or floating-point type.
456`lower`   | `bool`  | whether to use the upper or lower triangle of `a`.
457
458If `lower` is `true`, computes lower-triangular matrices `l` such that $$ a = l
459. l^T $$. If `lower` is `false`, computes upper-triangular matrices `u` such
460that $$ a = u^T . u $$.
461
462Input data is read only from the lower/upper triangle of `a`, depending on the
463value of `lower`. Values from the other triangle are ignored. Output data is
464returned in the same triangle; the values in the other triangle are
465implementation-defined and may be anything.
466
467If the rank of `a` is greater than 2, `a` is treated as a batch of matrices,
468where all except the minor 2 dimensions are batch dimensions.
469
470If `a` is not symmetric (Hermitian) positive definite, the result is
471implementation-defined.
472
473## Clamp
474
475See also
476[`XlaBuilder::Clamp`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
477
478Clamps an operand to within the range between a minimum and maximum value.
479
480<b> `Clamp(min, operand, max)` </b>
481
482Arguments | Type    | Semantics
483--------- | ------- | ---------------
484`min`     | `XlaOp` | array of type T
485`operand` | `XlaOp` | array of type T
486`max`     | `XlaOp` | array of type T
487
488Given an operand and minimum and maximum values, returns the operand if it is in
489the range between the minimum and maximum, else returns the minimum value if the
490operand is below this range or the maximum value if the operand is above this
491range.  That is, `clamp(a, x, b) =  min(max(a, x), b)`.
492
493All three arrays must be the same shape. Alternatively, as a restricted form of
494[broadcasting](broadcasting.md), `min` and/or `max` can be a scalar of type `T`.
495
496Example with scalar `min` and `max`:
497
498```
499let operand: s32[3] = {-1, 5, 9};
500let min: s32 = 0;
501let max: s32 = 6;
502==>
503Clamp(min, operand, max) = s32[3]{0, 5, 6};
504```
505
506## Collapse
507
508See also
509[`XlaBuilder::Collapse`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
510and the `tf.reshape` operation.
511
512Collapses dimensions of an array into one dimension.
513
514<b> `Collapse(operand, dimensions)` </b>
515
516Arguments    | Type           | Semantics
517------------ | -------------- | -----------------------------------------------
518`operand`    | `XlaOp`        | array of type T
519`dimensions` | `int64` vector | in-order, consecutive subset of T's dimensions.
520
521Collapse replaces the given subset of the operand's dimensions by a single
522dimension. The input arguments are an arbitrary array of type T and a
523compile-time-constant vector of dimension indices. The dimension indices must be
524an in-order (low to high dimension numbers), consecutive subset of T's
525dimensions. Thus, {0, 1, 2}, {0, 1}, or {1, 2} are all valid dimension sets, but
526{1, 0} or {0, 2} are not. They are replaced by a single new dimension, in the
527same position in the dimension sequence as those they replace, with the new
528dimension size equal to the product of original dimension sizes. The lowest
529dimension number in `dimensions` is the slowest varying dimension (most major)
in the loop nest which collapses these dimensions, and the highest dimension
531number is fastest varying (most minor). See the `tf.reshape` operator
532if more general collapse ordering is needed.
533
534For example, let v be an array of 24 elements:
535
```
let v = f32[4x2x3] {{{10, 11, 12},  {15, 16, 17}},
                    {{20, 21, 22},  {25, 26, 27}},
                    {{30, 31, 32},  {35, 36, 37}},
                    {{40, 41, 42},  {45, 46, 47}}};

// Collapse to a single dimension, leaving one dimension.
let v012 = Collapse(v, {0,1,2});
then v012 == f32[24] {10, 11, 12, 15, 16, 17,
                      20, 21, 22, 25, 26, 27,
                      30, 31, 32, 35, 36, 37,
                      40, 41, 42, 45, 46, 47};

// Collapse the two lower dimensions, leaving two dimensions.
let v01 = Collapse(v, {0,1});
then v01 == f32[4x6] {{10, 11, 12, 15, 16, 17},
                      {20, 21, 22, 25, 26, 27},
                      {30, 31, 32, 35, 36, 37},
                      {40, 41, 42, 45, 46, 47}};

// Collapse the two higher dimensions, leaving two dimensions.
let v12 = Collapse(v, {1,2});
then v12 == f32[8x3] {{10, 11, 12},
                      {15, 16, 17},
                      {20, 21, 22},
                      {25, 26, 27},
                      {30, 31, 32},
                      {35, 36, 37},
                      {40, 41, 42},
                      {45, 46, 47}};
```
568
569## CollectivePermute
570
571See also
572[`XlaBuilder::CollectivePermute`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
573
CollectivePermute is a collective operation that sends and receives data across
replicas.
576
577<b> `CollectivePermute(operand, source_target_pairs)` </b>
578
579| Arguments             | Type                    | Semantics                  |
580| --------------------- | ----------------------- | -------------------------- |
581| `operand`             | `XlaOp`                 | n dimensional input array  |
582| `source_target_pairs` | `<int64, int64>` vector | A list of                  |
583:                       :                         : (source_replica_id,        :
584:                       :                         : target_replica_id) pairs.  :
585:                       :                         : For each pair, the operand :
586:                       :                         : is sent from source        :
587:                       :                         : replica to target replica. :
588
Note the following restrictions on `source_target_pairs`:

-   No two pairs may have the same target replica id, and no two pairs may have
    the same source replica id.
-   If a replica id is not a target in any pair, then the output on that replica
    is a tensor consisting of 0(s) with the same shape as the input.
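
For example, a ring-style shift across four replicas might look like this (a
minimal sketch; the pair list is an illustrative assumption):

```
XlaBuilder b("collective_permute");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x");
// Replica 0 sends to 1, 1 to 2, 2 to 3, and 3 back to 0.
CollectivePermute(x, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});
```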
595
596## Concatenate
597
598See also
599[`XlaBuilder::ConcatInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
600
601Concatenate composes an array from multiple array operands. The array is of the
602same rank as each of the input array operands (which must be of the same rank as
603each other) and contains the arguments in the order that they were specified.
604
605<b> `Concatenate(operands..., dimension)` </b>
606
607| Arguments   | Type                  | Semantics                              |
608| ----------- | --------------------- | -------------------------------------- |
609| `operands`  | sequence of N `XlaOp` | N arrays of type T with dimensions     |
610:             :                       : [L0, L1, ...]. Requires N >= 1.        :
| `dimension` | `int64`               | A value in the interval `[0, rank)`    |
:             :                       : that names the dimension to be         :
:             :                       : concatenated between the `operands`.   :
614
With the exception of `dimension`, all dimension sizes must be the same. This is
because XLA does not support "ragged" arrays. Also note that rank-0 values
cannot be concatenated (as it's impossible to name the dimension along which the
concatenation occurs).
619
6201-dimensional example:
621
622```
623Concat({{2, 3}, {4, 5}, {6, 7}}, 0)
624>>> {2, 3, 4, 5, 6, 7}
625```
626
6272-dimensional example:
628
629```
630let a = {
631{1, 2},
632{3, 4},
633{5, 6},
634};
635let b = {
636{7, 8},
637};
638Concat({a, b}, 0)
639>>> {
640{1, 2},
641{3, 4},
642{5, 6},
643{7, 8},
644}
645```
646
647Diagram:
648<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
649<img style="width:100%" src="./images/ops_concatenate.png">
650</div>
651
652## Conditional
653
654See also
655[`XlaBuilder::Conditional`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
656
657<b> `Conditional(pred, true_operand, true_computation, false_operand,
658false_computation)` </b>
659
660<!-- mdformat off(disable mdformat for proper MathJax formatting) -->
661
662Arguments           | Type             | Semantics
663------------------- | ---------------- | ------------------------------------
664`pred`              | `XlaOp`          | Scalar of type `PRED`
665`true_operand`      | `XlaOp`          | Argument of type \\(T_0\\)
666`true_computation`  | `XlaComputation` | XlaComputation of type \\(T_0 \to S\\)
667`false_operand`     | `XlaOp`          | Argument of type \\(T_1\\)
668`false_computation` | `XlaComputation` | XlaComputation of type \\(T_1 \to S\\)
669
670Executes `true_computation` if `pred` is `true`, `false_computation` if `pred`
671is `false`, and returns the result.
672
673The `true_computation` must take in a single argument of type \\(T_0\\) and will
674be invoked with `true_operand` which must be of the same type. The
675`false_computation` must take in a single argument of type \\(T_1\\) and will be
676invoked with `false_operand` which must be of the same type. The type of the
677returned value of `true_computation` and `false_computation` must be the same.
678
679<!-- mdformat on -->
680
681Note that only one of `true_computation` and `false_computation` will be
682executed depending on the value of `pred`.
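
A minimal sketch of this two-branch form (the branch computations here,
negation and identity on an `f32` scalar, are illustrative assumptions):

```
// Branch taken when pred is true: negates its argument.
XlaBuilder tb("negate");
Neg(Parameter(&tb, 0, ShapeUtil::MakeShape(F32, {}), "x"));
XlaComputation true_comp = tb.Build().ValueOrDie();

// Branch taken when pred is false: returns its argument unchanged.
XlaBuilder fb("identity");
Parameter(&fb, 0, ShapeUtil::MakeShape(F32, {}), "x");
XlaComputation false_comp = fb.Build().ValueOrDie();

XlaBuilder b("conditional");
auto pred = Parameter(&b, 0, ShapeUtil::MakeShape(PRED, {}), "pred");
auto x = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {}), "x");
Conditional(pred, x, true_comp, x, false_comp);
```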
683
684<b> `Conditional(branch_index, branch_computations, branch_operands)` </b>
685
686<!-- mdformat off(disable mdformat for proper MathJax formatting) -->
687
688| Arguments             | Type                  | Semantics                    |
689| --------------------- | --------------------- | ---------------------------- |
690| `branch_index`        | `XlaOp`               | Scalar of type `S32`         |
691| `branch_computations` | sequence of N         | XlaComputations of type \\(  |
692:                       : `XlaComputation`      : T_0 \to S , T_1 \to S , ..., :
693:                       :                       : T_{N-1} \to S \\)            :
694| `branch_operands`     | sequence of N `XlaOp` | Arguments of type \\( T_0 ,  |
695:                       :                       : T_1 , ..., T_{N-1} \\)       :
696
697<!-- mdformat on -->
698
699Executes `branch_computations[branch_index]`, and returns the result. If
700`branch_index` is an `S32` which is < 0 or >= N, then `branch_computations[N-1]`
701is executed as the default branch.
702
703Each `branch_computations[b]` must take in a single argument of type `T_b` and
704will be invoked with `branch_operands[b]` which must be of the same type. The
705type of the returned value of each `branch_computations[b]` must be the same.
706
707Note that only one of the `branch_computations` will be executed depending on
708the value of `branch_index`.
709
710## Conv (convolution)
711
712See also
713[`XlaBuilder::Conv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
714
715As ConvWithGeneralPadding, but the padding is specified in a short-hand way as
716either SAME or VALID. SAME padding pads the input (`lhs`) with zeroes so that
717the output has the same shape as the input when not taking striding into
718account. VALID padding simply means no padding.
719
720## ConvWithGeneralPadding (convolution)
721
722See also
723[`XlaBuilder::ConvWithGeneralPadding`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
724
725Computes a convolution of the kind used in neural networks. Here, a convolution
can be thought of as an n-dimensional window moving across an n-dimensional base
area, where a computation is performed for each possible position of the window.
728
729| Arguments             | Type                     | Semantics                |
730| --------------------- | ------------------------ | ------------------------ |
731| `lhs`                 | `XlaOp`                  | rank n+2 array of inputs |
732| `rhs`                 | `XlaOp`                  | rank n+2 array of kernel |
733:                       :                          : weights                  :
734| `window_strides`      | `ArraySlice<int64>`      | n-d array of kernel      |
735:                       :                          : strides                  :
736| `padding`             | `ArraySlice< pair<int64, | n-d array of (low, high) |
737:                       : int64>>`                 : padding                  :
738| `lhs_dilation`        | `ArraySlice<int64>`      | n-d lhs dilation factor  |
739:                       :                          : array                    :
740| `rhs_dilation`        | `ArraySlice<int64>`      | n-d rhs dilation factor  |
741:                       :                          : array                    :
742| `feature_group_count` | int64                    | the number of feature    |
743:                       :                          : groups                   :
744| `batch_group_count`   | int64                    | the number of batch      |
745:                       :                          : groups                   :
746
747Let n be the number of spatial dimensions. The `lhs` argument is a rank n+2
748array describing the base area. This is called the input, even though of course
749the rhs is also an input. In a neural network, these are the input activations.
750The n+2 dimensions are, in this order:
751
752*   `batch`: Each coordinate in this dimension represents an independent input
753for which convolution is carried out.
754*   `z/depth/features`: Each (y,x) position in the base area has a vector
755associated to it, which goes into this dimension.
756*   `spatial_dims`: Describes the `n` spatial dimensions that define the base
757area that the window moves across.
758
759The `rhs` argument is a rank n+2 array describing the convolutional
760filter/kernel/window. The dimensions are, in this order:
761
762*   `output-z`: The `z` dimension of the output.
763*   `input-z`: The size of this dimension times `feature_group_count` should
764equal the size of the `z` dimension in lhs.
765*   `spatial_dims`: Describes the `n` spatial dimensions that define the n-d
766window that moves across the base area.
767
768The `window_strides` argument specifies the stride of the convolutional window
769in the spatial dimensions. For example, if the stride in the first spatial
770dimension is 3, then the window can only be placed at coordinates where the
771first spatial index is divisible by 3.
772
773The `padding` argument specifies the amount of zero padding to be applied to the
774base area. The amount of padding can be negative -- the absolute value of
775negative padding indicates the number of elements to remove from the specified
776dimension before doing the convolution. `padding[0]` specifies the padding for
777dimension `y` and `padding[1]` specifies the padding for dimension `x`. Each
778pair has the low padding as the first element and the high padding as the second
779element. The low padding is applied in the direction of lower indices while the
780high padding is applied in the direction of higher indices. For example, if
781`padding[1]` is `(2,3)` then there will be a padding by 2 zeroes on the left and
782by 3 zeroes on the right in the second spatial dimension. Using padding is
783equivalent to inserting those same zero values into the input (`lhs`) before
784doing the convolution.
785
786The `lhs_dilation` and `rhs_dilation` arguments specify the dilation factor to
787be applied to the lhs and rhs, respectively, in each spatial dimension. If the
788dilation factor in a spatial dimension is d, then d-1 holes are implicitly
789placed between each of the entries in that dimension, increasing the size of the
790array. The holes are filled with a no-op value, which for convolution means
791zeroes.
792
793Dilation of the rhs is also called atrous convolution. For more details, see
794`tf.nn.atrous_conv2d`. Dilation of the lhs is also called transposed
795convolution. For more details, see `tf.nn.conv2d_transpose`.
796
797The `feature_group_count` argument (default value 1) can be used for grouped
798convolutions. `feature_group_count` needs to be a divisor of both the input and
799the output feature dimension. If `feature_group_count` is greater than 1, it
800means that conceptually the input and output feature dimension and the `rhs`
801output feature dimension are split evenly into `feature_group_count` many
802groups, each group consisting of a consecutive subsequence of features. The
803input feature dimension of `rhs` needs to be equal to the `lhs` input feature
804dimension divided by `feature_group_count` (so it already has the size of a
805group of input features). The i-th groups are used together to compute
806`feature_group_count` many separate convolutions. The results of these
807convolutions are concatenated together in the output feature dimension.
808
809For depthwise convolution the `feature_group_count` argument would be set to the
810input feature dimension, and the filter would be reshaped from
811`[filter_height, filter_width, in_channels, channel_multiplier]` to
812`[filter_height, filter_width, 1, in_channels * channel_multiplier]`. For more
813details, see `tf.nn.depthwise_conv2d`.
814
815The `batch_group_count` (default value 1) argument can be used for grouped
816filters during backpropagation. `batch_group_count` needs to be a divisor of the
817size of the `lhs` (input) batch dimension. If `batch_group_count` is greater
818than 1, it means that the output batch dimension should be of size `input batch
819/ batch_group_count`. The `batch_group_count` must be a divisor of the output
820feature size.
821
822The output shape has these dimensions, in this order:
823
824*   `batch`: The size of this dimension times `batch_group_count` should equal
825    the size of the `batch` dimension in lhs.
826*   `z`: Same size as `output-z` on the kernel (`rhs`).
827*   `spatial_dims`: One value for each valid placement of the convolutional
828    window.
829
830The valid placements of the convolutional window are determined by the strides
831and the size of the base area after padding.
832
833To describe what a convolution does, consider a 2d convolution, and pick some
834fixed `batch`, `z`, `y`, `x` coordinates in the output. Then `(y,x)` is a
835position of a corner of the window within the base area (e.g. the upper left
836corner, depending on how you interpret the spatial dimensions). We now have a 2d
837window, taken from the base area, where each 2d point is associated to a 1d
838vector, so we get a 3d box. From the convolutional kernel, since we fixed the
839output coordinate `z`, we also have a 3d box. The two boxes have the same
840dimensions, so we can take the sum of the element-wise products between the two
841boxes (similar to a dot product). That is the output value.
842
843Note that if `output-z` is e.g., 5, then each position of the window produces 5
844values in the output into the `z` dimension of the output. These values differ
845in what part of the convolutional kernel is used - there is a separate 3d box of
846values used for each `output-z` coordinate. So you could think of it as 5
847separate convolutions with a different filter for each of them.
848
849Here is pseudo-code for a 2d convolution with padding and striding:
850
851```
852for (b, oz, oy, ox) {  // output coordinates
853  value = 0;
854  for (iz, ky, kx) {  // kernel coordinates and input z
855    iy = oy*stride_y + ky - pad_low_y;
856    ix = ox*stride_x + kx - pad_low_x;
857    if ((iy, ix) inside the base area considered without padding) {
858      value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
859    }
860  }
861  output(b, oz, oy, ox) = value;
862}
863```
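
A corresponding builder sketch (the shapes, padding, and default dimension
numbering are illustrative assumptions):

```
XlaBuilder b("conv");
// Input: batch=1, features=1, 8x8 spatial area (default dimension numbering
// of batch, feature, then spatial dimensions is assumed).
auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {1, 1, 8, 8}), "lhs");
// Kernel: output-z=4, input-z=1, 3x3 spatial window.
auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {4, 1, 3, 3}), "rhs");
// Stride 1 with (1, 1) padding on both spatial dimensions keeps the 8x8
// spatial shape, so the output is f32[1, 4, 8, 8].
ConvWithGeneralPadding(lhs, rhs, /*window_strides=*/{1, 1},
                       /*padding=*/{{1, 1}, {1, 1}});
```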
864
865## ConvertElementType
866
867See also
868[`XlaBuilder::ConvertElementType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
869
870Similar to an element-wise `static_cast` in C++, performs an element-wise
871conversion operation from a data shape to a target shape. The dimensions must
872match, and the conversion is an element-wise one; e.g. `s32` elements become
873`f32` elements via an `s32`-to-`f32` conversion routine.
874
875<b> `ConvertElementType(operand, new_element_type)` </b>
876
877Arguments          | Type            | Semantics
878------------------ | --------------- | ---------------------------
879`operand`          | `XlaOp`         | array of type T with dims D
880`new_element_type` | `PrimitiveType` | type U
881
882The dimensions of the operand and the target shape must match. The source and
883destination element types must not be tuples.
884
885A conversion such as `T=s32` to `U=f32` will perform a normalizing int-to-float
886conversion routine such as round-to-nearest-even.
887
> Note: The precise float-to-int and vice-versa conversions are currently
> unspecified, but may become additional arguments to the convert operation in
> the future. Not all possible conversions have been implemented for all
> targets.
892
893```
894let a: s32[3] = {0, 1, 2};
895let b: f32[3] = convert(a, f32);
896then b == f32[3]{0.0, 1.0, 2.0}
897```
898
899## CrossReplicaSum
900
901Performs `AllReduce` with a summation computation.
902
903## CustomCall
904
905See also
906[`XlaBuilder::CustomCall`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
907
Calls a user-provided function within a computation.
909
910<b> `CustomCall(target_name, args..., shape)` </b>
911
912| Arguments     | Type                   | Semantics                         |
913| ------------- | ---------------------- | --------------------------------- |
914| `target_name` | `string`               | Name of the function. A call      |
915:               :                        : instruction will be emitted which :
916:               :                        : targets this symbol name.         :
917| `args`        | sequence of N `XlaOp`s | N arguments of arbitrary type,    |
918:               :                        : which will be passed to the       :
919:               :                        : function.                         :
920| `shape`       | `Shape`                | Output shape of the function      |
921
922The function signature is the same, regardless of the arity or type of args:
923
924```
925extern "C" void target_name(void* out, void** in);
926```
927
928For example, if CustomCall is used as follows:
929
930```
931let x = f32[2] {1,2};
932let y = f32[2x3] {{10, 20, 30}, {40, 50, 60}};
933
934CustomCall("myfunc", {x, y}, f32[3x3])
935```
936
937Here is an example of an implementation of `myfunc`:
938
939```
940extern "C" void myfunc(void* out, void** in) {
941  float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
942  float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
943  EXPECT_EQ(1, x[0]);
944  EXPECT_EQ(2, x[1]);
945  EXPECT_EQ(10, y[0][0]);
946  EXPECT_EQ(20, y[0][1]);
947  EXPECT_EQ(30, y[0][2]);
948  EXPECT_EQ(40, y[1][0]);
949  EXPECT_EQ(50, y[1][1]);
950  EXPECT_EQ(60, y[1][2]);
951  float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
952  z[0][0] = x[1] + y[1][0];
953  // ...
954}
955```
956
957The user-provided function must not have side-effects and its execution must be
958idempotent.
959
960> Note: The opaque nature of the user-provided function restricts optimization
961> opportunities for the compiler. Try to express your computation in terms of
962> native XLA ops whenever possible; only use CustomCall as a last resort.
963
964## Dot
965
966See also
967[`XlaBuilder::Dot`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
968
969<b> `Dot(lhs, rhs)` </b>
970
971Arguments | Type    | Semantics
972--------- | ------- | ---------------
973`lhs`     | `XlaOp` | array of type T
974`rhs`     | `XlaOp` | array of type T
975
976The exact semantics of this operation depend on the ranks of the operands:
977
978| Input                   | Output                | Semantics               |
979| ----------------------- | --------------------- | ----------------------- |
980| vector [n] `dot` vector | scalar                | vector dot product      |
981: [n]                     :                       :                         :
982| matrix [m x k] `dot`    | vector [m]            | matrix-vector           |
983: vector [k]              :                       : multiplication          :
984| matrix [m x k] `dot`    | matrix [m x n]        | matrix-matrix           |
985: matrix [k x n]          :                       : multiplication          :
986
987The operation performs sum of products over the second dimension of `lhs` (or
988the first if it has rank 1) and the first dimension of `rhs`. These are the
989"contracted" dimensions. The contracted dimensions of `lhs` and `rhs` must be of
990the same size. In practice, it can be used to perform dot products between
991vectors, vector/matrix multiplications or matrix/matrix multiplications.
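
For example, a plain matrix-matrix multiplication (a minimal sketch):

```
XlaBuilder b("dot");
auto lhs = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "lhs");
auto rhs = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {3, 4}), "rhs");
// Contracts dimension 1 of lhs with dimension 0 of rhs; the result is f32[2, 4].
Dot(lhs, rhs);
```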
992
993## DotGeneral
994
995See also
996[`XlaBuilder::DotGeneral`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
997
998<b> `DotGeneral(lhs, rhs, dimension_numbers)` </b>
999
1000Arguments           | Type                  | Semantics
1001------------------- | --------------------- | ---------------
1002`lhs`               | `XlaOp`               | array of type T
1003`rhs`               | `XlaOp`               | array of type T
1004`dimension_numbers` | `DotDimensionNumbers` | contracting and batch dimension numbers
1005
1006As Dot, but allows contracting and batch dimension numbers to be specified for
1007both the 'lhs' and 'rhs'.
1008
| DotDimensionNumbers Fields   | Type           | Semantics                           |
| ---------------------------- | -------------- | ----------------------------------- |
| `lhs_contracting_dimensions` | repeated int64 | `lhs` contracting dimension numbers |
| `rhs_contracting_dimensions` | repeated int64 | `rhs` contracting dimension numbers |
| `lhs_batch_dimensions`       | repeated int64 | `lhs` batch dimension numbers       |
| `rhs_batch_dimensions`       | repeated int64 | `rhs` batch dimension numbers       |
1015
1016DotGeneral performs the sum of products over contracting dimensions specified
1017in 'dimension_numbers'.
1018
1019Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need
1020to be the same but must have the same dimension sizes.
1021
1022Example with contracting dimension numbers:
1023
```
lhs = { {1.0, 2.0, 3.0},
        {4.0, 5.0, 6.0} }

rhs = { {1.0, 1.0, 1.0},
        {2.0, 2.0, 2.0} }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(1);

DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
                                 {15.0, 30.0} }
```
1038
1039Associated batch dimension numbers from the 'lhs' and 'rhs' must
1040have the same dimension sizes.
1041
1042Example with batch dimension numbers (batch size 2, 2x2 matrices):
1043
```
lhs = { { {1.0, 2.0},
          {3.0, 4.0} },
        { {5.0, 6.0},
          {7.0, 8.0} } }

rhs = { { {1.0, 0.0},
          {0.0, 1.0} },
        { {1.0, 0.0},
          {0.0, 1.0} } }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);

DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
                                   {3.0, 4.0} },
                                 { {5.0, 6.0},
                                   {7.0, 8.0} } }
```
1066
1067| Input                               | Output            | Semantics        |
1068| ----------------------------------- | ----------------- | ---------------- |
1069| [b0, m, k] `dot` [b0, k, n]         | [b0, m, n]        |  batch matmul    |
1070| [b0, b1, m, k] `dot` [b0, b1, k, n] | [b0, b1, m, n]    |  batch matmul    |
1071
1072It follows that the resulting dimension number starts with the batch dimension,
1073then the 'lhs' non-contracting/non-batch dimension, and finally the 'rhs'
1074non-contracting/non-batch dimension.
1075
1076## DynamicSlice
1077
1078See also
1079[`XlaBuilder::DynamicSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1080
DynamicSlice extracts a sub-array from the input array at dynamic
`start_indices`. The size of the slice in each dimension is passed in
`size_indices`, which specify the end point of exclusive slice intervals in each
dimension: [start, start + size). `start_indices` is a sequence of N scalar
integer indices, where N is the rank of `operand`.
1086
1087<b> `DynamicSlice(operand, start_indices, size_indices)` </b>
1088
1089| Arguments       | Type                  | Semantics                          |
1090| --------------- | --------------------- | ---------------------------------- |
1091| `operand`       | `XlaOp`               | N dimensional array of type T      |
1092| `start_indices` | sequence of N `XlaOp` | List of N scalar integers          |
1093:                 :                       : containing the starting indices of :
1094:                 :                       : the slice for each dimension.      :
1095:                 :                       : Value must be greater than or      :
1096:                 :                       : equal to zero.                     :
1097| `size_indices`  | `ArraySlice<int64>`   | List of N integers containing the  |
1098:                 :                       : slice size for each dimension.     :
1099:                 :                       : Each value must be strictly        :
1100:                 :                       : greater than zero, and start +     :
1101:                 :                       : size must be less than or equal to :
1102:                 :                       : the size of the dimension to avoid :
1103:                 :                       : wrapping modulo dimension size.    :
1104
The effective slice indices are computed by applying the following
transformation for each index `i` in `[0, N)` before performing the slice:
1107
1108```
1109start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])
1110```
1111
1112This ensures that the extracted slice is always in-bounds with respect to the
1113operand array. If the slice is in-bounds before the transformation is applied,
1114the transformation has no effect.
1115
11161-dimensional example:
1117
1118```
1119let a = {0.0, 1.0, 2.0, 3.0, 4.0}
1120let s = {2}
1121
1122DynamicSlice(a, s, {2}) produces:
1123{2.0, 3.0}
1124```
1125
11262-dimensional example:
1127
```
let b =
 { {0.0,  1.0,  2.0},
   {3.0,  4.0,  5.0},
   {6.0,  7.0,  8.0},
   {9.0, 10.0, 11.0} }
let s = {2, 1}

DynamicSlice(b, s, {2, 2}) produces:
  { { 7.0,  8.0},
    {10.0, 11.0} }
```

1140## DynamicUpdateSlice
1141
1142See also
1143[`XlaBuilder::DynamicUpdateSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1144
DynamicUpdateSlice generates a result which is the value of the input array
`operand`, with a slice `update` overwritten at `start_indices`.
The shape of `update` determines the shape of the sub-array of the result which
is updated.
`start_indices` is a sequence of N scalar integer indices, where N is the rank
of `operand`.
1151
1152<b> `DynamicUpdateSlice(operand, update, start_indices)` </b>
1153
1154| Arguments       | Type                  | Semantics                          |
1155| --------------- | --------------------- | ---------------------------------- |
1156| `operand`       | `XlaOp`               | N dimensional array of type T      |
1157| `update`        | `XlaOp`               | N dimensional array of type T      |
1158:                 :                       : containing the slice update. Each  :
1159:                 :                       : dimension of update shape must be  :
1160:                 :                       : strictly greater than zero, and    :
1161:                 :                       : start + update must be less than   :
1162:                 :                       : or equal to the operand size for   :
1163:                 :                       : each dimension to avoid generating :
1164:                 :                       : out-of-bounds update indices.      :
1165| `start_indices` | sequence of N `XlaOp` | List of N scalar integers          |
1166:                 :                       : containing the starting indices of :
1167:                 :                       : the slice for each dimension.      :
1168:                 :                       : Value must be greater than or      :
1169:                 :                       : equal to zero.                     :
1170
The effective slice indices are computed by applying the following
transformation for each index `i` in `[0, N)` before performing the slice:
1173
1174```
1175start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])
1176```
1177
1178This ensures that the updated slice is always in-bounds with respect to the
1179operand array. If the slice is in-bounds before the transformation is applied,
1180the transformation has no effect.
1181
11821-dimensional example:
1183
1184```
1185let a = {0.0, 1.0, 2.0, 3.0, 4.0}
1186let u = {5.0, 6.0}
1187let s = {2}
1188
1189DynamicUpdateSlice(a, u, s) produces:
1190{0.0, 1.0, 5.0, 6.0, 4.0}
1191```
1192
11932-dimensional example:
1194
1195```
1196let b =
1197{ {0.0,  1.0,  2.0},
1198  {3.0,  4.0,  5.0},
1199  {6.0,  7.0,  8.0},
1200  {9.0, 10.0, 11.0} }
1201let u =
1202{ {12.0,  13.0},
1203  {14.0,  15.0},
1204  {16.0,  17.0} }
1205
1206let s = {1, 1}
1207
1208DynamicUpdateSlice(b, u, s) produces:
1209{ {0.0,  1.0,  2.0},
1210  {3.0, 12.0, 13.0},
1211  {6.0, 14.0, 15.0},
1212  {9.0, 16.0, 17.0} }
1213```
1214
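A corresponding NumPy sketch of DynamicUpdateSlice, including the clamping
(illustrative only; the `dynamic_update_slice` helper below is not the XLA
client API):

```python
import numpy as np

def dynamic_update_slice(operand, update, start_indices):
    # Clamp each start index so the updated region stays in-bounds.
    starts = [int(np.clip(s, 0, operand.shape[i] - update.shape[i]))
              for i, s in enumerate(start_indices)]
    result = operand.copy()
    result[tuple(slice(s, s + size)
                 for s, size in zip(starts, update.shape))] = update
    return result

a = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
print(dynamic_update_slice(a, np.array([5.0, 6.0]), [2]))  # [0. 1. 5. 6. 4.]
```
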
1215## Element-wise binary arithmetic operations
1216
1217See also
1218[`XlaBuilder::Add`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1219
1220A set of element-wise binary arithmetic operations is supported.
1221
1222<b> `Op(lhs, rhs)` </b>
1223
1224Where `Op` is one of `Add` (addition), `Sub` (subtraction), `Mul`
1225(multiplication), `Div` (division), `Rem` (remainder), `Max` (maximum), `Min`
1226(minimum), `LogicalAnd` (logical AND), or `LogicalOr` (logical OR).
1227
1228Arguments | Type    | Semantics
1229--------- | ------- | ----------------------------------------
1230`lhs`     | `XlaOp` | left-hand-side operand: array of type T
1231`rhs`     | `XlaOp` | right-hand-side operand: array of type T
1232
1233The arguments' shapes have to be either similar or compatible. See the
1234[broadcasting](broadcasting.md) documentation about what it means for shapes to
1235be compatible. The result of an operation has a shape which is the result of
1236broadcasting the two input arrays. In this variant, operations between arrays of
1237different ranks are *not* supported, unless one of the operands is a scalar.
1238
1239When `Op` is `Rem`, the sign of the result is taken from the dividend, and the
1240absolute value of the result is always less than the divisor's absolute value.
1241
1242Integer division overflow (signed/unsigned division/remainder by zero, or signed
1243division/remainder of the minimum signed integer by `-1`) produces an
1244implementation-defined value.
1245
1246An alternative variant with different-rank broadcasting support exists for these
1247operations:
1248
1249<b> `Op(lhs, rhs, broadcast_dimensions)` </b>
1250
1251Where `Op` is the same as above. This variant of the operation should be used
1252for arithmetic operations between arrays of different ranks (such as adding a
1253matrix to a vector).
1254
1255The additional `broadcast_dimensions` operand is a slice of integers used to
1256expand the rank of the lower-rank operand up to the rank of the higher-rank
1257operand. `broadcast_dimensions` maps the dimensions of the lower-rank shape to
1258the dimensions of the higher-rank shape. The unmapped dimensions of the expanded
1259shape are filled with dimensions of size one. Degenerate-dimension broadcasting
1260then broadcasts the shapes along these degenerate dimensions to equalize the
1261shapes of both operands. The semantics are described in detail on the
1262[broadcasting page](broadcasting.md).
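
As an illustration, the following NumPy sketch mimics
`Add(matrix, vector, broadcast_dimensions={1})` (the helper name below is made
up for this example; it is not part of the XLA API):

```python
import numpy as np

def add_with_broadcast_dims(lhs, rhs, broadcast_dimensions):
    # Map dimension k of the lower-rank rhs to lhs dimension
    # broadcast_dimensions[k]; unmapped dimensions become size 1 and are then
    # broadcast (degenerate-dimension broadcasting).
    expanded = [1] * lhs.ndim
    for k, dim in enumerate(broadcast_dimensions):
        expanded[dim] = rhs.shape[k]
    return lhs + rhs.reshape(expanded)

m = np.arange(6.0).reshape(2, 3)      # 2x3 matrix
v = np.array([10.0, 20.0, 30.0])      # vector of length 3
print(add_with_broadcast_dims(m, v, [1]))  # v is added to every row of m
```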
1263
1264## Element-wise comparison operations
1265
1266See also
1267[`XlaBuilder::Eq`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1268
1269A set of standard element-wise binary comparison operations is supported. Note
1270that standard IEEE 754 floating-point comparison semantics apply when comparing
1271floating-point types.
1272
1273<b> `Op(lhs, rhs)` </b>
1274
1275Where `Op` is one of `Eq` (equal-to), `Ne` (not equal-to), `Ge`
1276(greater-than-or-equal-to), `Gt` (greater-than), `Le` (less-than-or-equal-to),
1277or `Lt` (less-than). Another set of operators, `EqTotalOrder`, `NeTotalOrder`,
1278`GeTotalOrder`, `GtTotalOrder`, `LeTotalOrder`, and `LtTotalOrder`, provides the
1279same functionality, except that it additionally imposes a total order over the
1280floating point numbers by enforcing `-NaN < -Inf < -Finite < -0 < +0 < +Finite < +Inf < +NaN`.
1281
1282Arguments | Type    | Semantics
1283--------- | ------- | ----------------------------------------
1284`lhs`     | `XlaOp` | left-hand-side operand: array of type T
1285`rhs`     | `XlaOp` | right-hand-side operand: array of type T
1286
1287The arguments' shapes have to be either similar or compatible. See the
1288[broadcasting](broadcasting.md) documentation about what it means for shapes to
1289be compatible. The result of an operation has a shape which is the result of
1290broadcasting the two input arrays with the element type `PRED`. In this variant,
1291operations between arrays of different ranks are *not* supported, unless one of
1292the operands is a scalar.
1293
1294An alternative variant with different-rank broadcasting support exists for these
1295operations:
1296
1297<b> `Op(lhs, rhs, broadcast_dimensions)` </b>
1298
1299Where `Op` is the same as above. This variant of the operation should be used
1300for comparison operations between arrays of different ranks (such as comparing
1301a matrix against a vector).
1302
1303The additional `broadcast_dimensions` operand is a slice of integers specifying
1304the dimensions to use for broadcasting the operands. The semantics are described
1305in detail on the [broadcasting page](broadcasting.md).
1306
1307## Element-wise unary functions
1308
1309XlaBuilder supports these element-wise unary functions:
1310
1311<b>`Abs(operand)`</b> Element-wise abs `x -> |x|`.
1312
1313<b>`Ceil(operand)`</b> Element-wise ceil `x -> ⌈x⌉`.
1314
1315<b>`Cos(operand)`</b> Element-wise cosine `x -> cos(x)`.
1316
1317<b>`Exp(operand)`</b> Element-wise natural exponential `x -> e^x`.
1318
1319<b>`Floor(operand)`</b> Element-wise floor `x -> ⌊x⌋`.
1320
1321<b>`Imag(operand)`</b> Element-wise imaginary part of a complex (or real)
1322shape. `x -> imag(x)`. If the operand is a floating point type, returns 0.
1323
1324<b>`IsFinite(operand)`</b> Tests whether each element of `operand` is finite,
1325i.e., is not positive or negative infinity, and is not `NaN`. Returns an array
1326of `PRED` values with the same shape as the input, where each element is `true`
1327if and only if the corresponding input element is finite.
1328
1329<b>`Log(operand)`</b> Element-wise natural logarithm `x -> ln(x)`.
1330
1331<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.
1332
1333<b>`Logistic(operand)`</b> Element-wise logistic (sigmoid) function
1334`x -> 1 / (1 + e^(-x))`.
1335
1336<b>`PopulationCount(operand)`</b> Computes the number of bits set in each
1337element of `operand`.
1338
1339<b>`Neg(operand)`</b> Element-wise negation `x -> -x`.
1340
1341<b>`Real(operand)`</b> Element-wise real part of a complex (or real) shape.
1342`x -> real(x)`. If the operand is a floating point type, returns the same value.
1343
1344<b>`Rsqrt(operand)`</b> Element-wise reciprocal of square root operation
1345`x -> 1.0 / sqrt(x)`.
1346
1347<b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where
1348
1349$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}$$
1350
1351using the comparison operator of the element type of `operand`.
1352
1353<b>`Sqrt(operand)`</b> Element-wise square root operation `x -> sqrt(x)`.
1354
1355<b>`Cbrt(operand)`</b> Element-wise cubic root operation `x -> cbrt(x)`.
1356
1357<b>`Tanh(operand)`</b> Element-wise hyperbolic tangent `x -> tanh(x)`.
1358
1359
1360Arguments | Type    | Semantics
1361--------- | ------- | ---------------------------
1362`operand` | `XlaOp` | The operand to the function
1363
1364The function is applied to each element in the `operand` array, resulting in an
1365array with the same shape. It is allowed for `operand` to be a scalar (rank 0).
1366
1367## Fft
1368
1369The XLA FFT operation implements the forward and inverse Fourier Transforms for
1370real and complex inputs/outputs. Multidimensional FFTs on up to 3 axes are
1371supported, except on TPU, where only a single axis is supported (please file a
1372GitHub issue if you require higher order).
1373
1374See also
1375[`XlaBuilder::Fft`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1376
1377| Arguments    | Type                | Semantics                |
1378| ------------ | ------------------- | ------------------------ |
1379| `operand`    | `XlaOp`             | The array we are Fourier |
1380:              :                     : transforming.            :
1381| `fft_type`   | `FftType`           | See the table below.     |
1382| `fft_length` | `ArraySlice<int64>` | The time-domain lengths  |
1383:              :                     : of the axes being        :
1384:              :                     : transformed. This is     :
1385:              :                     : needed in particular for :
1386:              :                     : IRFFT to right-size the  :
1387:              :                     : innermost axis, since    :
1388:              :                     : `RFFT(fft_length=[16])`  :
1389:              :                     : has the same output      :
1390:              :                     : shape as                 :
1391:              :                     : `RFFT(fft_length=[17])`. :
1392
1393| `FftType` | Semantics                                                        |
1394| --------- | ---------------------------------------------------------------- |
1395| `FFT`     | Forward complex-to-complex FFT. Shape is unchanged.              |
1396| `IFFT`    | Inverse complex-to-complex FFT. Shape is unchanged.              |
1397| `RFFT`    | Forward real-to-complex FFT. Shape of the innermost axis is      |
1398:           : reduced to `fft_length[-1] // 2 + 1` if `fft_length[-1]` is a    :
1399:           : non-zero value, omitting the reversed conjugate part of the      :
1400:           : transformed signal beyond the Nyquist frequency.                 :
1401| `IRFFT`   | Inverse complex-to-real FFT (i.e. takes complex, returns real).  |
1402:           : Shape of the innermost axis is expanded to `fft_length[-1]` if   :
1403:           : `fft_length[-1]` is a non-zero value, inferring the part of the  :
1404:           : transformed signal beyond the Nyquist frequency from the reverse :
1405:           : conjugate of the `1` to `fft_length[-1] // 2 + 1` entries.       :
1406
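The role of `fft_length` on the innermost axis can be seen with NumPy's FFT
helpers (shown only for illustration; this is not the XLA API):

```python
import numpy as np

# RFFT keeps only the non-redundant half of the spectrum, so inputs of length
# 16 and 17 both produce 16 // 2 + 1 == 17 // 2 + 1 == 9 output elements.
# This is why IRFFT needs fft_length to recover the original length.
print(np.fft.rfft(np.zeros(16)).shape)                        # (9,)
print(np.fft.rfft(np.zeros(17)).shape)                        # (9,)
print(np.fft.irfft(np.zeros(9, dtype=complex), n=17).shape)   # (17,)
```
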
1407#### Multidimensional FFT
1408
1409When more than 1 `fft_length` is provided, this is equivalent to applying a
1410cascade of FFT operations to each of the innermost axes. Note that for the
1411real->complex and complex->real cases, the innermost axis transform is
1412(effectively) performed first (RFFT; last for IRFFT), which is why the innermost
1413axis is the one which changes size. Other axis transforms will then be
1414complex->complex.
1415
1416#### Implementation details
1417
1418CPU FFT is backed by Eigen's TensorFFT. GPU FFT uses cuFFT.
1419
1420## Gather
1421
1422The XLA gather operation stitches together several slices (each slice at a
1423potentially different runtime offset) of an input array.
1424
1425### General Semantics
1426
1427See also
1428[`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1429For a more intuitive description, see the "Informal Description" section below.
1430
1431<b> `gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)` </b>
1432
1433| Arguments              | Type                | Semantics                     |
1434| ---------------------- | ------------------- | ----------------------------- |
1435| `operand`              | `XlaOp`             | The array we’re gathering     |
1436:                        :                     : from.                         :
1437| `start_indices`        | `XlaOp`             | Array containing the starting |
1438:                        :                     : indices of the slices we      :
1439:                        :                     : gather.                       :
1440| `index_vector_dim`     | `int64`             | The dimension in              |
1441:                        :                     : `start_indices` that          :
1442:                        :                     : "contains" the starting       :
1443:                        :                     : indices. See below for a      :
1444:                        :                     : detailed description.         :
1445| `offset_dims`          | `ArraySlice<int64>` | The set of dimensions in the  |
1446:                        :                     : output shape that offset into :
1447:                        :                     : an array sliced from operand. :
1448| `slice_sizes`          | `ArraySlice<int64>` | `slice_sizes[i]` is the       |
1449:                        :                     : bounds for the slice on       :
1450:                        :                     : dimension `i`.                :
1451| `collapsed_slice_dims` | `ArraySlice<int64>` | The set of dimensions in each |
1452:                        :                     : slice that are collapsed      :
1453:                        :                     : away. These dimensions must   :
1454:                        :                     : have size 1.                  :
1455| `start_index_map`      | `ArraySlice<int64>` | A map that describes how to   |
1456:                        :                     : map indices in                :
1457:                        :                     : `start_indices` to legal      :
1458:                        :                     : indices into operand.         :
1459| `indices_are_sorted`   | `bool`              | Whether the indices are       |
1460:                        :                     : guaranteed to be sorted by    :
1461:                        :                     : the caller.                   :
1462| `unique_indices`       | `bool`              | Whether the indices are       |
1463:                        :                     : guaranteed to be unique by    :
1464:                        :                     : the caller.                   :
1465
1466For convenience, we label dimensions in the output array not in `offset_dims`
1467as `batch_dims`.
1468
1469The output is an array of rank `batch_dims.size` + `offset_dims.size`.
1470
1471The `operand.rank` must equal the sum of `offset_dims.size` and
1472`collapsed_slice_dims.size`. Also, `slice_sizes.size` has to be equal to
1473`operand.rank`.
1474
1475If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider
1476`start_indices` to have a trailing `1` dimension (i.e. if `start_indices` was of
1477shape `[6,7]` and `index_vector_dim` is `2` then we implicitly consider the
1478shape of `start_indices` to be `[6,7,1]`).
1479
1480The bounds for the output array along dimension `i` is computed as follows:
1481
14821. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for
1483some `k`) then we pick the corresponding dimension bounds out of
1484`start_indices.shape`, skipping `index_vector_dim` (i.e. pick
1485`start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and
1486`start_indices.shape.dims`[`k`+`1`] otherwise).
1487
14882. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for
1489some `k`) then we pick the corresponding bound out of `slice_sizes` after
1490accounting for `collapsed_slice_dims` (i.e. we pick
1491`adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes`
1492with the bounds at indices `collapsed_slice_dims` removed).
1493
1494Formally, the operand index `In` corresponding to a given output index `Out` is
1495calculated as follows:
1496
14971.  Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out a
1498    vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where
1499    Combine(A, b) inserts b at position `index_vector_dim` into A. Note that
1500    this is well defined even if `G` is empty -- if `G` is empty then `S` =
1501    `start_indices`.
1502
15032.  Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by
1504    scattering `S` using `start_index_map`. More precisely:
1505
1506    1.  `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` <
1507        `start_index_map.size`.
1508
1509    2.  `S`<sub>`in`</sub>[`_`] = `0` otherwise.
1510
15113.  Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices
1512    at the offset dimensions in `Out` according to the `collapsed_slice_dims`
1513    set. More precisely:
1514
1515    1.  `O`<sub>`in`</sub>[`remapped_offset_dims`(`k`)] =
1516        `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size`
1517        (`remapped_offset_dims` is defined below).
1518
1519    2.  `O`<sub>`in`</sub>[`_`] = `0` otherwise.
1520
15214.  `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
1522    addition.
1523
1524`remapped_offset_dims` is a monotonic function with domain [`0`, `offset_dims.size`)
1525and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g.,
1526`offset_dims.size` is `4`, `operand.rank` is `6` and `collapsed_slice_dims` is {`0`,
1527`2`} then `remapped_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.
1528
1529If `indices_are_sorted` is set to true then XLA can assume that `start_indices`
1530are sorted (in ascending `start_index_map` order) by the user. If they are not
1531then the semantics is implementation defined.
1532
1533If `unique_indices` is set to true then XLA can assume that all indices
1534accessed are unique, so XLA may use non-atomic operations. If
1535`unique_indices` is set to true but the indices are not in fact unique,
1536then the semantics is implementation defined.
1537
1538### Informal Description and Examples
1539
1540Informally, every index `Out` in the output array corresponds to an element `E`
1541in the operand array, computed as follows:
1542
1543-   We use the batch dimensions in `Out` to look up a starting index from
1544    `start_indices`.
1545
1546-   We use `start_index_map` to map the starting index (whose size may be less
1547    than operand.rank) to a "full" starting index into the `operand`.
1548
1549-   We dynamic-slice out a slice with size `slice_sizes` using the full starting
1550    index.
1551
1552-   We reshape the slice by collapsing the `collapsed_slice_dims` dimensions.
1553    Since all collapsed slice dimensions must have a bound of 1, this reshape is
1554    always legal.
1555
1556-   We use the offset dimensions in `Out` to index into this slice to get the
1557    input element, `E`, corresponding to output index `Out`.
1558
1559`index_vector_dim` is set to `start_indices.rank` - `1` in all of the examples
1560that follow. More interesting values for `index_vector_dim` do not change the
1561operation fundamentally, but make the visual representation more cumbersome.
1562
1563To get an intuition on how all of the above fits together, let's look at an
1564example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array.  The
1565position of a slice into the `[16,11]` array can be represented as an index
1566vector of shape `S64[2]`, so the set of 5 positions can be represented as a
1567`S64[5,2]` array.
1568
1569The behavior of the gather operation can then be depicted as an index
1570transformation that takes [`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>], an index in
1571the output shape, and maps it to an element in the input array in the following
1572way:
1573
1574<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
1575<img style="width:100%" src="./images/ops_xla_gather_0.svg">
1576</div>
1577
1578We first select an (`X`,`Y`) vector from the gather indices array using `G`.
1579The element in the output array at index
1580[`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>] is then the element in the input
1581array at index [`X`+`O`<sub>`0`</sub>,`Y`+`O`<sub>`1`</sub>].
1582
1583`slice_sizes` is `[8,6]`, which decides the range of O<sub>`0`</sub> and
1584O<sub>`1`</sub>, and this in turn decides the bounds of the slice.
1585
1586This gather operation acts as a batch dynamic slice with `G` as the batch
1587dimension.
1588
1589The gather indices may be multidimensional.  For instance, a more general
1590version of the example above using a "gather indices" array of shape `[4,5,2]`
1591would translate indices like this:
1592
1593<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
1594<img style="width:100%" src="./images/ops_xla_gather_1.svg">
1595</div>
1596
1597Again, this acts as a batch dynamic slice with `G`<sub>`0`</sub> and
1598`G`<sub>`1`</sub> as the batch dimensions.  The slice size is still `[8,6]`.
1599
1600The gather operation in XLA generalizes the informal semantics outlined above in
1601the following ways:
1602
16031. We can configure which dimensions in the output shape are the offset
1604dimensions (dimensions containing `O`<sub>`0`</sub>, `O`<sub>`1`</sub> in
1605the last example).  The output batch dimensions (dimensions containing
1606`G`<sub>`0`</sub>, `G`<sub>`1`</sub> in the last example) are defined to be
1607the output dimensions that are not offset dimensions.
1608
16092. The number of output offset dimensions explicitly present in the output
1610shape may be smaller than the input rank.  These "missing" dimensions, which
1611are listed explicitly as `collapsed_slice_dims`, must have a slice size of
1612`1`.  Since they have a slice size of `1` the only valid index for them is
1613`0` and eliding them does not introduce ambiguity.
1614
16153. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last
1616example) may have fewer elements than the input array rank, and an explicit
1617mapping dictates how the index should be expanded to have the same rank as
1618the input.
1619
1620As a final example, we use (2) and (3) to implement `tf.gather_nd`:
1621
1622<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
1623<img style="width:100%" src="./images/ops_xla_gather_2.svg">
1624</div>
1625
1626`G`<sub>`0`</sub> and `G`<sub>`1`</sub> are used to slice out a starting index
1627from the gather indices array as usual, except the starting index has only one
1628element, `X`. Similarly, there is only one output offset index with the value
1629`O`<sub>`0`</sub>. However, before being used as indices into the input array,
1630these are expanded in accordance to "Gather Index Mapping" (`start_index_map` in
1631the formal description) and "Offset Mapping" (`remapped_offset_dims` in the
1632formal description) into [`X`,`0`] and [`0`,`O`<sub>`0`</sub>] respectively,
1633adding up to [`X`,`O`<sub>`0`</sub>]. In other words, the output index
1634[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index
1635[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us
1636the semantics for `tf.gather_nd`.
1637
1638`slice_sizes` for this case is `[1,11]`.  Intuitively this means that every
1639index `X` in the gather indices array picks an entire row and the result is the
1640concatenation of all these rows.
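
For this particular configuration the result can also be written as a one-line
NumPy expression (illustrative only; general gathers do not reduce to simple
indexing):

```python
import numpy as np

operand = np.arange(16 * 11).reshape(16, 11)        # the [16,11] input array
gather_indices = np.array([[0, 2, 4], [1, 3, 5]])   # a [2,3] batch of row indices
output = operand[gather_indices]                    # each index picks a whole row
print(output.shape)                                 # (2, 3, 11)
```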
1641
1642## GetDimensionSize
1643
1644See also
1645[`XlaBuilder::GetDimensionSize`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1646
1647Returns the size of the given dimension of the operand. The operand must be
1648array shaped.
1649
1650<b> `GetDimensionSize(operand, dimension)` </b>
1651
1652| Arguments   | Type    | Semantics                                           |
1653| ----------- | ------- | --------------------------------------------------- |
1654| `operand`   | `XlaOp` | n dimensional input array                           |
1655| `dimension` | `int64` | A value in the interval `[0, n)` that specifies the |
1656:             :         : dimension                                           :
1657
1658## SetDimensionSize
1659
1660See also
1661[`XlaBuilder::SetDimensionSize`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1662
1663Sets the dynamic size of the operand's given dimension. The operand must be
1664array shaped.
1665
1666<b> `SetDimensionSize(operand, size, dimension)` </b>
1667
1668| Arguments   | Type    | Semantics                                           |
1669| ----------- | ------- | --------------------------------------------------- |
1670| `operand`   | `XlaOp` | n dimensional input array.                          |
1671| `size`      | `XlaOp` | int32 representing the runtime dynamic size.        |
1672| `dimension` | `int64` | A value in the interval `[0, n)` that specifies the |
1673:             :         : dimension.                                          :
1674
1675Passes through the operand as the result, with the dynamic dimension tracked by
1676the compiler.
1677
1678Padded values will be ignored by downstream reduction ops.
1679
1680```
1681let v: f32[10] = f32[10]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
1682let five: s32 = 5;
1683let six: s32 = 6;
1684
1685// Setting dynamic dimension size doesn't change the upper bound of the static
1686// shape.
1687let padded_v_five: f32[10] = set_dimension_size(v, five, /*dimension=*/0);
1688let padded_v_six: f32[10] = set_dimension_size(v, six, /*dimension=*/0);
1689
1690// sum == 1 + 2 + 3 + 4 + 5
1691let sum:f32[] = reduce_sum(padded_v_five);
1692// product == 1 * 2 * 3 * 4 * 5
1693let product:f32[] = reduce_product(padded_v_five);
1694
1695// Changing padding size will yield different result.
1696// sum == 1 + 2 + 3 + 4 + 5 + 6
1697let sum':f32[] = reduce_sum(padded_v_six);
1698```
1699
1700## GetTupleElement
1701
1702See also
1703[`XlaBuilder::GetTupleElement`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1704
1705Indexes into a tuple with a compile-time-constant value.
1706
1707The value must be a compile-time-constant so that shape inference can determine
1708the type of the resulting value.
1709
1710This is analogous to `std::get<N>(t)` in C++. Conceptually:
1711
1712```
1713let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
1714let s: s32 = 5;
1715let t: (f32[10], s32) = tuple(v, s);
1716let element_1: s32 = gettupleelement(t, 1);  // Inferred shape matches s32.
1717```
1718
1719See also `tf.tuple`.
1720
1721## Infeed
1722
1723See also
1724[`XlaBuilder::Infeed`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1725
1726<b> `Infeed(shape)` </b>
1727
1728| Argument | Type    | Semantics                                             |
1729| -------- | ------- | ----------------------------------------------------- |
1730| `shape`  | `Shape` | Shape of the data read from the Infeed interface. The |
1731:          :         : layout field of the shape must be set to match the    :
1732:          :         : layout of the data sent to the device; otherwise its  :
1733:          :         : behavior is undefined.                                :
1734
1735Reads a single data item from the implicit Infeed streaming interface of the
1736device, interpreting the data as the given shape and its layout, and returns a
1737`XlaOp` of the data. Multiple Infeed operations are allowed in a
1738computation, but there must be a total order among the Infeed operations. For
1739example, two Infeeds in the code below have a total order since there is a
1740dependency between the while loops.
1741
1742```
1743result1 = while (condition, init = init_value) {
1744  Infeed(shape)
1745}
1746
1747result2 = while (condition, init = result1) {
1748  Infeed(shape)
1749}
1750```
1751
1752Nested tuple shapes are not supported. For an empty tuple shape, the Infeed
1753operation is effectively a no-op and proceeds without reading any data from the
1754Infeed of the device.
1755
1756> Note: We plan to allow multiple Infeed operations without a total order, in
1757> which case the compiler will provide information about how the Infeed
1758> operations are serialized in the compiled program.
1759
1760## Iota
1761
1762<b> `Iota()` </b>
1763
1764Builds a constant literal on device rather than a potentially large host
1765transfer. Creates a rank 1 array of values starting at zero and incrementing by
1766one. For floating-point types, the produced array is equivalent to
1767`ConvertElementType(Iota(...))` where the `Iota` is of integral type and the
1768conversion is to the floating-point type.
1769
1770Arguments        | Type            | Semantics
1771---------------- | --------------- | ------------------------------------
1772`type`           | `PrimitiveType` | type U
1773`size`           | `int64`         | The number of elements in the array.
1774`iota_dimension` | `int64`         | The dimension to increment along.
1775
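The produced values can be pictured with NumPy (shown only for illustration):

```python
import numpy as np

# Rank-1 iota: values start at zero and increment by one. For floating-point
# types the result matches converting an integral iota.
print(np.arange(5, dtype=np.int32))                     # [0 1 2 3 4]
print(np.arange(5, dtype=np.int32).astype(np.float32))  # [0. 1. 2. 3. 4.]

# A rank-2 iota incrementing along dimension 1 (iota_dimension == 1).
print(np.broadcast_to(np.arange(3), (2, 3)))            # [[0 1 2] [0 1 2]]
```
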
1776## Map
1777
1778See also
1779[`XlaBuilder::Map`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1780
1781<b> `Map(operands..., computation)` </b>
1782
1783| Arguments         | Type                   | Semantics                      |
1784| ----------------- | ---------------------- | ------------------------------ |
1785| `operands`        | sequence of N `XlaOp`s | N arrays of types T_0..T_{N-1} |
1786| `computation`     | `XlaComputation`       | computation of type `T_0, T_1, |
1787:                   :                        : ..., T_{N + M -1} -> S` with N :
1788:                   :                        : parameters of type T and M of  :
1789:                   :                        : arbitrary type                 :
1790| `dimensions`      | `int64` array          | array of map dimensions        |
1791
1792Applies a scalar function over the given `operands` arrays, producing an array
1793of the same dimensions where each element is the result of the mapped function
1794applied to the corresponding elements in the input arrays.
1795
1796The mapped function is an arbitrary computation with the restriction that it has
1797N inputs of scalar type `T` and a single output with type `S`. The output has
1798the same dimensions as the operands except that the element type T is replaced
1799with S.
1800
1801For example: `Map(op1, op2, op3, computation, par1)` maps `elem_out <-
1802computation(elem1, elem2, elem3, par1)` at each (multi-dimensional) index in the
1803input arrays to produce the output array.
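
The following NumPy sketch captures this behavior (`xla_map` is a made-up name
used only for illustration, not the XLA API):

```python
import numpy as np

def xla_map(operands, computation):
    # Apply the scalar computation at every index of the identically shaped
    # operands.
    out = np.empty(operands[0].shape)
    for idx in np.ndindex(operands[0].shape):
        out[idx] = computation(*(op[idx] for op in operands))
    return out

a = np.array([1.0, 2.0, 3.0])
b = np.array([10.0, 20.0, 30.0])
print(xla_map([a, b], lambda x, y: x * y + 1.0))   # [11. 41. 91.]
```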
1804
1805## Pad
1806
1807See also
1808[`XlaBuilder::Pad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1809
1810<b> `Pad(operand, padding_value, padding_config)` </b>
1811
1812| Arguments        | Type            | Semantics                               |
1813| ---------------- | --------------- | --------------------------------------- |
1814| `operand`        | `XlaOp`         | array of type `T`                       |
1815| `padding_value`  | `XlaOp`         | scalar of type `T` to fill in the added |
1816:                  :                 : padding                                 :
1817| `padding_config` | `PaddingConfig` | padding amount on both edges (low,      |
1818:                  :                 : high) and between the elements of each  :
1819:                  :                 : dimension                               :
1820
1821Expands the given `operand` array by padding around the array as well as between
1822the elements of the array with the given `padding_value`. `padding_config`
1823specifies the amount of edge padding and the interior padding for each
1824dimension.
1825
1826`PaddingConfig` is a repeated field of `PaddingConfigDimension`, which contains
1827three fields for each dimension: `edge_padding_low`, `edge_padding_high`, and
1828`interior_padding`.
1829
1830`edge_padding_low` and `edge_padding_high` specify the amount of padding added
1831at the low-end (next to index 0) and the high-end (next to the highest index) of
1832each dimension respectively. The amount of edge padding can be negative -- the
1833absolute value of negative padding indicates the number of elements to remove
1834from the specified dimension.
1835
1836`interior_padding` specifies the amount of padding added between any two
1837elements in each dimension; it may not be negative.  Interior padding occurs
1838logically before edge padding, so in the case of negative edge padding, elements
1839are removed from the interior-padded operand.
1840
1841This operation is a no-op if the edge padding pairs are all (0, 0) and the
1842interior padding values are all 0. The figure below shows examples of different
1843`edge_padding` and `interior_padding` values for a two-dimensional array.
1844
1845<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
1846  <img style="width:100%" src="./images/ops_pad.png">
1847</div>
1848
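A NumPy sketch of the padding rule for non-negative padding amounts
(illustrative only; the `pad` helper below is not the XLA API, and negative
edge padding is not handled):

```python
import numpy as np

def pad(operand, padding_value, padding_config):
    # padding_config holds one (edge_padding_low, edge_padding_high,
    # interior_padding) triple per dimension.
    result = operand
    for dim, (low, high, interior) in enumerate(padding_config):
        shape = list(result.shape)
        shape[dim] = low + result.shape[dim] + (result.shape[dim] - 1) * interior + high
        padded = np.full(shape, padding_value, dtype=result.dtype)
        index = [slice(None)] * result.ndim
        index[dim] = slice(low, shape[dim] - high, interior + 1)
        padded[tuple(index)] = result
        result = padded
    return result

x = np.array([[1.0, 2.0], [3.0, 4.0]])
# One element of edge padding on each side of dimension 0, and one element of
# interior padding between the columns of dimension 1.
print(pad(x, 0.0, [(1, 1, 0), (0, 0, 1)]))
```
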
1849## Recv
1850
1851See also
1852[`XlaBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1853
1854<b> `Recv(shape, channel_handle)` </b>
1855
1856| Arguments        | Type            | Semantics                            |
1857| ---------------- | --------------- | ------------------------------------ |
1858| `shape`          | `Shape`         | shape of the data to receive         |
1859| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair |
1860
1861Receives data of the given shape from a `Send` instruction in another
1862computation that shares the same channel handle. Returns an
1863`XlaOp` for the received data.
1864
1865The client API of the `Recv` operation represents synchronous communication.
1866However, the instruction is internally decomposed into 2 HLO instructions
1867(`Recv` and `RecvDone`) to enable asynchronous data transfers. See also
1868[`HloInstruction::CreateRecv` and `HloInstruction::CreateRecvDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
1869
1870<b>`Recv(const Shape& shape, int64 channel_id)`</b>
1871
1872Allocates resources required to receive data from a `Send` instruction with the
1873same channel_id. Returns a context for the allocated resources, which is used
1874by a following `RecvDone` instruction to wait for the completion of the data
1875transfer. The context is a tuple of {receive buffer (shape), request identifier
1876(U32)} and it can only be used by a `RecvDone` instruction.
1877
1878<b> `RecvDone(HloInstruction context)` </b>
1879
1880Given a context created by a `Recv` instruction, waits for the data transfer to
1881complete and returns the received data.
1882
1883## Reduce
1884
1885See also
1886[`XlaBuilder::Reduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
1887
1888Applies a reduction function to one or more arrays in parallel.
1889
1890<b> `Reduce(operands..., init_values..., computation, dimensions)` </b>
1891
1892| Arguments     | Type                  | Semantics                        |
1893| ------------- | --------------------- | -------------------------------- |
1894| `operands`    | Sequence of N `XlaOp` | N arrays of types `T_0, ...,     |
1895:               :                       : T_{N-1}`.                        :
1896| `init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ...,    |
1897:               :                       : T_{N-1}`.                        :
1898| `computation` | `XlaComputation`      | computation of type `T_0, ...,   |
1899:               :                       : T_{N-1}, T_0, ..., T_{N-1} ->`   :
1900:               :                       : `Collate(T_0, ..., T_{N-1})`.    :
1901| `dimensions`  | `int64` array         | unordered array of dimensions to |
1902:               :                       : reduce.                          :
1903
1904Where:
1905
1906*   N is required to be greater than or equal to 1.
1907*   All input arrays must have the same dimensions.
1908*   If `N = 1`, `Collate(T)` is `T`.
1909*   If `N > 1`, `Collate(T_0, ..., T_{N-1})` is a tuple of `N` elements of
1910    types `T_0, ..., T_{N-1}`.
1911
1912The output of the op is `Collate(Q_0, ..., Q_{N-1})` where `Q_i` is an array of type
1913`T_i`, the dimensions of which are described below.
1914
1915This operation reduces one or more dimensions of each input array into scalars.
1916The rank of each returned array is `rank(operand) - len(dimensions)`. The
1917initial value used for every reduction is `init_value`, and it may be inserted
1918anywhere during computation by the back-end. In most cases, `init_value` is an
1919identity of the reduction function (for example, `0` for addition). The applied
1920`computation` is always passed the `init_value` on the left-hand side.
1921
1922The evaluation order of the reduction function is arbitrary and may be
1923non-deterministic. Therefore, the reduction function should not be overly
1924sensitive to reassociation.
1925
1926Some reduction functions like addition are not strictly associative for floats.
1927However, if the range of the data is limited, floating-point addition is close
1928enough to being associative for most practical uses. It is possible to conceive
1929of some completely non-associative reductions, however, and these will produce
1930incorrect or unpredictable results in XLA.
1931
1932As an example, when reducing across one dimension in a single 1D array with
1933values `[10, 11, 12, 13]`, with reduction function `f` (this is `computation`)
1934then that could be computed as
1935
1936`f(10, f(11, f(12, f(init_value, 13))))`
1937
1938but there are also many other possibilities, e.g.
1939
1940`f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))`
1941
1942The following is a rough pseudo-code example of how reduction could be
1943implemented, using summation as the reduction computation with an initial value
1944of 0.
1945
1946```python
1947result_shape <- remove all dims in dimensions from operand_shape
1948
1949# Iterate over all elements in result_shape. The number of r's here is equal
1950# to the rank of the result
1951for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
1952  # Initialize this result element
1953  result[r0, r1...] <- 0
1954
1955  # Iterate over all the reduction dimensions
1956  for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
1957    # Increment the result element with the value of the operand's element.
1958    # The index of the operand's element is constructed from all ri's and di's
1959    # in the right order (by construction ri's and di's together index over the
1960    # whole operand shape).
1961    result[r0, r1...] += operand[ri... di]
1962```
1963
1964Here's an example of reducing a 2D array (matrix). The shape has rank 2,
1965dimension 0 of size 2 and dimension 1 of size 3:
1966
1967<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
1968  <img style="width:35%" src="./images/ops_2d_matrix.png">
1969</div>
1970
1971Results of reducing dimensions 0 or 1 with an "add" function:
1972
1973<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
1974  <img style="width:35%" src="./images/ops_reduce_from_2d_matrix.png">
1975</div>
1976
1977Note that both reduction results are 1D arrays. The diagram shows one as column
1978and another as row just for visual convenience.
1979
1980For a more complex example, here is a 3D array. Its rank is 3, dimension 0 of
1981size 4, dimension 1 of size 2 and dimension 2 of size 3. For simplicity, the
1982values 1 to 6 are replicated across dimension 0.
1983
1984<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
1985  <img style="width:35%" src="./images/ops_reduce_from_3d_matrix.png">
1986</div>
1987
1988Similarly to the 2D example, we can reduce just one dimension. If we reduce
1989dimension 0, for example, we get a rank-2 array where all values across
1990dimension 0 were folded into a scalar:
1991
1992```text
1993|  4   8  12 |
1994| 16  20  24 |
1995```
1996
1997If we reduce dimension 2, we also get a rank-2 array where all values across
1998dimension 2 were folded into a scalar:
1999
2000```text
2001| 6  15 |
2002| 6  15 |
2003| 6  15 |
2004| 6  15 |
2005```
2006
2007Note that the relative order between the remaining dimensions in the input is
2008preserved in the output, but some dimensions may get assigned new numbers (since
2009the rank changes).
2010
2011We can also reduce multiple dimensions. Add-reducing dimensions 0 and 1 produces
2012the 1D array `[20, 28, 36]`.
2013
2014Reducing the 3D array over all its dimensions produces the scalar `84`.
2015
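For reference, the add-reductions above can be reproduced with NumPy (a
`Reduce` whose computation is addition behaves like a sum over the given
dimensions):

```python
import numpy as np

# The 3D example: the values 1..6 replicated across dimension 0.
x = np.tile(np.array([[1, 2, 3], [4, 5, 6]]), (4, 1, 1))   # shape (4, 2, 3)
print(x.sum(axis=0))        # [[ 4  8 12] [16 20 24]]
print(x.sum(axis=2))        # [[ 6 15] [ 6 15] [ 6 15] [ 6 15]]
print(x.sum(axis=(0, 1)))   # [20 28 36]
print(x.sum())              # 84
```
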
2016### Variadic Reduce
2017
2018When `N > 1`, reduce function application is slightly more complex, as it is
2019applied simultaneously to all inputs. The operands are supplied to the
2020computation in the following order:
2021
2022*   Running reduced value for the first operand
2023*   ...
2024*   Running reduced value for the N'th operand
2025*   Input value for the first operand
2026*   ...
2027*   Input value for the N'th operand
2028
2029For example, consider the following reduction function, which can be used to
2030compute the max and the argmax of a 1-D array in parallel:
2031
2032```python
2033f: (Float, Int, Float, Int) -> Float, Int
2034f(max, argmax, value, index):
2035  if value >= max:
2036    return (value, index)
2037  else:
2038    return (max, argmax)
2039```
2040
2041For 1-D Input arrays `V = Float[N], K = Int[N]`, and init values
2042`I_V = Float, I_K =  Int`, the result `f_(N-1)` of reducing across the only
2043input dimension is equivalent to the following recursive application:
2044
2045```
2046f_0 = f(I_V, I_K, V_0, K_0)
2047f_1 = f(f_0.first, f_0.second, V_1, K_1)
2048...
2049f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))
2050```
2051
2052Applying this reduction to an array of values, and an array of sequential
2053indices (i.e. iota), will co-iterate over the arrays, and return a tuple
2054containing the maximal value and the matching index.
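
The recursion above can be written as runnable Python (a left-to-right order is
chosen here purely for illustration; XLA may apply the reduction in any order):

```python
import numpy as np
from functools import reduce

def f(acc, elem):
    # acc carries the running (max, argmax); elem is the next (value, index).
    max_val, argmax = acc
    value, index = elem
    return (value, index) if value >= max_val else (max_val, argmax)

V = np.array([3.0, 9.0, 1.0, 9.5, 2.0])
K = np.arange(len(V))
max_val, argmax = reduce(f, zip(V, K), (-np.inf, 0))
print(max_val, argmax)   # 9.5 3
```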
2055
2056## ReducePrecision
2057
2058See also
2059[`XlaBuilder::ReducePrecision`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2060
2061Models the effect of converting floating-point values to a lower-precision
2062format (such as IEEE-FP16) and back to the original format.  The number of
2063exponent and mantissa bits in the lower-precision format can be specified
2064arbitrarily, although all bit sizes may not be supported on all hardware
2065implementations.
2066
2067<b> `ReducePrecision(operand, exponent_bits, mantissa_bits)` </b>
2068
2069Arguments       | Type    | Semantics
2070--------------- | ------- | -------------------------------------------------
2071`operand`       | `XlaOp` | array of floating-point type `T`.
2072`exponent_bits` | `int32` | number of exponent bits in lower-precision format
2073`mantissa_bits` | `int32` | number of mantissa bits in lower-precision format
2074
2075The result is an array of type `T`.  The input values are rounded to the nearest
2076value representable with the given number of mantissa bits (using "ties to even"
2077semantics), and any values that exceed the range specified by the number of
2078exponent bits are clamped to positive or negative infinity.  `NaN` values are
2079retained, although they may be converted to canonical `NaN` values.
2080
2081The lower-precision format must have at least one exponent bit (in order to
2082distinguish a zero value from an infinity, since both have a zero mantissa), and
2083must have a non-negative number of mantissa bits.  The number of exponent or
2084mantissa bits may exceed the corresponding value for type `T`; the corresponding
2085portion of the conversion is then simply a no-op.
2086
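For example, reducing precision to 5 exponent bits and 10 mantissa bits models
an IEEE-FP16 round trip, which can be pictured with NumPy:

```python
import numpy as np

x = np.float32(0.1)
# Round-trip through float16: the nearest value representable with 10 mantissa
# bits and 5 exponent bits.
print(np.float32(np.float16(x)))   # 0.099975586
```
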
2087## ReduceWindow
2088
2089See also
2090[`XlaBuilder::ReduceWindow`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2091
2092Applies a reduction function to all elements in each window of the input
2093multi-dimensional array, producing an output multi-dimensional array with the
2094same number of elements as the number of valid positions of the window. A
2095pooling layer can be expressed as a `ReduceWindow`. Similar to
2096[`Reduce`](#reduce), the applied `computation` is always passed the `init_value`
2097on the left-hand side.
2098
2099<b> `ReduceWindow(operand, init_value, computation, window_dimensions,
2100window_strides, padding)` </b>
2101
2102| Arguments           | Type                | Semantics                        |
2103| ------------------- | ------------------- | -------------------------------- |
2104| `operand`           | `XlaOp`             | N dimensional array containing   |
2105:                     :                     : elements of type T. This is the  :
2106:                     :                     : base area on which the window is :
2107:                     :                     : placed.                          :
2108| `init_value`        | `XlaOp`             | Starting value for the           |
2109:                     :                     : reduction. See [Reduce](#reduce) :
2110:                     :                     : for details.                     :
2111| `computation`       | `XlaComputation`    | Reduction function of type `T, T |
2112:                     :                     : -> T`, to apply to all elements  :
2113:                     :                     : in each window                   :
2114| `window_dimensions` | `ArraySlice<int64>` | array of integers for window     |
2115:                     :                     : dimension values                 :
2116| `window_strides`    | `ArraySlice<int64>` | array of integers for window     |
2117:                     :                     : stride values                    :
2118| `base_dilations`    | `ArraySlice<int64>` | array of integers for base       |
2119:                     :                     : dilation values                  :
2120| `window_dilations`  | `ArraySlice<int64>` | array of integers for window     |
2121:                     :                     : dilation values                  :
2122| `padding`           | `Padding`           | padding type for window          |
2123:                     :                     : (Padding\:\:kSame, which pads so :
2124:                     :                     : as to have the same output shape :
2125:                     :                     : as input if the stride is 1, or  :
2126:                     :                     : Padding\:\:kValid, which uses no :
2127:                     :                     : padding and "stops" the window   :
2128:                     :                     : once it no longer fits)          :
2129
2130The code and figure below show an example of using `ReduceWindow`. The input is
2131a matrix of size [4x6] and both `window_dimensions` and
2132`window_stride_dimensions` are [2x3].
2133
2134```
2135// Create a computation for the reduction (maximum).
2136XlaComputation max;
2137{
2138  XlaBuilder builder(client_, "max");
2139  auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
2140  auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
2141  builder.Max(y, x);
2142  max = builder.Build().ConsumeValueOrDie();
2143}
2144
2145// Create a ReduceWindow computation with the max reduction computation.
2146XlaBuilder builder(client_, "reduce_window_2x3");
2147auto shape = ShapeUtil::MakeShape(F32, {4, 6});
2148auto input = builder.Parameter(0, shape, "input");
2149builder.ReduceWindow(
2150    input,
2151    /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
2152    max,
2153    /*window_dimensions=*/{2, 3},
2154    /*window_stride_dimensions=*/{2, 3},
2155    Padding::kValid);
2156```
2157
2158<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
2159  <img style="width:35%" src="./images/ops_reduce_window.png">
2160</div>
2161
2162Stride of 1 in a dimension specifies that the position of a window in the
2163dimension is 1 element away from its adjacent window. In order to specify that
2164no windows overlap with each other, window_stride_dimensions should be equal to
2165window_dimensions. The figure below illustrates the use of two different stride
2166values. Padding is applied to each dimension of the input and the calculations
2167are the same as though the input came in with the dimensions it has after
2168padding.
2169
2170<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
2171  <img style="width:75%" src="./images/ops_reduce_window_stride.png">
2172</div>
2173
2174For a non-trivial padding example, consider computing reduce-window minimum
2175(initial value is `MAX_FLOAT`) with dimension `3` and stride `2` over the input
2176array `[10000, 1000, 100, 10, 1]`. Padding `kValid` computes minimums over two
2177_valid_ windows: `[10000, 1000, 100]` and `[100, 10, 1]`, resulting in the
2178output `[100, 1]`. Padding `kSame` first pads the array so that the shape after
2179the reduce-window would be the _same_ as input for stride one by adding initial
2180elements on both sides, getting `[MAX_VALUE, 10000, 1000, 100, 10, 1,
2181MAX_VALUE]`. Running reduce-window over the padded array operates on three
2182windows `[MAX_VALUE, 10000, 1000]`, `[1000, 100, 10]`, `[10, 1, MAX_VALUE]`, and
2183yields `[1000, 10, 1]`.
2184
2185The evaluation order of the reduction function is arbitrary and may be
2186non-deterministic. Therefore, the reduction function should not be overly
2187sensitive to reassociation. See the discussion about associativity in the
2188context of [`Reduce`](#reduce) for more details.
2189
2190## ReplicaId
2191
2192See also
2193[`XlaBuilder::ReplicaId`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2194
2195Returns the unique ID (U32 scalar) of the replica.
2196
2197<b> `ReplicaId()` </b>
2198
2199The unique ID of each replica is an unsigned integer in the interval `[0, N)`,
2200where `N` is the number of replicas. Since all the replicas are running the same
2201program, a `ReplicaId()` call in the program will return a different value on
2202each replica.
2203
2204## Reshape
2205
2206See also
2207[`XlaBuilder::Reshape`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
2208and the [`Collapse`](#collapse) operation.
2209
2210Reshapes the dimensions of an array into a new configuration.
2211
2212<b> `Reshape(operand, new_sizes)` </b>
2213<b> `Reshape(operand, dimensions, new_sizes)` </b>
2214
2215Arguments    | Type           | Semantics
2216------------ | -------------- | ---------------------------------------
2217`operand`    | `XlaOp`        | array of type T
2218`dimensions` | `int64` vector | order in which dimensions are collapsed
2219`new_sizes`  | `int64` vector | vector of sizes of new dimensions
2220
2221Conceptually, reshape first flattens an array into a one-dimensional vector of
2222data values, and then refines this vector into a new shape. The input arguments
2223are an arbitrary array of type T, a compile-time-constant vector of dimension
2224indices, and a compile-time-constant vector of dimension sizes for the result.
2225The values in the `dimensions` vector, if given, must be a permutation of all of
2226T's dimensions; the default if not given is `{0, ..., rank - 1}`. The order of
2227the dimensions in `dimensions` is from slowest-varying dimension (most major) to
2228fastest-varying dimension (most minor) in the loop nest which collapses the
2229input array into a single dimension. The `new_sizes` vector determines the size
2230of the output array. The value at index 0 in `new_sizes` is the size of
2231dimension 0, the value at index 1 is the size of dimension 1, and so on. The
2232product of the `new_sizes` dimensions must equal the product of the operand's
2233dimension sizes. When refining the collapsed array into the multidimensional
2234array defined by `new_sizes`, the dimensions in `new_sizes` are ordered from
2235slowest varying (most major) to fastest varying (most minor).
2236
2237For example, let v be an array of 24 elements:
2238
2239```
2240let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}},
2241                    {{20, 21, 22}, {25, 26, 27}},
2242                    {{30, 31, 32}, {35, 36, 37}},
2243                    {{40, 41, 42}, {45, 46, 47}}};
2244
2245In-order collapse:
2246let v012_24 = Reshape(v, {0,1,2}, {24});
2247then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
2248                         30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};
2249
2250let v012_83 = Reshape(v, {0,1,2}, {8,3});
2251then v012_83 == f32[8x3] {{10, 11, 12}, {15, 16, 17},
2252                          {20, 21, 22}, {25, 26, 27},
2253                          {30, 31, 32}, {35, 36, 37},
2254                          {40, 41, 42}, {45, 46, 47}};
2255
2256Out-of-order collapse:
2257let v021_24 = Reshape(v, {1,2,0}, {24});
2258then v021_24 == f32[24]  {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
2259                          15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47};
2260
2261let v021_83 = Reshape(v, {1,2,0}, {8,3});
2262then v021_83 == f32[8x3] {{10, 20, 30}, {40, 11, 21},
2263                          {31, 41, 12}, {22, 32, 42},
2264                          {15, 25, 35}, {45, 16, 26},
2265                          {36, 46, 17}, {27, 37, 47}};
2266
2267
2268let v021_262 = Reshape(v, {1,2,0}, {2,6,2});
2269then v021_262 == f32[2x6x2] {{{10, 20}, {30, 40},
2270                              {11, 21}, {31, 41},
2271                              {12, 22}, {32, 42}},
2272                             {{15, 25}, {35, 45},
2273                              {16, 26}, {36, 46},
2274                              {17, 27}, {37, 47}}};
2275```
2276
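In NumPy terms (shown only for illustration), `Reshape(operand, dimensions,
new_sizes)` is a transpose by `dimensions` followed by a row-major reshape to
`new_sizes`:

```python
import numpy as np

v = np.array([[[10, 11, 12], [15, 16, 17]],
              [[20, 21, 22], [25, 26, 27]],
              [[30, 31, 32], [35, 36, 37]],
              [[40, 41, 42], [45, 46, 47]]], dtype=np.float32)

# Equivalent of Reshape(v, {1,2,0}, {8,3}) from the example above.
print(np.transpose(v, (1, 2, 0)).reshape(8, 3))
```
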
2277As a special case, reshape can transform a single-element array to a scalar and
2278vice versa. For example,
2279
2280```
2281Reshape(f32[1x1] {{5}}, {0,1}, {}) == 5;
2282Reshape(5, {}, {1,1}) == f32[1x1] {{5}};
2283```
2284
2285## Rev (reverse)
2286
2287See also
2288[`XlaBuilder::Rev`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2289
2290<b>`Rev(operand, dimensions)`</b>
2291
2292Arguments    | Type                | Semantics
2293------------ | ------------------- | ---------------------
2294`operand`    | `XlaOp`             | array of type T
2295`dimensions` | `ArraySlice<int64>` | dimensions to reverse
2296
2297Reverses the order of elements in the `operand` array along the specified
2298`dimensions`, generating an output array of the same shape. Each element of the
2299operand array at a multidimensional index is stored into the output array at a
2300transformed index. The multidimensional index is transformed by reversing the
2301index in each dimension to be reversed (i.e., if a dimension of size N is one of
2302the reversing dimensions, its index i is transformed into N - 1 - i).
2303
2304One use for the `Rev` operation is to reverse the convolution weight array along
2305the two window dimensions during the gradient computation in neural networks.
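
For a rank-2 operand, reversing both dimensions corresponds to NumPy's `flip`
(shown only to illustrate the semantics):

```python
import numpy as np

x = np.arange(6).reshape(2, 3)
print(np.flip(x, axis=(0, 1)))
# [[5 4 3]
#  [2 1 0]]
```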
2306
2307## RngNormal
2308
2309See also
2310[`XlaBuilder::RngNormal`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2311
2312Constructs an output of a given shape with random numbers generated following
2313the $$N(\mu, \sigma)$$ normal distribution. The parameters $$\mu$$ and
2314$$\sigma$$, and output shape have to have a floating point elemental type. The
2315parameters furthermore have to be scalar valued.
2316
2317<b>`RngNormal(mu, sigma, shape)`</b>
2318
2319| Arguments | Type    | Semantics                                           |
2320| --------- | ------- | --------------------------------------------------- |
2321| `mu`      | `XlaOp` | Scalar of type T specifying mean of generated       |
2322:           :         : numbers                                   :
2323| `sigma`   | `XlaOp` | Scalar of type T specifying standard deviation of   |
2324:           :         : generated numbers                                   :
2325| `shape`   | `Shape` | Output shape of type T                              |
2326
2327## RngUniform
2328
2329See also
2330[`XlaBuilder::RngUniform`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2331
2332Constructs an output of a given shape with random numbers generated following
2333the uniform distribution over the interval $$[a,b)$$. The parameters and output
2334element type have to be a boolean type, an integral type, or a floating point
2335type, and the types have to be consistent. The CPU and GPU backends currently
2336only support F64, F32, F16, BF16, S64, U64, S32 and U32. Furthermore, the
2337parameters need to be scalar valued. If $$b <= a$$ the result is
2338implementation-defined.
2339
2340<b>`RngUniform(a, b, shape)`</b>
2341
2342| Arguments | Type                    | Semantics                         |
2343| --------- | ----------------------- | --------------------------------- |
2344| `a`       | `XlaOp`                 | Scalar of type T specifying lower |
2345:           :                         : limit of interval                 :
2346| `b`       | `XlaOp`                 | Scalar of type T specifying upper |
2347:           :                         : limit of interval                 :
2348| `shape`   | `Shape`                 | Output shape of type T            |
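
A similar sketch for the uniform case (again, the generated values are
backend-dependent and shown only for illustration):

```
let a: s32 = 0;
let b: s32 = 10;
RngUniform(a, b, s32[4])  // e.g. {3, 7, 0, 9}, each value drawn from [0, 10)
```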
2349
2350## RngBitGenerator
2351
2352Generates an output with a given shape filled with uniform random bits using the
2353specified algorithm (or backend default) and returns an updated state (with the
same shape as the initial state) and the generated random data.
2355
`initial_state` is the initial state of the current random number generation.
Its required shape and valid values depend on the algorithm used.
2358
2359The output is guaranteed to be a deterministic function of the initial state but
2360it is *not* guaranteed to be deterministic between backends and different
2361compiler versions.
2362
<b>`RngBitGenerator(algorithm, initial_state, shape)`</b>
2364
2365Arguments       | Type              | Semantics
2366--------------- | ----------------- | -------------------------------------
2367`algorithm`     | `RandomAlgorithm` | PRNG algorithm to be used.
2368`initial_state` | `XlaOp`           | Initial state for the PRNG algorithm.
2369`shape`         | `Shape`           | Output shape for generated data.
2370
2371Available values for `algorithm`:
2372
2373-   `rng_default`: Backend specific algorithm with backend specific shape
2374    requirements.
2375
2376-   `rng_three_fry`: ThreeFry counter-based PRNG algorithm. The `initial_state`
2377    shape is `u64[2]` with arbitrary values.
2378    [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
2379
2380-   `rng_philox`: Philox algorithm to generate random numbers in parallel. The
2381    `initial_state` shape is `u64[3]` with arbitrary values.
2382    [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
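
A minimal sketch using the ThreeFry algorithm; the `initial_state` values are
arbitrary and the produced bits are implementation-defined:

```
let state: u64[2] = {42, 0};
RngBitGenerator(rng_three_fry, state, u32[4])
    // == (u64[2] updated_state, u32[4] random_bits)
```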
2383
2384## Scatter
2385
2386The XLA scatter operation generates a result which is the value of the input
2387array `operand`, with several slices (at indices specified by `scatter_indices`)
2388updated with the values in `updates` using `update_computation`.
2389
2390See also
2391[`XlaBuilder::Scatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2392
<b> `scatter(operand, scatter_indices, updates, update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, indices_are_sorted)` </b>
2394
2395Arguments                      | Type                | Semantics
2396------------------------------ | ------------------- | ---------
2397`operand`                      | `XlaOp`             | Array to be scattered into.
2398`scatter_indices`              | `XlaOp`             | Array containing the starting indices of the slices that must be scattered to.
2399`updates`                      | `XlaOp`             | Array containing the values that must be used for scattering.
2400`update_computation`           | `XlaComputation`    | Computation to be used for combining the existing values in the input array and the updates during scatter. This computation should be of type `(T, T) -> T`.
2401`index_vector_dim`             | `int64`             | The dimension in `scatter_indices` that contains the starting indices.
2402`update_window_dims`           | `ArraySlice<int64>` | The set of dimensions in `updates` shape that are _window dimensions_.
2403`inserted_window_dims`         | `ArraySlice<int64>` | The set of _window dimensions_ that must be inserted into `updates` shape.
`scatter_dims_to_operand_dims` | `ArraySlice<int64>` | A dimension map from the scatter indices to the operand index space. This array is interpreted as mapping `i` to `scatter_dims_to_operand_dims[i]`. It has to be one-to-one and total.
2405`indices_are_sorted`           | `bool`              | Whether the indices are guaranteed to be sorted by the caller.
2406
2407If `index_vector_dim` is equal to `scatter_indices.rank` we implicitly consider
2408`scatter_indices` to have a trailing `1` dimension.
2409
2410We define `update_scatter_dims` of type `ArraySlice<int64>` as the set of
2411dimensions in `updates` shape that are not in `update_window_dims`, in ascending
2412order.
2413
2414The arguments of scatter should follow these constraints:
2415
2416-   `updates` array must be of rank `update_window_dims.size +
2417    scatter_indices.rank - 1`.
2418
2419-   Bounds of dimension `i` in `updates` must conform to the following:
2420
2421    -   If `i` is present in `update_window_dims` (i.e. equal to
2422        `update_window_dims`[`k`] for some `k`), then the bound of dimension `i`
2423        in `updates` must not exceed the corresponding bound of `operand` after
2424        accounting for the `inserted_window_dims` (i.e.
2425        `adjusted_window_bounds`[`k`], where `adjusted_window_bounds` contains
2426        the bounds of `operand` with the bounds at indices
2427        `inserted_window_dims` removed).
2428    -   If `i` is present in `update_scatter_dims` (i.e. equal to
2429        `update_scatter_dims`[`k`] for some `k`), then the bound of dimension
2430        `i` in `updates` must be equal to the corresponding bound of
2431        `scatter_indices`, skipping `index_vector_dim` (i.e.
2432        `scatter_indices.shape.dims`[`k`], if `k` < `index_vector_dim` and
2433        `scatter_indices.shape.dims`[`k+1`] otherwise).
2434
2435-   `update_window_dims` must be in ascending order, not have any repeating
2436    dimension numbers, and be in the range `[0, updates.rank)`.
2437
2438-   `inserted_window_dims` must be in ascending order, not have any repeating
2439    dimension numbers, and be in the range `[0, operand.rank)`.
2440
2441-   `operand.rank` must equal the sum of `update_window_dims.size` and
2442    `inserted_window_dims.size`.
2443
-   `scatter_dims_to_operand_dims.size` must be equal to
    `scatter_indices.shape.dims`[`index_vector_dim`], and its values must be in
    the range `[0, operand.rank)`.
2447
2448For a given index `U` in the `updates` array, the corresponding index `I` in the
2449`operand` array into which this update has to be applied is computed as follows:
2450
24511.  Let `G` = { `U`[`k`] for `k` in `update_scatter_dims` }. Use `G` to look up
2452    an index vector `S` in the `scatter_indices` array such that `S`[`i`] =
2453    `scatter_indices`[Combine(`G`, `i`)] where Combine(A, b) inserts b at
    position `index_vector_dim` into A.
24552.  Create an index `S`<sub>`in`</sub> into `operand` using `S` by scattering
2456    `S` using the `scatter_dims_to_operand_dims` map. More formally:
2457    1.  `S`<sub>`in`</sub>[`scatter_dims_to_operand_dims`[`k`]] = `S`[`k`] if
2458        `k` < `scatter_dims_to_operand_dims.size`.
2459    2.  `S`<sub>`in`</sub>[`_`] = `0` otherwise.
24603.  Create an index `W`<sub>`in`</sub> into `operand` by scattering the indices
2461    at `update_window_dims` in `U` according to `inserted_window_dims`. More
2462    formally:
2463    1.  `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `U`[`k`] if `k`
2464        is in `update_window_dims`, where `window_dims_to_operand_dims` is the
2465        monotonic function with domain [`0`, `update_window_dims.size`) and
2466        range [`0`, `operand.rank`) \\ `inserted_window_dims`. (For example, if
2467        `update_window_dims.size` is `4`, `operand.rank` is `6`, and
2468        `inserted_window_dims` is {`0`, `2`} then `window_dims_to_operand_dims`
2469        is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}).
2470    2.  `W`<sub>`in`</sub>[`_`] = `0` otherwise.
24714.  `I` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
2472    addition.
2473
2474In summary, the scatter operation can be defined as follows.
2475
2476-   Initialize `output` with `operand`, i.e. for all indices `O` in the
2477    `operand` array: \
2478    `output`[`O`] = `operand`[`O`]
2479-   For every index `U` in the `updates` array and the corresponding index `O`
2480    in the `operand` array, if `O` is a valid index for `output`: \
2481    `output`[`O`] = `update_computation`(`output`[`O`], `updates`[`U`])
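
As a concrete illustration, the following hypothetical one-dimensional example
adds two updates into `operand`, assuming `add` is an `update_computation` that
returns the sum of its two arguments:

```
let operand: s32[4] = {1, 2, 3, 4};
let scatter_indices: s32[2] = {1, 3};
let updates: s32[2] = {10, 20};
// add: (s32, s32) -> s32, returns the sum of its arguments.

Scatter(operand, scatter_indices, updates, add,
        /*index_vector_dim=*/1,
        /*update_window_dims=*/{},
        /*inserted_window_dims=*/{0},
        /*scatter_dims_to_operand_dims=*/{0})
  == s32[4] {1, 12, 3, 24};
```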
2482
2483The order in which updates are applied is non-deterministic. So, when multiple
2484indices in `updates` refer to the same index in `operand`, the corresponding
2485value in `output` will be non-deterministic.
2486
2487Note that the first parameter that is passed into the `update_computation` will
2488always be the current value from the `output` array and the second parameter
2489will always be the value from the `updates` array. This is important
2490specifically for cases when the `update_computation` is _not commutative_.
2491
If `indices_are_sorted` is set to true then XLA can assume that
`scatter_indices` are sorted (in ascending `scatter_dims_to_operand_dims` order)
by the user. If they are not then the semantics are implementation defined.
2495
2496Informally, the scatter op can be viewed as an _inverse_ of the gather op, i.e.
2497the scatter op updates the elements in the input that are extracted by the
2498corresponding gather op.
2499
2500For a detailed informal description and examples, refer to the
2501"Informal Description" section under `Gather`.
2502
2503## Select
2504
2505See also
2506[`XlaBuilder::Select`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2507
2508Constructs an output array from elements of two input arrays, based on the
2509values of a predicate array.
2510
2511<b> `Select(pred, on_true, on_false)` </b>
2512
2513Arguments  | Type    | Semantics
2514---------- | ------- | ------------------
2515`pred`     | `XlaOp` | array of type PRED
2516`on_true`  | `XlaOp` | array of type T
2517`on_false` | `XlaOp` | array of type T
2518
2519The arrays `on_true` and `on_false` must have the same shape. This is also the
2520shape of the output array. The array `pred` must have the same dimensionality as
2521`on_true` and `on_false`, with the `PRED` element type.
2522
2523For each element `P` of `pred`, the corresponding element of the output array is
2524taken from `on_true` if the value of `P` is `true`, and from `on_false` if the
2525value of `P` is `false`. As a restricted form of [broadcasting](broadcasting.md),
2526`pred` can be a scalar of type `PRED`. In this case, the output array is taken
2527wholly from `on_true` if `pred` is `true`, and from `on_false` if `pred` is `false`.
2528
2529Example with non-scalar `pred`:
2530
2531```
2532let pred: PRED[4] = {true, false, false, true};
2533let v1: s32[4] = {1, 2, 3, 4};
2534let v2: s32[4] = {100, 200, 300, 400};
2535==>
2536Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};
2537```
2538
2539Example with scalar `pred`:
2540
2541```
2542let pred: PRED = true;
2543let v1: s32[4] = {1, 2, 3, 4};
2544let v2: s32[4] = {100, 200, 300, 400};
2545==>
2546Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};
2547```
2548
2549Selections between tuples are supported. Tuples are considered to be scalar
2550types for this purpose. If `on_true` and `on_false` are tuples (which must have
2551the same shape!) then `pred` has to be a scalar of type `PRED`.
2552
2553## SelectAndScatter
2554
2555See also
2556[`XlaBuilder::SelectAndScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2557
2558This operation can be considered as a composite operation that first computes
2559`ReduceWindow` on the `operand` array to select an element from each window, and
2560then scatters the `source` array to the indices of the selected elements to
2561construct an output array with the same shape as the operand array. The binary
2562`select` function is used to select an element from each window by applying it
2563across each window, and it is called with the property that the first
2564parameter's index vector is lexicographically less than the second parameter's
2565index vector. The `select` function returns `true` if the first parameter is
2566selected and returns `false` if the second parameter is selected, and the
2567function must hold transitivity (i.e., if `select(a, b)` and `select(b, c)` are
2568`true`, then `select(a, c)` is also `true`) so that the selected element does
2569not depend on the order of the elements traversed for a given window.
2570
2571The function `scatter` is applied at each selected index in the output array. It
2572takes two scalar parameters:
2573
25741.  Current value at the selected index in the output array
25752.  The scatter value from `source` that applies to the selected index
2576
2577It combines the two parameters and returns a scalar value that's used to update
2578the value at the selected index in the output array. Initially, all indices of
2579the output array are set to `init_value`.
2580
2581The output array has the same shape as the `operand` array and the `source`
2582array must have the same shape as the result of applying a `ReduceWindow`
2583operation on the `operand` array. `SelectAndScatter` can be used to
2584backpropagate the gradient values for a pooling layer in a neural network.
2585
2586<b>`SelectAndScatter(operand, select, window_dimensions, window_strides,
2587padding, source, init_value, scatter)`</b>
2588
2589| Arguments           | Type                | Semantics                        |
2590| ------------------- | ------------------- | -------------------------------- |
2591| `operand`           | `XlaOp`             | array of type T over which the   |
2592:                     :                     : windows slide                    :
2593| `select`            | `XlaComputation`    | binary computation of type `T, T |
2594:                     :                     : -> PRED`, to apply to all        :
2595:                     :                     : elements in each window; returns :
2596:                     :                     : `true` if the first parameter is :
2597:                     :                     : selected and returns `false` if  :
2598:                     :                     : the second parameter is selected :
2599| `window_dimensions` | `ArraySlice<int64>` | array of integers for window     |
2600:                     :                     : dimension values                 :
2601| `window_strides`    | `ArraySlice<int64>` | array of integers for window     |
2602:                     :                     : stride values                    :
2603| `padding`           | `Padding`           | padding type for window          |
2604:                     :                     : (Padding\:\:kSame or             :
2605:                     :                     : Padding\:\:kValid)               :
2606| `source`            | `XlaOp`             | array of type T with the values  |
2607:                     :                     : to scatter                       :
2608| `init_value`        | `XlaOp`             | scalar value of type T for the   |
2609:                     :                     : initial value of the output      :
2610:                     :                     : array                            :
2611| `scatter`           | `XlaComputation`    | binary computation of type `T, T |
2612:                     :                     : -> T`, to apply each scatter     :
2613:                     :                     : source element with its          :
2614:                     :                     : destination element              :
2615
2616The figure below shows examples of using `SelectAndScatter`, with the `select`
2617function computing the maximal value among its parameters. Note that when the
2618windows overlap, as in the figure (2) below, an index of the `operand` array may
2619be selected multiple times by different windows. In the figure, the element of
2620value 9 is selected by both of the top windows (blue and red) and the binary
2621addition `scatter` function produces the output element of value 8 (2 + 6).
2622
2623<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
2624  <img style="width:100%"
2625    src="./images/ops_scatter_to_selected_window_element.png">
2626</div>
2627
2628The evaluation order of the `scatter` function is arbitrary and may be
2629non-deterministic. Therefore, the `scatter` function should not be overly
2630sensitive to reassociation. See the discussion about associativity in the
2631context of [`Reduce`](#reduce) for more details.
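
A small one-dimensional example in the same informal notation, assuming `ge` is
a `select` computation returning `a >= b` (so the maximum of each window is
selected) and `add` is a `scatter` computation returning the sum of its
arguments:

```
let operand: f32[4] = {1, 3, 2, 5};
let source:  f32[2] = {10, 20};

SelectAndScatter(operand, ge, /*window_dimensions=*/{2}, /*window_strides=*/{2},
                 Padding::kValid, source, /*init_value=*/0, add)
  == f32[4] {0, 10, 0, 20};
```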
2632
2633## Send
2634
2635See also
2636[`XlaBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2637
2638<b> `Send(operand, channel_handle)` </b>
2639
2640Arguments        | Type            | Semantics
2641---------------- | --------------- | -----------------------------------------
2642`operand`        | `XlaOp`         | data to send (array of type T)
2643`channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair
2644
2645Sends the given operand data to a `Recv` instruction in another computation
2646that shares the same channel handle. Does not return any data.
2647
Similar to the `Recv` operation, the client API of the `Send` operation represents
2649synchronous communication, and is internally decomposed into 2 HLO instructions
2650(`Send` and `SendDone`) to enable asynchronous data transfers. See also
2651[`HloInstruction::CreateSend` and `HloInstruction::CreateSendDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
2652
2653<b>`Send(HloInstruction operand, int64 channel_id)`</b>
2654
2655Initiates an asynchronous transfer of the operand to the resources allocated by
2656the `Recv` instruction with the same channel id. Returns a context, which is
2657used by a following `SendDone` instruction to wait for the completion of the
2658data transfer. The context is a tuple of {operand (shape), request identifier
2659(U32)} and it can only be used by a `SendDone` instruction.
2660
2661<b> `SendDone(HloInstruction context)` </b>
2662
2663Given a context created by a `Send` instruction, waits for the data transfer to
2664complete.  The instruction does not return any data.
2665
2666<b> Scheduling of channel instructions </b>
2667
2668The execution order of the 4 instructions for each channel (`Recv`, `RecvDone`,
`Send`, `SendDone`) is as follows.
2670
2671<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
2672  <img style="width:70%" src="./images/send_recv_order.png">
2673</div>
2674
2675* `Recv` happens before `Send`
2676* `Send` happens before `RecvDone`
2677* `Recv` happens before `RecvDone`
2678* `Send` happens before `SendDone`
2679
2680When the backend compilers generate a linear schedule for each computation that
2681communicates via channel instructions, there must not be cycles across the
computations. For example, the schedules below lead to deadlocks.
2683
2684<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
2685  <img style="width:100%" src="./images/send_recv_schedule.png">
2686</div>
2687
2688## Slice
2689
2690See also
2691[`XlaBuilder::Slice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2692
2693Slicing extracts a sub-array from the input array. The sub-array is of the same
2694rank as the input and contains the values inside a bounding box within the input
2695array where the dimensions and indices of the bounding box are given as
2696arguments to the slice operation.
2697
<b> `Slice(operand, start_indices, limit_indices, strides)` </b>
2699
2700| Arguments       | Type                | Semantics                            |
2701| --------------- | ------------------- | ------------------------------------ |
2702| `operand`       | `XlaOp`             | N dimensional array of type T        |
2703| `start_indices` | `ArraySlice<int64>` | List of N integers containing the    |
2704:                 :                     : starting indices of the slice for    :
2705:                 :                     : each dimension. Values must be       :
2706:                 :                     : greater than or equal to zero.       :
2707| `limit_indices` | `ArraySlice<int64>` | List of N integers containing the    |
2708:                 :                     : ending indices (exclusive) for the   :
2709:                 :                     : slice for each dimension. Each value :
2710:                 :                     : must be greater than or equal to the :
2711:                 :                     : respective `start_indices` value for :
2712:                 :                     : the dimension and less than or equal :
2713:                 :                     : to the size of the dimension.        :
| `strides`       | `ArraySlice<int64>` | List of N integers that decide the   |
:                 :                     : input stride of the slice. The slice :
:                 :                     : picks every `strides[d]`-th element  :
:                 :                     : in dimension `d`.                    :
2718
2719
27201-dimensional example:
2721
2722```
2723let a = {0.0, 1.0, 2.0, 3.0, 4.0}
2724Slice(a, {2}, {4}) produces:
2725  {2.0, 3.0}
2726```
2727
27282-dimensional example:
2729
2730```
2731let b =
2732 { {0.0,  1.0,  2.0},
2733   {3.0,  4.0,  5.0},
2734   {6.0,  7.0,  8.0},
2735   {9.0, 10.0, 11.0} }
2736
2737Slice(b, {2, 1}, {4, 3}) produces:
2738  { { 7.0,  8.0},
2739    {10.0, 11.0} }
2740```
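
Example using `strides` to pick every other element:

```
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
Slice(a, {0}, {5}, {2}) produces:
  {0.0, 2.0, 4.0}
```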
2741
2742## Sort
2743
2744See also
2745[`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2746
2747<b>`Sort(operands, comparator, dimension, is_stable)`</b>
2748
2749Arguments    | Type                | Semantics
2750------------ | ------------------- | --------------------
2751`operands`   | `ArraySlice<XlaOp>` | The operands to sort.
2752`comparator` | `XlaComputation`    | The comparator computation to use.
2753`dimension`  | `int64`             | The dimension along which to sort.
2754`is_stable`  | `bool`              | Whether stable sorting should be used.
2755
2756If only one operand is provided:
2757
2758* If the operand is a rank-1 tensor (an array), the result is a sorted array.
2759  If you want to sort the array into ascending order, the comparator should
2760  perform a less-than comparison. Formally, after the array is sorted, it holds
2761  for all index positions `i, j` with `i < j` that either
2762  `comparator(value[i], value[j]) = comparator(value[j], value[i]) = false` or
2763  `comparator(value[i], value[j]) = true`.
2764
2765* If the operand has higher rank, the operand is sorted along the provided
2766  dimension. For example, for a rank-2 tensor (a matrix), a dimension value of
2767  `0` will independently sort every column, and a dimension value of `1` will
2768  independently sort each row. If no dimension number is provided, then the last
2769  dimension is chosen by default. For the dimension which is sorted, the same
2770  sorting order applies as in the rank-1 case.
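
For instance, sorting a rank-1 operand into ascending order, assuming `lt` is a
comparator computation returning `a < b` (a sketch in the informal notation used
elsewhere in this document):

```
let keys: s32[5] = {3, 1, 4, 1, 5};
Sort({keys}, lt, /*dimension=*/0, /*is_stable=*/false) == s32[5] {1, 1, 3, 4, 5};
```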
2771
2772If `n > 1` operands are provided:
2773
2774* All `n` operands must be tensors with the same dimensions. The element types
2775  of the tensors may be different.
2776
2777* All operands are sorted together, not individually. Conceptually the operands
2778  are treated as a tuple. When checking whether the elements of each operand at
2779  index positions `i` and `j` need to be swapped, the comparator is called with
2780  `2 * n` scalar parameters, where parameter `2 * k` corresponds to the value at
2781  position `i` from the `k-th` operand, and parameter `2 * k + 1` corresponds to
2782  the value at position `j` from the `k-th` operand. Usually, the comparator
2783  would thus compare parameters `2 * k` and `2 * k + 1` with each other and
2784  possibly use other parameter pairs as tie breakers.
2785
2786* The result is a tuple that consists of the operands in sorted order (along
2787  the provided dimension, as above). The `i-th` operand of the tuple corresponds
2788  to the `i-th` operand of Sort.
2789
2790For example, if there are three operands `operand0 = [3, 1]`,
2791`operand1 = [42, 50]`, `operand2 = [-3.0, 1.1]`, and the comparator compares
2792only the values of `operand0` with less-than, then the output of the sort is the
2793tuple `([1, 3], [50, 42], [1.1, -3.0])`.
2794
2795If `is_stable` is set to true, the sort is guaranteed to be stable, that is, if
2796there are elements which are considered to be equal by the comparator, the
2797relative order of the equal values is preserved. By default, `is_stable` is set
2798to false.
2799
2800## Transpose
2801
See also
[`XlaBuilder::Transpose`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and the `tf.reshape` operation.
2803
<b>`Transpose(operand, permutation)`</b>
2805
2806Arguments     | Type                | Semantics
2807------------- | ------------------- | ------------------------------
2808`operand`     | `XlaOp`             | The operand to transpose.
2809`permutation` | `ArraySlice<int64>` | How to permute the dimensions.
2810
2811
2812Permutes the operand dimensions with the given permutation, so
2813`∀ i . 0 ≤ i < rank ⇒ input_dimensions[permutation[i]] = output_dimensions[i]`.
2814
2815This is the same as Reshape(operand, permutation,
2816                            Permute(permutation, operand.shape.dimensions)).
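
For example:

```
let m = f32[2x3] {{1, 2, 3},
                  {4, 5, 6}};
Transpose(m, {1, 0}) == f32[3x2] {{1, 4},
                                  {2, 5},
                                  {3, 6}};
```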
2817
2818## TriangularSolve
2819
2820See also
2821[`XlaBuilder::TriangularSolve`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2822
2823Solves systems of linear equations with lower or upper triangular coefficient
2824matrices by forward- or back-substitution. Broadcasting along leading
2825dimensions, this routine solves one of the matrix systems `op(a) * x =
2826b`, or `x * op(a) = b`, for the variable `x`, given `a` and `b`, where `op(a)` is
2827either `op(a) = a`, or `op(a) = Transpose(a)`, or `op(a) = Conj(Transpose(a))`.
2828
2829<b> `TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)` </b>
2830
2831| Arguments       | Type        | Semantics                                    |
2832| --------------- | ----------- | -------------------------------------------- |
| `a`             | `XlaOp`     | a rank >= 2 array of a complex or            |
2834:                 :             : floating-point type with shape `[..., M,     :
2835:                 :             : M]`.                                         :
| `b`             | `XlaOp`     | a rank >= 2 array of the same type with shape |
2837:                 :             : `[..., M, K]` if `left_side` is true, `[..., :
2838:                 :             : K, M]` otherwise.                            :
2839| `left_side`     | `bool`      | indicates whether to solve a system of the   |
2840:                 :             : form `op(a) * x = b` (`true`) or `x *        :
2841:                 :             : op(a) = b` (`false`).                        :
2842| `lower`         | `bool`      | whether to use the upper or lower triangle   |
2843:                 :             : of `a`.                                      :
2844| `unit_diagonal` | `bool`      | if `true`, the diagonal elements of `a` are  |
2845:                 :             : assumed to be `1` and not accessed.          :
2846| `transpose_a`   | `Transpose` | whether to use `a` as is, transpose it or    |
2847:                 :             : take its conjugate transpose.                :
2848
2849Input data is read only from the lower/upper triangle of `a`, depending on the
2850value of `lower`. Values from the other triangle are ignored. Output data is
2851returned in the same triangle; the values in the other triangle are
2852implementation-defined and may be anything.
2853
If the ranks of `a` and `b` are greater than 2, they are treated as batches of
2855matrices, where all except the minor 2 dimensions are batch dimensions. `a` and
2856`b` must have equal batch dimensions.
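
As an illustrative sketch, solving a single lower-triangular system `a * x = b`
with `left_side = true`, `lower = true`, `unit_diagonal = false`, and no
transposition (written here with a hypothetical `NO_TRANSPOSE` value for
`transpose_a`):

```
let a: f32[2x2] = {{2, 0},
                   {1, 3}};
let b: f32[2x1] = {{2},
                   {7}};
TriangularSolve(a, b, /*left_side=*/true, /*lower=*/true,
                /*unit_diagonal=*/false, /*transpose_a=*/NO_TRANSPOSE)
  == f32[2x1] {{1},
               {2}};
```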
2857
2858## Tuple
2859
2860See also
2861[`XlaBuilder::Tuple`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2862
2863A tuple containing a variable number of data handles, each of which has its own
2864shape.
2865
2866This is analogous to `std::tuple` in C++. Conceptually:
2867
2868```
2869let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
2870let s: s32 = 5;
2871let t: (f32[10], s32) = tuple(v, s);
2872```
2873
Tuples can be deconstructed (accessed) via the
[`GetTupleElement`](#gettupleelement) operation.
2876
2877## While
2878
2879See also
2880[`XlaBuilder::While`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
2881
2882<b> `While(condition, body, init)` </b>
2883
2884| Arguments   | Type             | Semantics                                |
2885| ----------- | ---------------- | ---------------------------------------- |
2886| `condition` | `XlaComputation` | XlaComputation of type `T -> PRED` which |
2887:             :                  : defines the termination condition of the :
2888:             :                  : loop.                                    :
2889| `body`      | `XlaComputation` | XlaComputation of type `T -> T` which    |
2890:             :                  : defines the body of the loop.            :
2891| `init`      | `T`              | Initial value for the parameter of       |
2892:             :                  : `condition` and `body`.                  :
2893
2894Sequentially executes the `body` until the `condition` fails. This is similar to
2895a typical while loop in many other languages except for the differences and
2896restrictions listed below.
2897
2898*   A `While` node returns a value of type `T`, which is the result from the
2899    last execution of the `body`.
2900*   The shape of the type `T` is statically determined and must be the same
2901    across all iterations.
2902
2903The T parameters of the computations are initialized with the `init` value in
2904the first iteration and are automatically updated to the new result from `body`
2905in each subsequent iteration.
2906
2907One main use case of the `While` node is to implement the repeated execution of
2908training in neural networks. Simplified pseudocode is shown below with a graph
2909that represents the computation. The code can be found in
2910[`while_test.cc`](https://www.tensorflow.org/code/tensorflow/compiler/xla/tests/while_test.cc).
2911The type `T` in this example is a `Tuple` consisting of an `int32` for the
2912iteration count and a `vector[10]` for the accumulator. For 1000 iterations, the
2913loop keeps adding a constant vector to the accumulator.
2914
2915```
2916// Pseudocode for the computation.
2917init = {0, zero_vector[10]} // Tuple of int32 and float[10].
2918result = init;
2919while (result(0) < 1000) {
2920  iteration = result(0) + 1;
2921  new_vector = result(1) + constant_vector[10];
2922  result = {iteration, new_vector};
2923}
2924```
2925
2926<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
2927  <img style="width:100%" src="./images/ops_while.png">
2928</div>
2929