/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_

#include <map>
#include <string>
#include <type_traits>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class XlaBuilder;
class XlaOp;
class HloInstruction;

namespace internal {

struct XlaBuilderFriend {
  static XlaOp BuildFusion(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands,
                           absl::string_view fusion_kind,
                           const XlaComputation& fused_computation);

  static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand,
                            const Shape& shape);

  static HloInstructionProto* GetInstruction(XlaOp op);
};

}  // namespace internal

// This represents an instruction that has been enqueued using the XlaBuilder.
// This is used to pass to subsequent computations that depend upon the
// instruction as an operand.
class XlaOp {
 public:
  XlaOp() : handle_(-1), builder_(nullptr) {
    static_assert(std::is_trivially_destructible<XlaOp>::value,
                  "XlaOp should be trivially destructible");
  }
  ~XlaOp() = default;

  XlaOp(const XlaOp& other) = default;
  XlaOp& operator=(const XlaOp& other) = default;

  // Precondition: !IsUninitialized().
  //
  // It's very common to do foo.builder()->bar().  Without this precondition, if
  // foo.builder() is null, the call to bar will segfault at some point possibly
  // deep in the callstack when we finally dereference `this`.  The precondition
  // lets us avoid this tricky-to-debug problem.
  XlaBuilder* builder() const {
    CHECK(builder_ != nullptr);
    return builder_;
  }

  // Returns true if the XlaOp represents a valid, non-erroneous value.
  bool valid() const { return handle_ >= 0; }

  // Returns true if the XlaOp was created by the XlaOp() constructor and
  // not returned by a builder.
  bool IsUninitialized() const { return builder_ == nullptr; }

  bool IsIdenticalTo(XlaOp rhs) const {
    return handle_ == rhs.handle_ && builder_ == rhs.builder_;
  }

  friend std::ostream& operator<<(std::ostream& out, XlaOp op) {
    out << op.handle();
    return out;
  }

 private:
  explicit XlaOp(XlaBuilder* builder) : handle_(-1), builder_(builder) {}
  XlaOp(int64 handle, XlaBuilder* builder)
      : handle_(handle), builder_(builder) {}

  int64 handle() const { return handle_; }

  friend class XlaBuilder;
  friend class MlirHloBuilder;
  friend struct internal::XlaBuilderFriend;

  // < 0 means "invalid handle".
  int64 handle_;

  // Not owned. Non-null for any handle returned by XlaBuilder, even if the
  // handle is invalid.
  XlaBuilder* builder_;
};

// Arithmetic operator overloads for the XlaOp type.
XlaOp operator-(XlaOp x);
XlaOp operator+(XlaOp x, XlaOp y);
XlaOp operator-(XlaOp x, XlaOp y);
XlaOp operator*(XlaOp x, XlaOp y);
XlaOp operator/(XlaOp x, XlaOp y);
XlaOp operator%(XlaOp x, XlaOp y);
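
// A minimal usage sketch of the overloads above (the builder and op names
// here are illustrative, not part of the API):
//
//   XlaBuilder b("axpy");
//   XlaOp x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "x");
//   XlaOp y = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {4}), "y");
//   XlaOp sum = x + y;  // Equivalent to Add(x, y); enqueues an add on `b`.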

// Bitwise operator overloads for the XlaOp type.
XlaOp operator~(XlaOp x);
XlaOp operator&(XlaOp x, XlaOp y);
XlaOp operator|(XlaOp x, XlaOp y);
XlaOp operator^(XlaOp x, XlaOp y);
XlaOp operator<<(XlaOp x, XlaOp y);
// Performs a right arithmetic shift if 'x' is a signed type, otherwise performs
// a right logical shift.
XlaOp operator>>(XlaOp x, XlaOp y);

// We don't overload the relational operators (==, !=, <, <=, >, >=) because the
// semantics might be surprising since their result types are usually 'bool'.
// Further, programmers may expect == to be a structural equality.
// We also choose not to overload any of the mutating operators (e.g., +=, -=)
// because the semantics might be misleading: XLA computations are immutable.

// A convenient interface for building up computations.
//
// Thread-compatible.
class XlaBuilder {
 public:
  // computation_name: name to use for the built computation.
  XlaBuilder(const string& computation_name);

  XlaBuilder(const XlaBuilder&) = delete;
  XlaBuilder& operator=(const XlaBuilder&) = delete;

  virtual ~XlaBuilder();

  // Returns the computation name.
  const string& name() const { return name_; }

  // Sets OpMetadata that will be added to all instructions until cleared.
  //
  // OpMetadata is often applied to a series of XLA HLO instructions. As a
  // result, OpMetadata is set on the computation builder. All subsequent
  // instructions generated via this computation builder will have the same
  // OpMetadata attached until a call to ClearOpMetadata.
  void SetOpMetadata(OpMetadata metadata) { metadata_ = std::move(metadata); }

  // Swaps the passed op metadata with the ones currently set.
  //
  // Returns the old op metadata.
  OpMetadata SwapOpMetadata(OpMetadata metadata) {
    OpMetadata old_metadata = std::move(metadata_);
    metadata_ = std::move(metadata);
    return old_metadata;
  }

  // Similar to SetOpMetadata, but only sets the metadata for the next op.
  void SetOneShotOpMetadata(OpMetadata metadata) {
    one_shot_metadata_ = std::move(metadata);
  }

  // Clears the OpMetadata state.
  void ClearOpMetadata() { metadata_.Clear(); }
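
  // A minimal sketch of the modal metadata flow (assumes OpMetadata's standard
  // proto fields; `builder` and the values below are illustrative):
  //
  //   OpMetadata m;
  //   m.set_op_type("Softmax");
  //   m.set_op_name("model/softmax");
  //   builder.SetOpMetadata(m);
  //   // ... every op enqueued here carries `m` ...
  //   builder.ClearOpMetadata();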

  // Sets an OpSharding that will be attached to all instructions until cleared.
  void SetSharding(const OpSharding& sharding) { sharding_ = sharding; }

  // Sets the FrontendAttributes that will be added to all instructions until
  // cleared.
  //
  // FrontendAttributes are often applied to a series of XLA HLO instructions.
  // As a result they are set on the computation builder and all the
  // instructions generated via the computation builder will have the same
  // frontend attributes attached to them.
  void SetFrontendAttributes(const FrontendAttributes& frontend_attributes) {
    frontend_attributes_ = frontend_attributes;
  }

  // Swaps the passed FrontendAttributes with the ones currently set.
  //
  // Returns the old attributes.
  FrontendAttributes SwapFrontendAttributes(
      const FrontendAttributes& frontend_attributes) {
    FrontendAttributes old_attributes = std::move(frontend_attributes_);
    frontend_attributes_ = frontend_attributes;
    return old_attributes;
  }

  // Returns the FrontendAttributes that will be attached to all instructions.
  const FrontendAttributes& frontend_attributes() const {
    return frontend_attributes_;
  }

  // Clears all the frontend attributes.
  void ClearFrontendAttributes() { frontend_attributes_.Clear(); }
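
  // A minimal sketch (assumes FrontendAttributes' map<string, string> proto
  // field; the key/value below are illustrative placeholders):
  //
  //   FrontendAttributes attrs;
  //   (*attrs.mutable_map())["some_attribute"] = "some_value";
  //   builder.SetFrontendAttributes(attrs);
  //   // ... ops enqueued here carry the attribute ...
  //   builder.ClearFrontendAttributes();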

  // Clears the sharding. Ops will be sharded according to the default placement
  // policy.
  void ClearSharding() { sharding_ = absl::nullopt; }

  // Returns the OpSharding that will be attached to all instructions.
  const absl::optional<OpSharding>& sharding() const { return sharding_; }

  // Sets the builder to a mode where it will die immediately when an error is
  // encountered, rather than producing it in a deferred fashion when Build() is
  // called (which is the default).
  void set_die_immediately_on_error(bool enabled) {
    die_immediately_on_error_ = enabled;
  }

  // Default dimension numbers used for a 2D convolution.
  static constexpr int64 kConvBatchDimension = 0;
  static constexpr int64 kConvFeatureDimension = 1;
  static constexpr int64 kConvFirstSpatialDimension = 2;
  static constexpr int64 kConvSecondSpatialDimension = 3;
  static constexpr int64 kConvKernelOutputDimension = 0;
  static constexpr int64 kConvKernelInputDimension = 1;
  static constexpr int64 kConvKernelFirstSpatialDimension = 2;
  static constexpr int64 kConvKernelSecondSpatialDimension = 3;

  // Creates a default ConvolutionDimensionNumbers. For a 2D convolution, for
  // the input operand {batch, feature, height, width} = {0, 1, 2, 3} and for
  // the kernel operand
  // {output_feature, input_feature, height, width} = {0, 1, 2, 3}.
  static ConvolutionDimensionNumbers CreateDefaultConvDimensionNumbers(
      int num_spatial_dims = 2);

  // Returns an error if the convolution dimension numbers have conflicts.
  static Status Validate(const ConvolutionDimensionNumbers& dnum);

  // Returns a new XlaBuilder whose resultant Computation is used only by this
  // XlaBuilder. The sub-XlaBuilder has the same die_immediately_on_error
  // behavior as the parent.
  std::unique_ptr<XlaBuilder> CreateSubBuilder(const string& computation_name);
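
  // A minimal sketch of building a subcomputation (names illustrative):
  //
  //   std::unique_ptr<XlaBuilder> sub = builder.CreateSubBuilder("add");
  //   Shape scalar = ShapeUtil::MakeShape(F32, {});
  //   Add(Parameter(sub.get(), 0, scalar, "x"),
  //       Parameter(sub.get(), 1, scalar, "y"));
  //   XlaComputation add = sub->BuildAndNoteError();
  //   // `add` can then be passed wherever an XlaComputation is expected,
  //   // e.g. to Reduce() on the parent builder.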

  // Builds the computation with the requested operations, or returns a non-ok
  // status. Note that all ops that have been enqueued will be moved to the
  // computation being returned. The root of the computation will be the last
  // added operation.
  //
  // `remove_dynamic_dimensions` tells the builder whether to remove the
  // dynamic dimensions information in all ops.
  //
  // TODO(b/121223198): Delete `remove_dynamic_dimensions` and keep the
  // dynamic dimensions information when the XLA backend can handle dynamic
  // dimensions.
  StatusOr<XlaComputation> Build(bool remove_dynamic_dimensions = false);

  // Overload of Build which specifies a particular root instruction for the
  // computation.
  StatusOr<XlaComputation> Build(XlaOp root,
                                 bool remove_dynamic_dimensions = false);
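
  // A minimal end-to-end sketch (the shape and names are illustrative):
  //
  //   XlaBuilder b("increment");
  //   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p");
  //   Add(p, ConstantR0<float>(&b, 1.0f));  // Last op becomes the root.
  //   StatusOr<XlaComputation> computation = b.Build();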

  // Builds the computation with the requested operations, or notes an error in
  // the parent XlaBuilder and returns an empty computation if building failed.
  // This function is intended to be used where the returned XlaComputation is
  // only used by the parent XlaBuilder, and hence further operations on the
  // returned XlaComputation will simply error out if an error occurred
  // while building this computation. If the built computation is to be used by
  // an XlaBuilder other than the parent XlaBuilder then Build() should be used
  // instead.
  XlaComputation BuildAndNoteError();

  // Returns a subgraph that roots on the given root. If the root is not a
  // compile-time constant (see `IsConstant`), returns an error.
  //
  // This will copy the needed ops/computations to the subgraph.
  StatusOr<XlaComputation> BuildConstantSubGraph(
      XlaOp root_op, bool dynamic_dimension_is_uint_max = false);
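
  // A minimal sketch (illustrative): since `c` below depends only on
  // constants, it can be extracted as a constant subgraph.
  //
  //   XlaOp c = Mul(ConstantR0<int32>(&b, 6), ConstantR0<int32>(&b, 7));
  //   StatusOr<XlaComputation> subgraph = b.BuildConstantSubGraph(c);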

  // Similar to BuildConstantSubGraph, but with the root element type changed to
  // boolean. A true value in the root indicates that the value is dynamic while
  // a false value indicates that the value is a constant. This will copy the
  // needed ops/computations to the subgraph.
  //
  // E.g.,
  // Computation {
  //   a = 3
  //   b = param(0)
  //   ROOT Tuple(a + b, a + 1, b + 1)
  // }
  // Calling BuildDynamicInferenceGraph on root will produce the following
  // graph:
  //
  // Computation {
  //   a = False
  //   b = True
  //   ROOT Tuple(a | b, a, b)
  // }
  //
  // The result, which is (True, False, True) after evaluation, can be
  // interpreted as "First element is dynamic; Second element is static; Third
  // element is dynamic".
  StatusOr<XlaComputation> BuildDynamicInferenceGraph(XlaOp root_op);

  // Returns the first error that was encountered while building the
  // computation. When an error is encountered, by default we return a vacuous
  // XlaOp and inform the user of the error that occurred while
  // building the computation when they make a final call to Build().
  //
  // See also set_die_immediately_on_error().
  Status first_error() const { return first_error_; }

  // Returns the current status of the builder, complete with the stack trace
  // information.
  Status GetCurrentStatus() const;

  // Returns the shape of the given op.
  StatusOr<Shape> GetShape(XlaOp op) const;

  // Returns the shape of the given op.
  virtual StatusOr<const Shape*> GetShapePtr(XlaOp op) const;

  // Returns the (inferred) result for the current computation's shape. This
  // assumes the root instruction is the last added instruction.
  StatusOr<ProgramShape> GetProgramShape() const;

  // Returns the (inferred) result for the current computation's shape using the
  // given operation as the root.
  StatusOr<ProgramShape> GetProgramShape(XlaOp root) const;

  // Reports an error to the builder, by
  // * storing it internally and capturing a backtrace if it's the first error
  //   (this deferred value will be produced on the call to
  //    Build()/GetShape()/...)
  // * dying if die_immediately_on_error_ is true.
  // Returns an XlaOp with an invalid handle but a valid builder. This value can
  // be returned in place of a value in APIs that return an XlaOp.
  XlaOp ReportError(const Status& error);

  // A helper function that converts a StatusOr<XlaOp> into an XlaOp.
  // If the Status was an error, reports the error to the builder and returns an
  // invalid XlaOp handle.
  XlaOp ReportErrorOrReturn(const StatusOr<XlaOp>& op);

  // A helper function that runs a function that returns a StatusOr<XlaOp> and
  // returns an XlaOp.
  XlaOp ReportErrorOrReturn(const std::function<StatusOr<XlaOp>()>& op_creator);
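
  // A common usage sketch for op-building helpers (illustrative; assumes
  // TF_ASSIGN_OR_RETURN from TensorFlow's status macros):
  //
  //   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
  //     TF_ASSIGN_OR_RETURN(const Shape* shape, builder->GetShapePtr(operand));
  //     // ... infer the result shape, then add the instruction ...
  //   });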

  // Returns true if 'operand' is a compile-time constant. A compile-time
  // constant does not depend on any parameters, or on stateful operators such
  // as `RngNormal` or `Infeed`.
  //
  // This tests whether a computation is a compile-time constant without
  // evaluating the computation.
  StatusOr<bool> IsConstant(XlaOp operand) const;
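
  // A minimal sketch (illustrative):
  //
  //   XlaOp p = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {}), "p");
  //   XlaOp c = ConstantR0<float>(&b, 2.0f);
  //   b.IsConstant(c);          // true: depends on no parameters.
  //   b.IsConstant(Add(p, c));  // false: depends on parameter `p`.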

  // Sets up a binding which indicates that the `target_dim_num` in the subshape
  // `target_param_index` of parameter `target_param_num` is a dynamic dimension
  // and its real dynamic size is represented by `dynamic_param_index` in
  // parameter `dynamic_param_num`.
  //
  // Note that this should be called before the dynamic parameters are used to
  // create other operations, otherwise created operations won't have the
  // dynamic dimensions information.
  //
  // TODO(b/119520625): Remove this API once we have more dynamic shape infra
  // ready.
  ABSL_DEPRECATED("Use SetDimensionSize to set a dynamic dimension.")
  Status SetDynamicBinding(int64 dynamic_size_param_num,
                           ShapeIndex dynamic_size_param_index,
                           int64 target_param_num,
                           ShapeIndex target_param_index, int64 target_dim_num);

  // Adds a new input/output alias. Since the input/output shape information is
  // not available until the computation is built, any error in the arguments
  // of this API will be detected only at computation Build() time.
  //
  // Note: Except when 'must-alias' is true, the alias is assumed to be
  // 'may-alias', and only a buffer donated at runtime will be aliased with the
  // output. If a buffer is not donated at runtime, a copy will be inserted by
  // XLA to prevent buffer clobbering.
  void SetUpAlias(const ShapeIndex& output_index, int64 param_number,
                  const ShapeIndex& param_index,
                  HloInputOutputAliasConfig::AliasKind kind =
                      HloInputOutputAliasConfig::AliasKind::kMayAlias) {
    input_output_aliases_.push_back(
        {output_index, param_number, param_index, kind});
  }
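
  // A minimal sketch (illustrative): may-alias output tuple element {0} with
  // the whole buffer of parameter 1.
  //
  //   builder.SetUpAlias(/*output_index=*/{0}, /*param_number=*/1,
  //                      /*param_index=*/{});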

  // Describes an input/output alias as inserted by the SetUpAlias() API.
  struct InputOutputAlias {
    // Specifies the index of the aliased buffer in the result tuple.
    ShapeIndex output_index;
    // Specifies the parameter containing the buffer to be aliased.
    int64 param_number;
    // Specifies the index of the aliased buffer in the parameter.
    ShapeIndex param_index;
    // Specifies if the alias is a must-alias or may-alias.
    HloInputOutputAliasConfig::AliasKind kind;
  };

  // Looks up the HloInstruction and sets the frontend attribute "attribute" to
  // "value".
  //
  // If the attribute already existed then its value is updated.
  //
  // Note: the attribute is only added to the HloInstruction, not to the
  // builder.
  Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
                                         string value);
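
  // A minimal sketch (illustrative; the attribute name/value are placeholders,
  // and TF_RETURN_IF_ERROR is TensorFlow's usual status macro):
  //
  //   TF_RETURN_IF_ERROR(builder.SetInstructionFrontendAttribute(
  //       op, "some_attribute", "some_value"));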

  // Returns shapes for the operands.
  StatusOr<std::vector<Shape>> GetOperandShapes(
      absl::Span<const XlaOp> operands) const;

 private:
  // Build helper which takes the id of the root operation.
  StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);

  // Descriptions for the methods below can be found in the corresponding public
  // functions section in this file.

  XlaOp Parameter(int64 parameter_number, const Shape& shape,
                  const string& name,
                  const std::vector<bool>& replicated_at_leaf_buffers);
  XlaOp Parameter(int64 parameter_number, const Shape& shape,
                  const string& name) {
    std::vector<bool> empty_bools;
    return Parameter(parameter_number, shape, name, empty_bools);
  }

  virtual XlaOp ConstantLiteral(const LiteralSlice& literal);

  XlaOp Broadcast(XlaOp operand, absl::Span<const int64> broadcast_sizes);

  XlaOp BroadcastInDim(XlaOp operand,
                       const absl::Span<const int64> out_dim_size,
                       const absl::Span<const int64> broadcast_dimensions);

  XlaOp Pad(XlaOp operand, XlaOp padding_value,
            const PaddingConfig& padding_config);
  XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64 dimno, int64 pad_lo,
                 int64 pad_hi);

  virtual StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
                                      XlaOp padding_value,
                                      const PaddingConfig& padding_config);

  XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
                absl::Span<const int64> new_sizes,
                int64 inferred_dimension = -1);

  XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
                int64 inferred_dimension = -1);

  XlaOp Reshape(const Shape& shape, XlaOp operand,
                int64 inferred_dimension = -1);

  XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                       absl::Span<const int64> new_size_bounds,
                       const std::vector<bool>& dims_are_dynamic);

  XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);

  XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
              absl::Span<const int64> limit_indices,
              absl::Span<const int64> strides);
  virtual StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
                                        absl::Span<const int64> start_indices,
                                        absl::Span<const int64> limit_indices,
                                        absl::Span<const int64> strides);
  virtual XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
                           int64 stride, int64 dimno);

  XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
                     absl::Span<const int64> slice_sizes);
  virtual StatusOr<XlaOp> DynamicSliceInternal(
      const Shape& shape, XlaOp operand, absl::Span<const XlaOp> start_indices,
      absl::Span<const int64> slice_sizes);

  XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                           absl::Span<const XlaOp> start_indices);
  virtual StatusOr<XlaOp> DynamicUpdateSliceInternal(
      const Shape& shape, XlaOp operand, XlaOp update,
      absl::Span<const XlaOp> start_indices);

  XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
  virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
                                              absl::Span<const XlaOp> operands,
                                              int64 dimension);

  void Trace(const string& tag, XlaOp operand);

  XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);

  XlaOp Tuple(absl::Span<const XlaOp> elements);
  virtual StatusOr<XlaOp> TupleInternal(const Shape& shape,
                                        absl::Span<const XlaOp> elements);

  XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
  virtual StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape,
                                                  XlaOp tuple_data,
                                                  int64 index);

  XlaOp Dot(
      XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DotGeneral(
      XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp Conv(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      Padding padding, int64 feature_group_count = 1,
      int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvWithGeneralPadding(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64 feature_group_count = 1, int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvWithGeneralDimensions(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count = 1, int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvGeneral(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count = 1, int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp ConvGeneralDilated(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count = 1, int64 batch_group_count = 1,
      const PrecisionConfig* precision_config = nullptr,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DynamicConvForward(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DynamicConvInputGrad(
      XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  XlaOp DynamicConvKernelGrad(
      XlaOp activations, XlaOp gradients,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  StatusOr<HloInstructionProto> DynamicConvInstruction(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

  virtual StatusOr<XlaOp> ConvGeneralDilatedInternal(
      const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config);

  XlaOp Fft(XlaOp operand, FftType fft_type,
            absl::Span<const int64> fft_length);
  virtual StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
                                      FftType fft_type,
                                      absl::Span<const int64> fft_length);

  virtual StatusOr<XlaOp> TriangularSolveInternal(
      const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options);

  virtual StatusOr<XlaOp> CholeskyInternal(const Shape& shape, XlaOp a,
                                           bool lower);

  XlaOp Infeed(const Shape& shape, const string& config = "");
  XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config);
  virtual StatusOr<XlaOp> InfeedWithTokenInternal(
      const Shape& infeed_instruction_shape, XlaOp token, const string& config);

  void Outfeed(XlaOp operand, const Shape& shape_with_layout,
               const string& outfeed_config);
  XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
                         const Shape& shape_with_layout,
                         const string& outfeed_config);
  virtual StatusOr<XlaOp> OutfeedWithTokenInternal(
      XlaOp operand, XlaOp token, const Shape& shape_with_layout,
      const string& outfeed_config);
  XlaOp Call(const XlaComputation& computation,
             absl::Span<const XlaOp> operands);

  XlaOp CustomCall(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const Shape& shape_with_layout, const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal);

  // Internal version of CustomCall without a computation; it doesn't do
  // op-specific error handling and expects arguments to be legal. The
  // CustomCall method above calls this method after error handling.
  virtual StatusOr<XlaOp> CustomCallInternal(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const Shape& shape_with_layout, const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal);

  XlaOp CustomCall(
      const string& call_target_name, absl::Span<const XlaOp> operands,
      const XlaComputation& computation, const Shape& shape_with_layout,
      const string& opaque,
      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout,
      bool has_side_effect,
      absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
          output_operand_aliasing,
      const Literal* literal);

  XlaOp Reduce(XlaOp operand, XlaOp init_value,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  XlaOp Reduce(absl::Span<const XlaOp> operands,
               absl::Span<const XlaOp> init_values,
               const XlaComputation& computation,
               absl::Span<const int64> dimensions_to_reduce);

  virtual StatusOr<XlaOp> ReduceInternal(
      const Shape& shape, absl::Span<const XlaOp> all_operands,
      const XlaComputation& computation,
      absl::Span<const int64> dimensions_to_reduce);

  XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                  const XlaComputation& computation);

  XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                     const XlaComputation& computation,
                     absl::Span<const int64> window_dimensions,
                     absl::Span<const int64> window_strides, Padding padding);

  XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                     absl::Span<const XlaOp> init_values,
                     const XlaComputation& computation,
                     absl::Span<const int64> window_dimensions,
                     absl::Span<const int64> window_strides, Padding padding);

  XlaOp ReduceWindowWithGeneralPadding(
      absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);
  StatusOr<HloInstructionProto> ReduceWindowInternal(
      absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);
  virtual StatusOr<XlaOp> ReduceWindowInternal(
      const Shape& shape, XlaOp operand, XlaOp init_value,
      const XlaComputation& computation, Window window);
  XlaOp CrossReplicaSum(XlaOp operand,
                        absl::Span<const ReplicaGroup> replica_groups = {});

  XlaOp AllGather(
      XlaOp operand, int64 all_gather_dimension, int64 shard_count,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
      const absl::optional<Layout>& layout = absl::nullopt,
      const absl::optional<bool> use_global_device_ids = absl::nullopt);

  XlaOp AllReduce(
      XlaOp operand, const XlaComputation& computation,
      absl::Span<const ReplicaGroup> replica_groups = {},
      const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
      const absl::optional<Shape>& shape_with_layout = absl::nullopt);

  XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
                 int64 split_count,
                 const std::vector<ReplicaGroup>& replica_groups,
                 const absl::optional<Layout>& layout = absl::nullopt);

  XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
                      int64 concat_dimension, int64 split_count,
                      const std::vector<ReplicaGroup>& replica_groups,
                      const absl::optional<Layout>& layout);

  XlaOp CollectivePermute(
      XlaOp operand,
      const std::vector<std::pair<int64, int64>>& source_target_pairs);

  XlaOp ReplicaId();

  XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
                         absl::Span<const int64> window_dimensions,
                         absl::Span<const int64> window_strides,
                         Padding padding, XlaOp source, XlaOp init_value,
                         const XlaComputation& scatter);

  XlaOp SelectAndScatterWithGeneralPadding(
      XlaOp operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
      XlaOp init_value, const XlaComputation& scatter);

  StatusOr<HloInstructionProto> SelectAndScatterInternal(
      XlaOp operand, const XlaComputation& select,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
      XlaOp init_value, const XlaComputation& scatter);

  virtual XlaOp Iota(const Shape& shape, int64 iota_dimension);

  XlaOp Iota(PrimitiveType type, int64 size);

  XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);

  XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
  virtual StatusOr<XlaOp> BitcastConvertTypeInternal(const Shape& shape,
                                                     XlaOp operand);

  XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
  virtual StatusOr<XlaOp> TransposeInternal(
      const Shape& shape, XlaOp operand, absl::Span<const int64> permutation);

  XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
  virtual StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
                                      absl::Span<const int64> dimensions);

  XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
             int64 dimension = -1, bool is_stable = false);
  virtual StatusOr<XlaOp> SortInternal(const Shape& shape,
                                       absl::Span<const XlaOp> operands,
                                       const XlaComputation& comparator,
                                       int64 dimension, bool is_stable);

  XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);

  XlaOp Map(absl::Span<const XlaOp> operands, const XlaComputation& computation,
            absl::Span<const int64> dimensions,
            absl::Span<const XlaOp> static_operands = {});

  XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);

  XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);

  XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
                        const Shape& shape);
  // Internal variant for the op with the full result shape containing both data
  // and state shape as a tuple.
  virtual StatusOr<XlaOp> RngBitGeneratorInternal(
      const Shape& full_result_shape, RandomAlgorithm algorithm,
      XlaOp initial_state);

  XlaOp While(const XlaComputation& condition, const XlaComputation& body,
              XlaOp init);
  virtual StatusOr<XlaOp> WhileInternal(const Shape& shape,
                                        const XlaComputation& condition,
                                        const XlaComputation& body, XlaOp init);

  XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                    const XlaComputation& true_computation, XlaOp false_operand,
                    const XlaComputation& false_computation);

  XlaOp Conditional(XlaOp branch_index,
                    absl::Span<const XlaComputation* const> branch_computations,
                    absl::Span<const XlaOp> branch_operands);

  XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                        const int mantissa_bits);
  virtual StatusOr<XlaOp> ReducePrecisionInternal(const Shape& shape,
                                                  XlaOp operand,
                                                  const int exponent_bits,
                                                  const int mantissa_bits);

  XlaOp Gather(XlaOp input, XlaOp start_indices,
               const GatherDimensionNumbers& dimension_numbers,
               absl::Span<const int64> slice_sizes,
               bool indices_are_sorted = false);

  virtual StatusOr<XlaOp> GatherInternal(
      const Shape& shape, XlaOp input, XlaOp start_indices,
      const GatherDimensionNumbers& dimension_numbers,
      absl::Span<const int64> slice_sizes, bool indices_are_sorted);

  XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
                const XlaComputation& update_computation,
                const ScatterDimensionNumbers& dimension_numbers,
                bool indices_are_sorted = false, bool unique_indices = false);

  virtual StatusOr<XlaOp> ScatterInternal(
      const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
      const XlaComputation& update_computation,
      const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
      bool unique_indices);

  void Send(XlaOp operand, const ChannelHandle& handle);
  XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);

  XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
                   const ChannelHandle& handle);

  XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                     const ChannelHandle& handle);

  virtual XlaOp CreateToken();

  XlaOp AfterAll(absl::Span<const XlaOp> tokens);

  XlaOp Recv(const Shape& shape, const ChannelHandle& handle);
  XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                      const ChannelHandle& handle);

  XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
                          float epsilon, int64 feature_index);

  XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean,
                           XlaOp variance, float epsilon, int64 feature_index);

  XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                      XlaOp batch_var, XlaOp grad_output, float epsilon,
                      int64 feature_index);

  XlaOp GetDimensionSize(XlaOp operand, int64 dimension);

  XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);

  virtual StatusOr<XlaOp> SetDimensionSizeInternal(const Shape& shape,
                                                   XlaOp operand, XlaOp val,
                                                   int64 dimension);

  XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);

  virtual StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr,
                                         HloOpcode opcode,
                                         absl::Span<const XlaOp> operands);
  StatusOr<XlaOp> AddInstruction(HloInstructionProto&& instr,
                                 HloOpcode opcode) {
    return AddInstruction(std::move(instr), opcode, /*operands=*/{});
  }

  void AddCalledComputation(const XlaComputation& computation,
                            HloInstructionProto* instr);

  StatusOr<const HloInstructionProto*> LookUpInstruction(XlaOp op) const;
  StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
      int64 handle) const;
  StatusOr<HloInstructionProto*> LookUpMutableInstruction(XlaOp op);
  StatusOr<HloInstructionProto*> LookUpMutableInstructionByHandle(int64 handle);

  // Internal helper method that does the building for an arbitrary unary op.
  virtual XlaOp UnaryOp(HloOpcode unop, XlaOp operand);

  // Internal helper method that does the building for an arbitrary binary op.
  // broadcast_dimensions specifies which dimensions to use for broadcasting
  // when the operation is between tensors of different ranks. The direction is
  // only used if opcode is kCompare.
  XlaOp BinaryOp(HloOpcode binop, XlaOp lhs, XlaOp rhs,
                 absl::Span<const int64> broadcast_dimensions,
                 absl::optional<ComparisonDirection> direction = absl::nullopt,
                 absl::optional<Comparison::Type> type = absl::nullopt);

  StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                          ComparisonDirection direction);

  // Internal helper method for a binary op compare without broadcast
  // dimensions.
  virtual StatusOr<XlaOp> Compare(const Shape& shape, XlaOp lhs, XlaOp rhs,
                                  ComparisonDirection direction,
                                  Comparison::Type type);

  // Internal helper method that does the building for an arbitrary binary op
  // with same-rank operands that doesn't broadcast.
  virtual XlaOp BinaryOpNoBroadcast(HloOpcode binop, const Shape& shape,
                                    XlaOp lhs, XlaOp rhs);

  // Internal helper method that does the building for an arbitrary ternary op.
  XlaOp TernaryOp(HloOpcode triop, XlaOp lhs, XlaOp rhs, XlaOp ehs);

  XlaOp RngOp(RandomDistribution distribution,
              absl::Span<const XlaOp> parameters, const Shape& shape);

  virtual StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution,
                                        absl::Span<const XlaOp> parameters,
                                        const Shape& shape);

  virtual StatusOr<XlaOp> InDimBroadcast(
      const Shape& shape, XlaOp operand,
      absl::Span<const int64> broadcast_dimensions);

  // Internal helper method that creates a sequence of instructions that
  // performs an explicit broadcast of the operand to the target shape.
  StatusOr<XlaOp> AddBroadcastSequence(const Shape& output_shape,
                                       XlaOp operand);

  // Internal helper method for creating a Reshape op with the already inferred
  // shape.
  virtual StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
                                          int64 inferred_dimension);

  // Returns the (inferred) result for the program shape using the given root.
  StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;

  // A visitor which checks whether an operation is a compile-time constant,
  // meaning that it doesn't depend on any parameters, or on any stateful
  // operation such as `RngNormal` or `Infeed`. The visitor walks the
  // computation starting at a given operation and sets is_constant to false iff
  // a parameter or stateful operation is encountered.
  void IsConstantVisitor(const int64 op_handle,
                         absl::flat_hash_set<int64>* visited,
                         bool* is_constant) const;

  // Checks bounds for convolution parameters.
  Status VerifyConvolution(
      const Shape& lhs_shape, const Shape& rhs_shape,
      const ConvolutionDimensionNumbers& dimension_numbers) const;

  int64 GetNextId() { return ++next_id_; }

  // Populates the module with the input/output alias information stored within
  // the input_output_aliases vector.
  static Status PopulateInputOutputAlias(
      HloModuleProto* module, const ProgramShape& program_shape,
      const std::vector<InputOutputAlias>& input_output_aliases);

  string name_;  // Name to use for the built computation.

  // The next sequential ID for every instruction/computation contained within
  // this computation.
  int64 next_id_ = 0;

  // The first error encountered while building the computation.
  // This is OK until the first error is encountered.
  Status first_error_;

  // The saved stack trace from the point at which the first error occurred.
  tensorflow::SavedStackTrace first_error_backtrace_;

  // The instructions of this computation.
  std::vector<HloInstructionProto> instructions_;

  // A cache for the HloInstructionProto shapes, to avoid recreating Shape
  // objects from protos and to support the GetShapePtr() API.
  std::vector<std::unique_ptr<Shape>> instruction_shapes_;

  // Dynamic parameter configuration of this computation.
  DynamicParameterBinding dynamic_parameter_binding_;

  // Holds the input/output alias information populated by the SetUpAlias() API.
  std::vector<InputOutputAlias> input_output_aliases_;

  // A map from XlaOp::Handle to the index in the instructions_ vector where the
  // instruction is held.
  absl::flat_hash_map<int64, int64> handle_to_index_;

  // The embedded computations used by this computation. Each computation was
  // the entry computation of some XlaComputation; the key is the unique id of
  // that XlaComputation.
  std::map<int64, HloComputationProto> embedded_;

  // The unique parameter numbers.
  absl::flat_hash_set<int64> parameter_numbers_;
  // The metadata to attach to each op. This is structured as a "modal"-like
  // operation, in order to simplify client code (and not sprinkle this metadata
  // throughout the TensorFlow op kernel implementations).
  OpMetadata metadata_;

  // A temporary metadata that will only be applied to the next op created.
  absl::optional<OpMetadata> one_shot_metadata_;

  // Sharding for this operator. This is structured as a "modal"-like operation,
  // in order to simplify client code, similar to metadata_.
  absl::optional<OpSharding> sharding_;

  // Mode bit that indicates whether to die when a first error is encountered.
  bool die_immediately_on_error_ = false;

  XlaBuilder* parent_builder_{nullptr};

  FrontendAttributes frontend_attributes_;

  friend XlaOp Parameter(XlaBuilder* builder, int64 parameter_number,
                         const Shape& shape, const string& name,
                         const std::vector<bool>& replicated_at_leaf_buffers);
  friend XlaOp ConstantLiteral(XlaBuilder* builder,
                               const LiteralSlice& literal);

  friend XlaOp Broadcast(XlaOp operand,
                         absl::Span<const int64> broadcast_sizes);

  friend XlaOp BroadcastInDim(
      XlaOp operand, const absl::Span<const int64> out_dim_size,
      const absl::Span<const int64> broadcast_dimensions);

  friend XlaOp Copy(XlaOp operand);

  friend XlaOp Pad(XlaOp operand, XlaOp padding_value,
                   const PaddingConfig& padding_config);

  friend XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64 dimno,
                        int64 pad_lo, int64 pad_hi);

  friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
                       absl::Span<const int64> new_sizes);

  friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);

  friend XlaOp Reshape(const Shape& shape, XlaOp operand);

  friend XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                              absl::Span<const int64> new_size_bounds,
                              const std::vector<bool>& dims_are_dynamic);

  friend XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                            absl::Span<const int64> new_sizes,
                                            int64 inferred_dimension);

  friend XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);

  friend XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
                     absl::Span<const int64> limit_indices,
                     absl::Span<const int64> strides);

  friend XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
                          int64 stride, int64 dimno);

  friend XlaOp DynamicSlice(XlaOp operand,
                            absl::Span<const XlaOp> start_indices,
                            absl::Span<const int64> slice_sizes);

  friend XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                                  absl::Span<const XlaOp> start_indices);

  friend XlaOp ConcatInDim(XlaBuilder* builder,
                           absl::Span<const XlaOp> operands, int64 dimension);

  friend void Trace(const string& tag, XlaOp operand);

  friend XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
  friend XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);
  friend XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
  friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64> broadcast_dimensions,
                       ComparisonDirection direction);
  friend XlaOp Compare(XlaOp lhs, XlaOp rhs,
                       absl::Span<const int64> broadcast_dimensions,
                       ComparisonDirection direction,
                       Comparison::Type compare_type);
  friend XlaOp Dot(XlaOp lhs, XlaOp rhs,
                   const PrecisionConfig* precision_config,
                   absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DotGeneral(XlaOp lhs, XlaOp rhs,
                          const DotDimensionNumbers& dimension_number,
                          const PrecisionConfig* precision_config,
                          absl::optional<PrimitiveType> preferred_element_type);
  virtual StatusOr<XlaOp> DotGeneralInternal(
      const Shape& shape, XlaOp lhs, XlaOp rhs,
      const DotDimensionNumbers& dimension_number,
      const PrecisionConfig* precision_config);
  friend XlaOp Conv(XlaOp lhs, XlaOp rhs,
                    absl::Span<const int64> window_strides, Padding padding,
                    int64 feature_group_count, int64 batch_group_count,
                    const PrecisionConfig* precision_config,
                    absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvWithGeneralPadding(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvWithGeneralDimensions(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp ConvGeneral(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvForward(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvKernelGrad(
      XlaOp activations, XlaOp gradients,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp DynamicConvInputGrad(
      XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
      absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config, PaddingType padding_type,
      absl::optional<PrimitiveType> preferred_element_type);

  friend XlaOp ConvKernelGrad(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);

  friend XlaOp ConvGeneralDilated(
      XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
      absl::Span<const std::pair<int64, int64>> padding,
      absl::Span<const int64> lhs_dilation,
      absl::Span<const int64> rhs_dilation,
      const ConvolutionDimensionNumbers& dimension_numbers,
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config,
      absl::optional<PrimitiveType> preferred_element_type);
  friend XlaOp Fft(XlaOp operand, FftType fft_type,
                   absl::Span<const int64> fft_length);
  friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                               bool unit_diagonal,
                               TriangularSolveOptions::Transpose transpose_a);
  friend XlaOp Cholesky(XlaOp a, bool lower);
  friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
                      const string& config);
  friend void Outfeed(XlaOp operand, const Shape& shape_with_layout,
                      const string& outfeed_config);
  friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
                    absl::Span<const XlaOp> operands);
  friend XlaOp CustomCall(
      XlaBuilder* builder, const string& call_target_name,
1217       absl::Span<const XlaOp> operands, const Shape& shape,
1218       const string& opaque, bool has_side_effect,
1219       absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
1220           output_operand_aliasing,
1221       const Literal* literal);
1222   friend XlaOp CustomCallWithComputation(
1223       XlaBuilder* builder, const string& call_target_name,
1224       absl::Span<const XlaOp> operands, const XlaComputation& computation,
1225       const Shape& shape, const string& opaque, bool has_side_effect,
1226       absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
1227           output_operand_aliasing,
1228       const Literal* literal);
1229   friend XlaOp CustomCallWithLayout(
1230       XlaBuilder* builder, const string& call_target_name,
1231       absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
1232       absl::Span<const Shape> operand_shapes_with_layout, const string& opaque,
1233       bool has_side_effect,
1234       absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
1235           output_operand_aliasing,
1236       const Literal* literal);
1237   friend XlaOp Complex(XlaOp real, XlaOp imag,
1238                        absl::Span<const int64> broadcast_dimensions);
1239   friend XlaOp Conj(XlaOp operand);
1240   friend XlaOp Add(XlaOp lhs, XlaOp rhs,
1241                    absl::Span<const int64> broadcast_dimensions);
1242   friend XlaOp Sub(XlaOp lhs, XlaOp rhs,
1243                    absl::Span<const int64> broadcast_dimensions);
1244   friend XlaOp Mul(XlaOp lhs, XlaOp rhs,
1245                    absl::Span<const int64> broadcast_dimensions);
1246   friend XlaOp Div(XlaOp lhs, XlaOp rhs,
1247                    absl::Span<const int64> broadcast_dimensions);
1248   friend XlaOp Rem(XlaOp lhs, XlaOp rhs,
1249                    absl::Span<const int64> broadcast_dimensions);
1250   friend XlaOp Max(XlaOp lhs, XlaOp rhs,
1251                    absl::Span<const int64> broadcast_dimensions);
1252   friend XlaOp Min(XlaOp lhs, XlaOp rhs,
1253                    absl::Span<const int64> broadcast_dimensions);
1254   friend XlaOp And(XlaOp lhs, XlaOp rhs,
1255                    absl::Span<const int64> broadcast_dimensions);
1256   friend XlaOp Or(XlaOp lhs, XlaOp rhs,
1257                   absl::Span<const int64> broadcast_dimensions);
1258   friend XlaOp Xor(XlaOp lhs, XlaOp rhs,
1259                    absl::Span<const int64> broadcast_dimensions);
1260   friend XlaOp Not(XlaOp operand);
1261   friend XlaOp PopulationCount(XlaOp operand);
1262   friend XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
1263                          absl::Span<const int64> broadcast_dimensions);
1264   friend XlaOp ShiftRightArithmetic(
1265       XlaOp lhs, XlaOp rhs, absl::Span<const int64> broadcast_dimensions);
1266   friend XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs,
1267                                  absl::Span<const int64> broadcast_dimensions);
1268   friend XlaOp Reduce(XlaOp operand, XlaOp init_value,
1269                       const XlaComputation& computation,
1270                       absl::Span<const int64> dimensions_to_reduce);
1271   friend XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1272                       absl::Span<const XlaOp> init_values,
1273                       const XlaComputation& computation,
1274                       absl::Span<const int64> dimensions_to_reduce);
1275   friend XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
1276                          const XlaComputation& computation);
1277   friend XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
1278                             const XlaComputation& computation,
1279                             absl::Span<const int64> window_dimensions,
1280                             absl::Span<const int64> window_strides,
1281                             Padding padding);
1282   friend XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
1283                             absl::Span<const XlaOp> init_values,
1284                             const XlaComputation& computation,
1285                             absl::Span<const int64> window_dimensions,
1286                             absl::Span<const int64> window_strides,
1287                             Padding padding);
1288   friend XlaOp ReduceWindowWithGeneralPadding(
1289       XlaOp operand, XlaOp init_value, const XlaComputation& computation,
1290       absl::Span<const int64> window_dimensions,
1291       absl::Span<const int64> window_strides,
1292       absl::Span<const int64> base_dilations,
1293       absl::Span<const int64> window_dilations,
1294       absl::Span<const std::pair<int64, int64>> padding);
1295   friend XlaOp CrossReplicaSum(XlaOp operand,
1296                                absl::Span<const ReplicaGroup> replica_groups);
1297   friend XlaOp AllGather(XlaOp operand, int64 all_gather_dimension,
1298                          int64 shard_count,
1299                          absl::Span<const ReplicaGroup> replica_groups,
1300                          const absl::optional<ChannelHandle>& channel_id,
1301                          const absl::optional<Layout>& layout,
1302                          const absl::optional<bool> use_global_device_ids);
1303   friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
1304                          absl::Span<const ReplicaGroup> replica_groups,
1305                          const absl::optional<ChannelHandle>& channel_id,
1306                          const absl::optional<Shape>& shape_with_layout);
1307   friend XlaOp AllToAll(XlaOp operand, int64 split_dimension,
1308                         int64 concat_dimension, int64 split_count,
1309                         const std::vector<ReplicaGroup>& replica_groups,
1310                         const absl::optional<Layout>& layout);
1311   friend XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
1312                              int64 concat_dimension, int64 split_count,
1313                              const std::vector<ReplicaGroup>& replica_groups,
1314                              const absl::optional<Layout>& layout);
1315   friend XlaOp CollectivePermute(
1316       XlaOp operand,
1317       const std::vector<std::pair<int64, int64>>& source_target_pairs);
1318   friend XlaOp ReplicaId(XlaBuilder* builder);
1319   friend XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
1320                                 absl::Span<const int64> window_dimensions,
1321                                 absl::Span<const int64> window_strides,
1322                                 Padding padding, XlaOp source, XlaOp init_value,
1323                                 const XlaComputation& scatter);
1324   friend XlaOp SelectAndScatterWithGeneralPadding(
1325       XlaOp operand, const XlaComputation& select,
1326       absl::Span<const int64> window_dimensions,
1327       absl::Span<const int64> window_strides,
1328       absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
1329       XlaOp init_value, const XlaComputation& scatter);
1330   friend XlaOp Abs(XlaOp operand);
1331   friend XlaOp Atan2(XlaOp y, XlaOp x,
1332                      absl::Span<const int64> broadcast_dimensions);
1333   friend XlaOp Exp(XlaOp operand);
1334   friend XlaOp Expm1(XlaOp operand);
1335   friend XlaOp Floor(XlaOp operand);
1336   friend XlaOp Ceil(XlaOp operand);
1337   friend XlaOp Round(XlaOp operand);
1338   friend XlaOp Log(XlaOp operand);
1339   friend XlaOp Log1p(XlaOp operand);
1340   friend XlaOp Logistic(XlaOp operand);
1341   friend XlaOp Sign(XlaOp operand);
1342   friend XlaOp Clz(XlaOp operand);
1343   friend XlaOp Cos(XlaOp operand);
1344   friend XlaOp Sin(XlaOp operand);
1345   friend XlaOp Tanh(XlaOp operand);
1346   friend XlaOp Real(XlaOp operand);
1347   friend XlaOp Imag(XlaOp operand);
1348   friend XlaOp Sqrt(XlaOp operand);
1349   friend XlaOp Rsqrt(XlaOp operand);
1350   friend XlaOp Cbrt(XlaOp operand);
1351   friend XlaOp Pow(XlaOp lhs, XlaOp rhs,
1352                    absl::Span<const int64> broadcast_dimensions);
1353   friend XlaOp IsFinite(XlaOp operand);
1354   friend XlaOp Iota(XlaBuilder* builder, const Shape& shape,
1355                     int64 iota_dimension);
1356   friend XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);
1357   friend XlaOp ConvertElementType(XlaOp operand,
1358                                   PrimitiveType new_element_type);
1359   friend XlaOp BitcastConvertType(XlaOp operand,
1360                                   PrimitiveType new_element_type);
1361   friend XlaOp Neg(XlaOp operand);
1362   friend XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);
1363   friend XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);
1364   friend XlaOp Sort(absl::Span<const XlaOp> operands,
1365                     const XlaComputation& comparator, int64 dimension,
1366                     bool is_stable);
1367   friend XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);
1368   friend XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
1369                    const XlaComputation& computation,
1370                    absl::Span<const int64> dimensions,
1371                    absl::Span<const XlaOp> static_operands);
1372   friend XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);
1373   friend XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);
1374   friend XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
1375                                const Shape& shape);
1376   friend XlaOp While(const XlaComputation& condition,
1377                      const XlaComputation& body, XlaOp init);
1378   friend XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
1379                            const XlaComputation& true_computation,
1380                            XlaOp false_operand,
1381                            const XlaComputation& false_computation);
1382   friend XlaOp Conditional(
1383       XlaOp branch_index,
1384       absl::Span<const XlaComputation* const> branch_computations,
1385       absl::Span<const XlaOp> branch_operands);
1386   friend XlaOp ConditionalImpl(
1387       XlaOp branch_index,
1388       absl::Span<const XlaComputation* const> branch_computations,
1389       absl::Span<const XlaOp> branch_operands);
1390   friend XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
1391                                const int mantissa_bits);
1392   friend XlaOp Gather(XlaOp input, XlaOp start_indices,
1393                       const GatherDimensionNumbers& dimension_numbers,
1394                       absl::Span<const int64> slice_sizes,
1395                       bool indices_are_sorted);
1396   friend XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
1397                        const XlaComputation& update_computation,
1398                        const ScatterDimensionNumbers& dimension_numbers,
1399                        bool indices_are_sorted, bool unique_indices);
1400   friend void Send(XlaOp operand, const ChannelHandle& handle);
1401   friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
1402                     const ChannelHandle& handle);
1403   friend XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset,
1404                                  float epsilon, int64 feature_index);
1405   friend XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset,
1406                                   XlaOp mean, XlaOp variance, float epsilon,
1407                                   int64 feature_index);
1408   friend XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
1409                              XlaOp batch_var, XlaOp grad_output, float epsilon,
1410                              int64 feature_index);
1411   friend XlaOp SendWithToken(XlaOp operand, XlaOp token,
1412                              const ChannelHandle& handle);
1413   friend XlaOp RecvWithToken(XlaOp token, const Shape& shape,
1414                              const ChannelHandle& handle);
1415   friend XlaOp SendToHost(XlaOp operand, XlaOp token,
1416                           const Shape& shape_with_layout,
1417                           const ChannelHandle& handle);
1418   friend XlaOp RecvFromHost(XlaOp token, const Shape& shape,
1419                             const ChannelHandle& handle);
1420   friend XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
1421                                const string& config);
1422   friend XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
1423                                 const Shape& shape_with_layout,
1424                                 const string& outfeed_config);
1425   friend XlaOp CreateToken(XlaBuilder* builder);
1426   friend XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
1427 
1428   friend XlaOp GetDimensionSize(XlaOp operand, int64 dimension);
1429   friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);
1430   friend XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);
1431 
1432  protected:
1433   // Returns OK status if the given op was built using this builder. Otherwise,
1434   // returns an error.
1435   Status CheckOpBuilder(XlaOp op) const;
1436 
1437  private:
1438   XlaOp ConditionalImpl(
1439       XlaOp branch_index,
1440       absl::Span<const XlaComputation* const> branch_computations,
1441       absl::Span<const XlaOp> branch_operands);
1442 
1443   XlaOp AllToAllArray(XlaOp operand, int64 split_dimension,
1444                       int64 concat_dimension, int64 split_count,
1445                       const std::vector<ReplicaGroup>& replica_groups);
1446 
1447   // Creates an op with the given opcode and the output shape.
1448   virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
1449                                          absl::Span<const XlaOp> operands);
1450 
1451   // Here, InstructionType is either const HloInstructionProto* or non-const
1452   // HloInstructionProto*.
1453   template <typename InstructionType>
LookUpInstructionByHandleInternal(int64 handle)1454   StatusOr<InstructionType> LookUpInstructionByHandleInternal(
1455       int64 handle) const {
1456     auto it = handle_to_index_.find(handle);
1457     if (it == handle_to_index_.end()) {
1458       return InvalidArgument("No XlaOp with handle %d", handle);
1459     }
1460     return const_cast<InstructionType>(&instructions_.at(it->second));
1461   }
1462 
1463   // Here, InstructionType is either const HloInstructionProto* or non-const
1464   // HloInstructionProto*.
1465   //
1466   // TODO(hinsu): Return const pointer within StatusOr and use
1467   // absl::implicit_cast at callsites. This requires implicit_cast support in
1468   // stream_executor::port::StatusOr similar to absl::StatusOr.
1469   template <typename InstructionType>
LookUpInstructionInternal(XlaOp op)1470   StatusOr<InstructionType> LookUpInstructionInternal(XlaOp op) const {
1471     TF_RETURN_IF_ERROR(CheckOpBuilder(op));
1472     return LookUpInstructionByHandleInternal<InstructionType>(op.handle());
1473   }
1474 
1475   friend struct internal::XlaBuilderFriend;
1476 };

// RAII-style object: sets the current sharding assignment in builder on
// construction, and sets back to the previous assignment on destruction.
class XlaScopedShardingAssignment {
 public:
  XlaScopedShardingAssignment(xla::XlaBuilder* builder,
                              absl::optional<OpSharding> sharding)
      : builder_(builder), prev_sharding_(builder->sharding()) {
    SetSharding(sharding);
  }

  XlaScopedShardingAssignment(const XlaScopedShardingAssignment&) = delete;
  XlaScopedShardingAssignment& operator=(const XlaScopedShardingAssignment&) =
      delete;

  ~XlaScopedShardingAssignment() { SetSharding(prev_sharding_); }

 private:
  void SetSharding(const absl::optional<OpSharding>& sharding) {
    if (sharding.has_value()) {
      builder_->SetSharding(sharding.value());
    } else {
      builder_->ClearSharding();
    }
  }

  xla::XlaBuilder* const builder_;
  absl::optional<OpSharding> prev_sharding_;
};

// RAII-style object: saves the current builder's frontend attributes on
// construction and merges them with the new ones; restores the original
// attributes on destruction.
class XlaScopedFrontendAttributesAssignment {
 public:
  XlaScopedFrontendAttributesAssignment(xla::XlaBuilder* builder,
                                        FrontendAttributes attributes)
      : builder_(builder) {
    saved_ = builder_->SwapFrontendAttributes(attributes);
  }

  ~XlaScopedFrontendAttributesAssignment() {
    builder_->SetFrontendAttributes(saved_);
  }

 private:
  xla::XlaBuilder* const builder_;
  FrontendAttributes saved_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedFrontendAttributesAssignment);
};

// RAII-style object: sets the current op metadata in builder on construction,
// and sets back to the previous assignment on destruction.
class XlaScopedOpMetadataAssignment {
 public:
  XlaScopedOpMetadataAssignment(xla::XlaBuilder* builder, OpMetadata metadata)
      : builder_(builder) {
    saved_ = builder_->SwapOpMetadata(metadata);
  }

  ~XlaScopedOpMetadataAssignment() { builder_->SwapOpMetadata(saved_); }

 private:
  xla::XlaBuilder* const builder_;
  OpMetadata saved_;

  TF_DISALLOW_COPY_AND_ASSIGN(XlaScopedOpMetadataAssignment);
};
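
// A minimal usage sketch of the scoped RAII helpers above (illustrative
// only; `b`, `x`, `y`, and `sharding` are assumed to be an XlaBuilder*, two
// XlaOps, and an OpSharding proto, respectively):
//
//   {
//     XlaScopedShardingAssignment scoped_sharding(b, sharding);
//     XlaOp sharded_sum = Add(x, y);  // Built with `sharding` attached.
//   }
//   XlaOp plain_sum = Add(x, y);  // Built with the previous sharding state.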

// Free functions for building XlaOps. The intention is that these will
// become the public API for building XlaOps rather than calling methods on
// XlaBuilder directly.
//

// Enqueues a "retrieve parameter value" instruction for a parameter that was
// passed to the computation.
XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
                const string& name);

// Same as above, but with leaf buffer replication annotation.
XlaOp Parameter(XlaBuilder* builder, int64 parameter_number, const Shape& shape,
                const string& name,
                const std::vector<bool>& replicated_at_leaf_buffers);

// Enqueues a constant with the value of the given literal onto the
// computation.
XlaOp ConstantLiteral(XlaBuilder* builder, const LiteralSlice& literal);

// Enqueues a constant onto the computation. Methods are templated on the
// native host type (NativeT) which corresponds to a specific XLA
// PrimitiveType as given in the following table:
//
//  Native Type   PrimitiveType
// -----------------------------
//   bool           PRED
//   int32          S32
//   int64          S64
//   uint32         U32
//   uint64         U64
//   float          F32
//   double         F64
//
// Note: not all primitive types defined in xla_data.proto have a
// corresponding native type yet.
template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values);
XlaOp ConstantR1(XlaBuilder* builder, const tensorflow::core::Bitmap& values);
template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout);
template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values);
template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout);
template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values);

// Enqueues a rank one constant (vector) onto the computation. The vector has
// size 'length' and every element has the value 'value'.
template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
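
// A minimal sketch of building constants (illustrative only; `b` is assumed
// to be an XlaBuilder*):
//
//   XlaOp zero = ConstantR0<float>(b, 0.0f);                   // f32[]
//   XlaOp vec = ConstantR1<float>(b, {1.0f, 2.0f, 3.0f});      // f32[3]
//   XlaOp ones = ConstantR1<float>(b, /*length=*/4, 1.0f);     // f32[4]
//   XlaOp mat = ConstantR2<float>(b, {{1.0f, 2.0f},
//                                     {3.0f, 4.0f}});          // f32[2,2]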

// Adds dimensions to an array by duplicating the data in the array.
//
// The new dimensions are inserted on the left, i.e. if
// broadcast_sizes has values {a0, ..., aN} and the operand shape
// has dimensions {b0, ..., bM} then the shape of the output has
// dimensions {a0, ..., aN, b0, ..., bM}.
//
// The new dimensions index into copies of the operand, i.e.
//
//   output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
XlaOp Broadcast(XlaOp operand, absl::Span<const int64> broadcast_sizes);
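
// Example (illustrative sketch; `b` is assumed to be an XlaBuilder*):
//
//   XlaOp scalar = ConstantR0<float>(b, 2.0f);                    // f32[]
//   XlaOp bcast = Broadcast(scalar, /*broadcast_sizes=*/{2, 3});  // f32[2,3]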

// This op broadcasts the `operand` to an output with the given `shape`.
// `broadcast_dimensions` are the dimensions to broadcast into, i.e., the
// i'th dimension of the operand is mapped to the broadcast_dimensions[i]'th
// dimension of the output. This also requires that the i'th input dimension is
// either 1 or is the same as the output dimension it's broadcasting into.
//
// For example, say operand = {1, 2}, i.e., a 1D tensor in shape s32[2]; the
// output shape is s32[2,2]:
// - Specifying {1} as broadcast_dimensions will generate output
//   {{1, 2},
//    {1, 2}}
// - On the other hand, specifying {0} as broadcast_dimensions
//   will generate output
//   {{1, 1},
//    {2, 2}}
XlaOp BroadcastInDim(XlaOp operand, const absl::Span<const int64> out_dim_size,
                     const absl::Span<const int64> broadcast_dimensions);
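
// Example (illustrative sketch): the two broadcasts described above, with `b`
// assumed to be an XlaBuilder*.
//
//   XlaOp v = ConstantR1<int32>(b, {1, 2});        // s32[2]
//   XlaOp rows = BroadcastInDim(v, {2, 2}, {1});   // {{1, 2}, {1, 2}}
//   XlaOp cols = BroadcastInDim(v, {2, 2}, {0});   // {{1, 1}, {2, 2}}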

// Copies the input operand to the output. This operation is for internal
// purposes only and is used by the compiler for optimization purposes or to
// ensure correctness. The XLA client should never have to generate this
// instruction.
//
// Copy has two potential use cases:
//
// * Create a copy of the operand with a new layout.
//
// * Create a copy of the operand in a separately allocated buffer. This is
//   necessary for some backends if the operand is a parameter or constant and
//   the operand is returned within a tuple. In this case, the lifetime of the
//   operand buffer must be the same as the lifetime of the output result.
//   However, the lifetimes of parameters and constants are managed separately
//   from the lifetime of the output result. Creating a separate copy of the
//   parameter or constant buffer resolves this issue.
XlaOp Copy(XlaOp operand);

// Enqueues a pad operation onto the computation that pads the given value on
// the edges as well as between the elements of the input. padding_config
// specifies the padding amount for each dimension.
XlaOp Pad(XlaOp operand, XlaOp padding_value,
          const PaddingConfig& padding_config);

// Enqueues a pad operation in a given dimension, taking all other
// dimensions as they are.
XlaOp PadInDim(XlaOp operand, XlaOp padding_value, int64 dimno, int64 pad_lo,
               int64 pad_hi);
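
// Example (illustrative sketch): padding one zero below and two above along
// dimension 0 of a vector `v` (assumed to be an f32[4] XlaOp).
//
//   XlaOp zero = ConstantR0<float>(b, 0.0f);
//   // f32[4] -> f32[7]
//   XlaOp padded = PadInDim(v, zero, /*dimno=*/0, /*pad_lo=*/1, /*pad_hi=*/2);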

// Enqueues an operation onto the computation that flattens the operand based
// on the dimension order (major/slowest-varying to minor/fastest-varying)
// given, followed by reshaping it into the shape with the given dimension
// sizes (also major to minor). Conceptually, this is a limited form of
// "shape casting".
XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
              absl::Span<const int64> new_sizes);

// Enqueues a dynamic reshape operation. The dynamic reshape takes additional
// XlaOps as sizes for the result dimensions. Result dim i is a dynamic
// dimension if dims_are_dynamic[i] is true.
XlaOp DynamicReshape(XlaOp operand, absl::Span<const XlaOp> dim_sizes,
                     absl::Span<const int64> new_size_bounds,
                     const std::vector<bool>& dims_are_dynamic);

// Enqueues an operation onto the computation that collapses the operand,
// from first to last dimension (C order), then reshapes it to the given
// dimension sizes. Conceptually, this is a limited form of "shape casting".
XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);

// Enqueues a Reshape op that uses an explicit target shape.
XlaOp Reshape(const Shape& shape, XlaOp operand);

// `inferred_dimension` represents the output dimension that's inferred by
// upper-level framework by dividing the input element count by the known
// output element count. While an inferred_dimension can be static, if there
// is a dynamic dimension in the output, it must be the inferred dimension.
XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                   absl::Span<const int64> new_sizes,
                                   int64 inferred_dimension);

// Wrapper for Reshape.
// Enqueues an operation to collapse the provided dimensions; e.g. an
// operand with dimensions {x=256, y=2, z=2, p=32} can be collapsed to
// {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must
// be a consecutive, in-order subsequence of the operand dimensions.
//
// Note that collapsing a single dimension does nothing:
//
//    {256} collapsing {0} => {256}
//    {1} collapsing {0} => {1}
//
// Collapsing multiple dimensions produces a single result dimension:
//
//    {256, 2} collapsing {0,1} => {512}
//    {256, 2, 3} collapsing {0,1} => {512, 3}
//
// This could potentially cause data to be moved -- it provides a more
// structured form of reshaping than an arbitrary Reshape operation.
XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);
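
// Example (illustrative sketch): the collapse described above.
//
//   // operand: f32[256,2,2,32] -> f32[1024,32]
//   XlaOp collapsed = Collapse(operand, /*dimensions=*/{0, 1, 2});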

// Enqueues a slice operation onto the computation that slices the operand
// from the start indices to the limit indices; e.g.
//
//        x
//   [ 0 1 2 3 ]
// y [ 4 5 6 7 ] => slice(start={1, 1}, limit={2, 3}) => [ 5 6 ]
//   [ 8 9 a b ]
//
// Note that "limit" means up-to-but-not-including; i.e. [start, limit) in 1D
// range notation.
// The strides parameter determines the stride over the slice.
XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
            absl::Span<const int64> limit_indices,
            absl::Span<const int64> strides);
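
// Example (illustrative sketch): the slice pictured above.
//
//   // operand: s32[3,4]; result: s32[1,2] containing {{5, 6}}.
//   XlaOp s = Slice(operand, /*start_indices=*/{1, 1},
//                   /*limit_indices=*/{2, 3}, /*strides=*/{1, 1});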

// Enqueues a slice operation in a given dimension, taking all other
// dimensions as they are; e.g. if dimno is 1 from start_index 2 to
// limit_index 4 by 1, and the shape is f32[7,8,9], this call is short-hand
// for:
//
//  array[:, 2:4:1, :]
XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
                 int64 stride, int64 dimno);

// Enqueues a slice operation onto the computation that slices the 'operand'
// from dynamic start indices which are passed in 'start_indices'.
// The size of the slice in each dimension is passed in 'slice_sizes',
// which specifies the extent of the exclusive slice interval in each
// dimension: [start, start + size).
// The shape of each element of 'start_indices' must be scalar, with the span
// size equal to the rank of the 'operand'. All elements of 'start_indices' must
// have the same shape.
// Slice index calculations are computed modulo input dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicSlice(XlaOp operand, absl::Span<const XlaOp> start_indices,
                   absl::Span<const int64> slice_sizes);
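
// Example (illustrative sketch): a 1x2 slice whose origin is computed at
// runtime (`b` is assumed to be an XlaBuilder*).
//
//   XlaOp i = ConstantR0<int32>(b, 1);
//   XlaOp j = ConstantR0<int32>(b, 1);
//   // operand: s32[3,3]; result: s32[1,2].
//   XlaOp ds = DynamicSlice(operand, {i, j}, /*slice_sizes=*/{1, 2});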

// Enqueues a dynamic update slice operation onto the computation, which
// updates a slice of 'operand' with 'update' at dynamic 'start_indices'.
// The shape of 'update' determines the shape of the slice of 'operand'
// which is updated.
// The indices specified in 'start_indices' specify the offset of the slice
// of 'operand' which is updated.
//
//               update = {10, 11} // calculated at runtime.
//   [1 2 3]     start  = {1, 1}   // calculated at runtime.  [1 2  3 ]
//   [4 5 6]  => DynamicUpdateSlice(data, update, start)   => [4 10 11]
//   [7 8 9]                                                  [7 8  9 ]
//
// The shape of each element of 'start_indices' must be scalar, with the span
// size equal to the rank of the 'operand'. All elements of 'start_indices' must
// have the same shape.
// Slice index calculations are computed modulo update dimension sizes to
// prevent dynamic start indices from generating out-of-bound array accesses.
XlaOp DynamicUpdateSlice(XlaOp operand, XlaOp update,
                         absl::Span<const XlaOp> start_indices);

// Enqueues a concatenate instruction onto the computation. 'operands' must
// have >= 1 entry.
XlaOp ConcatInDim(XlaBuilder* builder, absl::Span<const XlaOp> operands,
                  int64 dimension);
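
// Example (illustrative sketch): concatenating along dimension 1.
//
//   // x: f32[2,3], y: f32[2,5] -> f32[2,8]
//   XlaOp cat = ConcatInDim(b, {x, y}, /*dimension=*/1);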

// Enqueues a tracing operation onto the computation; the computation will
// emit a logging message with the operand.
void Trace(const string& tag, XlaOp operand);

// Enqueues a conditional-move-like select operation onto the computation;
// predicated on pred, selects between on_true and on_false.
XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);

// Enqueues a tuple-creation instruction onto the computation.
XlaOp Tuple(XlaBuilder* builder, absl::Span<const XlaOp> elements);

// Enqueues a tuple-element-get instruction onto the computation.
XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
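
// Example (illustrative sketch): packing and unpacking a tuple.
//
//   XlaOp t = Tuple(b, {x, y});
//   XlaOp first = GetTupleElement(t, 0);  // Same shape and value as `x`.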

// Enqueues an equal-to comparison instruction onto the computation.
XlaOp Eq(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});
XlaOp EqTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a not-equal comparison instruction onto the computation.
XlaOp Ne(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});
XlaOp NeTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a greater-or-equal comparison instruction onto the computation.
XlaOp Ge(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});
XlaOp GeTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a greater-than comparison instruction onto the computation.
XlaOp Gt(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});
XlaOp GtTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a less-than comparison instruction onto the computation.
XlaOp Lt(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});
XlaOp LtTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a less-or-equal comparison instruction onto the computation.
XlaOp Le(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});
XlaOp LeTotalOrder(XlaOp lhs, XlaOp rhs,
                   absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a comparison instruction onto the computation. Overloads are
// provided with and without broadcast_dimensions for consistency with the
// ops above.
XlaOp Compare(XlaOp lhs, XlaOp rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction, Comparison::Type compare_type);
XlaOp Compare(XlaOp lhs, XlaOp rhs,
              absl::Span<const int64> broadcast_dimensions,
              ComparisonDirection direction);
XlaOp Compare(XlaOp lhs, XlaOp rhs, ComparisonDirection direction);

// Enqueues a dot instruction onto the computation.
XlaOp Dot(XlaOp lhs, XlaOp rhs,
          const PrecisionConfig* precision_config = nullptr,
          absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

// Enqueues a general dot instruction onto the computation.
XlaOp DotGeneral(
    XlaOp lhs, XlaOp rhs, const DotDimensionNumbers& dimension_numbers,
    const PrecisionConfig* precision_config = nullptr,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
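
// Example (illustrative sketch): a batched matmul expressed via DotGeneral,
// assuming `lhs` is f32[B,M,K] and `rhs` is f32[B,K,N].
//
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_batch_dimensions(0);
//   dnums.add_rhs_batch_dimensions(0);
//   dnums.add_lhs_contracting_dimensions(2);
//   dnums.add_rhs_contracting_dimensions(1);
//   XlaOp out = DotGeneral(lhs, rhs, dnums);  // f32[B,M,N]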

// Enqueues a convolution instruction onto the computation, which uses the
// default convolution dimension numbers.
XlaOp Conv(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    Padding padding, int64 feature_group_count = 1, int64 batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);
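
// Example (illustrative sketch): a stride-1, same-padded convolution using
// the default dimension numbers (input f32[N,C,H,W], kernel f32[O,I,KH,KW]).
//
//   XlaOp out = Conv(lhs, rhs, /*window_strides=*/{1, 1}, Padding::kSame);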

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration in the format returned by MakePadding().
XlaOp ConvWithGeneralPadding(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    int64 feature_group_count = 1, int64 batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

// Enqueues a convolution instruction onto the computation, with the caller
// provided dimension numbers configuration.
XlaOp ConvWithGeneralDimensions(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    Padding padding, const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count = 1, int64 batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration as well as the dimension numbers.
XlaOp ConvGeneral(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count = 1, int64 batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

// Enqueues a convolution instruction onto the computation, with the caller
// provided padding configuration, dilation factors and dimension numbers.
XlaOp ConvGeneralDilated(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count = 1, int64 batch_group_count = 1,
    const PrecisionConfig* precision_config = nullptr,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

XlaOp DynamicConvForward(
    XlaOp lhs, XlaOp rhs, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

XlaOp DynamicConvInputGrad(
    XlaOp input_sizes, XlaOp lhs, XlaOp rhs,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

XlaOp DynamicConvKernelGrad(
    XlaOp activations, XlaOp gradients, absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding,
    absl::Span<const int64> lhs_dilation, absl::Span<const int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers,
    int64 feature_group_count, int64 batch_group_count,
    const PrecisionConfig* precision_config, PaddingType padding_type,
    absl::optional<PrimitiveType> preferred_element_type = absl::nullopt);

// Enqueues an FFT instruction onto the computation, of the given type and
// with the given FFT length.
XlaOp Fft(XlaOp operand, FftType fft_type, absl::Span<const int64> fft_length);

// Solves systems of linear equations with lower or upper triangular coefficient
// matrices by forward- or back-substitution. Broadcasting along leading
// dimensions, this routine solves for x in one of the matrix systems
//   `op(a) * x = b`,  or `x * op(a) = b`,
// for the variable `x` given `a` and `b`, where `op(a)` is either
//   `op(a) = a`,  or `op(a) = transpose(a)`,  or `op(a) = conj(transpose(a))`.
//
// * `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form
//   square matrices. If `lower` is true (false), then the strictly upper
//   (lower) triangular part of each innermost matrix in `a` is assumed to be
//   zero and is not accessed.
// * `b` is a tensor of shape `[..., M, K]` if `left_side` is true, otherwise a
//   tensor of shape `[..., K, M]`.
// * `left_side` is a boolean, indicating whether to solve a system of the form
//   op(a) * x = b (true) or x * op(a) = b (false).
// * `lower` is a boolean, indicating whether the argument `a` is
//   lower-triangular (true) or upper-triangular (false).
// * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be
//   1 and not accessed.
// * `transpose_a` indicates which function `op` we use to transform the tensor
//   `a`: the identity function, transpose(a), or conjugate(transpose(a))
XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                      bool unit_diagonal,
                      TriangularSolveOptions::Transpose transpose_a);

// Computes the Cholesky decompositions of a batch of symmetric (Hermitian)
// positive definite matrices.
// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
// two minor dimensions equal.
// If `lower` is true, the data from the lower triangle is used; if false, the
// upper triangle is used. The input data in the other triangle of the input
// does not affect the output. Returns the output in the same lower/upper
// triangle. The data returned in the other output triangle is arbitrary and
// implementation-defined.
//
// If `a` is not Hermitian positive definite, returns an array full of NaNs.
XlaOp Cholesky(XlaOp a, bool lower);

// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
             const string& config = "");

// Variant of Infeed which takes a token-shaped operand and produces a
// two-element tuple containing the data value and a token-shaped value.
// Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
                      const string& config = "");

// Enqueues an outfeed instruction onto the computation. This instruction
// generates outgoing data transfers for the given data.
//
// shape_with_layout communicates the laid out shape that we want to outfeed
// -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error
// will occur.
void Outfeed(XlaOp operand, const Shape& shape_with_layout,
             const string& outfeed_config);

// Variant of Outfeed which takes a token-shaped operand and produces a
// token-shaped value. Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
                       const Shape& shape_with_layout,
                       const string& outfeed_config);

// Enqueues a call instruction onto the computation.
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
           absl::Span<const XlaOp> operands);

// Enqueues a custom call instruction onto the computation. A custom call
// invokes code external to XLA. The |operands| are passed to the external code,
// and the external code is expected to produce a result of the given
// |shape|. The exact mechanism is backend-specific. For example, in the CPU
// backend, a call instruction is emitted which targets a symbol with the name
// |call_target_name|. |call_target_name| and |opaque| can be arbitrary strings,
// but |call_target_name| should be short as it may be used in labels. |opaque|
// can encode arbitrarily large amounts of information. |has_side_effect|
// specifies whether the instruction can have side effects.
// |output_operand_aliasing| specifies a list of output/operand buffer pairs
// that alias each other, where the output buffer is represented as a
// ShapeIndex, and the operand buffer is represented as the operand index and
// the ShapeIndex.
XlaOp CustomCall(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape,
    const string& opaque = "", bool has_side_effect = false,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing = {},
    const Literal* literal = nullptr);
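
// Example (illustrative sketch): calling a backend-registered function named
// "my_custom_fn". The target name is hypothetical; registering it with the
// backend is a separate, backend-specific step not shown here.
//
//   XlaOp out = CustomCall(b, "my_custom_fn", /*operands=*/{x, y},
//                          /*shape=*/ShapeUtil::MakeShape(F32, {16}));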

// Overload which constructs a custom call that applies an Xla computation.
XlaOp CustomCallWithComputation(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const XlaComputation& computation,
    const Shape& shape, const string& opaque = "", bool has_side_effect = false,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing = {},
    const Literal* literal = nullptr);

// Overload which constructs a custom call with fixed layouts. The operands will
// have the layouts specified by |operand_shapes_with_layout| when provided to
// external code, and the external code is expected to produce a result with the
// layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
// and |operand_shapes_with_layout| must have layouts.
XlaOp CustomCallWithLayout(
    XlaBuilder* builder, const string& call_target_name,
    absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
    absl::Span<const Shape> operand_shapes_with_layout,
    const string& opaque = "", bool has_side_effect = false,
    absl::Span<const std::pair<ShapeIndex, std::pair<int64, ShapeIndex>>>
        output_operand_aliasing = {},
    const Literal* literal = nullptr);

// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
// of the operands is a scalar, or an explicit broadcast dimension is given
// (see g3doc for more details).

// Enqueues a complex compose instruction onto the computation.
XlaOp Complex(XlaOp real, XlaOp imag,
              absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a complex conjugate instruction onto the computation.
XlaOp Conj(XlaOp operand);

// Enqueues an add instruction onto the computation.
XlaOp Add(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a subtract instruction onto the computation.
XlaOp Sub(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a multiply instruction onto the computation.
XlaOp Mul(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a divide instruction onto the computation.
XlaOp Div(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a remainder instruction onto the computation.
XlaOp Rem(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a max instruction onto the computation.
XlaOp Max(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues a min instruction onto the computation.
XlaOp Min(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

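// Example (illustrative sketch): adding a vector to every row of a matrix by
// mapping the vector's dimension 0 onto the matrix's dimension 1.
//
//   // lhs: f32[2,3], rhs: f32[3] -> f32[2,3]
//   XlaOp sum = Add(lhs, rhs, /*broadcast_dimensions=*/{1});
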
// Element-wise logical operators
XlaOp And(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Overload to call And with 3 or more operands.  We need the following somewhat
// convoluted overload set to disambiguate with the overload that takes the
// `broadcast_dimensions` optional param.
inline XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3) {
  return And(op1, And(op2, op3));
}
template <typename... XlaOpTs>
XlaOp And(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
  return And(op1, And(op2, And(op3, operands...)));
}

XlaOp Or(XlaOp lhs, XlaOp rhs,
         absl::Span<const int64> broadcast_dimensions = {});

// Overload to call Or with 3 or more operands.  As with `And`, we need the
// following complicated overload set to handle the default arg in the `Or`
// overload above.
inline XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3) {
  return Or(op1, Or(op2, op3));
}
template <typename... XlaOpTs>
XlaOp Or(XlaOp op1, XlaOp op2, XlaOp op3, const XlaOpTs&... operands) {
  return Or(op1, Or(op2, Or(op3, operands...)));
}

XlaOp Xor(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

XlaOp Not(XlaOp operand);

XlaOp PopulationCount(XlaOp operand);

XlaOp ShiftLeft(XlaOp lhs, XlaOp rhs,
                absl::Span<const int64> broadcast_dimensions = {});
XlaOp ShiftRightArithmetic(XlaOp lhs, XlaOp rhs,
                           absl::Span<const int64> broadcast_dimensions = {});
XlaOp ShiftRightLogical(XlaOp lhs, XlaOp rhs,
                        absl::Span<const int64> broadcast_dimensions = {});

// Reduces an array among the provided dimensions, given "computation" as a
// reduction operator.
XlaOp Reduce(XlaOp operand, XlaOp init_value, const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce);

// Reduces several arrays simultaneously among the provided dimensions, given
// "computation" as a reduction operator.
XlaOp Reduce(XlaBuilder* builder, absl::Span<const XlaOp> operands,
             absl::Span<const XlaOp> init_values,
             const XlaComputation& computation,
             absl::Span<const int64> dimensions_to_reduce);

// Convenience wrapper around the above that reduces all the dimensions in the
// operand shape.
XlaOp ReduceAll(XlaOp operand, XlaOp init_value,
                const XlaComputation& computation);
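
// Example (illustrative sketch): a row-sum reduction. The scalar `add`
// reduction computation is built with a sub-builder; error handling is
// elided and `b` is assumed to be an XlaBuilder*.
//
//   XlaComputation add;
//   {
//     auto sub = b->CreateSubBuilder("add");
//     Shape scalar = ShapeUtil::MakeShape(F32, {});
//     Add(Parameter(sub.get(), 0, scalar, "x"),
//         Parameter(sub.get(), 1, scalar, "y"));
//     add = sub->Build().ConsumeValueOrDie();
//   }
//   // operand: f32[8,16]; reduce dimension 1 away -> f32[8]
//   XlaOp row_sums = Reduce(operand, ConstantR0<float>(b, 0.0f), add, {1});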

// Enqueues a windowed reduce instruction onto the computation.
XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding);

XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                   absl::Span<const XlaOp> init_values,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding);

// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding().
XlaOp ReduceWindowWithGeneralPadding(
    XlaOp operand, XlaOp init_value, const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding);
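
// Example (illustrative sketch): 2x2 max pooling with stride 2 on an NCHW
// input, where `max` is a scalar maximum computation built the same way as
// the `add` computation in the Reduce example above.
//
//   // operand: f32[N,C,H,W] -> f32[N,C,H/2,W/2]
//   XlaOp neg_inf =
//       ConstantR0<float>(b, -std::numeric_limits<float>::infinity());
//   XlaOp pooled = ReduceWindow(operand, neg_inf, max,
//                               /*window_dimensions=*/{1, 1, 2, 2},
//                               /*window_strides=*/{1, 1, 2, 2},
//                               Padding::kValid);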
2181 
2182 // Returns the sum of the operand value within each subgroup of replicas. All
2183 // replicas supply one input to the sum and all replicas receive the resulting
2184 // sum for each subgroup.
2185 XlaOp CrossReplicaSum(XlaOp operand,
2186                       absl::Span<const ReplicaGroup> replica_groups = {});
2187 
2188 XlaOp AllGather(
2189     XlaOp operand, int64 all_gather_dimension, int64 shard_count,
2190     absl::Span<const ReplicaGroup> replica_groups = {},
2191     const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
2192     const absl::optional<Layout>& layout = absl::nullopt,
2193     const absl::optional<bool> use_global_device_ids = absl::nullopt);
2194 
2195 // Enqueues an operation that do an AllReduce of the operand cross cores. Here
2196 // AllReduce means doing a reduction on the input operand cross cores and then
2197 // broadcasting the reduction result to those cores. The reduction function is
2198 // defined by `computation`, which should be a commutative computation on
2199 // scalars, e.g., add, min, or max. The way that AllReduce is applied is
2200 // configured by:
2201 //
2202 // - `replica_groups`: each ReplicaGroup contains a list of replica id. If
2203 // empty, all replicas belong to one group. Allreduce will be applied within
2204 // subgroups. For example, we have 4 replicas, then replica_groups={{0,2},{1,3}}
2205 // means, replica 0 and 2 are in subgroup 0, replica 1 and 3 are in subgroup 1.
2206 //
2207 // - `channel_id`: for Allreduce nodes from different modules, if they have the
2208 // same channel_id, they will be 'AllReduce'd. If empty, AllReduce will not be
2209 // applied cross modules.
2210 //
2211 // - `shape_with_layout`: forces the layout of the AllReduce to the given
2212 // layout. This is used to guarantee the same layout for a group of AllReduce
2213 // ops compiled separately.
2214 XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
2215                 absl::Span<const ReplicaGroup> replica_groups = {},
2216                 const absl::optional<ChannelHandle>& channel_id = absl::nullopt,
2217                 const absl::optional<Shape>& shape_with_layout = absl::nullopt);
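
// Example (illustrative sketch): a cross-replica sum over the two subgroups
// described above in a 4-replica computation. `add` is assumed to be a
// scalar-add XlaComputation built as in the Reduce example.
//
//   ReplicaGroup g0, g1;
//   g0.add_replica_ids(0);
//   g0.add_replica_ids(2);
//   g1.add_replica_ids(1);
//   g1.add_replica_ids(3);
//   AllReduce(operand, add, /*replica_groups=*/{g0, g1});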

// Enqueues an operation that performs an AllToAll of the operand across
// cores. An optional `layout` can be specified to force the layout of the
// instruction. This is used to guarantee the same layout for a group of
// AllToAll ops compiled separately.
XlaOp AllToAll(XlaOp operand, int64 split_dimension, int64 concat_dimension,
               int64 split_count,
               const std::vector<ReplicaGroup>& replica_groups = {},
               const absl::optional<Layout>& layout = absl::nullopt);

XlaOp AllToAllTuple(XlaOp operand, int64 split_dimension,
                    int64 concat_dimension, int64 split_count,
                    const std::vector<ReplicaGroup>& replica_groups = {},
                    const absl::optional<Layout>& layout = absl::nullopt);

// Enqueues a collective operation that sends and receives data across
// replicas.
//
// - `source_target_pairs`: a list of (source_replica_id, target_replica_id)
// pairs. For each pair, the operand is sent from the source replica to the
// target replica. Note that 1) no two pairs may have the same target replica
// id, and no two pairs may have the same source replica id; 2) if a replica
// id is not a target in any pair, then the output on that replica is a
// tensor of 0(s) with the same shape as the input. (See the sketch after
// this declaration.)
XlaOp CollectivePermute(
    XlaOp operand,
    const std::vector<std::pair<int64, int64>>& source_target_pairs);
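
// Example (illustrative sketch): rotating data one step around a 4-replica
// ring. Every replica is a target exactly once, so no replica receives the
// all-zeros output described above.
//
//   CollectivePermute(operand, {{0, 1}, {1, 2}, {2, 3}, {3, 0}});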

// Enqueues an operation that returns the replica ID.
XlaOp ReplicaId(XlaBuilder* builder);

// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(XlaOp operand, const XlaComputation& select,
                       absl::Span<const int64> window_dimensions,
                       absl::Span<const int64> window_strides, Padding padding,
                       XlaOp source, XlaOp init_value,
                       const XlaComputation& scatter);

// As SelectAndScatter(), but the padding is given in the format
// returned by MakePadding().
XlaOp SelectAndScatterWithGeneralPadding(
    XlaOp operand, const XlaComputation& select,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const std::pair<int64, int64>> padding, XlaOp source,
    XlaOp init_value, const XlaComputation& scatter);

// Enqueues an abs instruction onto the computation.
XlaOp Abs(XlaOp operand);

// Enqueues an atan2 instruction onto the computation.
XlaOp Atan2(XlaOp y, XlaOp x,
            absl::Span<const int64> broadcast_dimensions = {});

// Enqueues an exp instruction onto the computation.
XlaOp Exp(XlaOp operand);

// Enqueues an expm1 instruction onto the computation.
XlaOp Expm1(XlaOp operand);

// Enqueues a floor instruction onto the computation.
XlaOp Floor(XlaOp operand);

// Enqueues a ceil instruction onto the computation.
XlaOp Ceil(XlaOp operand);

// Enqueues a round instruction onto the computation, rounding to the nearest
// integer, with half-way cases rounding away from zero.
XlaOp Round(XlaOp operand);

// Enqueues a log instruction (natural logarithm) onto the computation.
XlaOp Log(XlaOp operand);

// Enqueues a log1p instruction (log(x+1)) onto the computation.
XlaOp Log1p(XlaOp operand);

// Enqueues a logistic instruction onto the computation.
XlaOp Logistic(XlaOp operand);

// Enqueues a sign instruction onto the computation.
XlaOp Sign(XlaOp operand);

// Enqueues a count-leading-zeros instruction onto the computation.
XlaOp Clz(XlaOp operand);

// Enqueues a cosine instruction onto the computation.
XlaOp Cos(XlaOp operand);

// Enqueues a sine instruction onto the computation.
XlaOp Sin(XlaOp operand);

// Enqueues a tanh instruction onto the computation.
XlaOp Tanh(XlaOp operand);

// Enqueues a real-part instruction onto the computation.
XlaOp Real(XlaOp operand);

// Enqueues an imaginary-part instruction onto the computation.
XlaOp Imag(XlaOp operand);

// Enqueues a sqrt instruction onto the computation.
XlaOp Sqrt(XlaOp operand);

// Enqueues a cbrt instruction onto the computation.
XlaOp Cbrt(XlaOp operand);

// Enqueues a rsqrt instruction onto the computation.
XlaOp Rsqrt(XlaOp operand);

// Enqueues a lhs^rhs instruction onto the computation.
XlaOp Pow(XlaOp lhs, XlaOp rhs,
          absl::Span<const int64> broadcast_dimensions = {});

// Enqueues an operator that tests if the operand's values are finite, i.e., not
// +/-Inf or NaN.  Returns an array of booleans with the same shape where
// entries are true iff the corresponding entry was not infinite or NaN.
//
// Defined only for real-valued (i.e. not complex) floating-point types; raises
// an error for other types.
//
// See also IsInf, IsPosInf, IsNegInf, and IsNan in lib/math.h.
XlaOp IsFinite(XlaOp operand);

// Enqueues an iota operation onto the computation.
XlaOp Iota(XlaBuilder* builder, const Shape& shape, int64 iota_dimension);

// Enqueues a rank-1 iota operation onto the computation.
XlaOp Iota(XlaBuilder* builder, PrimitiveType type, int64 size);

// Enqueues a convert instruction onto the computation that changes the
// element type of the operand array to new_element_type.
XlaOp ConvertElementType(XlaOp operand, PrimitiveType new_element_type);

// Enqueues a no-op instruction onto the computation that changes
// the element type of the operand array to new_element_type. The
// bit-widths of the source and destination element types must be
// identical.
XlaOp BitcastConvertType(XlaOp operand, PrimitiveType new_element_type);
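
// Example (illustrative sketch, assuming an enclosing builder `b` and a f32
// operand `f32_operand`): reinterpreting f32 data as its raw 32-bit pattern
// to extract the sign bits.
//
//   auto bits = BitcastConvertType(f32_operand, U32);  // same shape, now u32
//   auto sign = ShiftRightLogical(bits, ConstantR0<uint32>(&b, 31));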

// Enqueues a negate instruction onto the computation.
XlaOp Neg(XlaOp operand);

// Enqueues a transpose instruction onto the computation.
XlaOp Transpose(XlaOp operand, absl::Span<const int64> permutation);

// Enqueues a reverse instruction onto the computation. The order of the
// elements in the given dimensions is reversed (i.e., the element at index i
// is moved to index dimension_size - 1 - i).
XlaOp Rev(XlaOp operand, absl::Span<const int64> dimensions);

// Enqueues a sort instruction onto the computation, using 'comparator' for
// comparisons. 'comparator' needs to define a strict weak order. 'is_stable'
// determines whether stable sorting should be used.
// If only one operand is provided:
// * If the operand is a rank-1 tensor (an array), the result is a sorted array.
//   The resulting sorting order has the property that for all index positions
//   i, j with i < j, either
//   comparator(value[i], value[j]) = comparator(value[j], value[i]) = false or
//   comparator(value[i], value[j]) = true.
// * If the operand has higher rank, the operand is sorted along the provided
//   dimension. For example, for a rank-2 tensor (a matrix), a dimension value
//   of 0 will independently sort every column, and a dimension value of 1 will
//   independently sort each row. If no dimension number is provided, then the
//   last dimension is chosen by default. For the dimension which is sorted, the
//   same sorting order applies as in the rank-1 case.
//
// If more than one operand is provided:
// * All operands must be tensors with the same dimensions. The element types of
//   the tensors may be different.
// * The result is a tuple that consists of the operands in sorted order (along
//   the provided dimension, as above). The same permutation as implied by the
//   comparison computation is applied to all operand tensors. When comparing
//   two index positions, 'comparator' is called with 2 * n scalar parameters,
//   where parameter 2 * i and 2 * i + 1 correspond to the value of operand i at
//   the two index positions.
// Default comparator computations can be found in lib/comparators.h; a
// key-value sort sketch follows this declaration.
XlaOp Sort(absl::Span<const XlaOp> operands, const XlaComputation& comparator,
           int64 dimension = -1, bool is_stable = false);
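
// Example (illustrative sketch): a key-value sort of f32[100] keys with
// matching s32[100] values. `less_than` is assumed to be a comparator built
// with CreateScalarLtComputation({F32, S32}, &b) from lib/comparators.h,
// which orders by the key parameters only.
//
//   auto sorted = Sort({keys, values}, less_than, /*dimension=*/0);
//   auto sorted_keys = GetTupleElement(sorted, 0);
//   auto sorted_values = GetTupleElement(sorted, 1);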

// Enqueues a clamp instruction onto the computation.
XlaOp Clamp(XlaOp min, XlaOp operand, XlaOp max);

// Enqueues a map instruction onto the computation.
XlaOp Map(XlaBuilder* builder, absl::Span<const XlaOp> operands,
          const XlaComputation& computation, absl::Span<const int64> dimensions,
          absl::Span<const XlaOp> static_operands = {});

// Enqueues an N(mu, sigma) random number generation instruction onto the
// computation.
XlaOp RngNormal(XlaOp mu, XlaOp sigma, const Shape& shape);

// Enqueues a U(a, b) random number generation instruction onto the
// computation. Returns values in the semi-open interval [a, b).
XlaOp RngUniform(XlaOp a, XlaOp b, const Shape& shape);

// Enqueues a B(initial_state) random bit generation instruction onto the
// computation. Returns the new key and random bits with the specified shape.
XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state,
                      const Shape& shape);

// Enqueues a while node onto the computation.
XlaOp While(const XlaComputation& condition, const XlaComputation& body,
            XlaOp init);
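
// Example (illustrative sketch): counting from 0 to 10 with a scalar s32
// loop state. Both computations take the loop state shape as their only
// parameter; `b` is the enclosing builder.
//
//   XlaBuilder cond_b("condition");
//   auto s = Parameter(&cond_b, 0, ShapeUtil::MakeShape(S32, {}), "s");
//   Lt(s, ConstantR0<int32>(&cond_b, 10));
//   XlaComputation condition = cond_b.Build().ValueOrDie();
//
//   XlaBuilder body_b("body");
//   s = Parameter(&body_b, 0, ShapeUtil::MakeShape(S32, {}), "s");
//   Add(s, ConstantR0<int32>(&body_b, 1));
//   XlaComputation body = body_b.Build().ValueOrDie();
//
//   While(condition, body, /*init=*/ConstantR0<int32>(&b, 0));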

// Enqueues a conditional node onto the computation.
XlaOp Conditional(XlaOp predicate, XlaOp true_operand,
                  const XlaComputation& true_computation, XlaOp false_operand,
                  const XlaComputation& false_computation);

// Enqueues either a predicated (if/else) or indexed (switch/case/default)
// conditional node onto the computation. N >= 1 branch_computations and
// branch_operands are matched by index. branch_index selects the branch that
// will be executed. An out-of-range branch_index uses the (N-1)'th
// branch_computation as the default.
XlaOp Conditional(XlaOp branch_index,
                  absl::Span<const XlaComputation* const> branch_computations,
                  absl::Span<const XlaOp> branch_operands);
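
// Example (illustrative sketch): a three-way switch on an s32 scalar
// `branch_index`. Each case computation takes the shape of its corresponding
// operand; an index outside [0, 2] falls through to `case2`.
//
//   std::vector<const XlaComputation*> branches = {&case0, &case1, &case2};
//   Conditional(branch_index, branches, {operand0, operand1, operand2});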

// Enqueues a ReducePrecision node onto the computation.
XlaOp ReducePrecision(XlaOp operand, const int exponent_bits,
                      const int mantissa_bits);

// Enqueues a Gather node onto the computation.
XlaOp Gather(XlaOp input, XlaOp start_indices,
             const GatherDimensionNumbers& dimension_numbers,
             absl::Span<const int64> slice_sizes,
             bool indices_are_sorted = false);
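
// Example (illustrative sketch): gathering two full rows of a f32[3,4]
// matrix `input` with s32[2] row indices, producing f32[2,4]. The field
// names below are those of GatherDimensionNumbers in xla_data.proto.
//
//   GatherDimensionNumbers dnums;
//   dnums.add_offset_dims(1);           // output dim 1 holds the row slice
//   dnums.add_collapsed_slice_dims(0);  // drop the size-1 sliced row dim
//   dnums.add_start_index_map(0);       // indices address operand dim 0
//   dnums.set_index_vector_dim(1);      // scalar indices (implicit vector)
//   Gather(input, row_indices, dnums, /*slice_sizes=*/{1, 4});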

// Enqueues a Scatter node onto the computation.
XlaOp Scatter(XlaOp input, XlaOp scatter_indices, XlaOp updates,
              const XlaComputation& update_computation,
              const ScatterDimensionNumbers& dimension_numbers,
              bool indices_are_sorted = false, bool unique_indices = false);

// Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to
// a Recv instruction in a different computation that shares the same channel
// handle.
void Send(XlaOp operand, const ChannelHandle& handle);

// Variant of Send which takes a token-shaped operand and produces a
// token-shaped value.  Tokens are used for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp SendWithToken(XlaOp operand, XlaOp token, const ChannelHandle& handle);

// Enqueues a Recv node onto the computation for device-to-device
// communication. The data comes from a Send instruction in a different
// computation that shares the same channel handle and its shape must be the
// same as the given shape.
XlaOp Recv(XlaBuilder* builder, const Shape& shape,
           const ChannelHandle& handle);

// Variant of Recv which takes a token-shaped operand and produces a two-element
// tuple containing the data value and a token-shaped value. Tokens are used
// for ordering side-effecting operations.
// TODO(b/110532604): Replace all uses of the non-token form with this variant.
XlaOp RecvWithToken(XlaOp token, const Shape& shape,
                    const ChannelHandle& handle);

// Enqueues a Send node which transfers data from the device to the host. The
// 'shape_with_layout' argument defines the layout of the data transferred; its
// shape must be compatible with the shape of the operand. The operand must be
// array-shaped.
// TODO(b/111544877): Support tuple shapes.
XlaOp SendToHost(XlaOp operand, XlaOp token, const Shape& shape_with_layout,
                 const ChannelHandle& handle);

// Enqueues a Recv node which transfers data from the host to the device. The
// given shape must contain a layout and must be an array.
// TODO(b/111544877): Support tuple shapes.
XlaOp RecvFromHost(XlaOp token, const Shape& shape,
                   const ChannelHandle& handle);

// Enqueues an operation (AfterAll) with no operands that produces a
// token-shaped value.  Tokens are used for ordering side-effecting operations.
// This is a separate method from AfterAll to facilitate the removal of
// operand-less AfterAll instructions.
// TODO(b/110532604): Remove this function when all tokens are derived from a
// single token generated or passed into the entry computation.
XlaOp CreateToken(XlaBuilder* builder);

// Enqueues an AfterAll instruction which produces a token-shaped value and
// takes a variadic number of token-shaped operands. The number of operands must
// be greater than zero. Used for joining tokens.
XlaOp AfterAll(XlaBuilder* builder, absl::Span<const XlaOp> tokens);
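
// Example (illustrative sketch): ordering two side-effecting transfers with
// tokens and joining their result tokens. `b` is the enclosing builder; the
// data operand and channel handles are assumed to exist.
//
//   auto token = CreateToken(&b);
//   auto send_done = SendWithToken(data, token, send_handle);
//   auto recv = RecvWithToken(token, shape, recv_handle);
//   auto recv_done = GetTupleElement(recv, 1);  // token part of the tuple
//   auto joined = AfterAll(&b, {send_done, recv_done});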

// Normalizes operand across spatial and batch dimensions for each feature.
//
// Returns a tuple (normalized, batch_mean, batch_var) where `normalized`
// is the normalized result and batch_mean and batch_var are the mean and
// variance, respectively, across batch for the operand.
XlaOp BatchNormTraining(XlaOp operand, XlaOp scale, XlaOp offset, float epsilon,
                        int64 feature_index);

// Normalizes operand across spatial and batch dimensions for each feature.
//
// `BatchNormInference` is equivalent to calling `BatchNormTraining` without
// computing `mean` and `variance` for each batch inside the operation. It
// uses the input `mean` and `variance` instead as estimated values. The
// purpose of this op is to reduce latency in inference, hence the name
// `BatchNormInference`.
//
// The output has the same shape as `operand`, and contains the normalized
// values for each batch.
XlaOp BatchNormInference(XlaOp operand, XlaOp scale, XlaOp offset, XlaOp mean,
                         XlaOp variance, float epsilon, int64 feature_index);

// Calculates the gradients of a batch norm op.
//
// The inputs `batch_mean` and `batch_var` represent the mean and variance
// across the batch.
//
// Returns a tuple of three elements:
//   - grad_operand: Gradient with respect to input `operand`
//   - grad_offset: Gradient with respect to input `offset`
//   - grad_scale: Gradient with respect to input `scale`
XlaOp BatchNormGrad(XlaOp operand, XlaOp scale, XlaOp batch_mean,
                    XlaOp batch_var, XlaOp grad_output, float epsilon,
                    int64 feature_index);

// Returns the size of the given dimension of the operand. The operand must be
// array shaped.
XlaOp GetDimensionSize(XlaOp operand, int64 dimension);

// Sets the size of the given dynamic dimension of the operand to `val`.
XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension);

// Returns the same op but with the dynamic dimension removed.
XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension);

// Implementation details below this point.
//

// Free function template implementations.

template <typename NativeT>
XlaOp ConstantR0(XlaBuilder* builder, NativeT value) {
  return ConstantLiteral(builder, LiteralUtil::CreateR0<NativeT>(value));
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, absl::Span<const NativeT> values) {
  BorrowingLiteral literal(
      reinterpret_cast<const char*>(values.begin()),
      ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<NativeT>(),
                           {static_cast<int64>(values.size())}));
  return ConstantLiteral(builder, literal);
}

template <typename NativeT>
XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value) {
  Literal literal(ShapeUtil::MakeShape(
      primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
  literal.PopulateWithValue(value);
  return ConstantLiteral(builder, literal);
}

inline XlaOp ConstantR1(XlaBuilder* builder,
                        const tensorflow::core::Bitmap& values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR1(values));
}

template <typename NativeT>
XlaOp ConstantR2(XlaBuilder* builder,
                 std::initializer_list<std::initializer_list<NativeT>> values) {
  return ConstantLiteral(builder, LiteralUtil::CreateR2<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
                                  const Array<NativeT>& values,
                                  const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantFromArray(XlaBuilder* builder, const Array<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateFromArray<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
                                      const Array2D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder, LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
                            const Array2D<NativeT>& values) {
  return ConstantLiteral(builder,
                         LiteralUtil::CreateR2FromArray2D<NativeT>(values));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
                                      const Array3D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantLiteral(
      builder,
      LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
}

template <typename NativeT>
XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
                            const Array3D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
                                      const Array4D<NativeT>& values,
                                      const Layout& layout) {
  return ConstantFromArrayWithLayout(builder, values, layout);
}

template <typename NativeT>
XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
                            const Array4D<NativeT>& values) {
  return ConstantFromArray(builder, values);
}

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_XLA_BUILDER_H_