1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
18 
19 #include <map>
20 #include <string>
21 #include <type_traits>
22 #include <utility>
23 
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/container/flat_hash_set.h"
26 #include "absl/strings/string_view.h"
27 #include "absl/types/span.h"
28 #include "tensorflow/compiler/xla/client/padding.h"
29 #include "tensorflow/compiler/xla/client/xla_computation.h"
30 #include "tensorflow/compiler/xla/comparison_util.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/literal_util.h"
33 #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
34 #include "tensorflow/compiler/xla/service/hlo.pb.h"
35 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
36 #include "tensorflow/compiler/xla/shape_util.h"
37 #include "tensorflow/compiler/xla/status_macros.h"
38 #include "tensorflow/compiler/xla/statusor.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/compiler/xla/xla_data.pb.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/platform/stacktrace.h"
43 #include "tensorflow/core/platform/types.h"
44 
45 namespace xla {
46 
47 class XlaBuilder;
48 
49 // This represents an instruction that has been enqueued using the XlaBuilder.
50 // This is used to pass to subsequent computations that depend upon the
51 // instruction as an operand.
52 class XlaOp {
53  public:
XlaOp()54   XlaOp() : handle_(-1), builder_(nullptr) {
55     static_assert(std::is_trivially_destructible<XlaOp>::value,
56                   "XlaOp should be trivially destructible");
57   }
58   ~XlaOp() = default;
59 
60   XlaOp(const XlaOp& other) = default;
61   XlaOp& operator=(const XlaOp& other) = default;
62 
63   // Precondition: !IsUninitialized().
64   //
65   // It's very common to do foo.builder()->bar().  Without this precondition, if
66   // foo.builder() is null, the call to bar will segfault at some point possibly
67   // deep in the callstack when we finally dereference `this`.  The precondition
68   // lets us avoid this tricky-to-debug problem.
builder()69   XlaBuilder* builder() const {
70     CHECK(builder_ != nullptr);
71     return builder_;
72   }
73 
74   // Returns true if the XlaOp represents valid, non-erroneous value.
valid()75   bool valid() const { return handle_ >= 0; }
76 
77   // Returns true if the XlaOp was created by the XlaOp() constructor and
78   // not returned by a builder.
IsUninitialized()79   bool IsUninitialized() const { return builder_ == nullptr; }
80 
IsIdenticalTo(const XlaOp & rhs)81   bool IsIdenticalTo(const XlaOp& rhs) const {
82     return handle_ == rhs.handle_ && builder_ == rhs.builder_;
83   }
84 
85   friend std::ostream& operator<<(std::ostream& out, const XlaOp& op) {
86     out << op.handle();
87     return out;
88   }
89 
90  private:
XlaOp(XlaBuilder * builder)91   explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
XlaOp(int64 handle,XlaBuilder * builder)92   XlaOp(int64 handle, XlaBuilder* builder)
93       : handle_(handle), builder_(builder) {}
94 
handle()95   int64 handle() const { return handle_; }
96 
97   friend class XlaBuilder;
98 
99   // < 0 means "invalid handle".
100   int64 handle_;
101 
102   // Not owned. Non-null for any handle returned by XlaBuilder, even if the
103   // handle is invalid.
104   XlaBuilder* builder_;
105 };
106 
107 // Arithmetic operator overloads for the XlaOp type. Declarations only: each
// takes its operand(s) by const reference and returns an XlaOp by value.
// NOTE(review): presumably each forwards to the matching builder operation --
// confirm against the definitions, which are not part of this header.
108 XlaOp operator-(const XlaOp& x);
109 XlaOp operator+(const XlaOp& x, const XlaOp& y);
110 XlaOp operator-(const XlaOp& x, const XlaOp& y);
111 XlaOp operator*(const XlaOp& x, const XlaOp& y);
112 XlaOp operator/(const XlaOp& x, const XlaOp& y);
113 XlaOp operator%(const XlaOp& x, const XlaOp& y);
114 
115 // Bitwise operator overloads for the XlaOp type. Same calling convention as
// the arithmetic overloads: const-reference operands, XlaOp returned by value.
116 XlaOp operator~(const XlaOp& x);
117 XlaOp operator&(const XlaOp& x, const XlaOp& y);
118 XlaOp operator|(const XlaOp& x, const XlaOp& y);
119 XlaOp operator^(const XlaOp& x, const XlaOp& y);
120 XlaOp operator<<(const XlaOp& x, const XlaOp& y);
121 // Performs a right arithmetic shift if 'x' is a signed type, otherwise performs
122 // a right logical shift.
123 XlaOp operator>>(const XlaOp& x, const XlaOp& y);
124 
125 // We don't overload the relational operators (==, !=, <, <=, >, >=) because the
126 // semantics might be surprising since their result types are usually 'bool'.
127 // Furthermore, programmers may expect == to be a structural equality.
128 // We also choose not to overload any of the mutating operators (e.g., +=, -=)
129 // because the semantics might be misleading — XLA computations are immutable.
130 
131 // A convenient interface for building up computations.
132 //
133 // Thread-compatible.
134 class XlaBuilder {
135  public:
136   // computation_name: name to use for the built computation.
137   XlaBuilder(const string& computation_name);
138 
139   XlaBuilder(const XlaBuilder&) = delete;
140   XlaBuilder& operator=(const XlaBuilder&) = delete;
141 
142   ~XlaBuilder();
143 
144   // Returns the computation name.
name()145   const string& name() const { return name_; }
146 
147   // Sets OpMetadata that will be added to all instructions until cleared.
148   //
149   // OpMetadata is often applied to a series of XLA HLO instructions. As a
150   // result, OpMetadata is set on the Computation Builder. All subsequent
151   // instructions generated via this Computation Builder will have the same
152   // OpMetadata attached until a call to ClearOpMetadata.
SetOpMetadata(const OpMetadata & metadata)153   void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; }
154 
155   // Clears the OpMetadata state.
ClearOpMetadata()156   void ClearOpMetadata() { metadata_.Clear(); }
157 
158   // Sets an OpSharding that will be attached to all instructions until cleared.
SetSharding(const OpSharding & sharding)159   void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }
160 
161   // Clears the sharding. Ops will be sharded according to the default placement
162   // policy.
ClearSharding()163   void ClearSharding() { sharding_ = absl::nullopt; }
164 
165   // Returns the OpSharding that will be attached to all instructions.
sharding()166   const absl::optional<OpSharding>& sharding() const { return sharding_; }
167 
168   // Sets the builder to a mode where it will die immediately when an error is
169   // encountered, rather than producing it in a deferred fashion when Build() is
170   // called (which is the default).
set_die_immediately_on_error(bool enabled)171   void set_die_immediately_on_error(bool enabled) {
172     die_immediately_on_error_ = enabled;
173   }
174 
175   // Default dimension numbers used for a 2D convolution.
176   static constexpr int64 kConvBatchDimension = 0;
177   static constexpr int64 kConvFeatureDimension = 1;
178   static constexpr int64 kConvFirstSpatialDimension = 2;
179   static constexpr int64 kConvSecondSpatialDimension = 3;
180   static constexpr int64 kConvKernelOutputDimension = 0;
181   static constexpr int64 kConvKernelInputDimension = 1;
182   static constexpr int64 kConvKernelFirstSpatialDimension = 2;
183   static constexpr int64 kConvKernelSecondSpatialDimension = 3;
184 
185   // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
186   // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
187   // the kernel operand
188   // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
189   static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
190       int num_spatial_dims = 2);
191 
192   // Returns an error if the convolution dimension numbers have conflicts.
193   static Status Validate(const ConvolutionDimensionNumbers& dnum);
194 
195   // Returns a new XlaBuilder whose resultant Computation is used only by this
196   // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
197   // behavior as the parent.
198   std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
199 
200   // Builds the computation with the requested operations, or returns a non-ok
201   // status. Note that all ops that have been enqueued will be moved to the
202   // computation being returned. The root of the computation will be the last
203   // added operation.
204   //
205   // `remove_dynamic_dimensions` tells the builder whether to remove the
206   // dynamic dimensions information in all ops.
207   //
208   // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keep the
209   // dynamic dimensions information when XLA backend can handle dynamic
210   // dimensions.
211   StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = true);
212 
213   // Overload of Build which specifies a particular root instruction for the
214   // computation.
215   StatusOr<XlaComputation> Build(XlaOp root,
216                                  bool remove_dynamic_dimensions = true);
217 
218   // Builds the computation with the requested operations, or notes an error in
219   // the parent XlaBuilder and returns an empty computation if building failed.
220   // This function is intended to be used where the returned XlaComputation is
221   // only used by the parent XlaBuilder and hence further operation on the
222   // returned XlaComputation will simply be error'ed out if an error occurred
223   // while building this computation. If the built computation is to be used by
224   // a XlaBuilder other than the parent XlaBuilder then Build() should be used
225   // instead.
226   XlaComputation BuildAndNoteError();
227 
228   // Returns a subgraph that roots on the given root. If the root is not a
229   // compile-time constant (see `IsConstant`), returns an error.
230   //
231   // This will copy the needed ops/computations to the subgraph.
232   StatusOr<XlaComputation> BuildConstantSubGraph(const XlaOp& root_op);
233 
234   // Returns the first error that was encountered while building the
235   // computation. When an error is encountered, by default we return a vacuous
236   // XlaOp and inform the user of the error that occurred while
237   // building the computation when they make a final call to Build().
238   //
239   // See also set_die_immediately_on_error().
first_error()240   Status first_error() const { return first_error_; }
241 
242   // Returns the current status of the builder, complete with the stack trace
243   // information.
244   Status GetCurrentStatus() const;
245 
246   // Returns the shape of the given op.
247   StatusOr<Shape> GetShape(const XlaOp& op) const;
248 
249   // Returns the (inferred) result for the current computation's shape. This
250   // assumes the root instruction is the last added instruction.
251   StatusOr<ProgramShape> GetProgramShape() const;
252 
253   // Returns the (inferred) result for the current computation's shape using the
254   // given operation as the root.
255   StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;
256 
257   // Reports an error to the builder, by
258   // * storing it internally and capturing a backtrace if it's the first error
259   //   (this deferred value will be produced on the call to
260   //    Build()/GetShape()/...)
261   // * dying if die_immediately_on_error_ is true.
262   // Returns an XlaOp with an invalid handle but a valid builder. This value can
263   // be returned in place of a value in APIs that return an XlaOp.
264   XlaOp ReportError(const Status& error);
265 
266   // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
267   // If the Status was an error, reports the error to builder and returns an
268   // invalid XlaOp handle.
269   XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);
270 
271   // A helper function that runs a function that returns a StatusOr<XlaOp> and
272   // returns an XlaOp.
273   XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
274 
275   // Returns true if 'operand' is a compile-time constant. A compile-time
276   // constant does not depend on any parameters, or on stateful operators such
277   // as `RngNormal` or `Infeed`.
278   //
279   // This tests whether a computation is a compile-time constant without
280   // evaluating the computation.
281   StatusOr<bool> IsConstant(const XlaOp& operand) const;
282 
283   // Sets up binding which indicates that the `target_dim_num` in the subshape
284   // `target_param_index` of parameter `target_param_num` is a dynamic dimension
285   // and its real dynamic size is represented by `dynamic_size_param_index` in
286   // parameter `dynamic_size_param_num`.
287   //
288   // Note that this should be called before the dynamic parameters are used to
289   // create other operations, otherwise created operations won't have the
290   // dynamic dimensions information.
291   //
292   // TODO(b/119520625): Remove this API once we have more dynamic shape infra
293   // ready.
294   Status SetDynamicBinding(int64 dynamic_size_param_num,
295                            ShapeIndex dynamic_size_param_index,
296                            int64 target_param_num,
297                            ShapeIndex target_param_index, int64 target_dim_num);
298 
299   // Adds a new input/output alias. Since the input/output shape information is
300   // not available until the computation is built, any error in the arguments
301   // of this API will be detected only at computation Build() time.
SetUpAlias(const ShapeIndex & output_index,int64 param_number,const ShapeIndex & param_index)302   void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
303                   const ShapeIndex& param_index) {
304     input_output_aliases_.push_back({output_index, param_number, param_index});
305   }
306 
307   // Describes an input/output alias as inserted by the SetUpAlias() API.
308   struct InputOutputAlias {
309     // Specifies the index of the aliased buffer in the result tuple.
310     ShapeIndex output_index;
311     // Specifies the parameter containing the buffer to be aliased.
312     int64 param_number;
313     // Specifies the index of the aliased buffer in the parameter
314     ShapeIndex param_index;
315   };
316 
317  private:
318   // Build helper which takes the id of the root operation.
319   StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);
320 
321   // Description for the methods below can be found in the corresponding public
322   // functions section in this file.
323 
324   XlaOp Parameter(int64 parameter_number, const Shape& shape,
325                   const string& name);
326 
327   XlaOp ConstantLiteral(const LiteralSlice& literal);
328 
329   XlaOp Broadcast(const XlaOp& operand,
330                   absl::Span<const int64> broadcast_sizes);
331 
332   XlaOp BroadcastInDim(const XlaOp& operand,
333                        const absl::Span<const int64> out_dim_size,
334                        const absl::Span<const int64> broadcast_dimensions);
335 
336   XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
337             const PaddingConfig& padding_config);
338 
339   XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
340                 absl::Span<const int64> new_sizes);
341 
342   XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);
343 
344   XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);
345 
346   XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
347               absl::Span<const int64> limit_indices,
348               absl::Span<const int64> strides);
349 
350   XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
351                    int64 stride, int64 dimno);
352 
353   ABSL_DEPRECATED("Use span-of-indices form instead")
354   XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
355                      absl::Span<const int64> slice_sizes);
356   XlaOp DynamicSlice(const XlaOp& operand,
357                      absl::Span<const XlaOp> start_indices,
358                      absl::Span<const int64> slice_sizes);
359 
360   ABSL_DEPRECATED("Use span-of-indices form instead")
361   XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
362                            const XlaOp& start_indices);
363   XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
364                            absl::Span<const XlaOp> start_indices);
365 
366   XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
367 
368   void Trace(const string& tag, const XlaOp& operand);
369 
370   XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
371 
372   XlaOp Tuple(absl::Span<const XlaOp> elements);
373 
374   XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
375 
376   XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
377             const PrecisionConfig* precision_config = nullptr);
378 
379   XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
380                    const DotDimensionNumbers& dimension_numbers,
381                    const PrecisionConfig* precision_config = nullptr);
382 
383   XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
384              absl::Span<const int64> window_strides, Padding padding,
385              int64 feature_group_count = 1, int64 batch_group_count = 1,
386              const PrecisionConfig* precision_config = nullptr);
387 
388   XlaOp ConvWithGeneralPadding(
389       const XlaOp& lhs, const XlaOp& rhs,
390       absl::Span<const int64> window_strides,
391       absl::Span<const std::pair<int64, int64>> padding,
392       int64 feature_group_count = 1, int64 batch_group_count = 1,
393       const PrecisionConfig* precision_config = nullptr);
394 
395   XlaOp ConvWithGeneralDimensions(
396       const XlaOp& lhs, const XlaOp& rhs,
397       absl::Span<const int64> window_strides, Padding padding,
398       const ConvolutionDimensionNumbers& dimension_numbers,
399       int64 feature_group_count = 1, int64 batch_group_count = 1,
400       const PrecisionConfig* precision_config = nullptr);
401 
402   XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
403                     absl::Span<const int64> window_strides,
404                     absl::Span<const std::pair<int64, int64>> padding,
405                     const ConvolutionDimensionNumbers& dimension_numbers,
406                     int64 feature_group_count = 1, int64 batch_group_count = 1,
407                     const PrecisionConfig* precision_config = nullptr);
408 
409   XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
410                            absl::Span<const int64> window_strides,
411                            absl::Span<const std::pair<int64, int64>> padding,
412                            absl::Span<const int64> lhs_dilation,
413                            absl::Span<const int64> rhs_dilation,
414                            const ConvolutionDimensionNumbers& dimension_numbers,
415                            int64 feature_group_count = 1,
416                            int64 batch_group_count = 1,
417                            const PrecisionConfig* precision_config = nullptr);
418 
419   XlaOp Fft(const XlaOp& operand, FftType fft_type,
420             absl::Span<const int64> fft_length);
421 
422   XlaOp Infeed(const Shape& shape, const string& config = "");
423   XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
424                         const string& config = "");
425 
426   void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
427                const string& outfeed_config);
428   XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
429                          const Shape& shape_with_layout,
430                          const string& outfeed_config);
431 
432   XlaOp Call(const XlaComputation& computation,
433              absl::Span<const XlaOp> operands);
434 
435   XlaOp CustomCall(
436       const string& call_target_name, absl::Span<const XlaOp> operands,
437       const Shape& shape_with_layout, const string& opaque,
438       absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
439 
440   XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
441                const XlaComputation& computation,
442                absl::Span<const int64> dimensions_to_reduce);
443 
444   XlaOp Reduce(absl::Span<const XlaOp> operands,
445                absl::Span<const XlaOp> init_values,
446                const XlaComputation& computation,
447                absl::Span<const int64> dimensions_to_reduce);
448 
449   XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
450                   const XlaComputation& computation);
451 
452   XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
453                      const XlaComputation& computation,
454                      absl::Span<const int64> window_dimensions,
455                      absl::Span<const int64> window_strides, Padding padding);
456 
457   XlaOp ReduceWindowWithGeneralPadding(
458       const XlaOp& operand, const XlaOp& init_value,
459       const XlaComputation& computation,
460       absl::Span<const int64> window_dimensions,
461       absl::Span<const int64> window_strides,
462       absl::Span<const int64> base_dilations,
463       absl::Span<const int64> window_dilations,
464       absl::Span<const std::pair<int64, int64>> padding);
465 
466   XlaOp CrossReplicaSum(const XlaOp& operand,
467                         absl::Span<const ReplicaGroup> replica_groups = {});
468 
469   XlaOp CrossReplicaSum(
470       const XlaOp& operand, const XlaComputation& computation,
471       absl::Span<const ReplicaGroup> replica_groups = {},
472       const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
473 
474   XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
475                  int64 concat_dimension, int64 split_count,
476                  const std::vector<ReplicaGroup>& replica_groups);
477 
478   XlaOp CollectivePermute(
479       const XlaOp& operand,
480       const std::vector<std::pair<int64, int64>>& source_target_pairs);
481 
482   XlaOp ReplicaId();
483 
484   XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
485                          absl::Span<const int64> window_dimensions,
486                          absl::Span<const int64> window_strides,
487                          Padding padding, const XlaOp& source,
488                          const XlaOp& init_value,
489                          const XlaComputation& scatter);
490 
491   XlaOp SelectAndScatterWithGeneralPadding(
492       const XlaOp& operand, const XlaComputation& select,
493       absl::Span<const int64> window_dimensions,
494       absl::Span<const int64> window_strides,
495       absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
496       const XlaOp& init_value, const XlaComputation& scatter);
497 
498   XlaOp Iota(const Shape& shape, int64 iota_dimension);
499 
500   XlaOp Iota(PrimitiveType type, int64 size);
501 
502   XlaOp ConvertElementType(const XlaOp& operand,
503                            PrimitiveType new_element_type);
504 
505   XlaOp BitcastConvertType(const XlaOp& operand,
506                            PrimitiveType new_element_type);
507 
508   XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);
509 
510   XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
511 
512   ABSL_DEPRECATED("Use form with comparator computation instead")
513   XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
514              int64 dimension = -1);
515   XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
516              int64 dimension = -1, bool is_stable = false);
517 
518   XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
519 
520   XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
521             absl::Span<const int64> dimensions,
522             absl::Span<const XlaOp> static_operands = {});
523 
524   XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);
525 
526   XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
527 
528   XlaOp While(const XlaComputation& condition, const XlaComputation& body,
529               const XlaOp& init);
530 
531   XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
532                     const XlaComputation& true_computation,
533                     const XlaOp& false_operand,
534                     const XlaComputation& false_computation);
535 
536   XlaOp Conditional(const XlaOp& branch_index,
537                     absl::Span<const XlaComputation* const> branch_computations,
538                     absl::Span<const XlaOp> branch_operands);
539 
540   XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
541                         const int mantissa_bits);
542 
543   XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
544                const GatherDimensionNumbers& dimension_numbers,
545                absl::Span<const int64> slice_sizes);
546 
547   XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
548                 const XlaOp& updates, const XlaComputation& update_computation,
549                 const ScatterDimensionNumbers& dimension_numbers);
550 
551   void Send(const XlaOp& operand, const ChannelHandle& handle);
552   XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
553                       const ChannelHandle& handle);
554 
555   XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
556                    const Shape& shape_with_layout, const ChannelHandle& handle);
557 
558   XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
559                      const ChannelHandle& handle);
560 
561   XlaOp CreateToken();
562 
563   XlaOp AfterAll(absl::Span<const XlaOp> tokens);
564 
565   XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
566   XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
567                       const ChannelHandle& handle);
568 
569   XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
570                           const XlaOp& offset, float epsilon,
571                           int64 feature_index);
572 
573   XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
574                            const XlaOp& offset, const XlaOp& mean,
575                            const XlaOp& variance, float epsilon,
576                            int64 feature_index);
577 
578   XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
579                       const XlaOp& batch_mean, const XlaOp& batch_var,
580                       const XlaOp& grad_output, float epsilon,
581                       int64 feature_index);
582 
583   XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension);
584 
585   StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
586                                  absl::Span<const XlaOp> operands = {});
587 
588   void AddCalledComputation(const XlaComputation& computation,
589                             HloInstructionProto* instr);
590 
591   StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
592   StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
593       int64 handle) const;
594 
595   // Internal helper method that does the building for an arbitrary unary op.
596   XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
597 
598   // Internal helper method that does the building for an arbitrary binary op.
599   // broadcast_dimensions specifies which dimensions to use for broadcasting
600   // when the operation is between tensors of different ranks. The direction is
601   // only used if opcode is kCompare.
602   XlaOp BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
603                  absl::Span<const int64> broadcast_dimensions,
604                  absl::optional<ComparisonDirection> direction = absl::nullopt);
605 
606   // Internal helper method that does the building for an arbitrary ternary op.
607   XlaOp TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
608                   const XlaOp& ehs);
609 
610   XlaOp RngOp(RandomDistribution distribution,
611               absl::Span<const XlaOp> parameters, const Shape& shape);
612 
613   StatusOr<XlaOp> InDimBroadcast(const Shape& shape, const XlaOp& operand,
614                                  absl::Span<const int64> broadcast_dimensions);
615 
616   // Internal helper method that creates a sequence of instructions that
617   // performs an explicit broadcast of the operand to the target shape.
618   StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
619                                        const XlaOp& operand);
620 
621   // Internal helper method for creating a Reshape op with the already inferred
622   // shape.
623   StatusOr<XlaOp> Reshape(const Shape& shape, const XlaOp& operand);
624 
625   // Returns the (inferred) result for the program shape using the given root.
626   StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;
627 
628   // Returns shapes for the operands.
629   StatusOr<std::vector<Shape>> GetOperandShapes(
630       absl::Span<const XlaOp> operands) const;
631 
632   // A visitor which checks whether an operation is a compile-time constant,
633   // meaning that it doesn't depend on any parameters, or on any stateful
634   // operation such as `RngNormal` or `Infeed`. The visitor walks the
635   // computation starting at a given operation and sets is_constant to false iff
636   // a parameter or stateful operation is encountered.
637   void IsConstantVisitor(const int64 op_handle,
638                          absl::flat_hash_set<int64>* visited,
639                          bool* is_constant) const;
640 
641   // Checks bounds for convolution parameters.
642   Status VerifyConvolution(
643       const Shape& lhs_shape, const Shape& rhs_shape,
644       const ConvolutionDimensionNumbers& dimension_numbers) const;
645 
646   // Helper function for creating a Window proto from user-supplied data.
647   // Returns error if the user-supplied data was invalid.
648   StatusOr<Window> MakeWindow(absl::Span<const int64> window_dimensions,
649                               absl::Span<const int64> window_strides,
650                               absl::Span<const std::pair<int64, int64>> padding,
651                               absl::Span<const int64> lhs_dilation,
652                               absl::Span<const int64> rhs_dilation) const;
653 
GetNextId()654   int64 GetNextId() { return ++next_id_; }
655 
656   // Populates the module with the input/output alias information stored within
657   // the input_output_aliases vector.
658   static Status PopulateInputOutputAlias(
659       HloModuleProto* module, const ProgramShape& program_shape,
660       const std::vector<InputOutputAlias>& input_output_aliases);
661 
662   string name_;  // Name to use for the built computation.
663 
664   // The next sequential ID for every instruction/computation contained within
665   // this computation.
666   int64 next_id_ = 0;
667 
668   // The first error encountered while building the computation.
669   // This is OK until the first error is encountered.
670   Status first_error_;
671 
672   // The saved stack trace from the point at which the first error occurred.
673   tensorflow::SavedStackTrace first_error_backtrace_;
674 
675   // The instructions of this computation.
676   std::vector<HloInstructionProto> instructions_;
677 
678   // Dynamic parameter configuration of this computation.
679   DynamicParameterBinding dynamic_parameter_binding_;
680 
681   // Holds the input/output alias information populated by the SetUpAlias() API.
682   std::vector<InputOutputAlias> input_output_aliases_;
683 
684   // A map from XlaOp::Handle to the index in the instructions_ vector where the
685   // instruction is held.
686   absl::flat_hash_map<int64, int64> handle_to_index_;
687 
  // The embedded computations used by this computation. Each computation was
  // the entry computation of some XlaComputation; the key is the unique id of
  // that XlaComputation.
