# Operation Semantics

The following describes the semantics of operations defined in the
[`XlaBuilder`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
interface. Typically, these operations map one-to-one to operations defined in
the RPC interface in
[`xla_data.proto`](https://www.tensorflow.org/code/tensorflow/compiler/xla/xla_data.proto).

A note on nomenclature: the generalized data type XLA deals with is an
N-dimensional array holding elements of some uniform type (such as 32-bit
float). Throughout the documentation, *array* is used to denote an
arbitrary-dimensional array. For convenience, special cases have more specific
and familiar names; for example a *vector* is a 1-dimensional array and a
*matrix* is a 2-dimensional array.

## AfterAll

See also
[`XlaBuilder::AfterAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

AfterAll takes a variadic number of tokens and produces a single token. Tokens
are primitive types which can be threaded between side-effecting operations to
enforce ordering. `AfterAll` can be used as a join of tokens for ordering an
operation after a set of operations.

<b> `AfterAll(operands)` </b>

Arguments  | Type    | Semantics
---------- | ------- | -------------------------
`operands` | `XlaOp` | variadic number of tokens

## AllGather

See also
[`XlaBuilder::AllGather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Performs concatenation across replicas.

<b> `AllGather(operand, all_gather_dim, shard_count, replica_groups, channel_id)` </b>

| Arguments        | Type                         | Semantics                                            |
| ---------------- | ---------------------------- | ---------------------------------------------------- |
| `operand`        | `XlaOp`                      | Array to concatenate across replicas.                |
| `all_gather_dim` | `int64`                      | Concatenation dimension.                             |
| `shard_count`    | `int64`                      | Size of each replica group.                          |
| `replica_groups` | vector of vectors of `int64` | Groups between which the concatenation is performed. |
| `channel_id`     | optional `int64`             | Optional channel ID for cross-module communication.  |

-   `replica_groups` is a list of replica groups between which the concatenation
    is performed (the replica id for the current replica can be retrieved using
    [`ReplicaId`](#replicaid)). The order of replicas in each group determines
    the order in which their inputs are located in the result. `replica_groups`
    must either be empty (in which case all replicas belong to a single group,
    ordered from `0` to `N - 1`), or contain the same number of elements as the
    number of replicas. For example, `replica_groups = {0, 2}, {1, 3}` performs
    concatenation between the replicas `0` and `2`, and `1` and `3`.
-   `shard_count` is the size of each replica group. We need this in cases where
    `replica_groups` is empty.
-   `channel_id` is used for cross-module communication: only `all-gather`
    operations with the same `channel_id` can communicate with each other.

The output shape is the input shape with the `all_gather_dim` made `shard_count`
times larger. For example, if there are two replicas and the operand has the
value `[1.0, 2.5]` and `[3.0, 5.25]` respectively on the two replicas, then the
output value from this op where `all_gather_dim` is `0` will be `[1.0, 2.5, 3.0,
5.25]` on both replicas.
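
At the builder level, the two-replica example above might be written roughly as
follows (a sketch; the parameter name and shape are illustrative):

```
XlaBuilder b("all_gather");
// Each replica contributes its own f32[2] operand.
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2}), "x");
// With two replicas and empty replica_groups, every replica receives f32[4]:
// replica 0's values followed by replica 1's values.
AllGather(x, /*all_gather_dimension=*/0, /*shard_count=*/2);
```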

## AllReduce

See also
[`XlaBuilder::AllReduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Performs a custom computation across replicas.

<b> `AllReduce(operand, computation, replica_groups, channel_id)` </b>

| Arguments        | Type                         | Semantics                                                       |
| ---------------- | ---------------------------- | --------------------------------------------------------------- |
| `operand`        | `XlaOp`                      | Array or a non-empty tuple of arrays to reduce across replicas. |
| `computation`    | `XlaComputation`             | Reduction computation                                           |
| `replica_groups` | vector of vectors of `int64` | Groups between which the reductions are performed               |
| `channel_id`     | optional `int64`             | Optional channel ID for cross-module communication              |

-   When `operand` is a tuple of arrays, the all-reduce is performed on each
    element of the tuple.
-   `replica_groups` is a list of replica groups between which the reduction is
    performed (the replica id for the current replica can be retrieved using
    [`ReplicaId`](#replicaid)). `replica_groups` must either be empty (in which
    case all replicas belong to a single group), or contain the same number of
    elements as the number of replicas. For example, `replica_groups = {0, 2},
    {1, 3}` performs reduction between the replicas `0` and `2`, and `1` and
    `3`.
-   `channel_id` is used for cross-module communication: only `all-reduce`
    operations with the same `channel_id` can communicate with each other.

The output shape is the same as the input shape. For example, if there are two
replicas and the operand has the value `[1.0, 2.5]` and `[3.0, 5.25]`
respectively on the two replicas, then the output value from this op with a
summation computation will be `[4.0, 7.75]` on both replicas. If the input is a
tuple, the output is a tuple as well.

Computing the result of `AllReduce` requires having one input from each replica,
so if one replica executes an `AllReduce` node more times than another, then the
former replica will wait forever. Since the replicas are all running the same
program, there are not a lot of ways for that to happen, but it is possible when
a while loop's condition depends on data from infeed and the data that is infed
causes the while loop to iterate more times on one replica than on another.

## AllToAll

See also
[`XlaBuilder::AllToAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

AllToAll is a collective operation that sends data from all cores to all cores.
It has two phases:

1.  The scatter phase. On each core, the operand is split into `split_count`
    blocks along `split_dimension`, and the blocks are scattered
    to all cores, e.g., the ith block is sent to the ith core.
2.  The gather phase. Each core concatenates the received blocks along the
    `concat_dimension`.

The participating cores can be configured by:

-   `replica_groups`: each ReplicaGroup contains a list of replica ids
    participating in the computation (the replica id for the current replica can
    be retrieved using [`ReplicaId`](#replicaid)). AllToAll will be applied
    within subgroups in the specified order. For example, `replica_groups =
    {{1,2,3}, {4,5,0}}` means that an AllToAll will be applied within replicas
    `{1, 2, 3}`, and in the gather phase the received blocks will be
    concatenated in the order 1, 2, 3. Then, another AllToAll will be applied
    within replicas 4, 5, 0, and the concatenation order is also 4, 5, 0.
    If `replica_groups` is empty, all replicas belong to one group, in the
    concatenation order of their appearance.

Prerequisites:

-   The dimension size of the operand on the `split_dimension` is divisible by
    `split_count`.
-   The operand's shape is not a tuple.

<b> `AllToAll(operand, split_dimension, concat_dimension, split_count,
replica_groups)` </b>

| Arguments          | Type                  | Semantics                                                                                                                                                                                          |
| ------------------ | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `operand`          | `XlaOp`               | n dimensional input array                                                                                                                                                                          |
| `split_dimension`  | `int64`               | A value in the interval `[0, n)` that names the dimension along which the operand is split                                                                                                        |
| `concat_dimension` | `int64`               | A value in the interval `[0, n)` that names the dimension along which the split blocks are concatenated                                                                                           |
| `split_count`      | `int64`               | The number of cores that participate in this operation. If `replica_groups` is empty, this should be the number of replicas; otherwise, this should be equal to the number of replicas in each group. |
| `replica_groups`   | `ReplicaGroup` vector | Each group contains a list of replica ids.                                                                                                                                                         |

Below is an example of AllToAll.

```
XlaBuilder b("alltoall");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);
```

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="./images/ops_alltoall.png">
</div>

In this example, there are 4 cores participating in the AllToAll. On each core,
the operand is split into 4 parts along dimension 1, so each part has shape
f32[4,4]. The 4 parts are scattered to all cores. Then each core concatenates
the received parts along dimension 0, in the order of cores 0-3. So the output
on each core has shape f32[16,4].

## BatchNormGrad

See also
[`XlaBuilder::BatchNormGrad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
for a detailed description of the algorithm.

Calculates gradients of batch norm.

<b> `BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)` </b>

| Arguments       | Type    | Semantics                                                 |
| --------------- | ------- | --------------------------------------------------------- |
| `operand`       | `XlaOp` | n dimensional array to be normalized (x)                  |
| `scale`         | `XlaOp` | 1 dimensional array (\\(\gamma\\))                        |
| `mean`          | `XlaOp` | 1 dimensional array (\\(\mu\\))                           |
| `variance`      | `XlaOp` | 1 dimensional array (\\(\sigma^2\\))                      |
| `grad_output`   | `XlaOp` | Gradients passed to `BatchNormTraining` (\\(\nabla y\\))  |
| `epsilon`       | `float` | Epsilon value (\\(\epsilon\\))                            |
| `feature_index` | `int64` | Index to feature dimension in `operand`                   |

For each feature in the feature dimension (`feature_index` is the index for the
feature dimension in `operand`), the operation calculates the gradients with
respect to `operand`, `offset` and `scale` across all the other dimensions. The
`feature_index` must be a valid index for the feature dimension in `operand`.

The three gradients are defined by the following formulas (assuming a
4-dimensional array as `operand` and with feature dimension index `l`, batch
size `m` and spatial sizes `w` and `h`):

\\[ \begin{split} c_l&=
\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h
\left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right)
\\\\
\nabla x_{ijkl} &= \frac{\gamma_{l}}{\sqrt{\sigma^2_{l}+\epsilon}}
\left( \nabla y_{ijkl} - \mathrm{mean}(\nabla y) - c_l (x_{ijkl} - \mu_{l})
\right)
\\\\
\nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl}
\frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon}} \right)
\\\\
\nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl}
\end{split} \\]

The inputs `mean` and `variance` represent moments computed
across the batch and spatial dimensions.

The output type is a tuple of three handles:

| Outputs        | Type    | Semantics                                                    |
| -------------- | ------- | ------------------------------------------------------------ |
| `grad_operand` | `XlaOp` | gradient with respect to input `operand` (\\(\nabla x\\))    |
| `grad_scale`   | `XlaOp` | gradient with respect to input `scale` (\\(\nabla \gamma\\)) |
| `grad_offset`  | `XlaOp` | gradient with respect to input `offset` (\\(\nabla \beta\\)) |

## BatchNormInference

See also
[`XlaBuilder::BatchNormInference`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
for a detailed description of the algorithm.

Normalizes an array across batch and spatial dimensions.

<b> `BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)` </b>

Arguments       | Type    | Semantics
--------------- | ------- | ---------------------------------------
`operand`       | `XlaOp` | n dimensional array to be normalized
`scale`         | `XlaOp` | 1 dimensional array
`offset`        | `XlaOp` | 1 dimensional array
`mean`          | `XlaOp` | 1 dimensional array
`variance`      | `XlaOp` | 1 dimensional array
`epsilon`       | `float` | Epsilon value
`feature_index` | `int64` | Index to feature dimension in `operand`

For each feature in the feature dimension (`feature_index` is the index for the
feature dimension in `operand`), the operation calculates the mean and variance
across all the other dimensions and uses the mean and variance to normalize each
element in `operand`. The `feature_index` must be a valid index for the feature
dimension in `operand`.

`BatchNormInference` is equivalent to calling `BatchNormTraining` without
computing `mean` and `variance` for each batch. It uses the input `mean` and
`variance` instead as estimated values. The purpose of this op is to reduce
latency in inference, hence the name `BatchNormInference`.

The output is an n-dimensional, normalized array with the same shape as input
`operand`.
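
A rough builder-level sketch of `BatchNormInference` (the shapes and epsilon are
illustrative, and the feature dimension here is assumed to be dimension 1):

```
XlaBuilder b("batch_norm_inference");
auto x      = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {8, 4}), "x");
auto scale  = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {4}), "scale");
auto offset = Parameter(&b, 2, ShapeUtil::MakeShape(F32, {4}), "offset");
auto mean   = Parameter(&b, 3, ShapeUtil::MakeShape(F32, {4}), "mean");
auto var    = Parameter(&b, 4, ShapeUtil::MakeShape(F32, {4}), "variance");
// Each of the 4 features (dimension 1) is normalized with the supplied
// estimated moments, then scaled by `scale` and shifted by `offset`.
BatchNormInference(x, scale, offset, mean, var, /*epsilon=*/0.001f,
                   /*feature_index=*/1);
```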

## BatchNormTraining

See also
[`XlaBuilder::BatchNormTraining`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
for a detailed description of the algorithm.

Normalizes an array across batch and spatial dimensions.

<b> `BatchNormTraining(operand, scale, offset, epsilon, feature_index)` </b>

Arguments       | Type    | Semantics
--------------- | ------- | ----------------------------------------
`operand`       | `XlaOp` | n dimensional array to be normalized (x)
`scale`         | `XlaOp` | 1 dimensional array (\\(\gamma\\))
`offset`        | `XlaOp` | 1 dimensional array (\\(\beta\\))
`epsilon`       | `float` | Epsilon value (\\(\epsilon\\))
`feature_index` | `int64` | Index to feature dimension in `operand`

For each feature in the feature dimension (`feature_index` is the index for the
feature dimension in `operand`), the operation calculates the mean and variance
across all the other dimensions and uses the mean and variance to normalize each
element in `operand`. The `feature_index` must be a valid index for the feature
dimension in `operand`.

The algorithm goes as follows for each batch in `operand` \\(x\\) that
contains `m` elements with `w` and `h` as the size of spatial dimensions
(assuming `operand` is a 4 dimensional array):

- Calculates batch mean \\(\mu_l\\) for each feature `l` in the feature dimension:
\\(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\\)

- Calculates batch variance \\(\sigma^2_l\\):
\\(\sigma^2_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h (x_{ijkl} - \mu_l)^2\\)

- Normalizes, scales and shifts:
\\(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt{\sigma^2_l+\epsilon}}+\beta_l\\)

The epsilon value, usually a small number, is added to avoid divide-by-zero errors.

The output type is a tuple of three `XlaOp`s:

| Outputs      | Type    | Semantics                                                       |
| ------------ | ------- | --------------------------------------------------------------- |
| `output`     | `XlaOp` | n dimensional array with the same shape as input `operand` (y)  |
| `batch_mean` | `XlaOp` | 1 dimensional array (\\(\mu\\))                                  |
| `batch_var`  | `XlaOp` | 1 dimensional array (\\(\sigma^2\\))                             |

The `batch_mean` and `batch_var` are moments calculated across the batch and
spatial dimensions using the formulas above.

## BitcastConvertType

See also
[`XlaBuilder::BitcastConvertType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Similar to a `tf.bitcast` in TensorFlow, performs an element-wise bitcast
operation from a data shape to a target shape. The dimensions must match, and
the conversion is an element-wise one; e.g. `s32` elements become `f32` elements
via a bitcast routine. Bitcast is implemented as a low-level cast, so machines
with different floating-point representations will give different results.

<b> `BitcastConvertType(operand, new_element_type)` </b>

Arguments          | Type            | Semantics
------------------ | --------------- | ---------------------------
`operand`          | `XlaOp`         | array of type T with dims D
`new_element_type` | `PrimitiveType` | type U

The dimensions of the operand and the target shape must match. The bit-width of
the source and destination element types must be equal. The source
and destination element types must not be tuples.
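
A minimal builder-level sketch (the operand shape is illustrative):

```
XlaBuilder b("bitcast_convert");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {4}), "x");
// Reinterprets the raw 32-bit patterns as f32; no numeric conversion happens.
BitcastConvertType(x, F32);  // result shape: f32[4]
```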

## Broadcast

See also
[`XlaBuilder::Broadcast`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Adds dimensions to an array by duplicating the data in the array.

<b> `Broadcast(operand, broadcast_sizes)` </b>

Arguments         | Type                | Semantics
----------------- | ------------------- | -------------------------------
`operand`         | `XlaOp`             | The array to duplicate
`broadcast_sizes` | `ArraySlice<int64>` | The sizes of the new dimensions

The new dimensions are inserted on the left, i.e. if `broadcast_sizes` has
values `{a0, ..., aN}` and the operand shape has dimensions `{b0, ..., bM}` then
the shape of the output has dimensions `{a0, ..., aN, b0, ..., bM}`.

The new dimensions index into copies of the operand, i.e.

```
output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
```

For example, if `operand` is a scalar `f32` with value `2.0f`, and
`broadcast_sizes` is `{2, 3}`, then the result will be an array with shape
`f32[2, 3]` and all the values in the result will be `2.0f`.

## BroadcastInDim

See also
[`XlaBuilder::BroadcastInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Expands the size and rank of an array by duplicating the data in the array.

<b> `BroadcastInDim(operand, out_dim_size, broadcast_dimensions)` </b>

| Arguments              | Type                | Semantics                                                                               |
| ---------------------- | ------------------- | --------------------------------------------------------------------------------------- |
| `operand`              | `XlaOp`             | The array to duplicate                                                                  |
| `out_dim_size`         | `ArraySlice<int64>` | The sizes of the dimensions of the target shape                                         |
| `broadcast_dimensions` | `ArraySlice<int64>` | Which dimension in the target shape each dimension of the operand shape corresponds to  |

Similar to Broadcast, but allows adding dimensions anywhere and expanding
existing dimensions with size 1.

The `operand` is broadcast to the shape described by `out_dim_size`.
`broadcast_dimensions` maps the dimensions of `operand` to the dimensions of the
target shape, i.e. the i'th dimension of the operand is mapped to the
broadcast_dimension\[i\]'th dimension of the output shape. The dimensions of
`operand` must have size 1 or be the same size as the dimension in the output
shape they are mapped to. The remaining dimensions are filled with dimensions of
size 1. Degenerate-dimension broadcasting then broadcasts along these degenerate
dimensions to reach the output shape. The semantics are described in detail on
the [broadcasting page](broadcasting.md).

## Call

See also
[`XlaBuilder::Call`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Invokes a computation with the given arguments.

<b> `Call(computation, args...)` </b>

| Arguments     | Type                   | Semantics                                                                               |
| ------------- | ---------------------- | --------------------------------------------------------------------------------------- |
| `computation` | `XlaComputation`       | computation of type `T_0, T_1, ..., T_{N-1} -> S` with N parameters of arbitrary type   |
| `args`        | sequence of N `XlaOp`s | N arguments of arbitrary type                                                           |

The arity and types of the `args` must match the parameters of the
`computation`. It is allowed to have no `args`.
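
As a sketch, building a single-parameter computation and calling it might look
like the following (names are illustrative; `Build()` returns a `StatusOr` that
has to be unwrapped):

```
XlaBuilder sub("add_one");
auto p = Parameter(&sub, 0, ShapeUtil::MakeShape(F32, {}), "p");
Add(p, ConstantR0<float>(&sub, 1.0f));
XlaComputation add_one = sub.Build().value();  // computation of type f32[] -> f32[]

XlaBuilder b("call_example");
auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "x");
Call(&b, add_one, {x});  // result: x + 1.0
```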

## Cholesky

See also
[`XlaBuilder::Cholesky`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Computes the
[Cholesky decomposition](https://en.wikipedia.org/wiki/Cholesky_decomposition)
of a batch of symmetric (Hermitian) positive definite matrices.

<b> `Cholesky(a, lower)` </b>

Arguments | Type    | Semantics
--------- | ------- | -------------------------------------------------------
`a`       | `XlaOp` | a rank >= 2 array of a complex or floating-point type.
`lower`   | `bool`  | whether to use the upper or lower triangle of `a`.

If `lower` is `true`, computes lower-triangular matrices `l` such that $$ a = l
. l^T $$. If `lower` is `false`, computes upper-triangular matrices `u` such
that $$ a = u^T . u $$.

Input data is read only from the lower/upper triangle of `a`, depending on the
value of `lower`. Values from the other triangle are ignored. Output data is
returned in the same triangle; the values in the other triangle are
implementation-defined and may be anything.

If the rank of `a` is greater than 2, `a` is treated as a batch of matrices,
where all except the minor 2 dimensions are batch dimensions.

If `a` is not symmetric (Hermitian) positive definite, the result is
implementation-defined.

## Clamp

See also
[`XlaBuilder::Clamp`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Clamps an operand to within the range between a minimum and maximum value.

<b> `Clamp(min, operand, max)` </b>

Arguments | Type    | Semantics
--------- | ------- | ---------------
`min`     | `XlaOp` | array of type T
`operand` | `XlaOp` | array of type T
`max`     | `XlaOp` | array of type T

Given an operand and minimum and maximum values, returns the operand if it is in
the range between the minimum and maximum, else returns the minimum value if the
operand is below this range or the maximum value if the operand is above this
range. That is, `clamp(a, x, b) = min(max(a, x), b)`.

All three arrays must be the same shape. Alternatively, as a restricted form of
[broadcasting](broadcasting.md), `min` and/or `max` can be a scalar of type `T`.

Example with scalar `min` and `max`:

```
let operand: s32[3] = {-1, 5, 9};
let min: s32 = 0;
let max: s32 = 6;
==>
Clamp(min, operand, max) = s32[3]{0, 5, 6};
```

## Collapse

See also
[`XlaBuilder::Collapse`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
and the `tf.reshape` operation.

Collapses dimensions of an array into one dimension.

<b> `Collapse(operand, dimensions)` </b>

Arguments    | Type           | Semantics
------------ | -------------- | -----------------------------------------------
`operand`    | `XlaOp`        | array of type T
`dimensions` | `int64` vector | in-order, consecutive subset of T's dimensions.

Collapse replaces the given subset of the operand's dimensions by a single
dimension. The input arguments are an arbitrary array of type T and a
compile-time-constant vector of dimension indices. The dimension indices must be
an in-order (low to high dimension numbers), consecutive subset of T's
dimensions. Thus, {0, 1, 2}, {0, 1}, or {1, 2} are all valid dimension sets, but
{1, 0} or {0, 2} are not. They are replaced by a single new dimension, in the
same position in the dimension sequence as those they replace, with the new
dimension size equal to the product of original dimension sizes. The lowest
dimension number in `dimensions` is the slowest varying dimension (most major)
in the loop nest which collapses these dimensions, and the highest dimension
number is fastest varying (most minor).
See the `tf.reshape` operator
if more general collapse ordering is needed.

For example, let v be an array of 24 elements:

```
let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}},
                    {{20, 21, 22}, {25, 26, 27}},
                    {{30, 31, 32}, {35, 36, 37}},
                    {{40, 41, 42}, {45, 46, 47}}};

// Collapse to a single dimension, leaving one dimension.
let v012 = Collapse(v, {0,1,2});
then v012 == f32[24] {10, 11, 12, 15, 16, 17,
                      20, 21, 22, 25, 26, 27,
                      30, 31, 32, 35, 36, 37,
                      40, 41, 42, 45, 46, 47};

// Collapse the two lower dimensions, leaving two dimensions.
let v01 = Collapse(v, {0,1});
then v01 == f32[8x3] {{10, 11, 12},
                      {15, 16, 17},
                      {20, 21, 22},
                      {25, 26, 27},
                      {30, 31, 32},
                      {35, 36, 37},
                      {40, 41, 42},
                      {45, 46, 47}};

// Collapse the two higher dimensions, leaving two dimensions.
let v12 = Collapse(v, {1,2});
then v12 == f32[4x6] {{10, 11, 12, 15, 16, 17},
                      {20, 21, 22, 25, 26, 27},
                      {30, 31, 32, 35, 36, 37},
                      {40, 41, 42, 45, 46, 47}};
```

## CollectivePermute

See also
[`XlaBuilder::CollectivePermute`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

CollectivePermute is a collective operation that sends and receives data across
replicas.

<b> `CollectivePermute(operand, source_target_pairs)` </b>

| Arguments             | Type                    | Semantics                                                                                                                                  |
| --------------------- | ----------------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
| `operand`             | `XlaOp`                 | n dimensional input array                                                                                                                  |
| `source_target_pairs` | `<int64, int64>` vector | A list of (source_replica_id, target_replica_id) pairs. For each pair, the operand is sent from the source replica to the target replica.  |

Note that there are the following restrictions on the `source_target_pairs`:

-   Any two pairs should not have the same target replica id, and they should
    not have the same source replica id.
-   If a replica id is not a target in any pair, then the output on that replica
    is a tensor consisting of 0(s) with the same shape as the input.

## Concatenate

See also
[`XlaBuilder::ConcatInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Concatenate composes an array from multiple array operands. The array is of the
same rank as each of the input array operands (which must be of the same rank as
each other) and contains the arguments in the order that they were specified.

<b> `Concatenate(operands..., dimension)` </b>

| Arguments   | Type                  | Semantics                                                                                                     |
| ----------- | --------------------- | -------------------------------------------------------------------------------------------------------------- |
| `operands`  | sequence of N `XlaOp` | N arrays of type T with dimensions [L0, L1, ...]. Requires N >= 1.                                            |
| `dimension` | `int64`               | A value in the interval `[0, rank)`, where `rank` is the rank of the operands, that names the dimension to be concatenated between the `operands`. |

With the exception of `dimension` all dimensions must be the same. This is
because XLA does not support "ragged" arrays. Also note that rank-0 values
cannot be concatenated (as it's impossible to name the dimension along which the
concatenation occurs).
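
At the builder level, concatenation is expressed with `ConcatInDim`. A rough
sketch (parameter shapes are illustrative):

```
XlaBuilder b("concatenate");
auto a = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {3, 2}), "a");
auto c = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {1, 2}), "c");
// Concatenate along dimension 0; the remaining dimension (size 2) must match.
ConcatInDim(&b, {a, c}, /*dimension=*/0);  // result shape: f32[4, 2]
```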

1-dimensional example:

```
Concat({{2, 3}, {4, 5}, {6, 7}}, 0)
>>> {2, 3, 4, 5, 6, 7}
```

2-dimensional example:

```
let a = {
  {1, 2},
  {3, 4},
  {5, 6},
};
let b = {
  {7, 8},
};
Concat({a, b}, 0)
>>> {
  {1, 2},
  {3, 4},
  {5, 6},
  {7, 8},
}
```

Diagram:
<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="./images/ops_concatenate.png">
</div>

## Conditional

See also
[`XlaBuilder::Conditional`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b> `Conditional(pred, true_operand, true_computation, false_operand,
false_computation)` </b>

<!-- mdformat off(disable mdformat for proper MathJax formatting) -->

Arguments           | Type             | Semantics
------------------- | ---------------- | ---------------------------------------
`pred`              | `XlaOp`          | Scalar of type `PRED`
`true_operand`      | `XlaOp`          | Argument of type \\(T_0\\)
`true_computation`  | `XlaComputation` | XlaComputation of type \\(T_0 \to S\\)
`false_operand`     | `XlaOp`          | Argument of type \\(T_1\\)
`false_computation` | `XlaComputation` | XlaComputation of type \\(T_1 \to S\\)

Executes `true_computation` if `pred` is `true`, `false_computation` if `pred`
is `false`, and returns the result.

The `true_computation` must take in a single argument of type \\(T_0\\) and will
be invoked with `true_operand` which must be of the same type. The
`false_computation` must take in a single argument of type \\(T_1\\) and will be
invoked with `false_operand` which must be of the same type. The type of the
returned value of `true_computation` and `false_computation` must be the same.

<!-- mdformat on -->

Note that only one of `true_computation` and `false_computation` will be
executed depending on the value of `pred`.

<b> `Conditional(branch_index, branch_computations, branch_operands)` </b>

<!-- mdformat off(disable mdformat for proper MathJax formatting) -->

| Arguments             | Type                           | Semantics                                                                   |
| --------------------- | ------------------------------ | ---------------------------------------------------------------------------- |
| `branch_index`        | `XlaOp`                        | Scalar of type `S32`                                                        |
| `branch_computations` | sequence of N `XlaComputation` | XlaComputations of type \\( T_0 \to S , T_1 \to S , ..., T_{N-1} \to S \\)  |
| `branch_operands`     | sequence of N `XlaOp`          | Arguments of type \\( T_0 , T_1 , ..., T_{N-1} \\)                          |

<!-- mdformat on -->

Executes `branch_computations[branch_index]`, and returns the result. If
`branch_index` is an `S32` which is < 0 or >= N, then `branch_computations[N-1]`
is executed as the default branch.

Each `branch_computations[b]` must take in a single argument of type `T_b` and
will be invoked with `branch_operands[b]` which must be of the same type. The
type of the returned value of each `branch_computations[b]` must be the same.

Note that only one of the `branch_computations` will be executed depending on
the value of `branch_index`.

## Conv (convolution)

See also
[`XlaBuilder::Conv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

As ConvWithGeneralPadding, but the padding is specified in a short-hand way as
either SAME or VALID. SAME padding pads the input (`lhs`) with zeroes so that
the output has the same shape as the input when not taking striding into
account.
VALID padding simply means no padding.

## ConvWithGeneralPadding (convolution)

See also
[`XlaBuilder::ConvWithGeneralPadding`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Computes a convolution of the kind used in neural networks. Here, a convolution
can be thought of as an n-dimensional window moving across an n-dimensional base
area and a computation is performed for each possible position of the window.

| Arguments             | Type                             | Semantics                        |
| --------------------- | -------------------------------- | --------------------------------- |
| `lhs`                 | `XlaOp`                          | rank n+2 array of inputs         |
| `rhs`                 | `XlaOp`                          | rank n+2 array of kernel weights |
| `window_strides`      | `ArraySlice<int64>`              | n-d array of kernel strides      |
| `padding`             | `ArraySlice<pair<int64, int64>>` | n-d array of (low, high) padding |
| `lhs_dilation`        | `ArraySlice<int64>`              | n-d lhs dilation factor array    |
| `rhs_dilation`        | `ArraySlice<int64>`              | n-d rhs dilation factor array    |
| `feature_group_count` | int64                            | the number of feature groups     |
| `batch_group_count`   | int64                            | the number of batch groups       |

Let n be the number of spatial dimensions. The `lhs` argument is a rank n+2
array describing the base area. This is called the input, even though of course
the rhs is also an input. In a neural network, these are the input activations.
The n+2 dimensions are, in this order:

*   `batch`: Each coordinate in this dimension represents an independent input
    for which convolution is carried out.
*   `z/depth/features`: Each (y,x) position in the base area has a vector
    associated to it, which goes into this dimension.
*   `spatial_dims`: Describes the `n` spatial dimensions that define the base
    area that the window moves across.

The `rhs` argument is a rank n+2 array describing the convolutional
filter/kernel/window. The dimensions are, in this order:

*   `output-z`: The `z` dimension of the output.
*   `input-z`: The size of this dimension times `feature_group_count` should
    equal the size of the `z` dimension in lhs.
*   `spatial_dims`: Describes the `n` spatial dimensions that define the n-d
    window that moves across the base area.

The `window_strides` argument specifies the stride of the convolutional window
in the spatial dimensions. For example, if the stride in the first spatial
dimension is 3, then the window can only be placed at coordinates where the
first spatial index is divisible by 3.

The `padding` argument specifies the amount of zero padding to be applied to the
base area. The amount of padding can be negative -- the absolute value of
negative padding indicates the number of elements to remove from the specified
dimension before doing the convolution. `padding[0]` specifies the padding for
dimension `y` and `padding[1]` specifies the padding for dimension `x`. Each
pair has the low padding as the first element and the high padding as the second
element. The low padding is applied in the direction of lower indices while the
high padding is applied in the direction of higher indices. For example, if
`padding[1]` is `(2,3)` then there will be a padding by 2 zeroes on the left and
by 3 zeroes on the right in the second spatial dimension. Using padding is
equivalent to inserting those same zero values into the input (`lhs`) before
doing the convolution.

The `lhs_dilation` and `rhs_dilation` arguments specify the dilation factor to
be applied to the lhs and rhs, respectively, in each spatial dimension. If the
dilation factor in a spatial dimension is d, then d-1 holes are implicitly
placed between each of the entries in that dimension, increasing the size of the
array. The holes are filled with a no-op value, which for convolution means
zeroes.

Dilation of the rhs is also called atrous convolution. For more details, see
`tf.nn.atrous_conv2d`. Dilation of the lhs is also called transposed
convolution. For more details, see `tf.nn.conv2d_transpose`.

The `feature_group_count` argument (default value 1) can be used for grouped
convolutions. `feature_group_count` needs to be a divisor of both the input and
the output feature dimension. If `feature_group_count` is greater than 1, it
means that conceptually the input and output feature dimension and the `rhs`
output feature dimension are split evenly into `feature_group_count` many
groups, each group consisting of a consecutive subsequence of features. The
input feature dimension of `rhs` needs to be equal to the `lhs` input feature
dimension divided by `feature_group_count` (so it already has the size of a
group of input features). The i-th groups are used together to compute
`feature_group_count` many separate convolutions. The results of these
convolutions are concatenated together in the output feature dimension.

For depthwise convolution the `feature_group_count` argument would be set to the
input feature dimension, and the filter would be reshaped from
`[filter_height, filter_width, in_channels, channel_multiplier]` to
`[filter_height, filter_width, 1, in_channels * channel_multiplier]`. For more
details, see `tf.nn.depthwise_conv2d`.

The `batch_group_count` (default value 1) argument can be used for grouped
filters during backpropagation. `batch_group_count` needs to be a divisor of the
size of the `lhs` (input) batch dimension. If `batch_group_count` is greater
than 1, it means that the output batch dimension should be of size `input batch
/ batch_group_count`. The `batch_group_count` must be a divisor of the output
feature size.

The output shape has these dimensions, in this order:

*   `batch`: The size of this dimension times `batch_group_count` should equal
    the size of the `batch` dimension in lhs.
*   `z`: Same size as `output-z` on the kernel (`rhs`).
*   `spatial_dims`: One value for each valid placement of the convolutional
    window.

The valid placements of the convolutional window are determined by the strides
and the size of the base area after padding.

To describe what a convolution does, consider a 2d convolution, and pick some
fixed `batch`, `z`, `y`, `x` coordinates in the output. Then `(y,x)` is a
position of a corner of the window within the base area (e.g. the upper left
corner, depending on how you interpret the spatial dimensions). We now have a 2d
window, taken from the base area, where each 2d point is associated to a 1d
vector, so we get a 3d box. From the convolutional kernel, since we fixed the
output coordinate `z`, we also have a 3d box. The two boxes have the same
dimensions, so we can take the sum of the element-wise products between the two
boxes (similar to a dot product). That is the output value.

Note that if `output-z` is e.g., 5, then each position of the window produces 5
values in the output into the `z` dimension of the output. These values differ
in what part of the convolutional kernel is used - there is a separate 3d box of
values used for each `output-z` coordinate. So you could think of it as 5
separate convolutions with a different filter for each of them.

Here is pseudo-code for a 2d convolution with padding and striding:

```
for (b, oz, oy, ox) {  // output coordinates
  value = 0;
  for (iz, ky, kx) {  // kernel coordinates and input z
    iy = oy*stride_y + ky - pad_low_y;
    ix = ox*stride_x + kx - pad_low_x;
    if ((iy, ix) inside the base area considered without padding) {
      value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
    }
  }
  output(b, oz, oy, ox) = value;
}
```

## ConvertElementType

See also
[`XlaBuilder::ConvertElementType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Similar to an element-wise `static_cast` in C++, performs an element-wise
conversion operation from a data shape to a target shape. The dimensions must
match, and the conversion is an element-wise one; e.g. `s32` elements become
`f32` elements via an `s32`-to-`f32` conversion routine.

<b> `ConvertElementType(operand, new_element_type)` </b>

Arguments          | Type            | Semantics
------------------ | --------------- | ---------------------------
`operand`          | `XlaOp`         | array of type T with dims D
`new_element_type` | `PrimitiveType` | type U

The dimensions of the operand and the target shape must match. The source and
destination element types must not be tuples.

A conversion such as `T=s32` to `U=f32` will perform a normalizing int-to-float
conversion routine such as round-to-nearest-even.

> Note: The precise float-to-int and vice-versa conversions are currently
> unspecified, but may become additional arguments to the convert operation in
> the future. Not all possible conversions have been implemented for all
> targets.

```
let a: s32[3] = {0, 1, 2};
let b: f32[3] = convert(a, f32);
then b == f32[3]{0.0, 1.0, 2.0}
```

## CrossReplicaSum

Performs `AllReduce` with a summation computation.

## CustomCall

See also
[`XlaBuilder::CustomCall`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Call a user-provided function within a computation.

<b> `CustomCall(target_name, args..., shape)` </b>

| Arguments     | Type                   | Semantics                                                                                  |
| ------------- | ---------------------- | ------------------------------------------------------------------------------------------- |
| `target_name` | `string`               | Name of the function. A call instruction will be emitted which targets this symbol name.   |
| `args`        | sequence of N `XlaOp`s | N arguments of arbitrary type, which will be passed to the function.                       |
| `shape`       | `Shape`                | Output shape of the function                                                               |

The function signature is the same, regardless of the arity or type of args:

```
extern "C" void target_name(void* out, void** in);
```

For example, if CustomCall is used as follows:

```
let x = f32[2] {1,2};
let y = f32[2x3] {{10, 20, 30}, {40, 50, 60}};

CustomCall("myfunc", {x, y}, f32[3x3])
```

Here is an example of an implementation of `myfunc`:

```
extern "C" void myfunc(void* out, void** in) {
  float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
  float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
  EXPECT_EQ(1, x[0]);
  EXPECT_EQ(2, x[1]);
  EXPECT_EQ(10, y[0][0]);
  EXPECT_EQ(20, y[0][1]);
  EXPECT_EQ(30, y[0][2]);
  EXPECT_EQ(40, y[1][0]);
  EXPECT_EQ(50, y[1][1]);
  EXPECT_EQ(60, y[1][2]);
  float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
  z[0][0] = x[1] + y[1][0];
  // ...
}
```

The user-provided function must not have side-effects and its execution must be
idempotent.

> Note: The opaque nature of the user-provided function restricts optimization
> opportunities for the compiler. Try to express your computation in terms of
> native XLA ops whenever possible; only use CustomCall as a last resort.

## Dot

See also
[`XlaBuilder::Dot`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b> `Dot(lhs, rhs)` </b>

Arguments | Type    | Semantics
--------- | ------- | ---------------
`lhs`     | `XlaOp` | array of type T
`rhs`     | `XlaOp` | array of type T

The exact semantics of this operation depend on the ranks of the operands:

| Input                               | Output         | Semantics                    |
| ----------------------------------- | -------------- | ---------------------------- |
| vector [n] `dot` vector [n]         | scalar         | vector dot product           |
| matrix [m x k] `dot` vector [k]     | vector [m]     | matrix-vector multiplication |
| matrix [m x k] `dot` matrix [k x n] | matrix [m x n] | matrix-matrix multiplication |

The operation performs sum of products over the second dimension of `lhs` (or
the first if it has rank 1) and the first dimension of `rhs`. These are the
"contracted" dimensions. The contracted dimensions of `lhs` and `rhs` must be of
the same size. In practice, it can be used to perform dot products between
vectors, vector/matrix multiplications or matrix/matrix multiplications.

## DotGeneral

See also
[`XlaBuilder::DotGeneral`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b> `DotGeneral(lhs, rhs, dimension_numbers)` </b>

Arguments           | Type                  | Semantics
------------------- | --------------------- | ----------------------------------------
`lhs`               | `XlaOp`               | array of type T
`rhs`               | `XlaOp`               | array of type T
`dimension_numbers` | `DotDimensionNumbers` | contracting and batch dimension numbers

As Dot, but allows contracting and batch dimension numbers to be specified for
both the 'lhs' and 'rhs'.

| DotDimensionNumbers Fields    | Type           | Semantics                           |
| ----------------------------- | -------------- | ----------------------------------- |
| 'lhs_contracting_dimensions'  | repeated int64 | 'lhs' contracting dimension numbers |
| 'rhs_contracting_dimensions'  | repeated int64 | 'rhs' contracting dimension numbers |
| 'lhs_batch_dimensions'        | repeated int64 | 'lhs' batch dimension numbers       |
| 'rhs_batch_dimensions'        | repeated int64 | 'rhs' batch dimension numbers       |

DotGeneral performs the sum of products over contracting dimensions specified
in 'dimension_numbers'.

Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need
to be the same but must have the same dimension sizes.

Example with contracting dimension numbers:

```
lhs = { {1.0, 2.0, 3.0},
        {4.0, 5.0, 6.0} }

rhs = { {1.0, 1.0, 1.0},
        {2.0, 2.0, 2.0} }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(1);
dnums.add_rhs_contracting_dimensions(1);

DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
                                 {15.0, 30.0} }
```

Associated batch dimension numbers from the 'lhs' and 'rhs' must
have the same dimension sizes.

Example with batch dimension numbers (batch size 2, 2x2 matrices):

```
lhs = { { {1.0, 2.0},
          {3.0, 4.0} },
        { {5.0, 6.0},
          {7.0, 8.0} } }

rhs = { { {1.0, 0.0},
          {0.0, 1.0} },
        { {1.0, 0.0},
          {0.0, 1.0} } }

DotDimensionNumbers dnums;
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);

DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
                                   {3.0, 4.0} },
                                 { {5.0, 6.0},
                                   {7.0, 8.0} } }
```

| Input                               | Output         | Semantics    |
| ----------------------------------- | -------------- | ------------ |
| [b0, m, k] `dot` [b0, k, n]         | [b0, m, n]     | batch matmul |
| [b0, b1, m, k] `dot` [b0, b1, k, n] | [b0, b1, m, n] | batch matmul |

It follows that the resulting dimensions start with the batch dimensions,
then the 'lhs' non-contracting/non-batch dimensions, and finally the 'rhs'
non-contracting/non-batch dimensions.

## DynamicSlice

See also
[`XlaBuilder::DynamicSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

DynamicSlice extracts a sub-array from the input array at dynamic
`start_indices`. The size of the slice in each dimension is passed in
`size_indices`, which specify the end point of exclusive slice intervals in each
dimension: [start, start + size). `start_indices` is a sequence of N scalar
integer operands, where N is the rank of `operand`.

<b> `DynamicSlice(operand, start_indices, size_indices)` </b>

| Arguments       | Type                  | Semantics                                                                                                                                                                                                          |
| --------------- | --------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `operand`       | `XlaOp`               | N dimensional array of type T                                                                                                                                                                                      |
| `start_indices` | sequence of N `XlaOp` | List of N scalar integers containing the starting indices of the slice for each dimension. Value must be greater than or equal to zero.                                                                           |
| `size_indices`  | `ArraySlice<int64>`   | List of N integers containing the slice size for each dimension. Each value must be strictly greater than zero, and start + size must be less than or equal to the size of the dimension to avoid wrapping modulo dimension size. |

The effective slice indices are computed by applying the following
transformation for each index `i` in `[0, N)` before performing the slice:

```
start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])
```

This ensures that the extracted slice is always in-bounds with respect to the
operand array. If the slice is in-bounds before the transformation is applied,
the transformation has no effect.

1-dimensional example:

```
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let s = {2}

DynamicSlice(a, s, {2}) produces:
{2.0, 3.0}
```

2-dimensional example:

```
let b =
{ {0.0,  1.0,  2.0},
  {3.0,  4.0,  5.0},
  {6.0,  7.0,  8.0},
  {9.0, 10.0, 11.0} }
let s = {2, 1}

DynamicSlice(b, s, {2, 2}) produces:
{ { 7.0,  8.0},
  {10.0, 11.0} }
```

## DynamicUpdateSlice

See also
[`XlaBuilder::DynamicUpdateSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

DynamicUpdateSlice generates a result which is the value of the input array
`operand`, with a slice `update` overwritten at `start_indices`.
The shape of `update` determines the shape of the sub-array of the result which
is updated.
`start_indices` is a sequence of N scalar integer operands, where N is the rank
of `operand`.

<b> `DynamicUpdateSlice(operand, update, start_indices)` </b>

| Arguments       | Type                  | Semantics                                                                                                                                                                                                                                   |
| --------------- | --------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `operand`       | `XlaOp`               | N dimensional array of type T                                                                                                                                                                                                               |
| `update`        | `XlaOp`               | N dimensional array of type T containing the slice update. Each dimension of the update shape must be strictly greater than zero, and start + update must be less than or equal to the operand size for each dimension to avoid generating out-of-bounds update indices. |
| `start_indices` | sequence of N `XlaOp` | List of N scalar integers containing the starting indices of the slice for each dimension. Value must be greater than or equal to zero.                                                                                                    |

The effective slice indices are computed by applying the following
transformation for each index `i` in `[0, N)` before performing the slice:

```
start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])
```

This ensures that the updated slice is always in-bounds with respect to the
operand array. If the slice is in-bounds before the transformation is applied,
the transformation has no effect.
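
A builder-level sketch of the 1-dimensional case below (parameter shapes and
the constant start index are illustrative):

```
XlaBuilder b("dynamic_update_slice");
auto operand = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {5}), "operand");
auto update  = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {2}), "update");
auto start   = ConstantR0<int32_t>(&b, 2);  // one scalar start index per dimension
// Overwrites elements 2 and 3 of `operand` with the values from `update`.
DynamicUpdateSlice(operand, update, {start});
```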

1-dimensional example:

```
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
let u = {5.0, 6.0}
let s = {2}

DynamicUpdateSlice(a, u, s) produces:
{0.0, 1.0, 5.0, 6.0, 4.0}
```

2-dimensional example:

```
let b =
{ {0.0,  1.0,  2.0},
  {3.0,  4.0,  5.0},
  {6.0,  7.0,  8.0},
  {9.0, 10.0, 11.0} }
let u =
{ {12.0, 13.0},
  {14.0, 15.0},
  {16.0, 17.0} }

let s = {1, 1}

DynamicUpdateSlice(b, u, s) produces:
{ {0.0,  1.0,  2.0},
  {3.0, 12.0, 13.0},
  {6.0, 14.0, 15.0},
  {9.0, 16.0, 17.0} }
```

## Element-wise binary arithmetic operations

See also
[`XlaBuilder::Add`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

A set of element-wise binary arithmetic operations is supported.

<b> `Op(lhs, rhs)` </b>

Where `Op` is one of `Add` (addition), `Sub` (subtraction), `Mul`
(multiplication), `Div` (division), `Rem` (remainder), `Max` (maximum), `Min`
(minimum), `LogicalAnd` (logical AND), or `LogicalOr` (logical OR).

Arguments | Type    | Semantics
--------- | ------- | ----------------------------------------
`lhs`     | `XlaOp` | left-hand-side operand: array of type T
`rhs`     | `XlaOp` | right-hand-side operand: array of type T

The arguments' shapes have to be either similar or compatible. See the
[broadcasting](broadcasting.md) documentation about what it means for shapes to
be compatible. The result of an operation has a shape which is the result of
broadcasting the two input arrays. In this variant, operations between arrays of
different ranks are *not* supported, unless one of the operands is a scalar.

When `Op` is `Rem`, the sign of the result is taken from the dividend, and the
absolute value of the result is always less than the divisor's absolute value.

Integer division overflow (signed/unsigned division/remainder by zero or signed
division/remainder of `INT_SMIN` with `-1`) produces an implementation-defined
value.

An alternative variant with different-rank broadcasting support exists for these
operations:

<b> `Op(lhs, rhs, broadcast_dimensions)` </b>

Where `Op` is the same as above. This variant of the operation should be used
for arithmetic operations between arrays of different ranks (such as adding a
matrix to a vector).

The additional `broadcast_dimensions` operand is a slice of integers used to
expand the rank of the lower-rank operand up to the rank of the higher-rank
operand. `broadcast_dimensions` maps the dimensions of the lower-rank shape to
the dimensions of the higher-rank shape. The unmapped dimensions of the expanded
shape are filled with dimensions of size one. Degenerate-dimension broadcasting
then broadcasts the shapes along these degenerate dimensions to equalize the
shapes of both operands. The semantics are described in detail on the
[broadcasting page](broadcasting.md).
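
For example, adding a vector to every row of a matrix might be written roughly
like this (a sketch; shapes are illustrative):

```
XlaBuilder b("add_matrix_vector");
auto m = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {2, 3}), "matrix");
auto v = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {3}), "vector");
// The vector's only dimension maps to dimension 1 of the matrix, so the
// vector is broadcast across rows and added element-wise.
Add(m, v, /*broadcast_dimensions=*/{1});  // result shape: f32[2, 3]
```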

## Element-wise comparison operations

See also
[`XlaBuilder::Eq`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

A set of standard element-wise binary comparison operations is supported. Note
that standard IEEE 754 floating-point comparison semantics apply when comparing
floating-point types.

<b> `Op(lhs, rhs)` </b>

Where `Op` is one of `Eq` (equal-to), `Ne` (not equal-to), `Ge`
(greater-or-equal-than), `Gt` (greater-than), `Le` (less-or-equal-than), `Lt`
(less-than). Another set of operators, `EqTotalOrder`, `NeTotalOrder`,
`GeTotalOrder`, `GtTotalOrder`, `LeTotalOrder`, and `LtTotalOrder`, provides the
same functionality, except that they additionally support a total order over the
floating point numbers, by enforcing -NaN < -Inf < -Finite < -0 < +0 < +Finite <
+Inf < +NaN.

Arguments | Type    | Semantics
--------- | ------- | ----------------------------------------
`lhs`     | `XlaOp` | left-hand-side operand: array of type T
`rhs`     | `XlaOp` | right-hand-side operand: array of type T

The arguments' shapes have to be either similar or compatible. See the
[broadcasting](broadcasting.md) documentation about what it means for shapes to
be compatible. The result of an operation has a shape which is the result of
broadcasting the two input arrays with the element type `PRED`. In this variant,
operations between arrays of different ranks are *not* supported, unless one of
the operands is a scalar.

An alternative variant with different-rank broadcasting support exists for these
operations:

<b> `Op(lhs, rhs, broadcast_dimensions)` </b>

Where `Op` is the same as above. This variant of the operation should be used
for comparison operations between arrays of different ranks (such as comparing a
matrix against a vector).

The additional `broadcast_dimensions` operand is a slice of integers specifying
the dimensions to use for broadcasting the operands. The semantics are described
in detail on the [broadcasting page](broadcasting.md).

## Element-wise unary functions

XlaBuilder supports these element-wise unary functions:

<b>`Abs(operand)`</b> Element-wise abs `x -> |x|`.

<b>`Ceil(operand)`</b> Element-wise ceil `x -> ⌈x⌉`.

<b>`Cos(operand)`</b> Element-wise cosine `x -> cos(x)`.

<b>`Exp(operand)`</b> Element-wise natural exponential `x -> e^x`.

<b>`Floor(operand)`</b> Element-wise floor `x -> ⌊x⌋`.

<b>`Imag(operand)`</b> Element-wise imaginary part of a complex (or real)
shape. `x -> imag(x)`. If the operand is a floating point type, returns 0.

<b>`IsFinite(operand)`</b> Tests whether each element of `operand` is finite,
i.e., is not positive or negative infinity, and is not `NaN`. Returns an array
of `PRED` values with the same shape as the input, where each element is `true`
if and only if the corresponding input element is finite.

<b>`Log(operand)`</b> Element-wise natural logarithm `x -> ln(x)`.

<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.

<b>`Logistic(operand)`</b> Element-wise logistic function computation `x ->
logistic(x)`.

<b>`PopulationCount(operand)`</b> Computes the number of bits set in each
element of `operand`.

<b>`Neg(operand)`</b> Element-wise negation `x -> -x`.

<b>`Real(operand)`</b> Element-wise real part of a complex (or real) shape.
`x -> real(x)`. If the operand is a floating point type, returns the same value.

<b>`Rsqrt(operand)`</b> Element-wise reciprocal of square root operation
`x -> 1.0 / sqrt(x)`.
1346 1347<b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where 1348 1349$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}$$ 1350 1351using the comparison operator of the element type of `operand`. 1352 1353<b>`Sqrt(operand)`</b> Element-wise square root operation `x -> sqrt(x)`. 1354 1355<b>`Cbrt(operand)`</b> Element-wise cubic root operation `x -> cbrt(x)`. 1356 1357<b>`Tanh(operand)`</b> Element-wise hyperbolic tangent `x -> tanh(x)`. 1358 1359 1360Arguments | Type | Semantics 1361--------- | ------- | --------------------------- 1362`operand` | `XlaOp` | The operand to the function 1363 1364The function is applied to each element in the `operand` array, resulting in an 1365array with the same shape. It is allowed for `operand` to be a scalar (rank 0). 1366 1367## Fft 1368 1369The XLA FFT operation implements the forward and inverse Fourier Transforms for 1370real and complex inputs/outputs. Multidimensional FFTs on up to 3 axes are 1371supported, except on TPU, where only a single axis is supported (please file a 1372GitHub issue if you require higher order). 1373 1374See also 1375[`XlaBuilder::Fft`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1376 1377| Arguments | Type | Semantics | 1378| ------------ | ------------------- | ------------------------ | 1379| `operand` | `XlaOp` | The array we are Fourier | 1380: : : transforming. : 1381| `fft_type` | `FftType` | See the table below. | 1382| `fft_length` | `ArraySlice<int64>` | The time-domain lengths | 1383: : : of the axes being : 1384: : : transformed. This is : 1385: : : needed in particular for : 1386: : : IRFFT to right-size the : 1387: : : innermost axis, since : 1388: : : `RFFT(fft_length=[16])` : 1389: : : has the same output : 1390: : : shape as : 1391: : : `RFFT(fft_length=[17])`. : 1392 1393| `FftType` | Semantics | 1394| --------- | ---------------------------------------------------------------- | 1395| `FFT` | Forward complex-to-complex FFT. Shape is unchanged. | 1396| `IFFT` | Inverse complex-to-complex FFT. Shape is unchanged. | 1397| `RFFT` | Forward real-to-complex FFT. Shape of the innermost axis is | 1398: : reduced to `fft_length[-1] // 2 + 1` if `fft_length[-1]` is a : 1399: : non-zero value, omitting the reversed conjugate part of the : 1400: : transformed signal beyond the Nyquist frequency. : 1401| `IRFFT` | Inverse real-to-complex FFT (i.e. takes complex, returns real). | 1402: : Shape of the innermost axis is expanded to `fft_length[-1]` if : 1403: : `fft_length[-1]` is a non-zero value, inferring the part of the : 1404: : transformed signal beyond the Nyquist frequency from the reverse : 1405: : conjugate of the `1` to `fft_length[-1] // 2 + 1` entries. : 1406 1407#### Multidimensional FFT 1408 1409When more than 1 `fft_length` is provided, this is equivalent to applying a 1410cascade of FFT operations to each of the innermost axes. Note that for the 1411real->complex and complex->real cases, the innermost axis transform is 1412(effectively) performed first (RFFT; last for IRFFT), which is why the innermost 1413axis is the one which changes size. Other axis transforms will then be 1414complex->complex. 1415 1416#### Implementation details 1417 1418CPU FFT is backed by Eigen's TensorFFT. GPU FFT uses cuFFT. 1419 1420## Gather 1421 1422The XLA gather operation stitches together several slices (each slice at a 1423potentially different runtime offset) of an input array. 
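As a first intuition, consider gathering two rows of a small matrix, where the row numbers are themselves runtime data. The sketch below is informal; the precise attribute names and semantics are defined in the following sections, and each row of `start_indices` is assumed to hold one starting index.

```
let m: f32[3,2] = { {1.0, 2.0},
                    {3.0, 4.0},
                    {5.0, 6.0} };
let start_indices: s64[2,1] = { {2}, {0} };

gathering the row at each starting index produces:
f32[2,2] { {5.0, 6.0},
           {1.0, 2.0} };
```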
1424 1425### General Semantics 1426 1427See also 1428[`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1429For a more intuitive description, see the "Informal Description" section below. 1430 1431<b> `gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)` </b> 1432 1433| Arguments | Type | Semantics | 1434| ---------------------- | ------------------- | ----------------------------- | 1435| `operand` | `XlaOp` | The array we’re gathering | 1436: : : from. : 1437| `start_indices` | `XlaOp` | Array containing the starting | 1438: : : indices of the slices we : 1439: : : gather. : 1440| `index_vector_dim` | `int64` | The dimension in | 1441: : : `start_indices` that : 1442: : : "contains" the starting : 1443: : : indices. See below for a : 1444: : : detailed description. : 1445| `offset_dims` | `ArraySlice<int64>` | The set of dimensions in the | 1446: : : output shape that offset into : 1447: : : an array sliced from operand. : 1448| `slice_sizes` | `ArraySlice<int64>` | `slice_sizes[i]` is the | 1449: : : bounds for the slice on : 1450: : : dimension `i`. : 1451| `collapsed_slice_dims` | `ArraySlice<int64>` | The set of dimensions in each | 1452: : : slice that are collapsed : 1453: : : away. These dimensions must : 1454: : : have size 1. : 1455| `start_index_map` | `ArraySlice<int64>` | A map that describes how to | 1456: : : map indices in : 1457: : : `start_indices` to legal : 1458: : : indices into operand. : 1459| `indices_are_sorted` | `bool` | Whether the indices are | 1460: : : guaranteed to be sorted by : 1461: : : the caller. : 1462| `unique_indices` | `bool` | Whether the indices are | 1463: : : guaranteed to be unique by : 1464: : : the caller. : 1465 1466For convenience, we label dimensions in the output array not in `offset_dims` 1467as `batch_dims`. 1468 1469The output is an array of rank `batch_dims.size` + `offset_dims.size`. 1470 1471The `operand.rank` must equal the sum of `offset_dims.size` and 1472`collapsed_slice_dims`. Also, `slice_sizes.size` has to be equal to 1473`operand.rank`. 1474 1475If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider 1476`start_indices` to have a trailing `1` dimension (i.e. if `start_indices` was of 1477shape `[6,7]` and `index_vector_dim` is `2` then we implicitly consider the 1478shape of `start_indices` to be `[6,7,1]`). 1479 1480The bounds for the output array along dimension `i` is computed as follows: 1481 14821. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for 1483some `k`) then we pick the corresponding dimension bounds out of 1484`start_indices.shape`, skipping `index_vector_dim` (i.e. pick 1485`start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and 1486`start_indices.shape.dims`[`k`+`1`] otherwise). 1487 14882. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for 1489some `k`) then we pick the corresponding bound out of `slice_sizes` after 1490accounting for `collapsed_slice_dims` (i.e. we pick 1491`adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes` 1492with the bounds at indices `collapsed_slice_dims` removed). 1493 1494Formally, the operand index `In` corresponding to a given output index `Out` is 1495calculated as follows: 1496 14971. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. 
Use `G` to slice out a 1498 vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where 1499 Combine(A, b) inserts b at position `index_vector_dim` into A. Note that 1500 this is well defined even if `G` is empty -- if `G` is empty then `S` = 1501 `start_indices`. 1502 15032. Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by 1504 scattering `S` using `start_index_map`. More precisely: 1505 1506 1. `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` < 1507 `start_index_map.size`. 1508 1509 2. `S`<sub>`in`</sub>[`_`] = `0` otherwise. 1510 15113. Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices 1512 at the offset dimensions in `Out` according to the `collapsed_slice_dims` 1513 set. More precisely: 1514 1515 1. `O`<sub>`in`</sub>[`remapped_offset_dims`(`k`)] = 1516 `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size` 1517 (`remapped_offset_dims` is defined below). 1518 1519 2. `O`<sub>`in`</sub>[`_`] = `0` otherwise. 1520 15214. `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise 1522 addition. 1523 1524`remapped_offset_dims` is a monotonic function with domain [`0`, `offset.size`) 1525and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g., 1526`offset.size` is `4`, `operand.rank` is `6` and `collapsed_slice_dims` is {`0`, 1527`2`} then `remapped_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}. 1528 1529If `indices_are_sorted` is set to true then XLA can assume that `start_indices` 1530are sorted (in ascending `start_index_map` order) by the user. If they are not 1531then the semantics is implementation defined. 1532 1533If `unique_indices` is set to true then XLA can assume that all element 1534scattered to are unique. So XLA could use non-atomic operations. If 1535`unique_indices` is set to true and the indices being scattered to are not 1536unique then the semantics is implementation defined. 1537 1538### Informal Description and Examples 1539 1540Informally, every index `Out` in the output array corresponds to an element `E` 1541in the operand array, computed as follows: 1542 1543- We use the batch dimensions in `Out` to look up a starting index from 1544 `start_indices`. 1545 1546- We use `start_index_map` to map the starting index (whose size may be less 1547 than operand.rank) to a "full" starting index into the `operand`. 1548 1549- We dynamic-slice out a slice with size `slice_sizes` using the full starting 1550 index. 1551 1552- We reshape the slice by collapsing the `collapsed_slice_dims` dimensions. 1553 Since all collapsed slice dimensions must have a bound of 1, this reshape is 1554 always legal. 1555 1556- We use the offset dimensions in `Out` to index into this slice to get the 1557 input element, `E`, corresponding to output index `Out`. 1558 1559`index_vector_dim` is set to `start_indices.rank` - `1` in all of the examples 1560that follow. More interesting values for `index_vector_dim` do not change the 1561operation fundamentally, but make the visual representation more cumbersome. 1562 1563To get an intuition on how all of the above fits together, let's look at an 1564example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array. The 1565position of a slice into the `[16,11]` array can be represented as an index 1566vector of shape `S64[2]`, so the set of 5 positions can be represented as a 1567`S64[5,2]` array. 
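Spelled out in terms of the arguments above, one configuration consistent with this example is the following (a sketch; the `f32` element type is assumed purely for concreteness):

```
operand              : f32[16,11]
start_indices        : s64[5,2]
index_vector_dim     : 1
offset_dims          : {1, 2}
collapsed_slice_dims : {}
start_index_map      : {0, 1}
slice_sizes          : {8, 6}
==>
output shape         : f32[5,8,6]
```

Dimension `0` of the output is the batch dimension (it walks over the 5 starting indices), while dimensions `1` and `2` are the offset dimensions that index within each `[8,6]` slice.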
1568 1569The behavior of the gather operation can then be depicted as an index 1570transformation that takes [`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>], an index in 1571the output shape, and maps it to an element in the input array in the following 1572way: 1573 1574<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 1575<img style="width:100%" src="./images/ops_xla_gather_0.svg"> 1576</div> 1577 1578We first select an (`X`,`Y`) vector from the gather indices array using `G`. 1579The element in the output array at index 1580[`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>] is then the element in the input 1581array at index [`X`+`O`<sub>`0`</sub>,`Y`+`O`<sub>`1`</sub>]. 1582 1583`slice_sizes` is `[8,6]`, which decides the range of O<sub>`0`</sub> and 1584O<sub>`1`</sub>, and this in turn decides the bounds of the slice. 1585 1586This gather operation acts as a batch dynamic slice with `G` as the batch 1587dimension. 1588 1589The gather indices may be multidimensional. For instance, a more general 1590version of the example above using a "gather indices" array of shape `[4,5,2]` 1591would translate indices like this: 1592 1593<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 1594<img style="width:100%" src="./images/ops_xla_gather_1.svg"> 1595</div> 1596 1597Again, this acts as a batch dynamic slice `G`<sub>`0`</sub> and 1598`G`<sub>`1`</sub> as the batch dimensions. The slice size is still `[8,6]`. 1599 1600The gather operation in XLA generalizes the informal semantics outlined above in 1601the following ways: 1602 16031. We can configure which dimensions in the output shape are the offset 1604dimensions (dimensions containing `O`<sub>`0`</sub>, `O`<sub>`1`</sub> in 1605the last example). The output batch dimensions (dimensions containing 1606`G`<sub>`0`</sub>, `G`<sub>`1`</sub> in the last example) are defined to be 1607the output dimensions that are not offset dimensions. 1608 16092. The number of output offset dimensions explicitly present in the output 1610shape may be smaller than the input rank. These "missing" dimensions, which 1611are listed explicitly as `collapsed_slice_dims`, must have a slice size of 1612`1`. Since they have a slice size of `1` the only valid index for them is 1613`0` and eliding them does not introduce ambiguity. 1614 16153. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last 1616example) may have fewer elements than the input array rank, and an explicit 1617mapping dictates how the index should be expanded to have the same rank as 1618the input. 1619 1620As a final example, we use (2) and (3) to implement `tf.gather_nd`: 1621 1622<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 1623<img style="width:100%" src="./images/ops_xla_gather_2.svg"> 1624</div> 1625 1626`G`<sub>`0`</sub> and `G`<sub>`1`</sub> are used to slice out a starting index 1627from the gather indices array as usual, except the starting index has only one 1628element, `X`. Similarly, there is only one output offset index with the value 1629`O`<sub>`0`</sub>. However, before being used as indices into the input array, 1630these are expanded in accordance to "Gather Index Mapping" (`start_index_map` in 1631the formal description) and "Offset Mapping" (`remapped_offset_dims` in the 1632formal description) into [`X`,`0`] and [`0`,`O`<sub>`0`</sub>] respectively, 1633adding up to [`X`,`O`<sub>`0`</sub>]. 
In other words, the output index 1634[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index 1635[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us 1636the semantics for `tf.gather_nd`. 1637 1638`slice_sizes` for this case is `[1,11]`. Intuitively this means that every 1639index `X` in the gather indices array picks an entire row and the result is the 1640concatenation of all these rows. 1641 1642## GetDimensionSize 1643 1644See also 1645[`XlaBuilder::GetDimensionSize`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1646 1647Returns the size of the given dimension of the operand. The operand must be 1648array shaped. 1649 1650<b> `GetDimensionSize(operand, dimension)` </b> 1651 1652| Arguments | Type | Semantics | 1653| ----------- | ------- | --------------------------------------------------- | 1654| `operand` | `XlaOp` | n dimensional input array | 1655| `dimension` | `int64` | A value in the interval `[0, n)` that specifies the | 1656: : : dimension : 1657 1658## SetDimensionSize 1659 1660See also 1661[`XlaBuilder::SetDimensionSize`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1662 1663Sets the dynamic size of XlaOp's given dimension. The operand must be 1664array shaped. 1665 1666<b> `SetDimensionSize(operand, size, dimension)` </b> 1667 1668| Arguments | Type | Semantics | 1669| ----------- | ------- | --------------------------------------------------- | 1670| `operand` | `XlaOp` | n dimensional input array. | 1671| `size` | `XlaOp` | int32 representing the runtime dynamic size. | 1672| `dimension` | `int64` | A value in the interval `[0, n)` that specifies the | 1673: : : dimension. : 1674 1675Pass through the operand as result, with dynamic dimension tracked by the 1676compiler. 1677 1678Padded values will be ignored by downstream reduction ops. 1679 1680``` 1681let v: f32[10] = f32[10]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; 1682let five: s32 = 5; 1683let six: s32 = 6; 1684 1685// Setting dynamic dimension size doesn't change the upper bound of the static 1686// shape. 1687let padded_v_five: f32[10] = set_dimension_size(v, five, /*dimension=*/0); 1688let padded_v_six: f32[10] = set_dimension_size(v, six, /*dimension=*/0); 1689 1690// sum == 1 + 2 + 3 + 4 + 5 1691let sum:f32[] = reduce_sum(padded_v_five); 1692// product == 1 * 2 * 3 * 4 * 5 1693let product:f32[] = reduce_product(padded_v_five); 1694 1695// Changing padding size will yield different result. 1696// sum == 1 + 2 + 3 + 4 + 5 + 6 1697let sum':f32[] = reduce_sum(padded_v_six); 1698``` 1699 1700## GetTupleElement 1701 1702See also 1703[`XlaBuilder::GetTupleElement`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1704 1705Indexes into a tuple with a compile-time-constant value. 1706 1707The value must be a compile-time-constant so that shape inference can determine 1708the type of the resulting value. 1709 1710This is analogous to `std::get<int N>(t)` in C++. Conceptually: 1711 1712``` 1713let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; 1714let s: s32 = 5; 1715let t: (f32[10], s32) = tuple(v, s); 1716let element_1: s32 = gettupleelement(t, 1); // Inferred shape matches s32. 1717``` 1718 1719See also `tf.tuple`. 1720 1721## Infeed 1722 1723See also 1724[`XlaBuilder::Infeed`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 
1725 1726<b> `Infeed(shape)` </b> 1727 1728| Argument | Type | Semantics | 1729| -------- | ------- | ----------------------------------------------------- | 1730| `shape` | `Shape` | Shape of the data read from the Infeed interface. The | 1731: : : layout field of the shape must be set to match the : 1732: : : layout of the data sent to the device; otherwise its : 1733: : : behavior is undefined. : 1734 1735Reads a single data item from the implicit Infeed streaming interface of the 1736device, interpreting the data as the given shape and its layout, and returns a 1737`XlaOp` of the data. Multiple Infeed operations are allowed in a 1738computation, but there must be a total order among the Infeed operations. For 1739example, two Infeeds in the code below have a total order since there is a 1740dependency between the while loops. 1741 1742``` 1743result1 = while (condition, init = init_value) { 1744 Infeed(shape) 1745} 1746 1747result2 = while (condition, init = result1) { 1748 Infeed(shape) 1749} 1750``` 1751 1752Nested tuple shapes are not supported. For an empty tuple shape, the Infeed 1753operation is effectively a no-op and proceeds without reading any data from the 1754Infeed of the device. 1755 1756> Note: We plan to allow multiple Infeed operations without a total order, in 1757> which case the compiler will provide information about how the Infeed 1758> operations are serialized in the compiled program. 1759 1760## Iota 1761 1762<b> `Iota()` </b> 1763 1764Builds a constant literal on device rather than a potentially large host 1765transfer. Creates a rank 1 array of values starting at zero and incrementing by 1766one. For floating-point types, the produced array is equivalent to 1767`ConvertElementType(Iota(...))` where the `Iota` is of integral type and the 1768conversion is to the floating-point type. 1769 1770Arguments | Type | Semantics 1771---------------- | --------------- | ------------------------------------ 1772`type` | `PrimitiveType` | type U 1773`size` | `int64` | The number of elements in the array. 1774`iota_dimension` | `int64` | The dimension to increment along. 1775 1776## Map 1777 1778See also 1779[`XlaBuilder::Map`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1780 1781<b> `Map(operands..., computation)` </b> 1782 1783| Arguments | Type | Semantics | 1784| ----------------- | ---------------------- | ------------------------------ | 1785| `operands` | sequence of N `XlaOp`s | N arrays of types T_0..T_{N-1} | 1786| `computation` | `XlaComputation` | computation of type `T_0, T_1, | 1787: : : ..., T_{N + M -1} -> S` with N : 1788: : : parameters of type T and M of : 1789: : : arbitrary type : 1790| `dimensions` | `int64` array | array of map dimensions | 1791 1792Applies a scalar function over the given `operands` arrays, producing an array 1793of the same dimensions where each element is the result of the mapped function 1794applied to the corresponding elements in the input arrays. 1795 1796The mapped function is an arbitrary computation with the restriction that it has 1797N inputs of scalar type `T` and a single output with type `S`. The output has 1798the same dimensions as the operands except that the element type T is replaced 1799with S. 1800 1801For example: `Map(op1, op2, op3, computation, par1)` maps `elem_out <- 1802computation(elem1, elem2, elem3, par1)` at each (multi-dimensional) index in the 1803input arrays to produce the output array. 
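As a concrete sketch, assume `add` is an `XlaComputation` of type `(f32[], f32[]) -> f32[]` that returns the sum of its two scalar parameters; the `dimensions` argument `{0}` covers the single dimension of the rank-1 operands:

```
let a: f32[4] = {1.0, 2.0, 3.0, 4.0};
let b: f32[4] = {10.0, 20.0, 30.0, 40.0};
==>
Map({a, b}, add, {0}) = f32[4] {11.0, 22.0, 33.0, 44.0};
```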
1804 1805## Pad 1806 1807See also 1808[`XlaBuilder::Pad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1809 1810<b> `Pad(operand, padding_value, padding_config)` </b> 1811 1812| Arguments | Type | Semantics | 1813| ---------------- | --------------- | --------------------------------------- | 1814| `operand` | `XlaOp` | array of type `T` | 1815| `padding_value` | `XlaOp` | scalar of type `T` to fill in the added | 1816: : : padding : 1817| `padding_config` | `PaddingConfig` | padding amount on both edges (low, | 1818: : : high) and between the elements of each : 1819: : : dimension : 1820 1821Expands the given `operand` array by padding around the array as well as between 1822the elements of the array with the given `padding_value`. `padding_config` 1823specifies the amount of edge padding and the interior padding for each 1824dimension. 1825 1826`PaddingConfig` is a repeated field of `PaddingConfigDimension`, which contains 1827three fields for each dimension: `edge_padding_low`, `edge_padding_high`, and 1828`interior_padding`. 1829 1830`edge_padding_low` and `edge_padding_high` specify the amount of padding added 1831at the low-end (next to index 0) and the high-end (next to the highest index) of 1832each dimension respectively. The amount of edge padding can be negative -- the 1833absolute value of negative padding indicates the number of elements to remove 1834from the specified dimension. 1835 1836`interior_padding` specifies the amount of padding added between any two 1837elements in each dimension; it may not be negative. Interior padding occurs 1838logically before edge padding, so in the case of negative edge padding, elements 1839are removed from the interior-padded operand. 1840 1841This operation is a no-op if the edge padding pairs are all (0, 0) and the 1842interior padding values are all 0. The figure below shows examples of different 1843`edge_padding` and `interior_padding` values for a two-dimensional array. 1844 1845<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 1846 <img style="width:100%" src="./images/ops_pad.png"> 1847</div> 1848 1849## Recv 1850 1851See also 1852[`XlaBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1853 1854<b> `Recv(shape, channel_handle)` </b> 1855 1856| Arguments | Type | Semantics | 1857| ---------------- | --------------- | ------------------------------------ | 1858| `shape` | `Shape` | shape of the data to receive | 1859| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair | 1860 1861Receives data of the given shape from a `Send` instruction in another 1862computation that shares the same channel handle. Returns a 1863XlaOp for the received data. 1864 1865The client API of `Recv` operation represents synchronous communication. 1866However, the instruction is internally decomposed into 2 HLO instructions 1867(`Recv` and `RecvDone`) to enable asynchronous data transfers. See also 1868[`HloInstruction::CreateRecv` and `HloInstruction::CreateRecvDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h). 1869 1870<b>`Recv(const Shape& shape, int64 channel_id)`</b> 1871 1872Allocates resources required to receive data from a `Send` instruction with the 1873same channel_id. Returns a context for the allocated resources, which is used 1874by a following `RecvDone` instruction to wait for the completion of the data 1875transfer. 
The context is a tuple of {receive buffer (shape), request identifier 1876(U32)} and it can only be used by a `RecvDone` instruction. 1877 1878<b> `RecvDone(HloInstruction context)` </b> 1879 1880Given a context created by a `Recv` instruction, waits for the data transfer to 1881complete and returns the received data. 1882 1883## Reduce 1884 1885See also 1886[`XlaBuilder::Reduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 1887 1888Applies a reduction function to one or more arrays in parallel. 1889 1890<b> `Reduce(operands..., init_values..., computation, dimensions)` </b> 1891 1892| Arguments | Type | Semantics | 1893| ------------- | --------------------- | -------------------------------- | 1894| `operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., | 1895: : : T_{N-1}`. : 1896| `init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., | 1897: : : T_{N-1}`. : 1898| `computation` | `XlaComputation` | computation of type `T_0, ..., | 1899: : : T_{N-1}, T_0, ..., T_{N-1} ->` : 1900: : : `Collate(T_0, ..., T_{N-1})`. : 1901| `dimensions` | `int64` array | unordered array of dimensions to | 1902: : : reduce. : 1903 1904Where: 1905 1906* N is required to be greater or equal to 1. 1907* All input arrays must have the same dimensions. 1908* If `N = 1`, `Collate(T)` is `T`. 1909* If `N > 1`, `Collate(T_0, ..., T_{N-1})` is a tuple of `N` elements of type 1910 `T`. 1911 1912The output of the op is `Collate(Q_0, ..., Q_N)` where `Q_i` is an array of type 1913`T_i`, the dimensions of which are described below. 1914 1915This operation reduces one or more dimensions of each input array into scalars. 1916The rank of each returned array is `rank(operand) - len(dimensions)`. The 1917initial value used for every reduction is `init_value`, and it may be inserted 1918anywhere during computation by the back-end. In most cases, `init_value` is an 1919identity of the reduction function (for example, `0` for addition). The applied 1920`computation` is always passed the `init_value` on the left-hand side. 1921 1922The evaluation order of the reduction function is arbitrary and may be 1923non-deterministic. Therefore, the reduction function should not be overly 1924sensitive to reassociation. 1925 1926Some reduction functions like addition are not strictly associative for floats. 1927However, if the range of the data is limited, floating-point addition is close 1928enough to being associative for most practical uses. It is possible to conceive 1929of some completely non-associative reductions, however, and these will produce 1930incorrect or unpredictable results in XLA. 1931 1932As an example, when reducing across one dimension in a single 1D array with 1933values `[10, 11, 12, 13]`, with reduction function `f` (this is `computation`) 1934then that could be computed as 1935 1936`f(10, f(11, f(12, f(init_value, 13)))` 1937 1938but there are also many other possibilities, e.g. 1939 1940`f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))` 1941 1942The following is a rough pseudo-code example of how reduction could be 1943implemented, using summation as the reduction computation with an initial value 1944of 0. 1945 1946```python 1947result_shape <- remove all dims in dimensions from operand_shape 1948 1949# Iterate over all elements in result_shape. 
The number of r's here is equal 1950# to the rank of the result 1951for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...: 1952 # Initialize this result element 1953 result[r0, r1...] <- 0 1954 1955 # Iterate over all the reduction dimensions 1956 for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...: 1957 # Increment the result element with the value of the operand's element. 1958 # The index of the operand's element is constructed from all ri's and di's 1959 # in the right order (by construction ri's and di's together index over the 1960 # whole operand shape). 1961 result[r0, r1...] += operand[ri... di] 1962``` 1963 1964Here's an example of reducing a 2D array (matrix). The shape has rank 2, 1965dimension 0 of size 2 and dimension 1 of size 3: 1966 1967<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 1968 <img style="width:35%" src="./images/ops_2d_matrix.png"> 1969</div> 1970 1971Results of reducing dimensions 0 or 1 with an "add" function: 1972 1973<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 1974 <img style="width:35%" src="./images/ops_reduce_from_2d_matrix.png"> 1975</div> 1976 1977Note that both reduction results are 1D arrays. The diagram shows one as column 1978and another as row just for visual convenience. 1979 1980For a more complex example, here is a 3D array. Its rank is 3, dimension 0 of 1981size 4, dimension 1 of size 2 and dimension 2 of size 3. For simplicity, the 1982values 1 to 6 are replicated across dimension 0. 1983 1984<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 1985 <img style="width:35%" src="./images/ops_reduce_from_3d_matrix.png"> 1986</div> 1987 1988Similarly to the 2D example, we can reduce just one dimension. If we reduce 1989dimension 0, for example, we get a rank-2 array where all values across 1990dimension 0 were folded into a scalar: 1991 1992```text 1993| 4 8 12 | 1994| 16 20 24 | 1995``` 1996 1997If we reduce dimension 2, we also get a rank-2 array where all values across 1998dimension 2 were folded into a scalar: 1999 2000```text 2001| 6 15 | 2002| 6 15 | 2003| 6 15 | 2004| 6 15 | 2005``` 2006 2007Note that the relative order between the remaining dimensions in the input is 2008preserved in the output, but some dimensions may get assigned new numbers (since 2009the rank changes). 2010 2011We can also reduce multiple dimensions. Add-reducing dimensions 0 and 1 produces 2012the 1D array `[20, 28, 36]`. 2013 2014Reducing the 3D array over all its dimensions produces the scalar `84`. 2015 2016### Variadic Reduce 2017 2018When `N > 1`, reduce function application is slightly more complex, as it is 2019applied simultaneously to all inputs. The operands are supplied to the 2020computation in the following order: 2021 2022* Running reduced value for the first operand 2023* ... 2024* Running reduced value for the N'th operand 2025* Input value for the first operand 2026* ... 
2027* Input value for the N'th operand 2028 2029For example, consider the following reduction function, which can be used to 2030compute the max and the argmax of a 1-D array in parallel: 2031 2032```python 2033f: (Float, Int, Float, Int) -> Float, Int 2034f(max, argmax, value, index): 2035 if value >= max: 2036 return (value, index) 2037 else: 2038 return (max, argmax) 2039``` 2040 2041For 1-D Input arrays `V = Float[N], K = Int[N]`, and init values 2042`I_V = Float, I_K = Int`, the result `f_(N-1)` of reducing across the only 2043input dimension is equivalent to the following recursive application: 2044 2045``` 2046f_0 = f(I_V, I_K, V_0, K_0) 2047f_1 = f(f_0.first, f_0.second, V_1, K_1) 2048... 2049f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1)) 2050``` 2051 2052Applying this reduction to an array of values, and an array of sequential 2053indices (i.e. iota), will co-iterate over the arrays, and return a tuple 2054containing the maximal value and the matching index. 2055 2056## ReducePrecision 2057 2058See also 2059[`XlaBuilder::ReducePrecision`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 2060 2061Models the effect of converting floating-point values to a lower-precision 2062format (such as IEEE-FP16) and back to the original format. The number of 2063exponent and mantissa bits in the lower-precision format can be specified 2064arbitrarily, although all bit sizes may not be supported on all hardware 2065implementations. 2066 2067<b> `ReducePrecision(operand, mantissa_bits, exponent_bits)` </b> 2068 2069Arguments | Type | Semantics 2070--------------- | ------- | ------------------------------------------------- 2071`operand` | `XlaOp` | array of floating-point type `T`. 2072`exponent_bits` | `int32` | number of exponent bits in lower-precision format 2073`mantissa_bits` | `int32` | number of mantissa bits in lower-precision format 2074 2075The result is an array of type `T`. The input values are rounded to the nearest 2076value representable with the given number of mantissa bits (using "ties to even" 2077semantics), and any values that exceed the range specified by the number of 2078exponent bits are clamped to positive or negative infinity. `NaN` values are 2079retained, although they may be converted to canonical `NaN` values. 2080 2081The lower-precision format must have at least one exponent bit (in order to 2082distinguish a zero value from an infinity, since both have a zero mantissa), and 2083must have a non-negative number of mantissa bits. The number of exponent or 2084mantissa bits may exceed the corresponding value for type `T`; the corresponding 2085portion of the conversion is then simply a no-op. 2086 2087## ReduceWindow 2088 2089See also 2090[`XlaBuilder::ReduceWindow`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 2091 2092Applies a reduction function to all elements in each window of the input 2093multi-dimensional array, producing an output multi-dimensional array with the 2094same number of elements as the number of valid positions of the window. A 2095pooling layer can be expressed as a `ReduceWindow`. Similar to 2096[`Reduce`](#reduce), the applied `computation` is always passed the `init_value` 2097on the left-hand side. 
2098 2099<b> `ReduceWindow(operand, init_value, computation, window_dimensions, 2100window_strides, padding)` </b> 2101 2102| Arguments | Type | Semantics | 2103| ------------------- | ------------------- | -------------------------------- | 2104| `operand` | `XlaOp` | N dimensional array containing | 2105: : : elements of type T. This is the : 2106: : : base area on which the window is : 2107: : : placed. : 2108| `init_value` | `XlaOp` | Starting value for the | 2109: : : reduction. See [Reduce](#reduce) : 2110: : : for details. : 2111| `computation` | `XlaComputation` | Reduction function of type `T, T | 2112: : : -> T`, to apply to all elements : 2113: : : in each window : 2114| `window_dimensions` | `ArraySlice<int64>` | array of integers for window | 2115: : : dimension values : 2116| `window_strides` | `ArraySlice<int64>` | array of integers for window | 2117: : : stride values : 2118| `base_dilations` | `ArraySlice<int64>` | array of integers for base | 2119: : : dilation values : 2120| `window_dilations` | `ArraySlice<int64>` | array of integers for window | 2121: : : dilation values : 2122| `padding` | `Padding` | padding type for window | 2123: : : (Padding\:\:kSame, which pads so : 2124: : : as to have the same output shape : 2125: : : as input if the stride is 1, or : 2126: : : Padding\:\:kValid, which uses no : 2127: : : padding and "stops" the window : 2128: : : once it no longer fits) : 2129 2130Below code and figure shows an example of using `ReduceWindow`. Input is a 2131matrix of size [4x6] and both window_dimensions and window_stride_dimensions are 2132[2x3]. 2133 2134``` 2135// Create a computation for the reduction (maximum). 2136XlaComputation max; 2137{ 2138 XlaBuilder builder(client_, "max"); 2139 auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y"); 2140 auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x"); 2141 builder.Max(y, x); 2142 max = builder.Build().ConsumeValueOrDie(); 2143} 2144 2145// Create a ReduceWindow computation with the max reduction computation. 2146XlaBuilder builder(client_, "reduce_window_2x3"); 2147auto shape = ShapeUtil::MakeShape(F32, {4, 6}); 2148auto input = builder.Parameter(0, shape, "input"); 2149builder.ReduceWindow( 2150 input, 2151 /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)), 2152 *max, 2153 /*window_dimensions=*/{2, 3}, 2154 /*window_stride_dimensions=*/{2, 3}, 2155 Padding::kValid); 2156``` 2157 2158<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 2159 <img style="width:35%" src="./images/ops_reduce_window.png"> 2160</div> 2161 2162Stride of 1 in a dimension specifies that the position of a window in the 2163dimension is 1 element away from its adjacent window. In order to specify that 2164no windows overlap with each other, window_stride_dimensions should be equal to 2165window_dimensions. The figure below illustrates the use of two different stride 2166values. Padding is applied to each dimension of the input and the calculations 2167are the same as though the input came in with the dimensions it has after 2168padding. 2169 2170<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;"> 2171 <img style="width:75%" src="./images/ops_reduce_window_stride.png"> 2172</div> 2173 2174For a non-trivial padding example, consider computing reduce-window minimum 2175(initial value is `MAX_FLOAT`) with dimension `3` and stride `2` over the input 2176array `[10000, 1000, 100, 10, 1]`. 
Padding `kValid` computes minimums over two 2177_valid_ windows: `[10000, 1000, 100]` and `[100, 10, 1]`, resulting in the 2178output `[100, 1]`. Padding `kSame` first pads the array so that the shape after 2179the reduce-window would be the _same_ as input for stride one by adding initial 2180elements on both sides, getting `[MAX_VALUE, 10000, 1000, 100, 10, 1, 2181MAX_VALUE]`. Running reduce-window over the padded array operates on three 2182windows `[MAX_VALUE, 10000, 1000]`, `[1000, 100, 10]`, `[10, 1, MAX_VALUE]`, and 2183yields `[1000, 10, 1]`. 2184 2185The evaluation order of the reduction function is arbitrary and may be 2186non-deterministic. Therefore, the reduction function should not be overly 2187sensitive to reassociation. See the discussion about associativity in the 2188context of [`Reduce`](#reduce) for more details. 2189 2190## ReplicaId 2191 2192See also 2193[`XlaBuilder::ReplicaId`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 2194 2195Returns the unique ID (U32 scalar) of the replica. 2196 2197<b> `ReplicaId()` </b> 2198 2199The unique ID of each replica is an unsigned integer in the interval `[0, N)`, 2200where `N` is the number of replicas. Since all the replicas are running the same 2201program, a `ReplicaId()` call in the program will return a different value on 2202each replica. 2203 2204## Reshape 2205 2206See also 2207[`XlaBuilder::Reshape`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h) 2208and the [`Collapse`](#collapse) operation. 2209 2210Reshapes the dimensions of an array into a new configuration. 2211 2212<b> `Reshape(operand, new_sizes)` </b> 2213<b> `Reshape(operand, dimensions, new_sizes)` </b> 2214 2215Arguments | Type | Semantics 2216------------ | -------------- | --------------------------------------- 2217`operand` | `XlaOp` | array of type T 2218`dimensions` | `int64` vector | order in which dimensions are collapsed 2219`new_sizes` | `int64` vector | vector of sizes of new dimensions 2220 2221Conceptually, reshape first flattens an array into a one-dimensional vector of 2222data values, and then refines this vector into a new shape. The input arguments 2223are an arbitrary array of type T, a compile-time-constant vector of dimension 2224indices, and a compile-time-constant vector of dimension sizes for the result. 2225The values in the `dimension` vector, if given, must be a permutation of all of 2226T's dimensions; the default if not given is `{0, ..., rank - 1}`. The order of 2227the dimensions in `dimensions` is from slowest-varying dimension (most major) to 2228fastest-varying dimension (most minor) in the loop nest which collapses the 2229input array into a single dimension. The `new_sizes` vector determines the size 2230of the output array. The value at index 0 in `new_sizes` is the size of 2231dimension 0, the value at index 1 is the size of dimension 1, and so on. The 2232product of the `new_size` dimensions must equal the product of the operand's 2233dimension sizes. When refining the collapsed array into the multidimensional 2234array defined by `new_sizes`, the dimensions in `new_sizes` are ordered from 2235slowest varying (most major) and to fastest varying (most minor). 
2236 2237For example, let v be an array of 24 elements: 2238 2239``` 2240let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}}, 2241 {{20, 21, 22}, {25, 26, 27}}, 2242 {{30, 31, 32}, {35, 36, 37}}, 2243 {{40, 41, 42}, {45, 46, 47}}}; 2244 2245In-order collapse: 2246let v012_24 = Reshape(v, {0,1,2}, {24}); 2247then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, 2248 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}; 2249 2250let v012_83 = Reshape(v, {0,1,2}, {8,3}); 2251then v012_83 == f32[8x3] {{10, 11, 12}, {15, 16, 17}, 2252 {20, 21, 22}, {25, 26, 27}, 2253 {30, 31, 32}, {35, 36, 37}, 2254 {40, 41, 42}, {45, 46, 47}}; 2255 2256Out-of-order collapse: 2257let v021_24 = Reshape(v, {1,2,0}, {24}); 2258then v012_24 == f32[24] {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, 2259 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}; 2260 2261let v021_83 = Reshape(v, {1,2,0}, {8,3}); 2262then v021_83 == f32[8x3] {{10, 20, 30}, {40, 11, 21}, 2263 {31, 41, 12}, {22, 32, 42}, 2264 {15, 25, 35}, {45, 16, 26}, 2265 {36, 46, 17}, {27, 37, 47}}; 2266 2267 2268let v021_262 = Reshape(v, {1,2,0}, {2,6,2}); 2269then v021_262 == f32[2x6x2] {{{10, 20}, {30, 40}, 2270 {11, 21}, {31, 41}, 2271 {12, 22}, {32, 42}}, 2272 {{15, 25}, {35, 45}, 2273 {16, 26}, {36, 46}, 2274 {17, 27}, {37, 47}}}; 2275``` 2276 2277As a special case, reshape can transform a single-element array to a scalar and 2278vice versa. For example, 2279 2280``` 2281Reshape(f32[1x1] {{5}}, {0,1}, {}) == 5; 2282Reshape(5, {}, {1,1}) == f32[1x1] {{5}}; 2283``` 2284 2285## Rev (reverse) 2286 2287See also 2288[`XlaBuilder::Rev`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 2289 2290<b>`Rev(operand, dimensions)`</b> 2291 2292Arguments | Type | Semantics 2293------------ | ------------------- | --------------------- 2294`operand` | `XlaOp` | array of type T 2295`dimensions` | `ArraySlice<int64>` | dimensions to reverse 2296 2297Reverses the order of elements in the `operand` array along the specified 2298`dimensions`, generating an output array of the same shape. Each element of the 2299operand array at a multidimensional index is stored into the output array at a 2300transformed index. The multidimensional index is transformed by reversing the 2301index in each dimension to be reversed (i.e., if a dimension of size N is one of 2302the reversing dimensions, its index i is transformed into N - 1 - i). 2303 2304One use for the `Rev` operation is to reverse the convolution weight array along 2305the two window dimensions during the gradient computation in neural networks. 2306 2307## RngNormal 2308 2309See also 2310[`XlaBuilder::RngNormal`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 2311 2312Constructs an output of a given shape with random numbers generated following 2313the $$N(\mu, \sigma)$$ normal distribution. The parameters $$\mu$$ and 2314$$\sigma$$, and output shape have to have a floating point elemental type. The 2315parameters furthermore have to be scalar valued. 
2316 2317<b>`RngNormal(mu, sigma, shape)`</b> 2318 2319| Arguments | Type | Semantics | 2320| --------- | ------- | --------------------------------------------------- | 2321| `mu` | `XlaOp` | Scalar of type T specifying mean of generated | 2322: : : numbers : 2323| `sigma` | `XlaOp` | Scalar of type T specifying standard deviation of | 2324: : : generated numbers : 2325| `shape` | `Shape` | Output shape of type T | 2326 2327## RngUniform 2328 2329See also 2330[`XlaBuilder::RngUniform`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 2331 2332Constructs an output of a given shape with random numbers generated following 2333the uniform distribution over the interval $$[a,b)$$. The parameters and output 2334element type have to be a boolean type, an integral type or a floating point 2335types, and the types have to be consistent. The CPU and GPU backends currently 2336only support F64, F32, F16, BF16, S64, U64, S32 and U32. Furthermore, the 2337parameters need to be scalar valued. If $$b <= a$$ the result is 2338implementation-defined. 2339 2340<b>`RngUniform(a, b, shape)`</b> 2341 2342| Arguments | Type | Semantics | 2343| --------- | ----------------------- | --------------------------------- | 2344| `a` | `XlaOp` | Scalar of type T specifying lower | 2345: : : limit of interval : 2346| `b` | `XlaOp` | Scalar of type T specifying upper | 2347: : : limit of interval : 2348| `shape` | `Shape` | Output shape of type T | 2349 2350## RngBitGenerator 2351 2352Generates an output with a given shape filled with uniform random bits using the 2353specified algorithm (or backend default) and returns an updated state (with the 2354same shape as initial state) and the generated random data. 2355 2356Initial state is the initial state of the current random number generation. It 2357and the required shape and valid values are dependent on the algorithm used. 2358 2359The output is guaranteed to be a deterministic function of the initial state but 2360it is *not* guaranteed to be deterministic between backends and different 2361compiler versions. 2362 2363<b>`RngBitGenerator(algorithm, key, shape)`</b> 2364 2365Arguments | Type | Semantics 2366--------------- | ----------------- | ------------------------------------- 2367`algorithm` | `RandomAlgorithm` | PRNG algorithm to be used. 2368`initial_state` | `XlaOp` | Initial state for the PRNG algorithm. 2369`shape` | `Shape` | Output shape for generated data. 2370 2371Available values for `algorithm`: 2372 2373- `rng_default`: Backend specific algorithm with backend specific shape 2374 requirements. 2375 2376- `rng_three_fry`: ThreeFry counter-based PRNG algorithm. The `initial_state` 2377 shape is `u64[2]` with arbitrary values. 2378 [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) 2379 2380- `rng_philox`: Philox algorithm to generate random numbers in parallel. The 2381 `initial_state` shape is `u64[3]` with arbitrary values. 2382 [Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf) 2383 2384## Scatter 2385 2386The XLA scatter operation generates a result which is the value of the input 2387array `operand`, with several slices (at indices specified by `scatter_indices`) 2388updated with the values in `updates` using `update_computation`. 
2389 2390See also 2391[`XlaBuilder::Scatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 2392 2393<b> `scatter(operand, scatter_indices, updates, update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)` </b> 2394 2395Arguments | Type | Semantics 2396------------------------------ | ------------------- | --------- 2397`operand` | `XlaOp` | Array to be scattered into. 2398`scatter_indices` | `XlaOp` | Array containing the starting indices of the slices that must be scattered to. 2399`updates` | `XlaOp` | Array containing the values that must be used for scattering. 2400`update_computation` | `XlaComputation` | Computation to be used for combining the existing values in the input array and the updates during scatter. This computation should be of type `(T, T) -> T`. 2401`index_vector_dim` | `int64` | The dimension in `scatter_indices` that contains the starting indices. 2402`update_window_dims` | `ArraySlice<int64>` | The set of dimensions in `updates` shape that are _window dimensions_. 2403`inserted_window_dims` | `ArraySlice<int64>` | The set of _window dimensions_ that must be inserted into `updates` shape. 2404`scatter_dims_to_operand_dims` | `ArraySlice<int64>` | A dimensions map from the scatter indices to the operand index space. This array is interpreted as mapping `i` to `scatter_dims_to_operand_dims[i]` . It has to be one-to-one and total. 2405`indices_are_sorted` | `bool` | Whether the indices are guaranteed to be sorted by the caller. 2406 2407If `index_vector_dim` is equal to `scatter_indices.rank` we implicitly consider 2408`scatter_indices` to have a trailing `1` dimension. 2409 2410We define `update_scatter_dims` of type `ArraySlice<int64>` as the set of 2411dimensions in `updates` shape that are not in `update_window_dims`, in ascending 2412order. 2413 2414The arguments of scatter should follow these constraints: 2415 2416- `updates` array must be of rank `update_window_dims.size + 2417 scatter_indices.rank - 1`. 2418 2419- Bounds of dimension `i` in `updates` must conform to the following: 2420 2421 - If `i` is present in `update_window_dims` (i.e. equal to 2422 `update_window_dims`[`k`] for some `k`), then the bound of dimension `i` 2423 in `updates` must not exceed the corresponding bound of `operand` after 2424 accounting for the `inserted_window_dims` (i.e. 2425 `adjusted_window_bounds`[`k`], where `adjusted_window_bounds` contains 2426 the bounds of `operand` with the bounds at indices 2427 `inserted_window_dims` removed). 2428 - If `i` is present in `update_scatter_dims` (i.e. equal to 2429 `update_scatter_dims`[`k`] for some `k`), then the bound of dimension 2430 `i` in `updates` must be equal to the corresponding bound of 2431 `scatter_indices`, skipping `index_vector_dim` (i.e. 2432 `scatter_indices.shape.dims`[`k`], if `k` < `index_vector_dim` and 2433 `scatter_indices.shape.dims`[`k+1`] otherwise). 2434 2435- `update_window_dims` must be in ascending order, not have any repeating 2436 dimension numbers, and be in the range `[0, updates.rank)`. 2437 2438- `inserted_window_dims` must be in ascending order, not have any repeating 2439 dimension numbers, and be in the range `[0, operand.rank)`. 2440 2441- `operand.rank` must equal the sum of `update_window_dims.size` and 2442 `inserted_window_dims.size`. 2443 2444- `scatter_dims_to_operand_dims.size` must be equal to 2445 `scatter_indices`[`index_vector_dim`], and its values must be in the range 2446 `[0, operand.rank)`. 
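For example, the following shapes and dimension numbers satisfy all of these constraints (an illustrative sketch only, mirroring the gather example in the [Gather](#gather) section):

```
operand                      : f32[16,11]
scatter_indices              : s64[5,2]
updates                      : f32[5,8,6]
index_vector_dim             : 1
update_window_dims           : {1, 2}
inserted_window_dims         : {}
scatter_dims_to_operand_dims : {0, 1}
```

Here `update_scatter_dims` is `{0}`; `updates` has rank `update_window_dims.size + scatter_indices.rank - 1 = 2 + 2 - 1 = 3`; the window bounds `8` and `6` do not exceed the operand bounds `16` and `11`; and the bound `5` of the update scatter dimension matches dimension `0` of `scatter_indices`.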
2447 2448For a given index `U` in the `updates` array, the corresponding index `I` in the 2449`operand` array into which this update has to be applied is computed as follows: 2450 24511. Let `G` = { `U`[`k`] for `k` in `update_scatter_dims` }. Use `G` to look up 2452 an index vector `S` in the `scatter_indices` array such that `S`[`i`] = 2453 `scatter_indices`[Combine(`G`, `i`)] where Combine(A, b) inserts b at 2454 positions `index_vector_dim` into A. 24552. Create an index `S`<sub>`in`</sub> into `operand` using `S` by scattering 2456 `S` using the `scatter_dims_to_operand_dims` map. More formally: 2457 1. `S`<sub>`in`</sub>[`scatter_dims_to_operand_dims`[`k`]] = `S`[`k`] if 2458 `k` < `scatter_dims_to_operand_dims.size`. 2459 2. `S`<sub>`in`</sub>[`_`] = `0` otherwise. 24603. Create an index `W`<sub>`in`</sub> into `operand` by scattering the indices 2461 at `update_window_dims` in `U` according to `inserted_window_dims`. More 2462 formally: 2463 1. `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `U`[`k`] if `k` 2464 is in `update_window_dims`, where `window_dims_to_operand_dims` is the 2465 monotonic function with domain [`0`, `update_window_dims.size`) and 2466 range [`0`, `operand.rank`) \\ `inserted_window_dims`. (For example, if 2467 `update_window_dims.size` is `4`, `operand.rank` is `6`, and 2468 `inserted_window_dims` is {`0`, `2`} then `window_dims_to_operand_dims` 2469 is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}). 2470 2. `W`<sub>`in`</sub>[`_`] = `0` otherwise. 24714. `I` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise 2472 addition. 2473 2474In summary, the scatter operation can be defined as follows. 2475 2476- Initialize `output` with `operand`, i.e. for all indices `O` in the 2477 `operand` array: \ 2478 `output`[`O`] = `operand`[`O`] 2479- For every index `U` in the `updates` array and the corresponding index `O` 2480 in the `operand` array, if `O` is a valid index for `output`: \ 2481 `output`[`O`] = `update_computation`(`output`[`O`], `updates`[`U`]) 2482 2483The order in which updates are applied is non-deterministic. So, when multiple 2484indices in `updates` refer to the same index in `operand`, the corresponding 2485value in `output` will be non-deterministic. 2486 2487Note that the first parameter that is passed into the `update_computation` will 2488always be the current value from the `output` array and the second parameter 2489will always be the value from the `updates` array. This is important 2490specifically for cases when the `update_computation` is _not commutative_. 2491 2492If `indices_are_sorted` is set to true then XLA can assume that `start_indices` 2493are sorted (in ascending `start_index_map` order) by the user. If they are not 2494then the semantics is implementation defined. 2495 2496Informally, the scatter op can be viewed as an _inverse_ of the gather op, i.e. 2497the scatter op updates the elements in the input that are extracted by the 2498corresponding gather op. 2499 2500For a detailed informal description and examples, refer to the 2501"Informal Description" section under `Gather`. 2502 2503## Select 2504 2505See also 2506[`XlaBuilder::Select`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 2507 2508Constructs an output array from elements of two input arrays, based on the 2509values of a predicate array. 
2510 2511<b> `Select(pred, on_true, on_false)` </b> 2512 2513Arguments | Type | Semantics 2514---------- | ------- | ------------------ 2515`pred` | `XlaOp` | array of type PRED 2516`on_true` | `XlaOp` | array of type T 2517`on_false` | `XlaOp` | array of type T 2518 2519The arrays `on_true` and `on_false` must have the same shape. This is also the 2520shape of the output array. The array `pred` must have the same dimensionality as 2521`on_true` and `on_false`, with the `PRED` element type. 2522 2523For each element `P` of `pred`, the corresponding element of the output array is 2524taken from `on_true` if the value of `P` is `true`, and from `on_false` if the 2525value of `P` is `false`. As a restricted form of [broadcasting](broadcasting.md), 2526`pred` can be a scalar of type `PRED`. In this case, the output array is taken 2527wholly from `on_true` if `pred` is `true`, and from `on_false` if `pred` is `false`. 2528 2529Example with non-scalar `pred`: 2530 2531``` 2532let pred: PRED[4] = {true, false, false, true}; 2533let v1: s32[4] = {1, 2, 3, 4}; 2534let v2: s32[4] = {100, 200, 300, 400}; 2535==> 2536Select(pred, v1, v2) = s32[4]{1, 200, 300, 4}; 2537``` 2538 2539Example with scalar `pred`: 2540 2541``` 2542let pred: PRED = true; 2543let v1: s32[4] = {1, 2, 3, 4}; 2544let v2: s32[4] = {100, 200, 300, 400}; 2545==> 2546Select(pred, v1, v2) = s32[4]{1, 2, 3, 4}; 2547``` 2548 2549Selections between tuples are supported. Tuples are considered to be scalar 2550types for this purpose. If `on_true` and `on_false` are tuples (which must have 2551the same shape!) then `pred` has to be a scalar of type `PRED`. 2552 2553## SelectAndScatter 2554 2555See also 2556[`XlaBuilder::SelectAndScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). 2557 2558This operation can be considered as a composite operation that first computes 2559`ReduceWindow` on the `operand` array to select an element from each window, and 2560then scatters the `source` array to the indices of the selected elements to 2561construct an output array with the same shape as the operand array. The binary 2562`select` function is used to select an element from each window by applying it 2563across each window, and it is called with the property that the first 2564parameter's index vector is lexicographically less than the second parameter's 2565index vector. The `select` function returns `true` if the first parameter is 2566selected and returns `false` if the second parameter is selected, and the 2567function must hold transitivity (i.e., if `select(a, b)` and `select(b, c)` are 2568`true`, then `select(a, c)` is also `true`) so that the selected element does 2569not depend on the order of the elements traversed for a given window. 2570 2571The function `scatter` is applied at each selected index in the output array. It 2572takes two scalar parameters: 2573 25741. Current value at the selected index in the output array 25752. The scatter value from `source` that applies to the selected index 2576 2577It combines the two parameters and returns a scalar value that's used to update 2578the value at the selected index in the output array. Initially, all indices of 2579the output array are set to `init_value`. 2580 2581The output array has the same shape as the `operand` array and the `source` 2582array must have the same shape as the result of applying a `ReduceWindow` 2583operation on the `operand` array. 
## SelectAndScatter

See also
[`XlaBuilder::SelectAndScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

This operation can be considered a composite operation that first computes
`ReduceWindow` on the `operand` array to select an element from each window,
and then scatters the `source` array to the indices of the selected elements to
construct an output array with the same shape as the operand array. The binary
`select` function is used to select an element from each window by applying it
across each window; it is called with the property that the first parameter's
index vector is lexicographically less than the second parameter's index
vector. The `select` function returns `true` if the first parameter is selected
and `false` if the second parameter is selected, and it must be transitive
(i.e., if `select(a, b)` and `select(b, c)` are `true`, then `select(a, c)` is
also `true`) so that the selected element does not depend on the order in which
the elements of a given window are traversed.

The function `scatter` is applied at each selected index in the output array.
It takes two scalar parameters:

1. The current value at the selected index in the output array
2. The scatter value from `source` that applies to the selected index

It combines the two parameters and returns a scalar value that is used to
update the value at the selected index in the output array. Initially, all
indices of the output array are set to `init_value`.

The output array has the same shape as the `operand` array, and the `source`
array must have the same shape as the result of applying a `ReduceWindow`
operation on the `operand` array. `SelectAndScatter` can be used to
backpropagate the gradient values for a pooling layer in a neural network.

<b>`SelectAndScatter(operand, select, window_dimensions, window_strides,
padding, source, init_value, scatter)`</b>

| Arguments           | Type                | Semantics                        |
| ------------------- | ------------------- | -------------------------------- |
| `operand`           | `XlaOp`             | array of type T over which the windows slide |
| `select`            | `XlaComputation`    | binary computation of type `T, T -> PRED`, to apply to all elements in each window; returns `true` if the first parameter is selected and `false` if the second parameter is selected |
| `window_dimensions` | `ArraySlice<int64>` | array of integers for window dimension values |
| `window_strides`    | `ArraySlice<int64>` | array of integers for window stride values |
| `padding`           | `Padding`           | padding type for window (`Padding::kSame` or `Padding::kValid`) |
| `source`            | `XlaOp`             | array of type T with the values to scatter |
| `init_value`        | `XlaOp`             | scalar value of type T for the initial value of the output array |
| `scatter`           | `XlaComputation`    | binary computation of type `T, T -> T`, to apply to each scatter source element and its destination element |

The figure below shows examples of using `SelectAndScatter`, with the `select`
function computing the maximal value among its parameters. Note that when the
windows overlap, as in figure (2) below, an index of the `operand` array may be
selected multiple times by different windows. In the figure, the element of
value 9 is selected by both of the top windows (blue and red) and the binary
addition `scatter` function produces the output element of value 8 (2 + 6).

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%"
  src="./images/ops_scatter_to_selected_window_element.png">
</div>

The evaluation order of the `scatter` function is arbitrary and may be
non-deterministic. Therefore, the `scatter` function should not be overly
sensitive to reassociation. See the discussion about associativity in the
context of [`Reduce`](#reduce) for more details.

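As a concrete, constructed example in the same notation: take a 1-dimensional
operand, non-overlapping windows of size 2, a `select` computation `ge` that
picks the larger element of each window, and a `scatter` computation `add` that
sums its parameters (both computation names and the attribute spellings are
illustrative).

```
let operand: f32[4] = {1.0, 5.0, 3.0, 2.0};
let source: f32[2] = {10.0, 20.0};
let init_value: f32 = 0.0;

SelectAndScatter(operand, select=ge, window_dimensions={2},
                 window_strides={2}, padding=Padding::kValid,
                 source, init_value, scatter=add) produces:
  {0.0, 10.0, 20.0, 0.0}
```

The first window selects the element `5.0` (index 1) and the second window
selects `3.0` (index 2), so the two `source` values are added into those
positions of an output that starts out filled with `init_value`.
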
## Send

See also
[`XlaBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b> `Send(operand, channel_handle)` </b>

Arguments        | Type            | Semantics
---------------- | --------------- | -----------------------------------------
`operand`        | `XlaOp`         | data to send (array of type T)
`channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair

Sends the given operand data to a `Recv` instruction in another computation
that shares the same channel handle. Does not return any data.

Similar to the `Recv` operation, the client API of the `Send` operation
represents synchronous communication, and is internally decomposed into two HLO
instructions (`Send` and `SendDone`) to enable asynchronous data transfers. See
also
[`HloInstruction::CreateSend` and `HloInstruction::CreateSendDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).

<b>`Send(HloInstruction operand, int64 channel_id)`</b>

Initiates an asynchronous transfer of the operand to the resources allocated by
the `Recv` instruction with the same channel id. Returns a context, which is
used by a following `SendDone` instruction to wait for the completion of the
data transfer. The context is a tuple of {operand (shape), request identifier
(U32)} and it can only be used by a `SendDone` instruction.

<b> `SendDone(HloInstruction context)` </b>

Given a context created by a `Send` instruction, waits for the data transfer to
complete. The instruction does not return any data.

<b> Scheduling of channel instructions </b>

The execution order of the 4 instructions for each channel (`Recv`, `RecvDone`,
`Send`, `SendDone`) is as below.

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:70%" src="./images/send_recv_order.png">
</div>

* `Recv` happens before `Send`
* `Send` happens before `RecvDone`
* `Recv` happens before `RecvDone`
* `Send` happens before `SendDone`

When the backend compilers generate a linear schedule for each computation that
communicates via channel instructions, there must not be cycles across the
computations. For example, the schedules below lead to deadlocks.

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="./images/send_recv_schedule.png">
</div>

## Slice

See also
[`XlaBuilder::Slice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Slicing extracts a sub-array from the input array. The sub-array has the same
rank as the input and contains the values inside a bounding box within the
input array, where the dimensions and indices of the bounding box are given as
arguments to the slice operation.

<b> `Slice(operand, start_indices, limit_indices, strides)` </b>

| Arguments       | Type                | Semantics                            |
| --------------- | ------------------- | ------------------------------------ |
| `operand`       | `XlaOp`             | N dimensional array of type T |
| `start_indices` | `ArraySlice<int64>` | List of N integers containing the starting indices of the slice for each dimension. Values must be greater than or equal to zero. |
| `limit_indices` | `ArraySlice<int64>` | List of N integers containing the ending indices (exclusive) of the slice for each dimension. Each value must be greater than or equal to the respective `start_indices` value for the dimension and less than or equal to the size of the dimension. |
| `strides`       | `ArraySlice<int64>` | List of N integers specifying the input stride of the slice. The slice picks every `strides[d]`-th element in dimension `d`. |

1-dimensional example:

```
let a = {0.0, 1.0, 2.0, 3.0, 4.0}
Slice(a, {2}, {4}) produces:
  {2.0, 3.0}
```

2-dimensional example:

```
let b =
 { {0.0, 1.0, 2.0},
   {3.0, 4.0, 5.0},
   {6.0, 7.0, 8.0},
   {9.0, 10.0, 11.0} }

Slice(b, {2, 1}, {4, 3}) produces:
  { { 7.0,  8.0},
    {10.0, 11.0} }
```

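A strided 1-dimensional example (constructed for illustration), showing how
`strides` picks every other element of the bounding box:

```
let c = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}
Slice(c, {1}, {6}, {2}) produces:
  {1.0, 3.0, 5.0}
```
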
## Sort

See also
[`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b>`Sort(operands, comparator, dimension, is_stable)`</b>

Arguments    | Type                | Semantics
------------ | ------------------- | ---------------------------------------
`operands`   | `ArraySlice<XlaOp>` | The operands to sort.
`comparator` | `XlaComputation`    | The comparator computation to use.
`dimension`  | `int64`             | The dimension along which to sort.
`is_stable`  | `bool`              | Whether stable sorting should be used.

If only one operand is provided:

* If the operand is a rank-1 tensor (an array), the result is a sorted array.
  If you want to sort the array into ascending order, the comparator should
  perform a less-than comparison. Formally, after the array is sorted, it holds
  for all index positions `i, j` with `i < j` that either
  `comparator(value[i], value[j]) = comparator(value[j], value[i]) = false` or
  `comparator(value[i], value[j]) = true`.

* If the operand has higher rank, the operand is sorted along the provided
  dimension. For example, for a rank-2 tensor (a matrix), a dimension value of
  `0` will independently sort every column, and a dimension value of `1` will
  independently sort each row. If no dimension number is provided, then the
  last dimension is chosen by default. For the dimension which is sorted, the
  same sorting order applies as in the rank-1 case.

If `n > 1` operands are provided:

* All `n` operands must be tensors with the same dimensions. The element types
  of the tensors may be different.

* All operands are sorted together, not individually. Conceptually the operands
  are treated as a tuple. When checking whether the elements of each operand at
  index positions `i` and `j` need to be swapped, the comparator is called with
  `2 * n` scalar parameters, where parameter `2 * k` corresponds to the value at
  position `i` from the `k-th` operand, and parameter `2 * k + 1` corresponds to
  the value at position `j` from the `k-th` operand. Usually, the comparator
  would thus compare parameters `2 * k` and `2 * k + 1` with each other and
  possibly use other parameter pairs as tie breakers.

* The result is a tuple that consists of the operands in sorted order (along
  the provided dimension, as above). The `i-th` operand of the tuple corresponds
  to the `i-th` operand of Sort.

For example, if there are three operands `operand0 = [3, 1]`,
`operand1 = [42, 50]`, `operand2 = [-3.0, 1.1]`, and the comparator compares
only the values of `operand0` with less-than, then the output of the sort is the
tuple `([1, 3], [50, 42], [1.1, -3.0])`.

If `is_stable` is set to true, the sort is guaranteed to be stable, that is, if
there are elements which are considered to be equal by the comparator, the
relative order of the equal values is preserved. By default, `is_stable` is set
to false.

## Transpose

See also the `tf.reshape` operation.

<b>`Transpose(operand, permutation)`</b>

Arguments     | Type                | Semantics
------------- | ------------------- | ------------------------------
`operand`     | `XlaOp`             | The operand to transpose.
`permutation` | `ArraySlice<int64>` | How to permute the dimensions.

Permutes the operand dimensions with the given permutation, so
`∀ i . 0 ≤ i < rank ⇒ input_dimensions[permutation[i]] = output_dimensions[i]`.

This is the same as Reshape(operand, permutation,
                            Permute(permutation, operand.shape.dimensions)).

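For illustration, transposing a 2x3 matrix with the permutation `{1, 0}` swaps
the two dimensions:

```
let m: s32[2,3] = { {1, 2, 3},
                    {4, 5, 6} };

Transpose(m, {1, 0}) produces:
  s32[3,2] { {1, 4},
             {2, 5},
             {3, 6} }
```
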
## TriangularSolve

See also
[`XlaBuilder::TriangularSolve`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

Solves systems of linear equations with lower or upper triangular coefficient
matrices by forward- or back-substitution. Broadcasting along leading
dimensions, this routine solves one of the matrix systems `op(a) * x = b`, or
`x * op(a) = b`, for the variable `x`, given `a` and `b`, where `op(a)` is
either `op(a) = a`, or `op(a) = Transpose(a)`, or `op(a) = Conj(Transpose(a))`.

<b> `TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)` </b>

| Arguments       | Type        | Semantics                                    |
| --------------- | ----------- | -------------------------------------------- |
| `a`             | `XlaOp`     | a rank >= 2 array of a complex or floating-point type with shape `[..., M, M]`. |
| `b`             | `XlaOp`     | a rank >= 2 array of the same type with shape `[..., M, K]` if `left_side` is true, `[..., K, M]` otherwise. |
| `left_side`     | `bool`      | indicates whether to solve a system of the form `op(a) * x = b` (`true`) or `x * op(a) = b` (`false`). |
| `lower`         | `bool`      | whether to use the upper or lower triangle of `a`. |
| `unit_diagonal` | `bool`      | if `true`, the diagonal elements of `a` are assumed to be `1` and not accessed. |
| `transpose_a`   | `Transpose` | whether to use `a` as is, transpose it, or take its conjugate transpose. |

Input data is read only from the lower/upper triangle of `a`, depending on the
value of `lower`. Values from the other triangle are ignored. Output data is
returned in the same triangle; the values in the other triangle are
implementation-defined and may be anything.

If the ranks of `a` and `b` are greater than 2, they are treated as batches of
matrices, where all except the minor 2 dimensions are batch dimensions. `a` and
`b` must have equal batch dimensions.

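A minimal constructed example: a left-side solve with a lower-triangular `a`
and no transposition (here `NO_TRANSPOSE` stands for the `Transpose` value that
uses `a` as is; the keyword-style spelling is illustrative, not literal builder
syntax).

```
let a: f32[2,2] = { {2.0, 0.0},
                    {1.0, 3.0} };
let b: f32[2,1] = { {4.0},
                    {11.0} };

TriangularSolve(a, b, left_side=true, lower=true,
                unit_diagonal=false, transpose_a=NO_TRANSPOSE) produces:
  { {2.0},
    {3.0} }
```

Forward substitution gives `x[0] = 4.0 / 2.0 = 2.0` and
`x[1] = (11.0 - 1.0 * 2.0) / 3.0 = 3.0`, so `a * x = b` holds.
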
## Tuple

See also
[`XlaBuilder::Tuple`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

A tuple containing a variable number of data handles, each of which has its own
shape.

This is analogous to `std::tuple` in C++. Conceptually:

```
let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
let s: s32 = 5;
let t: (f32[10], s32) = tuple(v, s);
```

Tuples can be deconstructed (accessed) via the
[`GetTupleElement`](#gettupleelement) operation.

## While

See also
[`XlaBuilder::While`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

<b> `While(condition, body, init)` </b>

| Arguments   | Type             | Semantics                                  |
| ----------- | ---------------- | ------------------------------------------ |
| `condition` | `XlaComputation` | XlaComputation of type `T -> PRED` which defines the termination condition of the loop. |
| `body`      | `XlaComputation` | XlaComputation of type `T -> T` which defines the body of the loop. |
| `init`      | `T`              | Initial value for the parameter of `condition` and `body`. |

Sequentially executes the `body` until the `condition` fails. This is similar
to a typical while loop in many other languages, except for the differences and
restrictions listed below.

* A `While` node returns a value of type `T`, which is the result from the
  last execution of the `body`.
* The shape of the type `T` is statically determined and must be the same
  across all iterations.

The `T` parameters of the computations are initialized with the `init` value in
the first iteration and are automatically updated to the new result from `body`
in each subsequent iteration.

One main use case of the `While` node is to implement the repeated execution of
training in neural networks. Simplified pseudocode is shown below with a graph
that represents the computation. The code can be found in
[`while_test.cc`](https://www.tensorflow.org/code/tensorflow/compiler/xla/tests/while_test.cc).
The type `T` in this example is a `Tuple` consisting of an `int32` for the
iteration count and a `vector[10]` for the accumulator. For 1000 iterations, the
loop keeps adding a constant vector to the accumulator.

```
// Pseudocode for the computation.
init = {0, zero_vector[10]} // Tuple of int32 and float[10].
result = init;
while (result(0) < 1000) {
  iteration = result(0) + 1;
  new_vector = result(1) + constant_vector[10];
  result = {iteration, new_vector};
}
```

<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
  <img style="width:100%" src="./images/ops_while.png">
</div>