691   std::map<int64, HloComputationProto> embedded_;
692 
693   // The unique parameter numbers.
694   absl::flat_hash_set<int64> parameter_numbers_;
695 
696   // The metadata to attach to each op. This is structured as a "modal"-like
697   // operation, in order to simplify client code (and not sprinkle this metadata
698   // throughout the TensorFlow op kernel implementations).
699   OpMetadata metadata_;
700 
  // Sharding for this operator. This is structured as a "modal"-like operation,
  // in order to simplify client code, similar to metadata_.
703   absl::optional<OpSharding> sharding_;
704 
705   // Mode bit that indicates whether to die when a first error is encountered.
706   bool die_immediately_on_error_ = false;
707 
708   XlaBuilder* parent_builder_{nullptr};
709 
710   friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
711                          const Shape& shape, const string& name);
712   friend XlaOp ConstantLiteral(XlaBuilder* builder,
713                                const LiteralSlice& literal);
714 
715   friend XlaOp Broadcast(const XlaOp& operand,
716                          absl::Span<const int64> broadcast_sizes);
717 
718   friend XlaOp BroadcastInDim(
719       const XlaOp& operand, const absl::Span<const int64> out_dim_size,
720       const absl::Span<const int64> broadcast_dimensions);
721 
722   friend XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
723                    const PaddingConfig& padding_config);
724 
725   friend XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
726                        absl::Span<const int64> new_sizes);
727 
728   friend XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);
729 
730   friend XlaOp Collapse(const XlaOp& operand,
731                         absl::Span<const int64> dimensions);
732 
733   friend XlaOp Slice(const XlaOp& operand,
734                      absl::Span<const int64> start_indices,
735                      absl::Span<const int64> limit_indices,
736                      absl::Span<const int64> strides);
737 
738   friend XlaOp SliceInDim(const XlaOp& operand, int64 start_index,
739                           int64 limit_index, int64 stride, int64 dimno);
740 
741   friend XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
742                             absl::Span<const int64> slice_sizes);
743   friend XlaOp DynamicSlice(const XlaOp& operand,
744                             absl::Span<const XlaOp> start_indices,
745                             absl::Span<const int64> slice_sizes);
746 
747   friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
748                                   const XlaOp& start_indices);
749   friend XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
750                                   absl::Span<const XlaOp> start_indices);
751 
752   friend XlaOp ConcatInDim(XlaBuilder* builder,
753                            absl::Span<const XlaOp> operands, int64 dimension);
754 
755   friend void Trace(const string& tag, const XlaOp& operand);
756 
757   friend XlaOp Select(const XlaOp& pred, const XlaOp& on_true,
758                       const XlaOp& on_false);
759   friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
760   friend XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
761   friend XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
762                   absl::Span<const int64> broadcast_dimensions);
763   friend XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
764                   absl::Span<const int64> broadcast_dimensions);
765   friend XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
766                   absl::Span<const int64> broadcast_dimensions);
767   friend XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
768                   absl::Span<const int64> broadcast_dimensions);
769   friend XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
770                   absl::Span<const int64> broadcast_dimensions);
771   friend XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
772                   absl::Span<const int64> broadcast_dimensions);
773   friend XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
774                        absl::Span<const int64> broadcast_dimensions,
775                        ComparisonDirection direction);
776   friend XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
777                    const PrecisionConfig* precision_config);
778   friend XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
779                           const DotDimensionNumbers& dimension_number,
780                           const PrecisionConfig* precision_config);
781   friend XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
782                     absl::Span<const int64> window_strides, Padding padding,
783                     int64 feature_group_count, int64 batch_group_count,
784                     const PrecisionConfig* precision_config);
785   friend XlaOp ConvWithGeneralPadding(
786       const XlaOp& lhs, const XlaOp& rhs,
787       absl::Span<const int64> window_strides,
788       absl::Span<const std::pair<int64, int64>> padding,
789       int64 feature_group_count, int64 batch_group_count,
790       const PrecisionConfig* precision_config);
791   friend XlaOp ConvWithGeneralDimensions(
792       const XlaOp& lhs, const XlaOp& rhs,
793       absl::Span<const int64> window_strides, Padding padding,
794       const ConvolutionDimensionNumbers& dimension_numbers,
795       int64 feature_group_count, int64 batch_group_count,
796       const PrecisionConfig* precision_config);
797   friend XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
798                            absl::Span<const int64> window_strides,
799                            absl::Span<const std::pair<int64, int64>> padding,
800                            const ConvolutionDimensionNumbers& dimension_numbers,
801                            int64 feature_group_count, int64 batch_group_count,
802                            const PrecisionConfig* precision_config);
803   friend XlaOp ConvGeneralDilated(
804       const XlaOp& lhs, const XlaOp& rhs,
805       absl::Span<const int64> window_strides,
806       absl::Span<const std::pair<int64, int64>> padding,
807       absl::Span<const int64> lhs_dilation,
808       absl::Span<const int64> rhs_dilation,
809       const ConvolutionDimensionNumbers& dimension_numbers,
810       int64 feature_group_count, int64 batch_group_count,
811       const PrecisionConfig* precision_config);
812   friend XlaOp Fft(const XlaOp& operand, FftType fft_type,
813                    absl::Span<const int64> fft_length);
814   friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
815                                bool unit_diagonal,
816                                TriangularSolveOptions::Transpose transpose_a);
817   friend XlaOp Cholesky(XlaOp a, bool lower);
818   friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
819                       const string& config);
820   friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
821                       const string& outfeed_config);
822   friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
823                     absl::Span<const XlaOp> operands);
824   friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
825                           absl::Span<const XlaOp> operands, const Shape& shape,
826                           const string& opaque);
827   friend XlaOp CustomCallWithLayout(
828       XlaBuilder* builder, const string& call_target_name,
829       absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
830       absl::Span<const Shape> operand_shapes_with_layout, const string& opaque);
831   friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
832                        absl::Span<const int64> broadcast_dimensions);
833   friend XlaOp Conj(const XlaOp& operand);
834   friend XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
835                    absl::Span<const int64> broadcast_dimensions);
836   friend XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
837                    absl::Span<const int64> broadcast_dimensions);
838   friend XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
839                    absl::Span<const int64> broadcast_dimensions);
840   friend XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
841                    absl::Span<const int64> broadcast_dimensions);
842   friend XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
843                    absl::Span<const int64> broadcast_dimensions);
844   friend XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
845                    absl::Span<const int64> broadcast_dimensions);
846   friend XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
847                    absl::Span<const int64> broadcast_dimensions);
848   friend XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
849                    absl::Span<const int64> broadcast_dimensions);
850   friend XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
851                   absl::Span<const int64> broadcast_dimensions);
852   friend XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
853                    absl::Span<const int64> broadcast_dimensions);
854   friend XlaOp Not(const XlaOp& operand);
855   friend XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
856                          absl::Span<const int64> broadcast_dimensions);
857   friend XlaOp ShiftRightArithmetic(
858       const XlaOp& lhs, const XlaOp& rhs,
859       absl::Span<const int64> broadcast_dimensions);
860   friend XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
861                                  absl::Span<const int64> broadcast_dimensions);
862   friend XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
863                       const XlaComputation& computation,
864                       absl::Span<const int64> dimensions_to_reduce);
865   friend XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
866                       absl::Span<const XlaOp> init_values,
867                       const XlaComputation& computation,
868                       absl::Span<const int64> dimensions_to_reduce);
869   friend XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
870                          const XlaComputation& computation);
871   friend XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
872                             const XlaComputation& computation,
873                             absl::Span<const int64> window_dimensions,
874                             absl::Span<const int64> window_strides,
875                             Padding padding);
876   friend XlaOp ReduceWindowWithGeneralPadding(
877       const XlaOp& operand, const XlaOp& init_value,
878       const XlaComputation& computation,
879       absl::Span<const int64> window_dimensions,
880       absl::Span<const int64> window_strides,
881       absl::Span<const int64> base_dilations,
882       absl::Span<const int64> window_dilations,
883       absl::Span<const std::pair<int64, int64>> padding);
884   friend XlaOp CrossReplicaSum(const XlaOp& operand,
885                                absl::Span<const ReplicaGroup> replica_groups);
886   friend XlaOp CrossReplicaSum(const XlaOp& operand,
887                                const XlaComputation& computation,
888                                absl::Span<const ReplicaGroup> replica_groups,
889                                const absl::optional<ChannelHandle>& channel_id);
890   friend XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
891                         int64 concat_dimension, int64 split_count,
892                         const std::vector<ReplicaGroup>& replica_groups);
893   friend XlaOp CollectivePermute(
894       const XlaOp& operand,
895       const std::vector<std::pair<int64, int64>>& source_target_pairs);
896   friend XlaOp ReplicaId(XlaBuilder* builder);
897   friend XlaOp SelectAndScatter(const XlaOp& operand,
898                                 const XlaComputation& select,
899                                 absl::Span<const int64> window_dimensions,
900                                 absl::Span<const int64> window_strides,
901                                 Padding padding, const XlaOp& source,
902                                 const XlaOp& init_value,
903                                 const XlaComputation& scatter);
904   friend XlaOp SelectAndScatterWithGeneralPadding(
905       const XlaOp& operand, const XlaComputation& select,
906       absl::Span<const int64> window_dimensions,
907       absl::Span<const int64> window_strides,
908       absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
909       const XlaOp& init_value, const XlaComputation& scatter);
910   friend XlaOp Abs(const XlaOp& operand);
911   friend XlaOp Atan2(const XlaOp& y, const XlaOp& x,
912                      absl::Span<const int64> broadcast_dimensions);
913   friend XlaOp Exp(const XlaOp& operand);
914   friend XlaOp Expm1(const XlaOp& operand);
915   friend XlaOp Floor(const XlaOp& operand);
916   friend XlaOp Ceil(const XlaOp& operand);
917   friend XlaOp Round(const XlaOp& operand);
918   friend XlaOp Log(const XlaOp& operand);
919   friend XlaOp Log1p(const XlaOp& operand);
920   friend XlaOp Sign(const XlaOp& operand);
921   friend XlaOp Clz(const XlaOp& operand);
922   friend XlaOp Cos(const XlaOp& operand);
923   friend XlaOp Sin(const XlaOp& operand);
924   friend XlaOp Tanh(const XlaOp& operand);
925   friend XlaOp Real(const XlaOp& operand);
926   friend XlaOp Imag(const XlaOp& operand);
927   friend XlaOp Sqrt(const XlaOp& operand);
928   friend XlaOp Rsqrt(const XlaOp& operand);
929   friend XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
930                    absl::Span<const int64> broadcast_dimensions);
931   friend XlaOp IsFinite(const XlaOp& operand);
932   friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
933                     int64 iota_dimension);
934   friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
935   friend XlaOp ConvertElementType(const XlaOp& operand,
936                                   PrimitiveType new_element_type);
937   friend XlaOp BitcastConvertType(const XlaOp& operand,
938                                   PrimitiveType new_element_type);
939   friend XlaOp Neg(const XlaOp& operand);
940   friend XlaOp Transpose(const XlaOp& operand,
941                          absl::Span<const int64> permutation);
942   friend XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
943   friend XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values,
944                     int64 dimension);
945   friend XlaOp Sort(absl::Span<const XlaOp> operands,
946                     const XlaComputation& comparator, int64 dimension,
947                     bool is_stable);
948   friend XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
949   friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
950                    const XlaComputation& computation,
951                    absl::Span<const int64> dimensions,
952                    absl::Span<const XlaOp> static_operands);
953   friend XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma,
954                          const Shape& shape);
955   friend XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
956   friend XlaOp While(const XlaComputation& condition,
957                      const XlaComputation& body, const XlaOp& init);
958   friend XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
959                            const XlaComputation& true_computation,
960                            const XlaOp& false_operand,
961                            const XlaComputation& false_computation);
962   friend XlaOp Conditional(
963       const XlaOp& branch_index,
964       absl::Span<const XlaComputation* const> branch_computations,
965       absl::Span<const XlaOp> branch_operands);
966   friend XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
967                                const int mantissa_bits);
968   friend XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
969                       const GatherDimensionNumbers& dimension_numbers,
970                       absl::Span<const int64> slice_sizes);
971   friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
972                        const XlaOp& updates,
973                        const XlaComputation& update_computation,
974                        const ScatterDimensionNumbers& dimension_numbers);
975   friend void Send(const XlaOp& operand, const ChannelHandle& handle);
976   friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
977                     const ChannelHandle& handle);
978   friend XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
979                                  const XlaOp& offset, float epsilon,
980                                  int64 feature_index);
981   friend XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
982                                   const XlaOp& offset, const XlaOp& mean,
983                                   const XlaOp& variance, float epsilon,
984                                   int64 feature_index);
985   friend XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
986                              const XlaOp& batch_mean, const XlaOp& batch_var,
987                              const XlaOp& grad_output, float epsilon,
988                              int64 feature_index);
989   friend XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
990                              const ChannelHandle& handle);
991   friend XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
992                              const ChannelHandle& handle);
993   friend XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
994                           const Shape& shape_with_layout,
995                           const ChannelHandle& handle);
996   friend XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
997                             const ChannelHandle& handle);
998   friend XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
999                                const string& config);
1000   friend XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
1001                                 const Shape& shape_with_layout,
1002                                 const string& outfeed_config);
1003   friend XlaOp CreateToken(XlaBuilder* builder);
1004   friend XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
1005 
1006   friend XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension);
1007 };
1008 
1009 // RAII-style object: sets the current sharding assignment in builder on
1010 // construction, and sets back to the previous assignment on destruction.
1011 class XlaScopedShardingAssignment {
1012  public:
XlaScopedShardingAssignment(xla::XlaBuilder * builder,absl::optional<OpSharding> sharding)1013   XlaScopedShardingAssignment(xla::XlaBuilder* builder,
1014                               absl::optional<OpSharding> sharding)
1015       : builder_(builder), prev_sharding_(builder->sharding()) {
1016     SetSharding(sharding);
1017   }
1018 
1019   XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
1020   XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
1021       delete;
1022 
~XlaScopedShardingAssignment()1023   ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }
1024 
1025  private:
SetSharding(const absl::optional<OpSharding> & sharding)1026   void SetSharding(const absl::optional<OpSharding>& sharding) {
1027     if (sharding.has_value()) {
1028       builder_->SetSharding(sharding.value());
1029     } else {
1030       builder_->ClearSharding();
1031     }
1032   }
1033 
1034   xla::XlaBuilder* const builder_;
1035   absl::optional<OpSharding> prev_sharding_;
1036 };
1037 
1038 // Free functions for building XlaOps. The intention is that these will
1039 // become the public API for building XlaOps rather than calling methods on
1040 // XlaBuilder directly.
1041 //
1042 
1043 // Enqueues a "retrieve parameter value" instruction for a parameter that was
1044 // passed to the computation.
1045 XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
1046                 const string& name);
1047 
1048 // Enqueues a constant with the value of the given literal onto the
1049 // computation.
1050 XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);
1051 
1052 // Enqueues a constant onto the computation. Methods are templated on the
1053 // native host type (NativeT) which corresponds to a specific XLA
1054 // PrimitiveType as given in the following table:
1055 //
1056 //  Native Type   PrimitiveType
1057 // -----------------------------
1058 //   bool           PRED
1059 //   int32          S32
1060 //   int64          S64
1061 //   uint32         U32
1062 //   uint64         U64
1063 //   float          F32
1064 //   double         F64
1065 //
1066 // Note: not all primitive types defined in xla_data.proto have a
1067 // corresponding native type yet.
1068 template <typename NativeT>
1069 XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
1070 template <typename NativeT>
1071 XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values);
1072 XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
1073 template <typename NativeT>
1074 XlaOp ConstantR2(XlaBuilder* builder,
1075                  std::initializer_list<std::initializer_list<NativeT>> values);
1076 template <typename NativeT>
1077 XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
1078                                   const Array<NativeT>& values,
1079                                   const Layout& layout);
1080 template <typename NativeT>
1081 XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
1082 template <typename NativeT>
1083 XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
1084                                       const Array2D<NativeT>& values,
1085                                       const Layout& layout);
1086 template <typename NativeT>
1087 XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
1088                             const Array2D<NativeT>& values);
1089 template <typename NativeT>
1090 XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
1091                                       const Array3D<NativeT>& values,
1092                                       const Layout& layout);
1093 template <typename NativeT>
1094 XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
1095                             const Array3D<NativeT>& values);
1096 template <typename NativeT>
1097 XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
1098                                       const Array4D<NativeT>& values,
1099                                       const Layout& layout);
1100 template <typename NativeT>
1101 XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
1102                             const Array4D<NativeT>& values);
1103 
// Enqueues a rank one constant (vector) onto the computation. The vector has
// size 'length' and every element has the value 'value'.
1107 template <typename NativeT>
1108 XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
1109 
1110 // Adds dimensions to an array by duplicating the data in the array.
1111 //
1112 // The new dimensions are inserted on the left, i.e. if
1113 // broadcast_sizes has values {a0, ..., aN} and the operand shape
1114 // has dimensions {b0, ..., bM} then the shape of the output has
1115 // dimensions {a0, ..., aN, b0, ..., bM}.
1116 //
1117 // The new dimensions index into copies of the operand, i.e.
1118 //
1119 //   output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
1120 XlaOp Broadcast(const XlaOp& operand, absl::Span<const int64> broadcast_sizes);
1121 
// This op broadcasts the `operand` to an output whose dimension sizes are given
// by `out_dim_size`.
1123 // `broadcast_dimensions` are the dimensions to be broadcasting into, i.e., the
1124 // i'th dimension of the operand is mapped to the broadcast_dimensions[i]'th
1125 // dimension of the output. This also requires that the i'th input dimension is
1126 // either 1 or is the same as the output dimension it's broadcasting into.
1127 //
1128 // For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the
1129 // output shape is s32[2,2]:
// - Specifying {1} as broadcast_dimension will generate output
1131 //   {{1, 2},
1132 //    {1, 2}}
1133 // - On the other hand, specifying {0} as broadcast_dimension
1134 //   will generate output
1135 //   {{1 , 1},
1136 //    {2 , 2}}
1137 XlaOp BroadcastInDim(const XlaOp& operand,
1138                      const absl::Span<const int64> out_dim_size,
1139                      const absl::Span<const int64> broadcast_dimensions);
1140 
1141 // Enqueues a pad operation onto the computation that pads the given value on
1142 // the edges as well as between the elements of the input. padding_config
1143 // specifies the padding amount for each dimension.
1144 XlaOp Pad(const XlaOp& operand, const XlaOp& padding_value,
1145           const PaddingConfig& padding_config);
1146 
1147 // Enqueues an operation onto the computation that flattens the operand based
1148 // on the dimension order (major/slowest-varying to minor/fastest-varying)
1149 // given, followed by reshaping it into the shape with the given dimension
1150 // sizes (also major to minor). Conceptually, this is a limited form of
1151 // "shape casting".
1152 XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> dimensions,
1153               absl::Span<const int64> new_sizes);
1154 
1155 // Enqueues an operation onto the computation that collapses the operand, from
1156 // first to last dimension (C order), then reshapes it to the given dimension
1157 // sizes. Conceptually, this is a limited form of "shape casting".
1158 XlaOp Reshape(const XlaOp& operand, absl::Span<const int64> new_sizes);
1159 
1160 // Wrapper for Reshape.
1161 // Enqueues an operation to collapse the provided dimensions; e.g. an
1162 // operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
1163 // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
1164 // be a consecutive, in-order subsequence of the operand dimensions.
1165 //
1166 // Note that collapsing a single dimension does nothing:
1167 //
1168 //    {256} collapsing {0} => {256}
1169 //    {1} collapsing {0} => {1}
1170 //
1171 // Collapsing multiple dimensions produces a single result dimension:
1172 //
1173 //    {256, 2} collapsing {0,1} => {512}
1174 //    {256, 2, 3} collapsing {0,1} => {512, 3}
1175 //
1176 // This could potentially cause data to be moved -- it provides a more
1177 // structured form of reshaping than an arbitrary Reshape operation.
1178 XlaOp Collapse(const XlaOp& operand, absl::Span<const int64> dimensions);
1179 
1180 // Enqueues a slice operation onto the computation that slices the operand
1181 // from the start indices to the limit indices; e.g.
1182 //
1183 //        x
1184 //   [ 0 1 2 3 ]
1185 // y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
1186 //   [ 8 9 a b ]
1187 //
1188 // Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
1189 // range notation.
// The strides parameter determines the stride over the slice.
1191 XlaOp Slice(const XlaOp& operand, absl::Span<const int64> start_indices,
1192             absl::Span<const int64> limit_indices,
1193             absl::Span<const int64> strides);
1194 
1195 // Enqueues a slice operation in a given dimension, taking all other
1196 // dimensions as they are; e.g. if dimno is 1 from start_index 2 to
1197 // limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
1198 // for:
1199 //
1200 //  array[:, 2:4:1, :]
1201 XlaOp SliceInDim(const XlaOp& operand, int64 start_index, int64 limit_index,
1202                  int64 stride, int64 dimno);
1203 
1204 // Enqueues a slice operation onto the computation that slices the 'operand'
1205 // from dynamic start indices which are passed in 'start_indices'.
1206 // The size of the slice in each dimension is passed in 'slice_sizes',
1207 // which specify the end point of exclusive slice intervals in each
1208 // dimension [start, start + size).
1209 // The shape of each element of 'start_indices' must be scalar, with the span
1210 // size equal to the rank of the 'operand'. All elements of 'start_indices' must
1211 // have the same shape.
1212 // Slice index calculations are computed modulo input dimension sizes to
1213 // prevent dynamic start indices from generating out-of-bound array accesses.
1214 XlaOp DynamicSlice(const XlaOp& operand, absl::Span<const XlaOp> start_indices,
1215                    absl::Span<const int64> slice_sizes);
1216 
1217 ABSL_DEPRECATED("Use span-of-indices form instead")
1218 XlaOp DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
1219                    absl::Span<const int64> slice_sizes);
1220 
1221 // Enqueues a dynamic update slice operation onto the computation, which
1222 // updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
1223 // The shape of 'update' determines the shape of the slice of 'operand'
1224 // which is updated.
1225 // The indices specified in 'start_indices' specify the offset of the slice
1226 // of 'operand' which is updated.
1227 //
1228 //               update = {10, 11} // calculated at runtime.
1229 //   [1 2 3]     start  = {1, 1}   // calculated at runtime.  [1 2  3 ]
//   [4 5 6]  => DynamicUpdateSlice(data, update, start)   => [4 10 11]
1231 //   [7 8 9]                                                  [7 8  9 ]
1232 //
1233 // The shape of each element of 'start_indices' must be scalar, with the span
1234 // size equal to the rank of the 'operand'. All elements of 'start_indices' must
1235 // have the same shape.
1236 // Slice index calculations are computed modulo update dimension sizes to
1237 // prevent dynamic start indices from generating out-of-bound array accesses.
1238 XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
1239                          absl::Span<const XlaOp> start_indices);
1240 
1241 ABSL_DEPRECATED("Use span-of-indices form instead")
1242 XlaOp DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
1243                          const XlaOp& start_indices);
1244 
1245 // Enqueues a concatenate instruction onto the computation. 'operands' must
1246 // have >= 1 entry.
1247 XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1248                   int64 dimension);
1249 
1250 // Enqueue a tracing operation onto the computation; the computation will emit
1251 // a logging message with the operand.
1252 void Trace(const string& tag, const XlaOp& operand);
1253 
1254 // Enqueues a conditional-move-like select operation onto the computation;
1255 // predicated on pred, selects between on_true and on_false.
1256 XlaOp Select(const XlaOp& pred, const XlaOp& on_true, const XlaOp& on_false);
1257 
1258 // Enqueues a tuple-creation instruction onto the computation.
1259 XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
1260 
1261 // Enqueues a tuple-element-get instruction onto the computation.
1262 XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
1263 
1264 // Enqueues an equal-to comparison instruction onto the computation.
1265 XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
1266          absl::Span<const int64> broadcast_dimensions = {});
1267 
1268 // Enqueues a not-equal comparison instruction onto the computation.
1269 XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
1270          absl::Span<const int64> broadcast_dimensions = {});
1271 
1272 // Enqueues a greater-or-equal comparison instruction onto the computation.
1273 XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
1274          absl::Span<const int64> broadcast_dimensions = {});
1275 
1276 // Enqueues a greater-than comparison instruction onto the computation.
1277 XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
1278          absl::Span<const int64> broadcast_dimensions = {});
1279 
1280 // Enqueues a less-than comparison instruction onto the computation.
1281 XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
1282          absl::Span<const int64> broadcast_dimensions = {});
1283 
1284 // Enqueues a less-or-equal comparison instruction onto the computation.
1285 XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
1286          absl::Span<const int64> broadcast_dimensions = {});
1287 
1288 // Enqueues a comparison instruction onto the computation.
1289 XlaOp Compare(const XlaOp& lhs, const XlaOp& rhs,
1290               absl::Span<const int64> broadcast_dimensions,
1291               ComparisonDirection direction);
1292 
1293 // Enqueues a dot instruction onto the computation.
1294 XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
1295           const PrecisionConfig* precision_config = nullptr);
1296 
1297 // Enqueues a general dot instruction onto the computation.
1298 XlaOp DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
1299                  const DotDimensionNumbers& dimension_numbers,
1300                  const PrecisionConfig* precision_config = nullptr);
1301 
1302 // Enqueues a convolution instruction onto the computation, which uses the
1303 // default convolution dimension numbers.
1304 XlaOp Conv(const XlaOp& lhs, const XlaOp& rhs,
1305            absl::Span<const int64> window_strides, Padding padding,
1306            int64 feature_group_count = 1, int64 batch_group_count = 1,
1307            const PrecisionConfig* precision_config = nullptr);
1308 
1309 // Enqueues a convolution instruction onto the computation, with the caller
1310 // provided padding configuration in the format returned by MakePadding().
1311 XlaOp ConvWithGeneralPadding(const XlaOp& lhs, const XlaOp& rhs,
1312                              absl::Span<const int64> window_strides,
1313                              absl::Span<const std::pair<int64, int64>> padding,
1314                              int64 feature_group_count = 1,
1315                              int64 batch_group_count = 1,
1316                              const PrecisionConfig* precision_config = nullptr);
1317 
1318 // Enqueues a convolution instruction onto the computation, with the caller
1319 // provided dimension numbers configuration.
1320 XlaOp ConvWithGeneralDimensions(
1321     const XlaOp& lhs, const XlaOp& rhs, absl::Span<const int64> window_strides,
1322     Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
1323     int64 feature_group_count = 1, int64 batch_group_count = 1,
1324     const PrecisionConfig* precision_config = nullptr);
1325 
1326 // Enqueues a convolution instruction onto the computation, with the caller
1327 // provided padding configuration as well as the dimension numbers.
1328 XlaOp ConvGeneral(const XlaOp& lhs, const XlaOp& rhs,
1329                   absl::Span<const int64> window_strides,
1330                   absl::Span<const std::pair<int64, int64>> padding,
1331                   const ConvolutionDimensionNumbers& dimension_numbers,
1332                   int64 feature_group_count = 1, int64 batch_group_count = 1,
1333                   const PrecisionConfig* precision_config = nullptr);
1334 
1335 // Enqueues a convolution instruction onto the computation, with the caller
1336 // provided padding configuration, dilation factors and dimension numbers.
1337 XlaOp ConvGeneralDilated(const XlaOp& lhs, const XlaOp& rhs,
1338                          absl::Span<const int64> window_strides,
1339                          absl::Span<const std::pair<int64, int64>> padding,
1340                          absl::Span<const int64> lhs_dilation,
1341                          absl::Span<const int64> rhs_dilation,
1342                          const ConvolutionDimensionNumbers& dimension_numbers,
1343                          int64 feature_group_count = 1,
1344                          int64 batch_group_count = 1,
1345                          const PrecisionConfig* precision_config = nullptr);
1346 
1347 // Enqueues an FFT instruction onto the computation, of the given type and
1348 // with the given FFT length.
1349 XlaOp Fft(const XlaOp& operand, FftType fft_type,
1350           absl::Span<const int64> fft_length);
1351 
1352 // Solves systems of linear equations with lower or upper triangular coefficient
1353 // matrices by forward- or back-substitution. Broadcasting along leading
1354 // dimensions, this routine solves for x in one of the matrix systems
1355 //   `op(a) * x = b`,  or `x * op(a) = b`,
1356 // for the variable `x` given `a` and `b`, where `op(a)` is either
1357 //   `op(a) = a`,  or `op(a) = transpose(a)`,  or `op(a) = conj(transpose(a))`.
1358 //
1359 // * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form
1360 //   square matrices. If `lower` is true (false), then the strictly upper
1361 //   (lower) triangular part of each innermost matrix in `a` is assumed to be
1362 //   zero and is not accessed.
1363 // * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a
1364 //   tensor of shape `[..., K, M]`.
1365 // * `left_side` is a boolean, indicating whether to solve a system of the form
1366 //   op(a) * x = b (true) or x * op(a) = b (false).
1367 // * `lower` is a boolean, indicating whether the argument `a` is
1368 //   lower-triangular (true) or upper-triangular (false).
1369 // * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be
1370 //   1 and not accessed.
1371 // * `transpose_a` indicates which function `op` we use to transform the tensor
1372 //   `a`: the identity function, transpose(a), or conjugate(transpose(a))
1373 XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
1374                       bool unit_diagonal,
1375                       TriangularSolveOptions::Transpose transpose_a);
1376 
1377 // Computes the Cholesky decompositions of a batch of symmetric (Hermitian)
1378 // positive definite matrices.
1379 // `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
1380 // two minor dimensions equal.
1381 // If `lower` is true, the data from the lower triangle is used; if false, the
1382 // upper triangle is used. The input data in the other triangle of the input
1383 // does not affect the output. Returns the output in the same lower/upper
1384 // triangle. The data returned in the other output triangle is arbitrary and
1385 // implementation-defined.
1386 //
1387 // The value returned if `a` is not Hermitian positive definite is
1388 // implementation-defined.
1389 XlaOp Cholesky(XlaOp a, bool lower);
1390 
1391 // Enqueues an infeed instruction onto the computation, which writes data of
1392 // the given shape to the infeed buffer of the device.
1393 XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
1394              const string& config = "");
1395 
1396 // Variant of Infeed which takes a token-shaped operand and produces a
1397 // two-element tuple containing the data value and a token-shaped value.
1398 // Tokens are used for ordering side-effecting operations.
1399 // TODO(b/110532604): Replace all uses of the non-token form with this variant.
1400 XlaOp InfeedWithToken(const XlaOp& token, const Shape& shape,
1401                       const string& config = "");
1402 
1403 // Enqueues an outfeed instruction onto the computation. This instruction
1404 // generates outgoing data transfers for the given data.
1405 //
1406 // shape_with_layout communicates the laid out shape that we want to outfeed
1407 // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
1408 // will occur.
1409 void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
1410              const string& outfeed_config);
1411 
1412 // Variant of Outfeed which takes a token-shaped operand and produces a
1413 // token-shaped value. Tokens are used for ordering side-effecting operations.
1414 // TODO(b/110532604): Replace all uses of the non-token form with this variant.
1415 XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
1416                        const Shape& shape_with_layout,
1417                        const string& outfeed_config);
1418 
1419 // Enqueues a call instruction onto the computation.
1420 XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
1421            absl::Span<const XlaOp> operands);
1422 
1423 // Enqueues a custom call instruction onto the computation. A custom call
1424 // invokes code external to XLA. The |operands| are passed to the external code,
1425 // and the external code is expected to produce a result of the given
1426 // |shape|. The exact mechanism is backend-specific. For example, in the CPU
1427 // backend, a call instruction is emitted which targets a symbol with the name
1428 // |call_target_name|. |call_target_name| and |opaque| can be arbitrary
1429 // strings, but |call_target_name| should be short as it may be used in labels.
1430 // |opaque| can encode arbitrarily large amounts of information.
1431 XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
1432                  absl::Span<const XlaOp> operands, const Shape& shape,
1433                  const string& opaque = "");
1434 
1435 // Overload which constructs a custom call with fixed layouts. The operands will
1436 // have the layouts specified by |operand_shapes_with_layout| when provided to
1437 // external code, and the external code is expected to produce a result with the
1438 // layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
1439 // and |operand_shapes_with_layout| must have layouts.
1440 XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
1441                            absl::Span<const XlaOp> operands,
1442                            const Shape& shape_with_layout,
1443                            absl::Span<const Shape> operand_shapes_with_layout,
1444                            const string& opaque = "");
1445 
1446 // The following methods enqueue element-wise binary arithmetic operations
1447 // onto the computation. The shapes of the operands have to match unless one
1448 // of the operands is a scalar, or an explicit broadcast dimension is given
1449 // (see g3doc for more details).
1450 
1451 // Enqueues a complex compose instruction onto the computation.
1452 XlaOp Complex(const XlaOp& real, const XlaOp& imag,
1453               absl::Span<const int64> broadcast_dimensions = {});
1454 
1455 // Enqueues a complex conjugate instruction onto the computation.
1456 XlaOp Conj(const XlaOp& operand);
1457 
1458 // Enqueues an add instruction onto the computation.
1459 XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
1460           absl::Span<const int64> broadcast_dimensions = {});
1461 
1462 // Enqueues a subtract instruction onto the computation.
1463 XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
1464           absl::Span<const int64> broadcast_dimensions = {});
1465 
1466 // Enqueues a multiply instruction onto the computation.
1467 XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
1468           absl::Span<const int64> broadcast_dimensions = {});
1469 
1470 // Enqueues a divide instruction onto the computation.
1471 XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
1472           absl::Span<const int64> broadcast_dimensions = {});
1473 
1474 // Enqueues a remainder instruction onto the computation.
1475 XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
1476           absl::Span<const int64> broadcast_dimensions = {});
1477 
1478 // Enqueues a max instruction onto the computation.
1479 XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
1480           absl::Span<const int64> broadcast_dimensions = {});
1481 
1482 // Enqueues a min instruction onto the computation.
1483 XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
1484           absl::Span<const int64> broadcast_dimensions = {});
1485 
1486 // Element-wise logical operators
1487 XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
1488           absl::Span<const int64> broadcast_dimensions = {});
1489 
1490 // Overload to call And with 3 or more operands.  We need the following somewhat
1491 // convoluted overload set to disambiguate with the overload that takes the
1492 // `broadcast_dimensions` optional param.
And(const XlaOp & op1,const XlaOp & op2,const XlaOp & op3)1493 inline XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) {
1494   return And(op1, And(op2, op3));
1495 }
1496 template <typename... XlaOpTs>
And(const XlaOp & op1,const XlaOp & op2,const XlaOp & op3,const XlaOpTs &...operands)1497 XlaOp And(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3,
1498           const XlaOpTs&... operands) {
1499   return And(op1, And(op2, And(op3, operands...)));
1500 }
1501 
1502 XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
1503          absl::Span<const int64> broadcast_dimensions = {});
1504 
1505 // Overload to call Or with 3 or more operands.  As with `And`, we need the
1506 // following complicated overload set to handle the default arg in the `Or`
1507 // overload above.
Or(const XlaOp & op1,const XlaOp & op2,const XlaOp & op3)1508 inline XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3) {
1509   return Or(op1, Or(op2, op3));
1510 }
1511 template <typename... XlaOpTs>
Or(const XlaOp & op1,const XlaOp & op2,const XlaOp & op3,const XlaOpTs &...operands)1512 XlaOp Or(const XlaOp& op1, const XlaOp& op2, const XlaOp& op3,
1513          const XlaOpTs&... operands) {
1514   return Or(op1, Or(op2, Or(op3, operands...)));
1515 }
1516 
// Enqueues an element-wise xor instruction onto the computation.
// NOTE(review): `broadcast_dimensions` presumably follows the same explicit
// broadcast convention as the other element-wise binary operations -- confirm.
1517 XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
1518           absl::Span<const int64> broadcast_dimensions = {});
1519 
// Enqueues an element-wise not instruction onto the computation; the unary
// member of the element-wise logical operators.
1520 XlaOp Not(const XlaOp& operand);
1521 
// Shift operators taking two operands and optional broadcast dimensions.
// NOTE(review): going by the names, these presumably shift `lhs` element-wise
// by `rhs` positions (left; right with sign extension; right with zero fill)
// -- confirm against the XLA operation semantics.
1522 XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
1523                 absl::Span<const int64> broadcast_dimensions = {});
1524 XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
1525                            absl::Span<const int64> broadcast_dimensions = {});
1526 XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
1527                         absl::Span<const int64> broadcast_dimensions = {});
1528 
1529 // Reduces an array among the provided dimensions, given "computation" as a
1530 // reduction operator.
1531 XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
1532              const XlaComputation& computation,
1533              absl::Span<const int64> dimensions_to_reduce);
1534 
1535 // Reduces several arrays simultaneously among the provided dimensions, given
1536 // "computation" as a reduction operator.
1537 XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1538              absl::Span<const XlaOp> init_values,
1539              const XlaComputation& computation,
1540              absl::Span<const int64> dimensions_to_reduce);
1541 
1542 // Convenience wrapper around the above that reduces all the dimensions in the
1543 // operand shape.
1544 XlaOp ReduceAll(const XlaOp& operand, const XlaOp& init_value,
1545                 const XlaComputation& computation);
1546 
1547 // Enqueues a windowed reduce instruction onto the computation.
1548 XlaOp ReduceWindow(const XlaOp& operand, const XlaOp& init_value,
1549                    const XlaComputation& computation,
1550                    absl::Span<const int64> window_dimensions,
1551                    absl::Span<const int64> window_strides, Padding padding);
1552 
1553 // As ReduceWindow(), but the padding is given in the format
1554 // returned by MakePadding().
1555 XlaOp ReduceWindowWithGeneralPadding(
1556     const XlaOp& operand, const XlaOp& init_value,
1557     const XlaComputation& computation,
1558     absl::Span<const int64> window_dimensions,
1559     absl::Span<const int64> window_strides,
1560     absl::Span<const int64> base_dilations,
1561     absl::Span<const int64> window_dilations,
1562     absl::Span<const std::pair<int64, int64>> padding);
1563 
1564 // Returns the sum of the operand value within each subgroup of replicas. All
1565 // replicas supply one input to the sum and all replicas receive the resulting
1566 // sum for each subgroup.
1567 XlaOp CrossReplicaSum(const XlaOp& operand,
1568                       absl::Span<const ReplicaGroup> replica_groups = {});
1569 
1570 // Enqueues an operation that does an AllReduce of the operand across cores.
1571 // Here AllReduce means doing a reduction on the input operand across cores and
1572 // then broadcasting the reduction result to those cores. The reduction
1573 // function is defined by `computation`, which should be a commutative
1574 // computation on scalars, e.g., add, min, or max. The way that AllReduce is
1575 // applied is configured by:
1576 //
1577 // - `replica_groups`: each ReplicaGroup contains a list of replica ids. If
1578 // empty, all replicas belong to one group. AllReduce will be applied within
1579 // subgroups. For example, we have 4 replicas, then replica_groups={{0,2},{1,3}}
1580 // means, replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
1581 //
1582 // - `channel_id`: for Allreduce nodes from different modules, if they have the
1583 // same channel_id, they will be 'Allreduce'd. If empty, Allreduce will not be
1584 // applied cross modules.
1585 //
1586 // TODO(b/117564385): Rename this to AllReduce when it's ready to use.
1587 XlaOp CrossReplicaSum(
1588     const XlaOp& operand, const XlaComputation& computation,
1589     absl::Span<const ReplicaGroup> replica_groups = {},
1590     const absl::optional<ChannelHandle>& channel_id = absl::nullopt);
1591 
1592 // Enqueues an operation that does an AllToAll of the operand across cores.
1593 XlaOp AllToAll(const XlaOp& operand, int64 split_dimension,
1594                int64 concat_dimension, int64 split_count,
1595                const std::vector<ReplicaGroup>& replica_groups = {});
1596 
1597 // Enqueues a collective operation that sends and receives data across replicas.
1598 //
1599 // - `source_target_pairs`: a list of (source_replica_id, target_replica_id)
1600 // pairs. For each pair, the operand is sent from source replica to target
1601 // replica. Note that, 1) any two pairs should not have the same target replica
1602 // id, and they should not have the same source replica id; 2) if a replica id
1603 // is not a target in any pair, then the output on that replica is a tensor
1604 // consisting of 0(s) with the same shape as the input.
1605 XlaOp CollectivePermute(
1606     const XlaOp& operand,
1607     const std::vector<std::pair<int64, int64>>& source_target_pairs);
1608 
1609 // Enqueues an operation that returns the replica ID.
1610 XlaOp ReplicaId(XlaBuilder* builder);
1611 
1612 // Enqueues an operation that scatters the `source` array to the selected
1613 // indices of each window.
1614 XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
1615                        absl::Span<const int64> window_dimensions,
1616                        absl::Span<const int64> window_strides, Padding padding,
1617                        const XlaOp& source, const XlaOp& init_value,
1618                        const XlaComputation& scatter);
1619 
1620 // As SelectAndScatter(), but the padding is given in the format
1621 // returned by MakePadding().
1622 XlaOp SelectAndScatterWithGeneralPadding(
1623     const XlaOp& operand, const XlaComputation& select,
1624     absl::Span<const int64> window_dimensions,
1625     absl::Span<const int64> window_strides,
1626     absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
1627     const XlaOp& init_value, const XlaComputation& scatter);
1628 
1629 // Enqueues an abs instruction onto the computation.
1630 XlaOp Abs(const XlaOp& operand);
1631 
1632 // Enqueues an atan2 instruction onto the computation.
1633 XlaOp Atan2(const XlaOp& y, const XlaOp& x,
1634             absl::Span<const int64> broadcast_dimensions = {});
1635 
1636 // Enqueues an exp instruction onto the computation.
1637 XlaOp Exp(const XlaOp& operand);
1638 
1639 // Enqueues an expm1 instruction onto the computation.
1640 XlaOp Expm1(const XlaOp& operand);
1641 
1642 // Enqueues a floor instruction onto the computation.
1643 XlaOp Floor(const XlaOp& operand);
1644 
1645 // Enqueues a ceil instruction onto the computation.
1646 XlaOp Ceil(const XlaOp& operand);
1647 
1648 // Enqueues a round instruction onto the computation, rounding to the nearest
1649 // integer, with half-way cases rounding away from zero.
1650 XlaOp Round(const XlaOp& operand);
1651 
1652 // Enqueues a log instruction (natural logarithm) onto the computation.
1653 XlaOp Log(const XlaOp& operand);
1654 
1655 // Enqueues a log1p instruction (log(x+1)) onto the computation.
1656 XlaOp Log1p(const XlaOp& operand);
1657 
1658 // Enqueues a sign instruction onto the computation.
1659 XlaOp Sign(const XlaOp& operand);
1660 
1661 // Enqueues a count leading zeros instruction onto the computation.
1662 XlaOp Clz(const XlaOp& operand);
1663 
1664 // Enqueues a cosine instruction onto the computation.
1665 XlaOp Cos(const XlaOp& operand);
1666 
1667 // Enqueues a sine instruction onto the computation.
1668 XlaOp Sin(const XlaOp& operand);
1669 
1670 // Enqueues a tanh instruction onto the computation.
1671 XlaOp Tanh(const XlaOp& operand);
1672 
1673 // Enqueues a real-part instruction onto the computation.
1674 XlaOp Real(const XlaOp& operand);
1675 
1676 // Enqueues an imaginary-part instruction onto the computation.
1677 XlaOp Imag(const XlaOp& operand);
1678 
1679 // Enqueues a sqrt computation onto the computation.
1680 XlaOp Sqrt(const XlaOp& operand);
1681 
1682 // Enqueues a rsqrt computation onto the computation.
1683 XlaOp Rsqrt(const XlaOp& operand);
1684 
1685 // Enqueues a lhs^rhs computation onto the computation.
1686 XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
1687           absl::Span<const int64> broadcast_dimensions = {});
1688 
1689 // Enqueues an operator that tests if the operand's values are finite, i.e., not
1690 // +/-Inf or NaN.  Returns an array of booleans with the same shape where
1691 // entries are true iff the corresponding entry was not infinite or NaN.
1692 //
1693 // Defined only for real-valued (i.e. not complex) floating-point types; raises
1694 // an error for other types.
1695 //
1696 // See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h.
1697 XlaOp IsFinite(const XlaOp& operand);
1698 
1699 // Enqueues an iota operation onto the computation.
1700 XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension);
1701 
1702 // Enqueues a rank-1 iota operation onto the computation.
1703 XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
1704 
1705 // Enqueues a convert instruction onto the computation that changes the
1706 // element type of the operand array to `new_element_type`.
1707 XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type);
1708 
1709 // Enqueues a no-op instruction onto the computation that changes
1710 // the element type of the operand array to `new_element_type`. The
1711 // bit-widths of the source and destination element types must be
1712 // identical.
1713 XlaOp BitcastConvertType(const XlaOp& operand, PrimitiveType new_element_type);
1714 
1715 // Enqueues a negate instruction onto the computation.
1716 XlaOp Neg(const XlaOp& operand);
1717 
1718 // Enqueues a transpose instruction onto the computation.
1719 XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);
1720 
1721 // Enqueues a reverse instruction onto the computation. The order of the
1722 // elements in the given dimensions is reversed (i.e., the element at index i
1723 // is moved to index dimension_size - 1 - i).
1724 XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
1725 
1726 // Enqueues a sort (as increasing order) instruction onto the computation.
1727 // If only keys are provided:
1728 // * If the keys are a rank-1 tensor (an array), the result is a sorted array
1729 // of keys, in ascending order.
1730 // * If the keys have higher rank, the keys are sorted along the provided
1731 // dimension. For example, for a rank-2 tensor (a matrix) of keys, a dimension
1732 // value of 0 will independently sort every column, and a dimension value of 1
1733 // will independently sort each row. If no dimension number is provided, then
1734 // the last dimension is chosen by default.
1735 //
1736 // If both keys and values are provided:
1737 // * The keys and all values must be tensors with the same dimensions. The
1738 // element types of the tensors may be different.
1739 // * The result is a tuple that consists of a sorted tensor of keys (along the
1740 // provided dimension, as above) as the first element, and tensors with their
1741 // corresponding values as the other elements.
1742 ABSL_DEPRECATED("Use form with comparator computation instead")
1743 XlaOp Sort(const XlaOp& keys, absl::Span<const XlaOp> values = {},
1744            int64 dimension = -1);
1745 
1746 // Enqueues a sort instruction onto the computation, using 'comparator' for
1747 // comparisons. 'comparator' needs to define a strict weak order. 'is_stable'
1748 // determines whether the stable sorting should be used.
1749 // If only one operand is provided:
1750 // * If the operand is a rank-1 tensor (an array), the result is a sorted array.
1751 //   The resulting sorting order has the property that for all index positions
1752 //   i, j with i < j, either
1753 //   comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or
1754 //   comparator(value[i], value[j]) = true.
1755 // * If the operand has higher rank, the operand is sorted along the provided
1756 //   dimension. For example, for a rank-2 tensor (a matrix), a dimension value
1757 //   of 0 will independently sort every column, and a dimension value of 1 will
1758 //   independently sort each row. If no dimension number is provided, then the
1759 //   last dimension is chosen by default. For the dimension which is sorted, the
1760 //   same sorting order applies as in the rank-1 case.
1761 //
1762 // If more than one operand is provided:
1763 // * All operands must be tensors with the same dimensions. The element types of
1764 //   the tensors may be different.
1765 // * The result is a tuple that consists of the operands in sorted order (along
1766 //   the provided dimension, as above). The same permutation as implied by the
1767 //   comparison computation is applied to all operand tensors. When comparing
1768 //   two index positions, 'comparator' is called with 2 * n scalar parameters,
1769 //   where parameter 2 * i and 2 * i + 1 correspond to the value of operand i at
1770 //   two index positions.
1771 // Default comparator computations can be found in lib/comparators.h
1772 XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
1773            int64 dimension = -1, bool is_stable = false);
1774 
1775 // Enqueues a clamp instruction onto the computation.
1776 XlaOp Clamp(const XlaOp& min, const XlaOp& operand, const XlaOp& max);
1777 
1778 // Enqueues a map instruction onto the computation.
1779 XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1780           const XlaComputation& computation, absl::Span<const int64> dimensions,
1781           absl::Span<const XlaOp> static_operands = {});
1782 
1783 // Enqueues a N(mu, sigma) random number generation instruction onto the
1784 // computation.
1785 XlaOp RngNormal(const XlaOp& mu, const XlaOp& sigma, const Shape& shape);
1786 
1787 // Enqueues a U(a, b) random number generation instruction onto the
1788 // computation. Returns values in the semi-open interval [a, b).
1789 XlaOp RngUniform(const XlaOp& a, const XlaOp& b, const Shape& shape);
1790 
1791 // Enqueues a while node onto the computation.
1792 XlaOp While(const XlaComputation& condition, const XlaComputation& body,
1793             const XlaOp& init);
1794 
1795 // Enqueues a conditional node onto the computation.
1796 XlaOp Conditional(const XlaOp& predicate, const XlaOp& true_operand,
1797                   const XlaComputation& true_computation,
1798                   const XlaOp& false_operand,
1799                   const XlaComputation& false_computation);
1800 
1801 // Enqueues either a predicated (if/else) or indexed (switch/case/default)
1802 // conditional node onto the computation. N >= 1 branch_computations and
1803 // branch_operands are matched by index. branch_index selects the branch that
1804 // will be executed. Out of range branch_index uses the N-1'th
1805 // branch_computation as default.
1806 XlaOp Conditional(const XlaOp& branch_index,
1807                   absl::Span<const XlaComputation* const> branch_computations,
1808                   absl::Span<const XlaOp> branch_operands);
1809 
1810 // Enqueues a ReducePrecision node onto the computation.
1811 XlaOp ReducePrecision(const XlaOp& operand, const int exponent_bits,
1812                       const int mantissa_bits);
1813 
1814 // Enqueues a Gather node onto the computation.
1815 XlaOp Gather(const XlaOp& input, const XlaOp& start_indices,
1816              const GatherDimensionNumbers& dimension_numbers,
1817              absl::Span<const int64> slice_sizes);
1818 
1819 // Enqueues a Scatter node onto the computation.
1820 XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
1821               const XlaOp& updates, const XlaComputation& update_computation,
1822               const ScatterDimensionNumbers& dimension_numbers);
1823 
1824 // Enqueues a Send node onto the computation for device-to-device
1825 // communication. This operation sends the given operand to
1826 // a Recv instruction in a different computation that shares the same channel
1827 // handle.
1828 void Send(const XlaOp& operand, const ChannelHandle& handle);
1829 
1830 // Variant of Send which takes a token-shaped operand and produces a
1831 // token-shaped value.  Tokens are used for ordering side-effecting operations.
1832 // TODO(b/110532604): Replace all uses of the non-token form with this variant.
1833 XlaOp SendWithToken(const XlaOp& operand, const XlaOp& token,
1834                     const ChannelHandle& handle);
1835 
1836 // Enqueues a Recv node onto the computation for device-to-device
1837 // communication. The data comes from a Send instruction in a different
1838 // computation that shares the same channel handle and its shape must be the
1839 // same as the given shape.
1840 XlaOp Recv(XlaBuilder* builder, const Shape& shape,
1841            const ChannelHandle& handle);
1842 
1843 // Variant of Recv which takes a token-shaped operand and produces a two-element
1844 // tuple containing the data value and a token-shaped value. Tokens are used
1845 // for ordering side-effecting operations.
1846 // TODO(b/110532604): Replace all uses of the non-token form with this variant.
1847 XlaOp RecvWithToken(const XlaOp& token, const Shape& shape,
1848                     const ChannelHandle& handle);
1849 
1850 // Enqueues a Send node which transfers data from the device to the host. The
1851 // 'shape_with_layout' argument defines the layout of the data transferred; its
1852 // shape must be compatible with the shape of the operand. The operand must be
1853 // array-shaped.
1854 // TODO(b/111544877): Support tuple shapes.
1855 XlaOp SendToHost(const XlaOp& operand, const XlaOp& token,
1856                  const Shape& shape_with_layout, const ChannelHandle& handle);
1857 
1858 // Enqueues a Recv node which transfers data from the host to the device. The
1859 // given shape must contain a layout and must be an array.
1860 // TODO(b/111544877): Support tuple shapes.
1861 XlaOp RecvFromHost(const XlaOp& token, const Shape& shape,
1862                    const ChannelHandle& handle);
1863 
1864 // Enqueues an operation (AfterAll) with no operands that produces a
1865 // token-shaped value.  Tokens are used for ordering side-effecting operations.
1866 // This is a separate method from AfterAll to facility the removal of
1867 // operand-less AfterAll instructions.
1868 // TODO(b/110532604): Remove this function when all tokens are derived from a
1869 // single token generated or passed into the entry computation.
1870 XlaOp CreateToken(XlaBuilder* builder);
1871 
1872 // Enqueues an AfterAll instruction which produces a token-shaped value and
1873 // takes a variadic number of token-shaped operands. The number of operands must
1874 // be greater than zero. Used for joining tokens.
1875 XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
1876 
1877 // Normalizes operand across spatial and batch dimensions for each feature.
1878 //
1879 // Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
1880 // is the normalized result and batch_mean and batch_var are the mean and
1881 // variance, respectively, across batch for the operand.
1882 XlaOp BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
1883                         const XlaOp& offset, float epsilon,
1884                         int64 feature_index);
1885 
1886 // Normalizes operand across spatial and batch dimensions for each feature.
1887 //
1888 // `BatchNormInference` is equivalent to calling `BatchNormTraining` without
1889 // computing `mean` and `variance` for each batch inside the operation. It
1890 // uses the input `mean` and `variance` instead as estimated values. The
1891 // purpose of this op is to reduce latency in inference, hence the name
1892 // `BatchNormInference`.
1893 //
1894 // The output has the same shape as `operand`, and contains the normalized
1895 // values for each batch.
1896 XlaOp BatchNormInference(const XlaOp& operand, const XlaOp& scale,
1897                          const XlaOp& offset, const XlaOp& mean,
1898                          const XlaOp& variance, float epsilon,
1899                          int64 feature_index);
1900 
1901 // Calculates the gradients of a batch norm op.
1902 //
1903 // The inputs `batch_mean` and `batch_var` represent the mean and variance
1904 // across the batch.
1905 //
1906 // Returns a tuple of three elements:
1907 //   - grad_operand: Gradient with respect to input `operand`
1908 //   - grad_offset: Gradient with respect to input `offset`
1909 //   - grad_scale: Gradient with respect to input `scale`
1910 XlaOp BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
1911                     const XlaOp& batch_mean, const XlaOp& batch_var,
1912                     const XlaOp& grad_output, float epsilon,
1913                     int64 feature_index);
1914 
1915 // Returns the size of the given dimension of the operand. The operand must be
1916 // array shaped.
1917 XlaOp GetDimensionSize(const XlaOp& operand, int64 dimension);
1918 
1919 // Implementation details below this point.
1920 //
1921 
1922 // Free function template implementations.
1923 
1924 template <typename NativeT>
ConstantR0(XlaBuilder * builder,NativeT value)1925 XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
1926   return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
1927 }
1928 
1929 template <typename NativeT>
ConstantR1(XlaBuilder * builder,absl::Span<const NativeT> values)1930 XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
1931   return ConstantLiteral(builder, LiteralUtil::CreateR1<NativeT>(values));
1932 }
1933 
1934 template <typename NativeT>
ConstantR1(XlaBuilder * builder,int64 length,NativeT value)1935 XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
1936   Literal literal(ShapeUtil::MakeShape(
1937       primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
1938   literal.PopulateWithValue(value);
1939   return ConstantLiteral(builder, literal);
1940 }
1941 
ConstantR1(XlaBuilder * builder,const tensorflow::core::Bitmap & values)1942 inline XlaOp ConstantR1(XlaBuilder* builder,
1943                         const tensorflow::core::Bitmap& values) {
1944   return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
1945 }
1946 
1947 template <typename NativeT>
ConstantR2(XlaBuilder * builder,std::initializer_list<std::initializer_list<NativeT>> values)1948 XlaOp ConstantR2(XlaBuilder* builder,
1949                  std::initializer_list<std::initializer_list<NativeT>> values) {
1950   return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
1951 }
1952 
1953 template <typename NativeT>
ConstantFromArrayWithLayout(XlaBuilder * builder,const Array<NativeT> & values,const Layout & layout)1954 XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
1955                                   const Array<NativeT>& values,
1956                                   const Layout& layout) {
1957   return ConstantLiteral(
1958       builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
1959 }
1960 
1961 template <typename NativeT>
ConstantFromArray(XlaBuilder * builder,const Array<NativeT> & values)1962 XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
1963   return ConstantLiteral(builder,
1964                          LiteralUtil::CreateFromArray<NativeT>(values));
1965 }
1966 
1967 template <typename NativeT>
ConstantR2FromArray2DWithLayout(XlaBuilder * builder,const Array2D<NativeT> & values,const Layout & layout)1968 XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
1969                                       const Array2D<NativeT>& values,
1970                                       const Layout& layout) {
1971   return ConstantLiteral(
1972       builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
1973 }
1974 
1975 template <typename NativeT>
ConstantR2FromArray2D(XlaBuilder * builder,const Array2D<NativeT> & values)1976 XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
1977                             const Array2D<NativeT>& values) {
1978   return ConstantLiteral(builder,
1979                          LiteralUtil::CreateR2FromArray2D<NativeT>(values));
1980 }
1981 
1982 template <typename NativeT>
ConstantR3FromArray3DWithLayout(XlaBuilder * builder,const Array3D<NativeT> & values,const Layout & layout)1983 XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
1984                                       const Array3D<NativeT>& values,
1985                                       const Layout& layout) {
1986   return ConstantLiteral(
1987       builder,
1988       LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
1989 }
1990 
1991 template <typename NativeT>
ConstantR3FromArray3D(XlaBuilder * builder,const Array3D<NativeT> & values)1992 XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
1993                             const Array3D<NativeT>& values) {
1994   return ConstantFromArray(builder, values);
1995 }
1996 
1997 template <typename NativeT>
ConstantR4FromArray4DWithLayout(XlaBuilder * builder,const Array4D<NativeT> & values,const Layout & layout)1998 XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
1999                                       const Array4D<NativeT>& values,
2000                                       const Layout& layout) {
2001   return ConstantFromArrayWithLayout(builder, values, layout);
2002 }
2003 
2004 template <typename NativeT>
ConstantR4FromArray4D(XlaBuilder * builder,const Array4D<NativeT> & values)2005 XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
2006                             const Array4D<NativeT>& values) {
2007   return ConstantFromArray(builder, values);
2008 }
2009 
2010 }  // namespace xla
2011 
2012 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
2013